diff --git a/.gitignore b/.gitignore index 068cb87484a0..b9357018a64c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,10 +2,11 @@ __pycache__/ *.py[cod] *$py.class - +*.S # C extensions *.so - +*.ll +.npm # Distribution / packaging .Python env/ @@ -224,7 +225,7 @@ Pipfile.lock # conda package artifacts conda/Dockerfile.cuda* conda/pkg - +.node_repl_history # nix files .envrc *.nix diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc index 56d05efb03a5..674feb4f29c0 100644 --- a/3rdparty/bfloat16/bfloat16.cc +++ b/3rdparty/bfloat16/bfloat16.cc @@ -17,6 +17,7 @@ ==============================================================================*/ #include + #include #include @@ -50,8 +51,7 @@ void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) { #endif } -void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, - size_t size) { +void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, size_t size) { float a_f, b_f; BFloat16ToFloat(a, &a_f, 1); BFloat16ToFloat(b, &b_f, 1); diff --git a/3rdparty/cma/cma.h b/3rdparty/cma/cma.h index f005b3065c3a..2cd550122614 100644 --- a/3rdparty/cma/cma.h +++ b/3rdparty/cma/cma.h @@ -27,20 +27,17 @@ #ifndef VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ #define VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ - /* Should be defined in settings.mk file */ #ifndef CMA_IOCTL_MAGIC -#define CMA_IOCTL_MAGIC 0xf2 +#define CMA_IOCTL_MAGIC 0xf2 #endif +#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4) +#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4) +#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) +#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4) +#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4) -#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4) -#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4) -#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) -#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4) -#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4) - -#define CMA_IOCTL_MAXNR 5 - +#define CMA_IOCTL_MAXNR 5 #endif // VTA_DE10_NANO_KERNEL_MODULE_CMA_H_ diff --git a/3rdparty/cma/cma_api_impl.h b/3rdparty/cma/cma_api_impl.h index 12c0e3b27efc..317be5c9af1a 100644 --- a/3rdparty/cma/cma_api_impl.h +++ b/3rdparty/cma/cma_api_impl.h @@ -30,48 +30,47 @@ * \brief Application layer implementation for contigous memory allocation. 
*/ +#include +#include #include #include -#include -#include -#include #include -#include #include #include +#include +#include #include "cma_api.h" #ifndef CMA_IOCTL_MAGIC -#define CMA_IOCTL_MAGIC 0xf2 +#define CMA_IOCTL_MAGIC 0xf2 #endif -#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 1, 4) -#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 2, 4) -#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) -#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 4, 4) -#define CMA_GET_SIZE _IOC(_IOC_WRITE|_IOC_READ, CMA_IOCTL_MAGIC, 5, 4) +#define CMA_ALLOC_CACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 1, 4) +#define CMA_ALLOC_NONCACHED _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 2, 4) +#define CMA_FREE _IOC(_IOC_WRITE, CMA_IOCTL_MAGIC, 3, 4) +#define CMA_GET_PHY_ADDR _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 4, 4) +#define CMA_GET_SIZE _IOC(_IOC_WRITE | _IOC_READ, CMA_IOCTL_MAGIC, 5, 4) -#define CMA_IOCTL_MAXNR 5 +#define CMA_IOCTL_MAXNR 5 #ifndef CMA_DEBUG - #define CMA_DEBUG 0 +#define CMA_DEBUG 0 #endif #ifndef DRIVER_NODE_NAME - #define DRIVER_NODE_NAME "cma" +#define DRIVER_NODE_NAME "cma" #endif #if CMA_DEBUG == 1 - #define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args) +#define __DEBUG(fmt, args...) printf("CMA_API_DEBUG: " fmt, ##args) #else - #define __DEBUG(fmt, args...) +#define __DEBUG(fmt, args...) #endif -#define ROUND_UP(N, S) ((((N) + (S) - 1) / (S)) * (S)) - +#define ROUND_UP(N, S) ((((N) + (S)-1) / (S)) * (S)) /* Private functions */ -void *cma_alloc(size_t size, unsigned ioctl_cmd); +void* cma_alloc(size_t size, unsigned ioctl_cmd); /* Global file descriptor */ int cma_fd = 0; @@ -99,23 +98,19 @@ int cma_release(void) { return 0; } -void *cma_alloc_cached(size_t size) { - return cma_alloc(size, CMA_ALLOC_CACHED); -} +void* cma_alloc_cached(size_t size) { return cma_alloc(size, CMA_ALLOC_CACHED); } -void *cma_alloc_noncached(size_t size) { - return cma_alloc(size, CMA_ALLOC_NONCACHED); -} +void* cma_alloc_noncached(size_t size) { return cma_alloc(size, CMA_ALLOC_NONCACHED); } -int cma_free(void *mem) { +int cma_free(void* mem) { __DEBUG("Releasing contigous memory from 0x%x\n", (unsigned)mem); unsigned data, v_addr; /* save user space pointer value */ - data = (unsigned)mem; + data = (unsigned)mem; v_addr = (unsigned)mem; - if ( ioctl(cma_fd, CMA_GET_SIZE, &data) == -1 ) { + if (ioctl(cma_fd, CMA_GET_SIZE, &data) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful - 0\n"); return -1; } @@ -125,7 +120,7 @@ int cma_free(void *mem) { munmap(mem, data); /* free cma entry */ - if ( ioctl(cma_fd, CMA_FREE, &v_addr) == -1 ) { + if (ioctl(cma_fd, CMA_FREE, &v_addr) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful - 1\n"); return -1; } @@ -133,7 +128,7 @@ int cma_free(void *mem) { return 0; } -unsigned cma_get_phy_addr(void *mem) { +unsigned cma_get_phy_addr(void* mem) { unsigned data; __DEBUG("Getting physical address from 0x%x\n", (unsigned)mem); @@ -141,7 +136,7 @@ unsigned cma_get_phy_addr(void *mem) { data = (unsigned)mem; /* get physical address */ - if ( ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1 ) { + if (ioctl(cma_fd, CMA_GET_PHY_ADDR, &data) == -1) { __DEBUG("cma_free - ioctl command unsuccsessful\n"); return 0; } @@ -150,10 +145,9 @@ unsigned cma_get_phy_addr(void *mem) { return data; } - -void *cma_alloc(size_t size, unsigned ioctl_cmd) { +void* cma_alloc(size_t size, unsigned ioctl_cmd) { unsigned data; - void *mem; + void* mem; __DEBUG("Allocating 0x%x bytes of contigous 
memory\n", size); /* Page align size */ @@ -161,7 +155,7 @@ void *cma_alloc(size_t size, unsigned ioctl_cmd) { /* ioctl cmd to allocate contigous memory */ data = (unsigned)size; - if ( ioctl(cma_fd, ioctl_cmd, &data) == -1 ) { + if (ioctl(cma_fd, ioctl_cmd, &data) == -1) { __DEBUG("cma_alloc - ioctl command unsuccsessful\n"); return NULL; } diff --git a/3rdparty/compiler-rt/builtin_fp16.h b/3rdparty/compiler-rt/builtin_fp16.h index fa8efddcd4ca..804898081996 100644 --- a/3rdparty/compiler-rt/builtin_fp16.h +++ b/3rdparty/compiler-rt/builtin_fp16.h @@ -29,16 +29,33 @@ static inline uint32_t __clz(uint32_t x) { int n = 32; uint32_t y; - y = x >>16; if (y) { n = n -16; x = y; } - y = x >> 8; if (y) { n = n - 8; x = y; } - y = x >> 4; if (y) { n = n - 4; x = y; } - y = x >> 2; if (y) { n = n - 2; x = y; } - y = x >> 1; if (y) return n - 2; + y = x >> 16; + if (y) { + n = n - 16; + x = y; + } + y = x >> 8; + if (y) { + n = n - 8; + x = y; + } + y = x >> 4; + if (y) { + n = n - 4; + x = y; + } + y = x >> 2; + if (y) { + n = n - 2; + x = y; + } + y = x >> 1; + if (y) return n - 2; return n - x; } -template +template static inline DST_T __truncXfYf2__(SRC_T a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. @@ -71,7 +88,10 @@ static inline DST_T __truncXfYf2__(SRC_T a) { const DST_REP_T dstNaNCode = dstQNaN - 1; // Break a into a sign and representation of the absolute value - union SrcExchangeType { SRC_T f; SRC_REP_T i; }; + union SrcExchangeType { + SRC_T f; + SRC_REP_T i; + }; SrcExchangeType src_rep; src_rep.f = a; const SRC_REP_T aRep = src_rep.i; @@ -88,25 +108,21 @@ static inline DST_T __truncXfYf2__(SRC_T a) { const SRC_REP_T roundBits = aAbs & roundMask; // Round to nearest - if (roundBits > halfway) - absResult++; - // Ties to even + if (roundBits > halfway) absResult++; + // Ties to even else if (roundBits == halfway) absResult += absResult & 1; - } - else if (aAbs > srcInfinity) { + } else if (aAbs > srcInfinity) { // a is NaN. // Conjure the result by beginning with infinity, setting the qNaN // bit and inserting the (truncated) trailing NaN field. absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; absResult |= dstQNaN; absResult |= ((aAbs & srcNaNCode) >> (SRC_SIG_BITS - DST_SIG_BITS)) & dstNaNCode; - } - else if (aAbs >= overflow) { + } else if (aAbs >= overflow) { // a overflows to infinity. absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; - } - else { + } else { // a underflows on conversion to the destination type or is an exact // zero. The result may be a denormal or zero. Extract the exponent // to get the shift amount for the denormalization. @@ -124,9 +140,8 @@ static inline DST_T __truncXfYf2__(SRC_T a) { absResult = denormalizedSignificand >> (SRC_SIG_BITS - DST_SIG_BITS); const SRC_REP_T roundBits = denormalizedSignificand & roundMask; // Round to nearest - if (roundBits > halfway) - absResult++; - // Ties to even + if (roundBits > halfway) absResult++; + // Ties to even else if (roundBits == halfway) absResult += absResult & 1; } @@ -134,14 +149,17 @@ static inline DST_T __truncXfYf2__(SRC_T a) { // Apply the signbit to (DST_T)abs(a). 
const DST_REP_T result = absResult | sign >> (srcBits - dstBits); - union DstExchangeType { DST_T f; DST_REP_T i; }; + union DstExchangeType { + DST_T f; + DST_REP_T i; + }; DstExchangeType dst_rep; dst_rep.i = result; return dst_rep.f; } -template +template static inline DST_T __extendXfYf2__(SRC_T a) { // Various constants whose values follow from the type parameters. // Any reasonable optimizer will fold and propagate all of these. @@ -157,7 +175,7 @@ static inline DST_T __extendXfYf2__(SRC_T a) { const SRC_REP_T srcQNaN = SRC_REP_T(1) << (SRC_SIG_BITS - 1); const SRC_REP_T srcNaNCode = srcQNaN - 1; - const int dstBits = sizeof(DST_T)*8; + const int dstBits = sizeof(DST_T) * 8; const int dstExpBits = dstBits - DST_SIG_BITS - 1; const int dstInfExp = (1 << dstExpBits) - 1; const int dstExpBias = dstInfExp >> 1; @@ -165,7 +183,10 @@ static inline DST_T __extendXfYf2__(SRC_T a) { const DST_REP_T dstMinNormal = DST_REP_T(1) << DST_SIG_BITS; // Break a into a sign and representation of the absolute value - union SrcExchangeType { SRC_T f; SRC_REP_T i; }; + union SrcExchangeType { + SRC_T f; + SRC_REP_T i; + }; SrcExchangeType src_rep; src_rep.f = a; const SRC_REP_T aRep = src_rep.i; @@ -191,8 +212,7 @@ static inline DST_T __extendXfYf2__(SRC_T a) { absResult = (DST_REP_T)dstInfExp << DST_SIG_BITS; absResult |= (DST_REP_T)(aAbs & srcQNaN) << (DST_SIG_BITS - SRC_SIG_BITS); absResult |= (DST_REP_T)(aAbs & srcNaNCode) << (DST_SIG_BITS - SRC_SIG_BITS); - } - else if (aAbs) { + } else if (aAbs) { // a is denormal. // renormalize the significand and clear the leading bit, then insert // the correct adjusted exponent in the destination type. @@ -201,15 +221,17 @@ static inline DST_T __extendXfYf2__(SRC_T a) { absResult ^= dstMinNormal; const int resultExponent = dstExpBias - srcExpBias - scale + 1; absResult |= (DST_REP_T)resultExponent << DST_SIG_BITS; - } - else { + } else { // a is zero. absResult = 0; } // Apply the signbit to (DST_T)abs(a). const DST_REP_T result = absResult | (DST_REP_T)sign << (dstBits - srcBits); - union DstExchangeType { DST_T f; DST_REP_T i; }; + union DstExchangeType { + DST_T f; + DST_REP_T i; + }; DstExchangeType dst_rep; dst_rep.i = result; return dst_rep.f; diff --git a/CMakeLists.txt b/CMakeLists.txt index 714dfa6f9aab..0c0cfb8cc507 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,7 +24,6 @@ endif() tvm_option(USE_CUDA "Build with CUDA" OFF) tvm_option(USE_OPENCL "Build with OpenCL" OFF) tvm_option(USE_VULKAN "Build with Vulkan" OFF) -tvm_option(USE_OPENGL "Build with OpenGL" OFF) tvm_option(USE_METAL "Build with Metal" OFF) tvm_option(USE_ROCM "Build with ROCM" OFF) tvm_option(ROCM_PATH "The path to rocm" /opt/rocm) @@ -71,6 +70,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) +tvm_option(USE_COREML "Build with coreml support" OFF) if(USE_CPP_RPC AND UNIX) message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. 
Use the Makefile for non-Windows.") @@ -122,7 +122,7 @@ else(MSVC) endif(USE_TF_COMPILE_FLAGS) if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("Build in Debug mode") + message(STATUS "Build in Debug mode") set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}") @@ -160,11 +160,35 @@ else(MSVC) if(BUILD_FOR_HEXAGON) message(STATUS "Building for Hexagon") endif() + + # Detect if we're compiling for Android. + set(TEST_FOR_ANDROID_CXX + "#ifndef __ANDROID__" + "#error" + "#endif" + "int main() {}") + set(TEST_FOR_ANDROID_DIR + "${CMAKE_BINARY_DIR}${CMAKE_FILES_DIRECTORY}/CMakeTmp") + set(TEST_FOR_ANDROID_FILE "${TEST_FOR_ANDROID_DIR}/test_for_android.cc") + string(REPLACE ";" "\n" TEST_FOR_ANDROID_CXX_TEXT "${TEST_FOR_ANDROID_CXX}") + file(WRITE "${TEST_FOR_ANDROID_FILE}" "${TEST_FOR_ANDROID_CXX_TEXT}") + try_compile(BUILD_FOR_ANDROID "${CMAKE_BINARY_DIR}${CMAKE_FILES_DIRECTORY}" + "${TEST_FOR_ANDROID_FILE}") + file(REMOVE "${TEST_FOR_ANDROID_FILE}") + if(BUILD_FOR_ANDROID) + message(STATUS "Building for Android") + endif() endif(MSVC) # Hexagon has dlopen built into QuRT (no need for static library). if(NOT BUILD_FOR_HEXAGON) - string(APPEND TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) +endif() + +if(BUILD_FOR_ANDROID) + # EmuTLS on Android is in libgcc. Without it linked in, libtvm_runtime.so + # won't load on Android due to missing __emutls_XXX symbols. + list(APPEND TVM_RUNTIME_LINKER_LIBS "gcc") endif() # add source group @@ -224,13 +248,6 @@ if(USE_VM_PROFILER) list(APPEND COMPILER_SRCS ${BACKEND_VM_PROFILER_SRCS}) endif(USE_VM_PROFILER) -if(BUILD_FOR_HEXAGON) - # Add file implementing posix_memalign. - list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon_posix.cc) - - add_definitions(-D_MACH_I32=int) -endif() - file(GLOB DATATYPE_SRCS src/target/datatype/*.cc) list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) @@ -243,6 +260,13 @@ file(GLOB RUNTIME_SRCS src/runtime/vm/*.cc ) +if(BUILD_FOR_HEXAGON) + # Add file implementing posix_memalign. 
+ list(APPEND RUNTIME_SRCS src/runtime/hexagon/hexagon_posix.cc) + + add_definitions(-D_MACH_I32=int) +endif() + # Package runtime rules if(NOT USE_RTTI) add_definitions(-DDMLC_ENABLE_RTTI=0) @@ -297,7 +321,6 @@ include(cmake/modules/VTA.cmake) include(cmake/modules/CUDA.cmake) include(cmake/modules/Hexagon.cmake) include(cmake/modules/OpenCL.cmake) -include(cmake/modules/OpenGL.cmake) include(cmake/modules/OpenMP.cmake) include(cmake/modules/Vulkan.cmake) include(cmake/modules/Metal.cmake) @@ -316,13 +339,17 @@ include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/HybridDump.cmake) include(cmake/modules/contrib/TFLite.cmake) include(cmake/modules/contrib/TF_TVMDSOOP.cmake) +include(cmake/modules/contrib/CoreML.cmake) +include(CheckCXXCompilerFlag) if(NOT MSVC) - include(CheckCXXCompilerFlag) check_cxx_compiler_flag("-std=c++14" SUPPORT_CXX14) - message(STATUS "Build with c++14") set(CMAKE_CXX_FLAGS "-std=c++14 ${CMAKE_CXX_FLAGS}") set(CMAKE_CUDA_STANDARD 14) +else() + check_cxx_compiler_flag("/std:c++14" SUPPORT_CXX14) + set(CMAKE_CXX_FLAGS "/std:c++14 ${CMAKE_CXX_FLAGS}") + set(CMAKE_CUDA_STANDARD 14) endif() add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 10b247ae8eb1..8945adbacd3a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,8 +41,9 @@ We do encourage everyone to work anything they are interested in. - [Aditya Atluri](https://github.com/adityaatluri): @adityaatluri - rocm - [Tianqi Chen](https://github.com/tqchen) (PPMC): @tqchen - topi, compiler, relay, docs +- [Liangfu Chen](https://github.com/liangfu): @liangfu - vta, chisel, intel FPGA, c runtime - [Wei Chen](https://github.com/wweic): @wweic - runtime, relay, vm -- [Zhi Chen](https://github.com/zhiics): @zhiics - relay, quantization, pass manager +- [Zhi Chen](https://github.com/zhiics) (PPMC): @zhiics - relay, quantization, pass manager - [Yuwei Hu](https://github.com/Huyuwei): @Huyuwei - topi, frontends - [Nick Hynes](https://github.com/nhynes): @nhynes: - sgx, rust - [Animesh Jain](https://github.com/anijain2305): @anijain2305 - quantization, relay @@ -51,7 +52,7 @@ We do encourage everyone to work anything they are interested in. - [Wuwei Lin](https://github.com/vinx13): @vinx13 - relay, topi - [Yizhi Liu](https://github.com/yzhliu) (PPMC): @yzhliu - jvm, topi, relay - [Hao Lu](https://github.com/hlu1): @hlu1 - nnpack, frontends -- [Masahiro Masuda](https://github.com/masahi): @masahi - topi, relay +- [Masahiro Masuda](https://github.com/masahi) (PPMC): @masahi - topi, relay - [Thierry Moreau](https://github.com/tmoreau89) (PPMC): @tmoreau89 - vta - [Kazutaka Morita](https://github.com/kazum): @kazum - frontends, opencl - [Jared Roesch](https://github.com/jroesch) (PPMC): @jroesch - relay @@ -98,6 +99,8 @@ We do encourage everyone to work anything they are interested in. 
- [Thierry Moreau](https://github.com/tmoreau89): @tmoreau89 - [Kazutaka Morita](https://github.com/kazum): @kazum - [Tatsuya Nishiyama](https://github.com/nishi-t): @nishi-t +- [Wei Pan](https://github.com/wpan11nv): @wpan11nv +- [Krzysztof Parzyszek](https://github.com/kparzysz-quic): @kparzysz-quic - [Pariksheet Pinjari](https://github.com/PariksheetPinjari909): @PariksheetPinjari909 - [Josh Pollock](https://github.com/joshpoll): @joshpoll - [Jared Roesch](https://github.com/jroesch): @jroesch diff --git a/Jenkinsfile b/Jenkinsfile index 3dbee4d7fa41..8a13998a6a64 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -43,9 +43,10 @@ // // -ci_lint = "tvmai/ci-lint:v0.60" -ci_gpu = "tvmai/ci-gpu:v0.61" -ci_cpu = "tvmai/ci-cpu:v0.61" +ci_lint = "tvmai/ci-lint:v0.61" +ci_gpu = "tvmai/ci-gpu:v0.64" +ci_cpu = "tvmai/ci-cpu:v0.62" +ci_wasm = "tvmai/ci-wasm:v0.60" ci_i386 = "tvmai/ci-i386:v0.52" // tvm libraries @@ -158,7 +159,7 @@ stage('Build') { init_git() sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh" make(ci_cpu, 'build', '-j2') - pack_lib('cpu', tvm_lib) + pack_lib('cpu', tvm_multilib) timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_unittest.sh" sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_integration.sh" @@ -169,6 +170,18 @@ stage('Build') { } } } + }, + 'BUILD: WASM': { + node('CPU') { + ws(per_exec_ws("tvm/build-wasm")) { + init_git() + sh "${docker_run} ${ci_wasm} ./tests/scripts/task_config_build_wasm.sh" + make(ci_wasm, 'build', '-j2') + timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${ci_wasm} ./tests/scripts/task_web_wasm.sh" + } + } + } }, 'BUILD : i386': { node('CPU') { @@ -189,8 +202,7 @@ stage('Unit Test') { init_git() unpack_lib('gpu', tvm_multilib) timeout(time: max_time, unit: 'MINUTES') { - // TODO(trevmorr): neo-ai/tvm disable sphinx due to missing PRs from upstream - // sh "${docker_run} ${ci_gpu} ./tests/scripts/task_sphinx_precheck.sh" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_sphinx_precheck.sh" sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_unittest.sh" sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_integration.sh" } @@ -245,10 +257,21 @@ stage('Integration Test') { } } } + }, + 'frontend: CPU': { + node('CPU') { + ws(per_exec_ws("tvm/frontend-python-cpu")) { + init_git() + unpack_lib('cpu', tvm_multilib) + timeout(time: max_time, unit: 'MINUTES') { + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_python_frontend_cpu.sh" + } + } + } } // TODO: Fix the doc // 'docs: GPU': { - // node('GPU') { + // node('TensorCore') { // ws(per_exec_ws("tvm/docs-python-gpu")) { // init_git() // unpack_lib('gpu', tvm_multilib) diff --git a/Makefile b/Makefile index 757b3300f7d5..e54b9a93b230 100644 --- a/Makefile +++ b/Makefile @@ -73,7 +73,8 @@ build/libtvm_web_runtime.js: build/libtvm_web_runtime.bc cpplint: python3 3rdparty/dmlc-core/scripts/lint.py vta cpp vta/include vta/src python3 3rdparty/dmlc-core/scripts/lint.py topi cpp topi/include; - python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp include src \ + python3 3rdparty/dmlc-core/scripts/lint.py tvm cpp \ + include src \ examples/extension/src examples/graph_executor/src pylint: diff --git a/apps/android_camera/app/src/main/jni/tvm_runtime.h b/apps/android_camera/app/src/main/jni/tvm_runtime.h index a58252e780fe..bc10bdaa508c 100644 --- a/apps/android_camera/app/src/main/jni/tvm_runtime.h +++ b/apps/android_camera/app/src/main/jni/tvm_runtime.h @@ -22,6 +22,7 @@ * \brief Pack all tvm runtime source files 
*/ #include + #include /* Enable custom logging - this will cause TVM to pass every log message @@ -38,23 +39,23 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_module.cc" +#include "../src/runtime/rpc/rpc_server_env.cc" +#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" @@ -69,7 +70,6 @@ #include "../src/runtime/contrib/sort/sort.cc" #endif - #include void dmlc::CustomLogMessage::Log(const std::string& msg) { diff --git a/apps/android_camera/models/prepare_model.py b/apps/android_camera/models/prepare_model.py index 36674d273bd1..703a4656c479 100644 --- a/apps/android_camera/models/prepare_model.py +++ b/apps/android_camera/models/prepare_model.py @@ -87,7 +87,7 @@ def main(model_str, output_path): except FileExistsError: pass print("building...") - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(net, target, target_host=target_host, params=params) print("dumping lib...") lib.export_library(output_path_str + '/' + 'deploy_lib_cpu.so', ndk.create_shared) diff --git a/apps/android_deploy/app/src/main/jni/tvm_runtime.h b/apps/android_deploy/app/src/main/jni/tvm_runtime.h index 0d038fb1060c..f1a47a674281 100644 --- a/apps/android_deploy/app/src/main/jni/tvm_runtime.h +++ b/apps/android_deploy/app/src/main/jni/tvm_runtime.h @@ -22,23 +22,23 @@ * \brief Pack all tvm runtime source files */ #include + #include #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" -#include "../src/runtime/object.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/ndarray.cc" - -#include "../src/runtime/graph/graph_runtime.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" diff --git 
a/apps/android_rpc/app/src/main/jni/tvm_runtime.h b/apps/android_rpc/app/src/main/jni/tvm_runtime.h index 5d2bca2e216d..0b713b88ba9e 100644 --- a/apps/android_rpc/app/src/main/jni/tvm_runtime.h +++ b/apps/android_rpc/app/src/main/jni/tvm_runtime.h @@ -22,6 +22,7 @@ * \brief Pack all tvm runtime source files */ #include + #include /* Enable custom logging - this will cause TVM to pass every log message @@ -38,23 +39,23 @@ #include "../src/runtime/c_runtime_api.cc" #include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" +#include "../src/runtime/dso_library.cc" +#include "../src/runtime/file_util.cc" +#include "../src/runtime/graph/graph_runtime.cc" #include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" #include "../src/runtime/module.cc" +#include "../src/runtime/ndarray.cc" +#include "../src/runtime/object.cc" #include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" #include "../src/runtime/rpc/rpc_module.cc" +#include "../src/runtime/rpc/rpc_server_env.cc" +#include "../src/runtime/rpc/rpc_session.cc" #include "../src/runtime/rpc/rpc_socket_impl.cc" +#include "../src/runtime/system_library.cc" #include "../src/runtime/thread_pool.cc" #include "../src/runtime/threading_backend.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" +#include "../src/runtime/workspace_pool.cc" #ifdef TVM_OPENCL_RUNTIME #include "../src/runtime/opencl/opencl_device_api.cc" @@ -69,7 +70,6 @@ #include "../src/runtime/contrib/sort/sort.cc" #endif - #include void dmlc::CustomLogMessage::Log(const std::string& msg) { diff --git a/apps/benchmark/arm_cpu_imagenet_bench.py b/apps/benchmark/arm_cpu_imagenet_bench.py index 53b616868bdd..f319d5a53042 100644 --- a/apps/benchmark/arm_cpu_imagenet_bench.py +++ b/apps/benchmark/arm_cpu_imagenet_bench.py @@ -39,7 +39,7 @@ def evaluate_network(network, target, target_host, repeat): net, params, input_shape, output_shape = get_network(network, batch_size=1) print_progress("%-20s building..." % network) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build( net, target=target, target_host=target_host, params=params) diff --git a/apps/benchmark/gpu_imagenet_bench.py b/apps/benchmark/gpu_imagenet_bench.py index dfb0445bf214..a3df2c46a24b 100644 --- a/apps/benchmark/gpu_imagenet_bench.py +++ b/apps/benchmark/gpu_imagenet_bench.py @@ -33,7 +33,7 @@ def benchmark(network, target): net, params, input_shape, output_shape = get_network(network, batch_size=1) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(net, target=target, params=params) # create runtime @@ -56,13 +56,17 @@ def benchmark(network, target): 'vgg-16', 'vgg-19', 'densenet-121', 'inception_v3', 'mobilenet', 'squeezenet_v1.0', 'squeezenet_v1.1'], help='The name of neural network') + parser.add_argument("--device", type=str, + choices=['amd_apu'], default='amd_apu', + help="The name of the test device. 
If your device is not listed in " + "the choices list, pick the most similar one as argument.") parser.add_argument("--model", type=str, - choices=['1080ti', 'titanx', 'tx2', 'gfx900'], default='1080ti', + choices=['1080ti', 'titanx', 'tx2', 'gfx900', 'v1000'], default='1080ti', help="The model of the test device. If your device is not listed in " "the choices list, pick the most similar one as argument.") parser.add_argument("--repeat", type=int, default=600) parser.add_argument("--target", type=str, - choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal'], default='cuda', + choices=['cuda', 'opencl', 'rocm', 'nvptx', 'metal', 'vulkan'], default='cuda', help="The tvm compilation target") parser.add_argument("--thread", type=int, default=1, help="The number of threads to be run.") args = parser.parse_args() @@ -74,7 +78,7 @@ def benchmark(network, target): else: networks = [args.network] - target = tvm.target.create('%s -model=%s' % (args.target, args.model)) + target = tvm.target.create('%s -device=%s -model=%s' % (args.target, args.device, args.model)) print("--------------------------------------------------") print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)")) diff --git a/apps/benchmark/mobile_gpu_imagenet_bench.py b/apps/benchmark/mobile_gpu_imagenet_bench.py index 4f93a0d5e383..83127ff5af72 100644 --- a/apps/benchmark/mobile_gpu_imagenet_bench.py +++ b/apps/benchmark/mobile_gpu_imagenet_bench.py @@ -38,7 +38,7 @@ def evaluate_network(network, target, target_host, dtype, repeat): net, params, input_shape, output_shape = get_network(network, batch_size=1, dtype=dtype) print_progress("%-20s building..." % network) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build( net, target=target, target_host=target_host, params=params) diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index 63d658e6d428..1d415cd40ef4 100644 --- a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -33,7 +33,7 @@ def build_module(opts): func = mod["main"] func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build( func, 'llvm --system-lib', params=params) diff --git a/apps/bundle_deploy/bundle.c b/apps/bundle_deploy/bundle.c index dd24bcbdc049..4def96eb12b9 100644 --- a/apps/bundle_deploy/bundle.c +++ b/apps/bundle_deploy/bundle.c @@ -17,23 +17,22 @@ * under the License. */ -#include #include #include +#include /*! 
\brief macro to do C API call */ -#define TVM_CCALL(func) \ - do { \ - int ret = (func); \ - if (ret != 0) { \ +#define TVM_CCALL(func) \ + do { \ + int ret = (func); \ + if (ret != 0) { \ fprintf(stderr, "%s: %d: error: %s\n", __FILE__, __LINE__, TVMGetLastError()); \ - exit(ret); \ - } \ + exit(ret); \ + } \ } while (0) -TVM_DLL void * tvm_runtime_create(const char * json_data, - const char * params_data, - const uint64_t params_size) { +TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, + const uint64_t params_size) { int64_t device_type = kDLCPU; int64_t device_id = 0; @@ -47,43 +46,47 @@ TVM_DLL void * tvm_runtime_create(const char * json_data, // declare pointers TVMModuleHandle (*SystemLibraryCreate)(); - TVMModuleHandle (*TVMGraphRuntimeCreate)(const char *, const TVMModuleHandle, const TVMContext *); - int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char *, const uint32_t); + TVMModuleHandle (*TVMGraphRuntimeCreate)(const char*, const TVMModuleHandle, const TVMContext*); + int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char*, const uint32_t); // get pointers TVM_CCALL(TVMFuncGetGlobal("runtime.SystemLib", (TVMFunctionHandle*)&SystemLibraryCreate)); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.create", (TVMFunctionHandle*)&TVMGraphRuntimeCreate)); + TVM_CCALL( + TVMFuncGetGlobal("tvm.graph_runtime.create", (TVMFunctionHandle*)&TVMGraphRuntimeCreate)); // run modules TVMModuleHandle mod_syslib = SystemLibraryCreate(); TVMModuleHandle mod = TVMGraphRuntimeCreate(json_data, mod_syslib, &ctx); - TVM_CCALL(TVMModGetFunction(mod, "load_params", 0, (TVMFunctionHandle*)&TVMGraphRuntime_LoadParams)); + TVM_CCALL( + TVMModGetFunction(mod, "load_params", 0, (TVMFunctionHandle*)&TVMGraphRuntime_LoadParams)); TVMGraphRuntime_LoadParams(mod, params.data, params.size); - + return mod; } -TVM_DLL void tvm_runtime_destroy(void * runtime) { - void (*TVMGraphRuntimeRelease)(TVMModuleHandle *); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.release", (TVMFunctionHandle*)&TVMGraphRuntimeRelease)); +TVM_DLL void tvm_runtime_destroy(void* runtime) { + void (*TVMGraphRuntimeRelease)(TVMModuleHandle*); + TVM_CCALL( + TVMFuncGetGlobal("tvm.graph_runtime.release", (TVMFunctionHandle*)&TVMGraphRuntimeRelease)); TVMGraphRuntimeRelease(&runtime); } -TVM_DLL void tvm_runtime_set_input(void * runtime, const char * name, DLTensor * tensor) { - void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char *, DLTensor*); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.set_input", (TVMFunctionHandle*)&TVMGraphRuntime_SetInput)); +TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* tensor) { + void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char*, DLTensor*); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.set_input", + (TVMFunctionHandle*)&TVMGraphRuntime_SetInput)); TVMGraphRuntime_SetInput(runtime, name, tensor); } -TVM_DLL void tvm_runtime_run(void * runtime) { +TVM_DLL void tvm_runtime_run(void* runtime) { void (*TVMGraphRuntime_Run)(TVMModuleHandle runtime); TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.run", (TVMFunctionHandle*)&TVMGraphRuntime_Run)); TVMGraphRuntime_Run(runtime); } -TVM_DLL void tvm_runtime_get_output(void * runtime, int32_t index, DLTensor * tensor) { - int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor *); - TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.get_output", (TVMFunctionHandle*)&TVMGraphRuntime_GetOutput)); +TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, 
DLTensor* tensor) { + int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor*); + TVM_CCALL(TVMFuncGetGlobal("tvm.graph_runtime.get_output", + (TVMFunctionHandle*)&TVMGraphRuntime_GetOutput)); TVMGraphRuntime_GetOutput(runtime, index, tensor); } - diff --git a/apps/bundle_deploy/bundle.cc b/apps/bundle_deploy/bundle.cc index 3e5080927db4..d8ff683decc3 100644 --- a/apps/bundle_deploy/bundle.cc +++ b/apps/bundle_deploy/bundle.cc @@ -17,51 +17,47 @@ * under the License. */ -#include #include #include +#include + #define TVM_BUNDLE_FUNCTION __attribute__((visibility("default"))) extern "C" { -TVM_BUNDLE_FUNCTION void *tvm_runtime_create(const char * build_graph_json, - const char * build_params_bin, +TVM_BUNDLE_FUNCTION void* tvm_runtime_create(const char* build_graph_json, + const char* build_params_bin, const uint64_t build_params_bin_len) { const int build_graph_json_len = strlen(build_graph_json); - const std::string json_data(&build_graph_json[0], - &build_graph_json[0] + build_graph_json_len); - tvm::runtime::Module mod_syslib = - (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); + const std::string json_data(&build_graph_json[0], &build_graph_json[0] + build_graph_json_len); + tvm::runtime::Module mod_syslib = (*tvm::runtime::Registry::Get("runtime.SystemLib"))(); int device_type = kDLCPU; int device_id = 0; - tvm::runtime::Module mod = - (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( - json_data, mod_syslib, device_type, device_id); + tvm::runtime::Module mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))( + json_data, mod_syslib, device_type, device_id); TVMByteArray params; - params.data = reinterpret_cast(&build_params_bin[0]); + params.data = reinterpret_cast(&build_params_bin[0]); params.size = build_params_bin_len; mod.GetFunction("load_params")(params); return new tvm::runtime::Module(mod); } -TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void *handle) { - delete reinterpret_cast(handle); +TVM_BUNDLE_FUNCTION void tvm_runtime_destroy(void* handle) { + delete reinterpret_cast(handle); } -TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void *handle, const char *name, - void *tensor) { - reinterpret_cast(handle)->GetFunction("set_input")( - name, reinterpret_cast(tensor)); +TVM_BUNDLE_FUNCTION void tvm_runtime_set_input(void* handle, const char* name, void* tensor) { + reinterpret_cast(handle)->GetFunction("set_input")( + name, reinterpret_cast(tensor)); } -TVM_BUNDLE_FUNCTION void tvm_runtime_run(void *handle) { - reinterpret_cast(handle)->GetFunction("run")(); +TVM_BUNDLE_FUNCTION void tvm_runtime_run(void* handle) { + reinterpret_cast(handle)->GetFunction("run")(); } -TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void *handle, int index, - void *tensor) { - reinterpret_cast(handle)->GetFunction("get_output")( - index, reinterpret_cast(tensor)); +TVM_BUNDLE_FUNCTION void tvm_runtime_get_output(void* handle, int index, void* tensor) { + reinterpret_cast(handle)->GetFunction("get_output")( + index, reinterpret_cast(tensor)); } } diff --git a/apps/bundle_deploy/bundle.h b/apps/bundle_deploy/bundle.h index aa57faa38666..80238e1e231a 100644 --- a/apps/bundle_deploy/bundle.h +++ b/apps/bundle_deploy/bundle.h @@ -22,20 +22,15 @@ #include -TVM_DLL void * tvm_runtime_create(const char * json_data, - const char * params_data, - const uint64_t params_size); +TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, + const uint64_t params_size); -TVM_DLL void tvm_runtime_destroy(void * runtime); +TVM_DLL void 
tvm_runtime_destroy(void* runtime); -TVM_DLL void tvm_runtime_set_input(void * runtime, - const char * name, - DLTensor * tensor); +TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* tensor); -TVM_DLL void tvm_runtime_run(void * runtime); +TVM_DLL void tvm_runtime_run(void* runtime); -TVM_DLL void tvm_runtime_get_output(void * runtime, - int32_t index, - DLTensor * tensor); +TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tensor); #endif /* TVM_APPS_BUNDLE_DEPLOY_BUNDLE_H_ */ diff --git a/apps/bundle_deploy/bundle_static.c b/apps/bundle_deploy/bundle_static.c index c7eb9352652b..5ecc5e58eea7 100644 --- a/apps/bundle_deploy/bundle_static.c +++ b/apps/bundle_deploy/bundle_static.c @@ -23,9 +23,8 @@ #include "bundle.h" #include "runtime.c" -TVM_DLL void * tvm_runtime_create(const char * json_data, - const char * params_data, - const uint64_t params_size) { +TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, + const uint64_t params_size) { int64_t device_type = kDLCPU; int64_t device_id = 0; @@ -38,9 +37,9 @@ TVM_DLL void * tvm_runtime_create(const char * json_data, ctx.device_id = device_id; // declare pointers - void * (*SystemLibraryCreate)(); - TVMGraphRuntime * (*TVMGraphRuntimeCreate)(const char *, const TVMModuleHandle, const TVMContext *); - int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char *, const uint32_t); + void* (*SystemLibraryCreate)(); + TVMGraphRuntime* (*TVMGraphRuntimeCreate)(const char*, const TVMModuleHandle, const TVMContext*); + int (*TVMGraphRuntime_LoadParams)(TVMModuleHandle, const char*, const uint32_t); // get pointers TVMFuncGetGlobal("runtime.SystemLib", (TVMFunctionHandle*)&SystemLibraryCreate); @@ -51,30 +50,30 @@ TVM_DLL void * tvm_runtime_create(const char * json_data, TVMModuleHandle mod = TVMGraphRuntimeCreate(json_data, mod_syslib, &ctx); TVMModGetFunction(mod, "load_params", 0, (TVMFunctionHandle*)&TVMGraphRuntime_LoadParams); TVMGraphRuntime_LoadParams(mod, params.data, params.size); - + return mod; } -TVM_DLL void tvm_runtime_destroy(void * runtime) { - void (*TVMGraphRuntimeRelease)(TVMModuleHandle *); +TVM_DLL void tvm_runtime_destroy(void* runtime) { + void (*TVMGraphRuntimeRelease)(TVMModuleHandle*); TVMFuncGetGlobal("tvm.graph_runtime.release", (TVMFunctionHandle*)&TVMGraphRuntimeRelease); TVMGraphRuntimeRelease(&runtime); } -TVM_DLL void tvm_runtime_set_input(void * runtime, const char * name, DLTensor * tensor) { - void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char *, DLTensor*); +TVM_DLL void tvm_runtime_set_input(void* runtime, const char* name, DLTensor* tensor) { + void (*TVMGraphRuntime_SetInput)(TVMModuleHandle, const char*, DLTensor*); TVMFuncGetGlobal("tvm.graph_runtime.set_input", (TVMFunctionHandle*)&TVMGraphRuntime_SetInput); TVMGraphRuntime_SetInput(runtime, name, tensor); } -TVM_DLL void tvm_runtime_run(void * runtime) { +TVM_DLL void tvm_runtime_run(void* runtime) { void (*TVMGraphRuntime_Run)(TVMModuleHandle runtime); TVMFuncGetGlobal("tvm.graph_runtime.run", (TVMFunctionHandle*)&TVMGraphRuntime_Run); TVMGraphRuntime_Run(runtime); } -TVM_DLL void tvm_runtime_get_output(void * runtime, int32_t index, DLTensor * tensor) { - int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor *); +TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tensor) { + int (*TVMGraphRuntime_GetOutput)(TVMModuleHandle, const int32_t, DLTensor*); TVMFuncGetGlobal("tvm.graph_runtime.get_output", 
(TVMFunctionHandle*)&TVMGraphRuntime_GetOutput); TVMGraphRuntime_GetOutput(runtime, index, tensor); } \ No newline at end of file diff --git a/apps/bundle_deploy/demo.cc b/apps/bundle_deploy/demo.cc index 0de10d7177eb..5c210a2cab88 100644 --- a/apps/bundle_deploy/demo.cc +++ b/apps/bundle_deploy/demo.cc @@ -17,44 +17,44 @@ * under the License. */ +#include +#include //dlopen +#include #include -#include -#include //dlopen #include #include #include -#include #include "build/graph.json.c" #include "build/params.bin.c" -template auto getFunc(void *bundle, const char *name) { +template +auto getFunc(void* bundle, const char* name) { dlerror(); - auto *f = - reinterpret_cast::type>(dlsym(bundle, name)); + auto* f = reinterpret_cast::type>(dlsym(bundle, name)); assert(!dlerror()); return f; } -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 3 && "Usage: demo "); - auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); - char * json_data = reinterpret_cast(build_graph_json); - char * params_data = reinterpret_cast(build_params_bin); + char* json_data = reinterpret_cast(build_graph_json); + char* params_data = reinterpret_cast(build_params_bin); uint64_t params_size = build_params_bin_len; struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = getFunc(bundle, "tvm_runtime_create")( + auto* handle = getFunc(bundle, "tvm_runtime_create")( json_data, params_data, params_size); gettimeofday(&t1, 0); float input_storage[1 * 3 * 224 * 224]; - FILE * fp = fopen(argv[2], "rb"); + FILE* fp = fopen(argv[2], "rb"); fread(input_storage, 3 * 224 * 224, 4, fp); fclose(fp); @@ -68,12 +68,10 @@ int main(int argc, char **argv) { input.strides = nullptr; input.byte_offset = 0; - getFunc(bundle, "tvm_runtime_set_input")( - handle, "data", &input); + getFunc(bundle, "tvm_runtime_set_input")(handle, "data", &input); gettimeofday(&t2, 0); - auto *ftvm_runtime_run = - (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); gettimeofday(&t3, 0); @@ -89,8 +87,7 @@ int main(int argc, char **argv) { output.strides = nullptr; output.byte_offset = 0; - getFunc(bundle, "tvm_runtime_get_output")( - handle, 0, &output); + getFunc(bundle, "tvm_runtime_get_output")(handle, 0, &output); gettimeofday(&t4, 0); float max_iter = -std::numeric_limits::max(); @@ -102,19 +99,19 @@ int main(int argc, char **argv) { } } - getFunc(bundle, "tvm_runtime_destroy")(handle); + getFunc(bundle, "tvm_runtime_destroy")(handle); gettimeofday(&t5, 0); - printf("The maximum position in output vector is: %d, with max-value %f.\n", - max_index, max_iter); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f); + printf("The maximum position in output vector is: %d, with max-value %f.\n", max_index, max_iter); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f, + 
(t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); dlclose(bundle); - + return 0; } diff --git a/apps/bundle_deploy/demo_static.c b/apps/bundle_deploy/demo_static.c index ed003738b0e6..24aafbaf658b 100644 --- a/apps/bundle_deploy/demo_static.c +++ b/apps/bundle_deploy/demo_static.c @@ -17,36 +17,35 @@ * under the License. */ -#include - #include +#include #include -#include #include -#include +#include +#include -#include "bundle.h" #include "build/graph.json.c" #include "build/params.bin.c" +#include "bundle.h" -#define OUTPUT_LEN 1000 +#define OUTPUT_LEN 1000 -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 2 && "Usage: demo_static "); - char * json_data = (char *)(build_graph_json); - char * params_data = (char *)(build_params_bin); + char* json_data = (char*)(build_graph_json); + char* params_data = (char*)(build_params_bin); uint64_t params_size = build_params_bin_len; struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = tvm_runtime_create(json_data, params_data, params_size); + void* handle = tvm_runtime_create(json_data, params_data, params_size); gettimeofday(&t1, 0); float input_storage[1 * 3 * 224 * 224]; - FILE * fp = fopen(argv[1], "rb"); - fread(input_storage, 3 * 224 * 224, 4, fp); + FILE* fp = fopen(argv[1], "rb"); + (void)fread(input_storage, 3 * 224 * 224, 4, fp); fclose(fp); DLTensor input; @@ -56,7 +55,7 @@ int main(int argc, char **argv) { input.ndim = 4; DLDataType dtype = {kDLFloat, 32, 1}; input.dtype = dtype; - int64_t shape [4] = {1, 3, 224, 224}; + int64_t shape[4] = {1, 3, 224, 224}; input.shape = shape; input.strides = NULL; input.byte_offset = 0; @@ -85,7 +84,7 @@ int main(int argc, char **argv) { float max_iter = -FLT_MAX; int32_t max_index = -1; - for (auto i = 0; i < OUTPUT_LEN; ++i) { + for (int i = 0; i < OUTPUT_LEN; ++i) { if (output_storage[i] > max_iter) { max_iter = output_storage[i]; max_index = i; @@ -95,15 +94,15 @@ int main(int argc, char **argv) { tvm_runtime_destroy(handle); gettimeofday(&t5, 0); - printf("The maximum position in output vector is: %d, with max-value %f.\n", - max_index, max_iter); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000 + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000 + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000 + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000 + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000 + (t5.tv_usec-t4.tv_usec)/1000.f); + printf("The maximum position in output vector is: %d, with max-value %f.\n", max_index, max_iter); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000 + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000 + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000 + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000 + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000 + (t5.tv_usec - t4.tv_usec) / 1000.f); return 0; } diff --git a/apps/bundle_deploy/runtime.c b/apps/bundle_deploy/runtime.c index a7ffea9bbf91..248a295f97b8 100644 --- 
a/apps/bundle_deploy/runtime.c +++ b/apps/bundle_deploy/runtime.c @@ -58,9 +58,9 @@ /*! \brief Page size for virtual memory allocation */ #define TVM_CRT_PAGE_BYTES 4096 -#include "../../src/runtime/crt/crt_runtime_api.c" #include "../../src/runtime/crt/crt_backend_api.c" +#include "../../src/runtime/crt/crt_runtime_api.c" #include "../../src/runtime/crt/graph_runtime.c" #include "../../src/runtime/crt/load_json.c" -#include "../../src/runtime/crt/ndarray.c" #include "../../src/runtime/crt/memory.c" +#include "../../src/runtime/crt/ndarray.c" diff --git a/apps/bundle_deploy/runtime.cc b/apps/bundle_deploy/runtime.cc index 7a116e89fa88..8e294a05775d 100644 --- a/apps/bundle_deploy/runtime.cc +++ b/apps/bundle_deploy/runtime.cc @@ -19,19 +19,19 @@ #include #include -#include #include +#include #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" -#include "../../src/runtime/workspace_pool.cc" +#include "../../src/runtime/file_util.cc" +#include "../../src/runtime/graph/graph_runtime.cc" #include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" -#include "../../src/runtime/registry.cc" -#include "../../src/runtime/file_util.cc" -#include "../../src/runtime/threading_backend.cc" -#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" #include "../../src/runtime/system_library.cc" -#include "../../src/runtime/graph/graph_runtime.cc" +#include "../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc index c92400d29516..882e04be8ef9 100644 --- a/apps/bundle_deploy/test.cc +++ b/apps/bundle_deploy/test.cc @@ -17,35 +17,35 @@ * under the License. 
*/ +#include +#include //dlopen +#include +#include #include -#include -#include //dlopen #include #include #include -#include -#include -template auto getFunc(void *bundle, const char *name) { +template +auto getFunc(void* bundle, const char* name) { dlerror(); - auto *f = - reinterpret_cast::type>(dlsym(bundle, name)); + auto* f = reinterpret_cast::type>(dlsym(bundle, name)); assert(!dlerror()); return f; } -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 6 && "Usage: test "); - auto *bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); + auto* bundle = dlopen(argv[1], RTLD_LAZY | RTLD_LOCAL); assert(bundle); struct stat st; - char * json_data; - char * params_data; + char* json_data; + char* params_data; uint64_t params_size; - FILE * fp = fopen(argv[4], "rb"); + FILE* fp = fopen(argv[4], "rb"); stat(argv[4], &st); json_data = (char*)malloc(st.st_size); fread(json_data, st.st_size, 1, fp); @@ -61,7 +61,7 @@ int main(int argc, char **argv) { struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = getFunc(bundle, "tvm_runtime_create")( + auto* handle = getFunc(bundle, "tvm_runtime_create")( json_data, params_data, params_size); gettimeofday(&t1, 0); @@ -85,12 +85,10 @@ int main(int argc, char **argv) { input.strides = nullptr; input.byte_offset = 0; - getFunc(bundle, "tvm_runtime_set_input")( - handle, "x", &input); + getFunc(bundle, "tvm_runtime_set_input")(handle, "x", &input); gettimeofday(&t2, 0); - auto *ftvm_runtime_run = - (auto (*)(void *)->void)dlsym(bundle, "tvm_runtime_run"); + auto* ftvm_runtime_run = (auto (*)(void*)->void)dlsym(bundle, "tvm_runtime_run"); assert(!dlerror()); ftvm_runtime_run(handle); gettimeofday(&t3, 0); @@ -106,8 +104,7 @@ int main(int argc, char **argv) { output.strides = nullptr; output.byte_offset = 0; - getFunc(bundle, "tvm_runtime_get_output")( - handle, 0, &output); + getFunc(bundle, "tvm_runtime_get_output")(handle, 0, &output); gettimeofday(&t4, 0); for (auto i = 0; i < 10 * 5; ++i) { @@ -117,20 +114,21 @@ int main(int argc, char **argv) { } } - getFunc(bundle, "tvm_runtime_destroy")(handle); + getFunc(bundle, "tvm_runtime_destroy")(handle); gettimeofday(&t5, 0); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000.0f + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000.0f + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000.0f + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000.0f + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000.0f + (t5.tv_usec-t4.tv_usec)/1000.f); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000.0f + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000.0f + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000.0f + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000.0f + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000.0f + (t5.tv_usec - t4.tv_usec) / 1000.f); free(json_data); free(params_data); dlclose(bundle); - + return 0; } diff --git a/apps/bundle_deploy/test_static.c b/apps/bundle_deploy/test_static.c index 05928744ba81..fca08d18da74 100644 --- a/apps/bundle_deploy/test_static.c +++ b/apps/bundle_deploy/test_static.c @@ -17,27 +17,25 @@ * under the License. 
*/ -#include - #include +#include #include #include -#include -#include #include +#include +#include #include "bundle.h" - -int main(int argc, char **argv) { +int main(int argc, char** argv) { assert(argc == 5 && "Usage: test_static "); struct stat st; - char * json_data; - char * params_data; + char* json_data; + char* params_data; uint64_t params_size; - FILE * fp = fopen(argv[3], "rb"); + FILE* fp = fopen(argv[3], "rb"); stat(argv[3], &st); json_data = (char*)malloc(st.st_size); fread(json_data, st.st_size, 1, fp); @@ -53,7 +51,7 @@ int main(int argc, char **argv) { struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto *handle = tvm_runtime_create(json_data, params_data, params_size); + auto* handle = tvm_runtime_create(json_data, params_data, params_size); gettimeofday(&t1, 0); float input_storage[10 * 5]; @@ -110,13 +108,14 @@ int main(int argc, char **argv) { tvm_runtime_destroy(handle); gettimeofday(&t5, 0); - printf("timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " - "%.2f ms (get_output), %.2f ms (destroy)\n", - (t1.tv_sec-t0.tv_sec)*1000 + (t1.tv_usec-t0.tv_usec)/1000.f, - (t2.tv_sec-t1.tv_sec)*1000 + (t2.tv_usec-t1.tv_usec)/1000.f, - (t3.tv_sec-t2.tv_sec)*1000 + (t3.tv_usec-t2.tv_usec)/1000.f, - (t4.tv_sec-t3.tv_sec)*1000 + (t4.tv_usec-t3.tv_usec)/1000.f, - (t5.tv_sec-t4.tv_sec)*1000 + (t5.tv_usec-t4.tv_usec)/1000.f); + printf( + "timing: %.2f ms (create), %.2f ms (set_input), %.2f ms (run), " + "%.2f ms (get_output), %.2f ms (destroy)\n", + (t1.tv_sec - t0.tv_sec) * 1000 + (t1.tv_usec - t0.tv_usec) / 1000.f, + (t2.tv_sec - t1.tv_sec) * 1000 + (t2.tv_usec - t1.tv_usec) / 1000.f, + (t3.tv_sec - t2.tv_sec) * 1000 + (t3.tv_usec - t2.tv_usec) / 1000.f, + (t4.tv_sec - t3.tv_sec) * 1000 + (t4.tv_usec - t3.tv_usec) / 1000.f, + (t5.tv_sec - t4.tv_sec) * 1000 + (t5.tv_usec - t4.tv_usec) / 1000.f); free(json_data); free(params_data); diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile index 927331ad00ea..5cd87e929223 100644 --- a/apps/cpp_rpc/Makefile +++ b/apps/cpp_rpc/Makefile @@ -47,7 +47,7 @@ all: tvm_rpc # Build rule for all in one TVM package library tvm_rpc: *.cc @mkdir -p $(@D) - $(CXX) $(PKG_CFLAGS) -o $@ $(filter %.cc %.o %.a, $^) $(PKG_LDFLAGS) + $(CXX) $(PKG_CFLAGS) -o $@ $(filter-out win32_process.cc, $(filter %.cc %.o %.a, $^)) $(PKG_LDFLAGS) clean: -rm -f tvm_rpc \ No newline at end of file diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md index c826dae80c22..6e50002cf4ca 100644 --- a/apps/cpp_rpc/README.md +++ b/apps/cpp_rpc/README.md @@ -39,7 +39,7 @@ This folder contains a simple recipe to make RPC server in c++. - Build tvm with the argument -DUSE_CPP_RPC - Install [LLVM pre-build binaries](https://releases.llvm.org/download.html), making sure to select the option to add it to the PATH. - Verify Python 3.6 or newer is installed and in the PATH. -- Use `\tvm_rpc.exe` to start the RPC server +- Use `\tvm_rpc.exe` to start the RPC server ## How it works - The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library. @@ -59,4 +59,4 @@ Command line usage ``` ## Note -Currently support is only there for Linux / Android / Windows environment and proxy mode doesn't be supported currently. \ No newline at end of file +Currently support is only there for Linux / Android / Windows environment and proxy mode doesn't be supported currently. 
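Note: the demo and test programs in apps/bundle_deploy time each runtime call with gettimeofday() and print the deltas in milliseconds; the reformatted printf calls above compute each interval as seconds times 1000 plus microseconds divided by 1000. A minimal sketch of that arithmetic as a standalone helper (the function name is illustrative; the demos inline the expression directly):

    #include <sys/time.h>

    /* Elapsed wall-clock time in milliseconds between two gettimeofday() samples,
     * mirroring the (tv_sec * 1000 + tv_usec / 1000.f) expression used in the demos. */
    static float elapsed_ms(struct timeval start, struct timeval end) {
      return (end.tv_sec - start.tv_sec) * 1000.0f + (end.tv_usec - start.tv_usec) / 1000.0f;
    }
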
diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index 5168da31d696..ae2636da7555 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -21,20 +21,21 @@ * \file rpc_server.cc * \brief RPC Server for TVM. */ -#include #include #include +#include #if defined(__linux__) || defined(__ANDROID__) #include #endif #include -#include + #include -#include +#include #include +#include -#include "../../src/support/util.h" #include "../../src/support/socket.h" +#include "../../src/support/util.h" #include "rpc_server.h" #if defined(_WIN32) @@ -45,21 +46,21 @@ using namespace std; using namespace tvm::runtime; using namespace tvm::support; -static const string kUsage = \ -"Command line usage\n" \ -" server - Start the server\n" \ -"--host - The hostname of the server, Default=0.0.0.0\n" \ -"--port - The port of the RPC, Default=9090\n" \ -"--port-end - The end search port of the RPC, Default=9199\n" \ -"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \ -"--key - The key used to identify the device type in tracker. Default=\"\"\n" \ -"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ -"--silent - Whether to run in silent mode. Default=False\n" \ -"\n" \ -" Example\n" \ -" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " -" --tracker=127.0.0.1:9190 --key=rasp" \ -"\n"; +static const string kUsage = + "Command line usage\n" + " server - Start the server\n" + "--host - The hostname of the server, Default=0.0.0.0\n" + "--port - The port of the RPC, Default=9090\n" + "--port-end - The end search port of the RPC, Default=9199\n" + "--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" + "--key - The key used to identify the device type in tracker. Default=\"\"\n" + "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" + "--silent - Whether to run in silent mode. Default=False\n" + "\n" + " Example\n" + " ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " + " --tracker=127.0.0.1:9190 --key=rasp" + "\n"; /*! * \brief RpcServerArgs. @@ -95,7 +96,7 @@ void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "tracker = " << args.tracker; LOG(INFO) << "key = " << args.key; LOG(INFO) << "custom_addr = " << args.custom_addr; - LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); + LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False")); } #if defined(__linux__) || defined(__ANDROID__) @@ -151,7 +152,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { * \param tracker The tracker input. * \return result of operation. 
*/ -bool ValidateTracker(string &tracker) { +bool ValidateTracker(string& tracker) { vector list = Split(tracker, ':'); if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { return false; @@ -168,7 +169,7 @@ bool ValidateTracker(string &tracker) { * \param argv arg values * \param args the output structure which holds the parsed values */ -void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { +void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) { const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; @@ -232,12 +233,11 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { } #if defined(WIN32) const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); - if(!mmap_path.empty()) { + if (!mmap_path.empty()) { args.mmap_path = mmap_path; dmlc::InitLogging("--minloglevel=0"); } #endif - } /*! @@ -246,7 +246,7 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { * \param argv arg values * \return result of operation. */ -int RpcServer(int argc, char * argv[]) { +int RpcServer(int argc, char* argv[]) { RpcServerArgs args; /* parse the command line args */ @@ -260,21 +260,21 @@ int RpcServer(int argc, char * argv[]) { #endif #if defined(WIN32) - if(!args.mmap_path.empty()) { + if (!args.mmap_path.empty()) { int ret = 0; try { - ChildProcSocketHandler(args.mmap_path); + ChildProcSocketHandler(args.mmap_path); } catch (const std::exception&) { - ret = -1; + ret = -1; } return ret; } #endif - RPCServerCreate(args.host, args.port, args.port_end, args.tracker, - args.key, args.custom_addr, args.silent); + RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr, + args.silent); return 0; } @@ -284,7 +284,7 @@ int RpcServer(int argc, char * argv[]) { * \param argv arg values * \return result of operation. */ -int main(int argc, char * argv[]) { +int main(int argc, char* argv[]) { if (argc <= 1) { LOG(INFO) << kUsage; return 0; diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b5dc51b9e7ef..a4286086888a 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,8 +20,9 @@ * \file rpc_env.cc * \brief Server environment of the RPC. 
*/ -#include #include + +#include #ifndef _WIN32 #include #include @@ -30,44 +31,53 @@ #include #include namespace { - int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } -} +int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } +} // namespace #endif #include #include #include #include #include -#include -#include "../../src/support/util.h" #include "../../src/runtime/file_util.h" +#include "../../src/support/util.h" #include "rpc_env.h" namespace { - std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { - std::string untar_cmd; - untar_cmd.reserve(512); +std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { + std::string untar_cmd; + untar_cmd.reserve(512); #if defined(__linux__) || defined(__ANDROID__) - untar_cmd += "tar -C "; - untar_cmd += output_dir; - untar_cmd += " -zxf "; - untar_cmd += tar_file; + untar_cmd += "tar -C "; + untar_cmd += output_dir; + untar_cmd += " -zxf "; + untar_cmd += tar_file; #elif defined(_WIN32) - untar_cmd += "python -m tarfile -e "; - untar_cmd += tar_file; - untar_cmd += " "; - untar_cmd += output_dir; + untar_cmd += "python -m tarfile -e "; + untar_cmd += tar_file; + untar_cmd += " "; + untar_cmd += output_dir; #endif - return untar_cmd; - } + return untar_cmd; +} -}// Anonymous namespace +} // Anonymous namespace namespace tvm { namespace runtime { RPCEnv::RPCEnv() { +#ifndef _WIN32 + char cwd[PATH_MAX]; + if (getcwd(cwd, sizeof(cwd))) { + base_ = std::string(cwd) + "/rpc"; + } else { + base_ = "./rpc"; + } +#else base_ = "./rpc"; +#endif + mkdir(base_.c_str(), 0777); TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { static RPCEnv env; @@ -162,22 +172,20 @@ std::vector ListDir(const std::string& dirname) { * \param options The compiler options * \param cc The compiler */ -void LinuxShared(const std::string output, - const std::vector &files, - std::string options = "", - std::string cc = "g++") { - std::string cmd = cc; - cmd += " -shared -fPIC "; - cmd += " -o " + output; - for (auto f = files.begin(); f != files.end(); ++f) { - cmd += " " + *f; - } - cmd += " " + options; - std::string err_msg; - auto executed_status = support::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } +void LinuxShared(const std::string output, const std::vector& files, + std::string options = "", std::string cc = "g++") { + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (auto f = files.begin(); f != files.end(); ++f) { + cmd += " " + *f; + } + cmd += " " + options; + std::string err_msg; + auto executed_status = support::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } } #endif @@ -189,10 +197,8 @@ void LinuxShared(const std::string output, * \param options The compiler options * \param cc The compiler */ -void WindowsShared(const std::string& output, - const std::vector& files, - const std::string& options = "", - const std::string& cc = "clang") { +void WindowsShared(const std::string& output, const std::vector& files, + const std::string& options = "", const std::string& cc = "clang") { std::string cmd = cc; cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; cmd += " -o " + output; @@ -233,7 +239,7 @@ void CreateShared(const std::string& output, const std::vector& fil * \param fmt The format of file * \return Module The loaded module */ -Module Load(std::string *fileIn, const std::string& fmt) 
{ +Module Load(std::string* fileIn, const std::string& fmt) { const std::string& file = *fileIn; if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) { return Module::LoadFromFile(file, fmt); diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index d046f6ecb480..464b10a2714c 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -25,6 +25,7 @@ #define TVM_APPS_CPP_RPC_ENV_H_ #include + #include namespace tvm { @@ -40,13 +41,13 @@ namespace runtime { * \param file The format of file * \return Module The loaded module */ -Module Load(std::string *path, const std::string& fmt = ""); +Module Load(std::string* path, const std::string& fmt = ""); /*! * \brief CleanDir Removes the files from the directory * \param dirname THe name of the directory */ -void CleanDir(const std::string &dirname); +void CleanDir(const std::string& dirname); /*! * \brief RPCEnv The RPC Environment parameters for c++ rpc server diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index ea4ab00c113b..2628ff77a5f7 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -32,9 +32,9 @@ #include #include -#include "../../src/support/socket.h" -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" +#include "../../src/support/socket.h" #include "rpc_env.h" #include "rpc_server.h" #include "rpc_tracker_client.h" @@ -66,6 +66,22 @@ static pid_t waitPidEintr(int* status) { } #endif +#ifdef __ANDROID__ +static std::string getNextString(std::stringstream* iss) { + std::string str = iss->str(); + size_t start = iss->tellg(); + size_t len = str.size(); + // Skip leading spaces. + while (start < len && isspace(str[start])) start++; + + size_t end = start; + while (end < len && !isspace(str[end])) end++; + + iss->seekg(end); + return str.substr(start, end - start); +} +#endif + /*! * \brief RPCServer RPC Server class. * \param host The hostname of the server, Default=0.0.0.0 @@ -80,14 +96,15 @@ class RPCServer { /*! * \brief Constructor. */ - RPCServer(std::string host, int port, int port_end, std::string tracker_addr, - std::string key, std::string custom_addr) : - host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end), - tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), - custom_addr_(std::move(custom_addr)) - { - - } + RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key, + std::string custom_addr) + : host_(std::move(host)), + port_(port), + my_port_(0), + port_end_(port_end), + tracker_addr_(std::move(tracker_addr)), + key_(std::move(key)), + custom_addr_(std::move(custom_addr)) {} /*! * \brief Destructor. @@ -97,8 +114,7 @@ class RPCServer { // Free the resources tracker_sock_.Close(); listen_sock_.Close(); - } catch(...) { - + } catch (...) 
{ } } @@ -144,7 +160,7 @@ class RPCServer { } int timeout = GetTimeOutFromOpts(opts); -#if defined(__linux__) || defined(__ANDROID__) +#if defined(__linux__) || defined(__ANDROID__) // step 3: serving if (timeout != 0) { const pid_t timer_pid = fork(); @@ -164,9 +180,9 @@ class RPCServer { int status = 0; const pid_t finished_first = waitPidEintr(&status); if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); + kill(worker_pid, SIGTERM); } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); + kill(timer_pid, SIGTERM); } else { LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; } @@ -197,7 +213,6 @@ class RPCServer { try { SpawnRPCChild(conn.sockfd, seconds(timeout)); } catch (const std::exception&) { - } auto dur = high_resolution_clock::now() - start_time; @@ -217,11 +232,8 @@ class RPCServer { * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, - support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, - int ping_period = 2) { + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, + support::SockAddr* addr, std::string* opts, int ping_period = 2) { std::set old_keyset; std::string matchkey; @@ -233,7 +245,7 @@ class RPCServer { support::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; @@ -260,7 +272,12 @@ class RPCServer { std::stringstream ssin(remote_key); std::string arg0; +#ifndef __ANDROID__ ssin >> arg0; +#else + arg0 = getNextString(&ssin); +#endif + if (arg0 != expect_header) { code = kRPCMismatch; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); @@ -274,7 +291,11 @@ class RPCServer { CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); +#ifndef __ANDROID__ ssin >> *opts; +#else + *opts = getNextString(&ssin); +#endif *conn_sock = conn; return; } @@ -301,8 +322,9 @@ class RPCServer { int GetTimeOutFromOpts(const std::string& opts) const { const std::string option = "-timeout="; - if (opts.find(option) == 0) { - const std::string cmd = opts.substr(opts.find_last_of(option) + 1); + size_t pos = opts.rfind(option); + if (pos != std::string::npos) { + const std::string cmd = opts.substr(pos + option.size()); CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } @@ -322,15 +344,15 @@ class RPCServer { #if defined(WIN32) /*! -* \brief ServerLoopFromChild The Server loop process. -* \param socket The socket information -*/ + * \brief ServerLoopFromChild The Server loop process. + * \param socket The socket information + */ void ServerLoopFromChild(SOCKET socket) { // Server loop tvm::support::TCPSocket sock(socket); const auto env = RPCEnv(); RPCServerLoop(int(sock.sockfd)); - + sock.Close(); env.CleanUp(); } @@ -341,10 +363,10 @@ void ServerLoopFromChild(SOCKET socket) { * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 - * \param tracker_addr The address of RPC tracker in host:port format e.g. 
10.77.1.234:9190 Default="" - * \param key The key used to identify the device type in tracker. Default="" - * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" - * \param silent Whether run in silent mode. Default=True + * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 + * Default="" \param key The key used to identify the device type in tracker. Default="" \param + * custom_addr Custom IP Address to Report to RPC Tracker. Default="" \param silent Whether run in + * silent mode. Default=True */ void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, std::string key, std::string custom_addr, bool silent) { @@ -353,13 +375,13 @@ void RPCServerCreate(std::string host, int port, int port_end, std::string track dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr)); + RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), + std::move(custom_addr)); rpc.Start(); } -TVM_REGISTER_GLOBAL("rpc._ServerCreate") -.set_body([](TVMArgs args, TVMRetValue* rv) { - RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); - }); +TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); +}); } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index db7c89d823dd..0936c51bb2ce 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -25,6 +25,7 @@ #define TVM_APPS_CPP_RPC_SERVER_H_ #include + #include "tvm/runtime/c_runtime_api.h" namespace tvm { @@ -49,13 +50,9 @@ void ServerLoopFromChild(SOCKET socket); * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. Default=True */ -void RPCServerCreate(std::string host = "", - int port = 9090, - int port_end = 9099, - std::string tracker_addr = "", - std::string key = "", - std::string custom_addr = "", - bool silent = true); +void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099, + std::string tracker_addr = "", std::string key = "", + std::string custom_addr = "", bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h index dfd576f4c195..cdfb64780ba6 100644 --- a/apps/cpp_rpc/rpc_tracker_client.h +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -24,14 +24,14 @@ #ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ #define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ -#include -#include #include +#include #include -#include +#include #include +#include -#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_endpoint.h" #include "../../src/support/socket.h" namespace tvm { @@ -47,29 +47,28 @@ class TrackerClient { public: /*! * \brief Constructor. - */ - TrackerClient(const std::string& tracker_addr, - const std::string& key, + */ + TrackerClient(const std::string& tracker_addr, const std::string& key, const std::string& custom_addr) - : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), - gen_(std::random_device{}()), dis_(0.0, 1.0) { - } + : tracker_addr_(tracker_addr), + key_(key), + custom_addr_(custom_addr), + gen_(std::random_device{}()), + dis_(0.0, 1.0) {} /*! 
* \brief Destructor. - */ + */ ~TrackerClient() { // Free the resources Close(); } /*! * \brief IsValid Check tracker is valid. - */ - bool IsValid() { - return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); - } + */ + bool IsValid() { return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); } /*! * \brief TryConnect Connect to tracker if the tracker address is valid. - */ + */ void TryConnect() { if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) { tracker_sock_ = ConnectWithRetry(); @@ -80,8 +79,8 @@ class TrackerClient { CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kUpdateInfo) - << ", {\"key\": \"server:"<< key_ << "\"}]"; + ss << "[" << static_cast(TrackerCode::kUpdateInfo) << ", {\"key\": \"server:" << key_ + << "\"}]"; tracker_sock_.SendBytes(ss.str()); // Receive status and validate @@ -91,20 +90,19 @@ class TrackerClient { } /*! * \brief Close Clean up tracker resources. - */ + */ void Close() { // close tracker resource if (!tracker_sock_.IsClosed()) { tracker_sock_.Close(); } } - /*! - * \brief ReportResourceAndGetKey Report resource to tracker. - * \param port listening port. - * \param matchkey Random match key output. - */ - void ReportResourceAndGetKey(int port, - std::string *matchkey) { + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param port listening port. + * \param matchkey Random match key output. + */ + void ReportResourceAndGetKey(int port, std::string* matchkey) { if (!tracker_sock_.IsClosed()) { *matchkey = RandomKey(key_ + ":", old_keyset_); if (custom_addr_.empty()) { @@ -112,8 +110,8 @@ class TrackerClient { } std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" - << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port + << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; tracker_sock_.SendBytes(ss.str()); @@ -121,7 +119,7 @@ class TrackerClient { std::string remote_status = tracker_sock_.RecvBytes(); CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); } else { - *matchkey = key_; + *matchkey = key_; } } @@ -131,11 +129,9 @@ class TrackerClient { * \param port listening port. * \param ping_period Select wait time. * \param matchkey Random match key output. - */ - void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, - int port, - int ping_period, - std::string *matchkey) { + */ + void WaitConnectionAndUpdateKey(support::TCPSocket listen_sock, int port, int ping_period, + std::string* matchkey) { int unmatch_period_count = 0; int unmatch_timeout = 4; while (true) { @@ -155,9 +151,9 @@ class TrackerClient { // if match key not in pending key set // it means the key is acquired by a client but not used. 
if (pending_keys.find(*matchkey) == std::string::npos) { - unmatch_period_count += 1; + unmatch_period_count += 1; } else { - unmatch_period_count = 0; + unmatch_period_count = 0; } // regenerate match key if key is acquired but not used for a while if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) { @@ -166,8 +162,8 @@ class TrackerClient { *matchkey = RandomKey(key_ + ":", old_keyset_); std::ostringstream ss; - ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" - << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" << port + << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; tracker_sock_.SendBytes(ss.str()); std::string remote_status = tracker_sock_.RecvBytes(); @@ -201,26 +197,25 @@ class TrackerClient { } auto period = (std::chrono::duration_cast( - std::chrono::system_clock::now() - tbegin)).count(); + std::chrono::system_clock::now() - tbegin)) + .count(); CHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); - LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() - << " retry in " << retry_period << " seconds."; + LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() << " retry in " + << retry_period << " seconds."; std::this_thread::sleep_for(std::chrono::seconds(retry_period)); } } /*! - * \brief Random Generate a random number between 0 and 1. - * \return random float value. - */ - float Random() { - return dis_(gen_); - } + * \brief Random Generate a random number between 0 and 1. + * \return random float value. + */ + float Random() { return dis_(gen_); } /*! * \brief Generate a random key. * \param prefix The string prefix. * \return cmap The conflict map set. 
*/ - std::string RandomKey(const std::string& prefix, const std::set &cmap) { + std::string RandomKey(const std::string& prefix, const std::set& cmap) { if (!cmap.empty()) { while (true) { std::string key = prefix + std::to_string(Random()); @@ -236,10 +231,9 @@ class TrackerClient { std::string key_; std::string custom_addr_; support::TCPSocket tracker_sock_; - std::set old_keyset_; + std::set old_keyset_; std::mt19937 gen_; std::uniform_real_distribution dis_; - }; } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc index c6c72d79ab81..bbf8367903bb 100644 --- a/apps/cpp_rpc/win32_process.cc +++ b/apps/cpp_rpc/win32_process.cc @@ -20,15 +20,18 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#include "win32_process.h" + +#include +#include #include #include + #include #include -#include -#include #include -#include -#include "win32_process.h" +#include + #include "rpc_server.h" using namespace std::chrono; @@ -82,36 +85,36 @@ UniqueHandle MakeUniqueHandle(HANDLE handle) { */ SOCKET GetSocket(const std::string& mmap_path) { WSAPROTOCOL_INFO protocol_info; - + const std::string parent_event_name = mmap_path + kParent; const std::string child_event_name = mmap_path + kChild; // Open the events UniqueHandle parent_file_mapping_event; - if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { + if ((parent_file_mapping_event = MakeUniqueHandle( + OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); } UniqueHandle child_file_mapping_event; - if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { + if ((child_file_mapping_event = MakeUniqueHandle( + OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); } - + // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read - if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { - LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != + WAIT_OBJECT_0) { + LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); } - const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, - false, - mmap_path.c_str())); + const UniqueHandle file_map = + MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, false, mmap_path.c_str())); if (!file_map) { - LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); } - void* map_view = MapViewOfFile(file_map.get(), - FILE_MAP_READ | FILE_MAP_WRITE, - 0, 0, 0); + void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); SOCKET sock_duplicated = INVALID_SOCKET; @@ -120,12 +123,8 @@ SOCKET GetSocket(const std::string& mmap_path) { UnmapViewOfFile(map_view); // Creates the duplicate socket, that was created in the parent - sock_duplicated = WSASocket(FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, - FROM_PROTOCOL_INFO, - &protocol_info, - 0, - 0); + sock_duplicated = + WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &protocol_info, 0, 0); // Let the parent know we are finished dupicating the 
socket SetEvent(child_file_mapping_event.get()); @@ -135,7 +134,7 @@ SOCKET GetSocket(const std::string& mmap_path) { return sock_duplicated; } -}// Anonymous namespace +} // Anonymous namespace namespace tvm { namespace runtime { @@ -146,7 +145,7 @@ namespace runtime { */ void SpawnRPCChild(SOCKET fd, seconds timeout) { STARTUPINFOA startup_info; - + memset(&startup_info, 0, sizeof(startup_info)); startup_info.cb = sizeof(startup_info); @@ -157,13 +156,15 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { // Create an event to let the child know the socket info was set to the mmap file UniqueHandle parent_file_mapping_event; - if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { + if ((parent_file_mapping_event = MakeUniqueHandle( + CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { LOG(FATAL) << "CreateEvent for parent file mapping failed"; } UniqueHandle child_file_mapping_event; // An event to let the parent know the socket info was read from the mmap file - if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { + if ((child_file_mapping_event = MakeUniqueHandle( + CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { LOG(FATAL) << "CreateEvent for child file mapping failed"; } @@ -181,35 +182,22 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { strcpy(command_line_ptr.get(), child_command_line.c_str()); PROCESS_INFORMATION child_process_info; - if (CreateProcessA(nullptr, - command_line_ptr.get(), - nullptr, - nullptr, - false, - CREATE_NO_WINDOW, - nullptr, - nullptr, - &startup_info, - &child_process_info)) { + if (CreateProcessA(nullptr, command_line_ptr.get(), nullptr, nullptr, false, CREATE_NO_WINDOW, + nullptr, nullptr, &startup_info, &child_process_info)) { // Child process and thread handles must be closed, so wrapped in RAII auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess); auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread); WSAPROTOCOL_INFO protocol_info; // Get info needed to duplicate the socket - if (WSADuplicateSocket(fd, - child_process_info.dwProcessId, - &protocol_info) == SOCKET_ERROR) { + if (WSADuplicateSocket(fd, child_process_info.dwProcessId, &protocol_info) == SOCKET_ERROR) { LOG(FATAL) << "WSADuplicateSocket(): failed. 
Error =" << WSAGetLastError(); } // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc - UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, - nullptr, - PAGE_READWRITE, - 0, - sizeof(WSAPROTOCOL_INFO), - file_map_path.c_str())); + UniqueHandle file_map = + MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, nullptr, PAGE_READWRITE, 0, + sizeof(WSAPROTOCOL_INFO), file_map_path.c_str())); if (!file_map) { LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); } @@ -225,11 +213,13 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { // Let child proc know the mmap file is ready to be read SetEvent(parent_file_mapping_event.get()); - + // Wait for the child to finish reading mmap file - if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != + WAIT_OBJECT_0) { TerminateProcess(child_process_handle.get(), 0); - LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process."; + LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child " + "process."; } } else { TerminateProcess(child_process_handle.get(), 0); @@ -237,9 +227,8 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } } - const DWORD process_timeout = timeout.count() - ? uint32_t(duration_cast(timeout).count()) - : INFINITE; + const DWORD process_timeout = + timeout.count() ? uint32_t(duration_cast(timeout).count()) : INFINITE; // Wait for child process to exit, or hit configured timeout if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) { @@ -251,8 +240,9 @@ void SpawnRPCChild(SOCKET fd, seconds timeout) { } } /*! - * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket - * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client + * socket \param mmap_path The memory mapped file path that will contain the information to + * duplicate the client socket from the parent */ void ChildProcSocketHandler(const std::string& mmap_path) { SOCKET socket; @@ -260,14 +250,12 @@ void ChildProcSocketHandler(const std::string& mmap_path) { // Set high thread priority to avoid the thread scheduler from // interfering with any measurements in the RPC server. SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); - + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { tvm::runtime::ServerLoopFromChild(socket); - } - else { + } else { LOG(FATAL) << "GetSocket() failed"; } - } } // namespace runtime } // namespace tvm \ No newline at end of file diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h index 7d1a27680ed3..621444e18764 100644 --- a/apps/cpp_rpc/win32_process.h +++ b/apps/cpp_rpc/win32_process.h @@ -17,10 +17,10 @@ * under the License. */ - /*! - * \file win32_process.h - * \brief Win32 process code to mimic a POSIX fork() - */ +/*! + * \file win32_process.h + * \brief Win32 process code to mimic a POSIX fork() + */ #ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ #define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ #include @@ -34,8 +34,9 @@ namespace runtime { */ void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout); /*! 
- * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket - * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client + * socket \param mmap_path The memory mapped file path that will contain the information to + * duplicate the client socket from the parent */ void ChildProcSocketHandler(const std::string& mmap_path); } // namespace runtime diff --git a/apps/dso_plugin_module/plugin_module.cc b/apps/dso_plugin_module/plugin_module.cc index 7c3c5accf1ec..eed11f855693 100644 --- a/apps/dso_plugin_module/plugin_module.cc +++ b/apps/dso_plugin_module/plugin_module.cc @@ -20,10 +20,10 @@ * \brief Example code that can be compiled and loaded by TVM runtime. * \file plugin_module.cc */ -#include #include -#include #include +#include +#include namespace tvm_dso_plugin { @@ -31,24 +31,16 @@ using namespace tvm::runtime; class MyModuleNode : public ModuleNode { public: - explicit MyModuleNode(int value) - : value_(value) {} + explicit MyModuleNode(int value) : value_(value) {} - virtual const char* type_key() const final { - return "MyModule"; - } + virtual const char* type_key() const final { return "MyModule"; } - virtual PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) final { if (name == "add") { - return TypedPackedFunc([sptr_to_self, this](int value) { - return value_ + value; - }); + return TypedPackedFunc([sptr_to_self, this](int value) { return value_ + value; }); } else if (name == "mul") { - return TypedPackedFunc([sptr_to_self, this](int value) { - return value_ * value; - }); + return TypedPackedFunc([sptr_to_self, this](int value) { return value_ * value; }); } else { LOG(FATAL) << "unknown function " << name; return PackedFunc(); @@ -64,18 +56,14 @@ void CreateMyModule_(TVMArgs args, TVMRetValue* rv) { *rv = Module(make_object(value)); } -int SubOne_(int x) { - return x - 1; -} +int SubOne_(int x) { return x - 1; } // USE TVM_DLL_EXPORT_TYPED_PACKED_FUNC to export a // typed function as packed function. TVM_DLL_EXPORT_TYPED_FUNC(SubOne, SubOne_); // TVM_DLL_EXPORT_TYPED_PACKED_FUNC also works for lambda. -TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { - return x + 1; -}); +TVM_DLL_EXPORT_TYPED_FUNC(AddOne, [](int x) -> int { return x + 1; }); // Use TVM_EXPORT_PACKED_FUNC to export a function with TVM_DLL_EXPORT_PACKED_FUNC(CreateMyModule, tvm_dso_plugin::CreateMyModule_); diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 377db7c1c6ea..1df304a67e2a 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -58,7 +58,7 @@ def __getitem__(self, idx): class NDSubClass(tvm.nd.NDArrayBase): """Example for subclassing TVM's NDArray infrastructure. - By inheriting TMV's NDArray, external libraries could + By inheriting TVM's NDArray, external libraries could leverage TVM's FFI without any modification. """ diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index a92d55fc4acd..87cb69b4f4ce 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -17,16 +17,15 @@ * under the License. */ - /*! * \brief Example package that uses TVM. 
* \file tvm_ext.cc */ -#include +#include #include -#include #include -#include +#include +#include #include using namespace tvm; @@ -50,8 +49,7 @@ class NDSubClass : public tvm::runtime::NDArray { public: class SubContainer : public NDArray::Container { public: - SubContainer(int additional_info) : - additional_info_(additional_info) { + SubContainer(int additional_info) : additional_info_(additional_info) { type_index_ = SubContainer::RuntimeTypeIndex(); } int additional_info_{0}; @@ -74,14 +72,14 @@ class NDSubClass : public tvm::runtime::NDArray { data_ = GetObjectPtr(ptr); } - NDSubClass AddWith(const NDSubClass &other) const { - SubContainer *a = static_cast(get_mutable()); - SubContainer *b = static_cast(other.get_mutable()); + NDSubClass AddWith(const NDSubClass& other) const { + SubContainer* a = static_cast(get_mutable()); + SubContainer* b = static_cast(other.get_mutable()); CHECK(a != nullptr && b != nullptr); return NDSubClass(a->additional_info_ + b->additional_info_); } int get_additional_info() const { - SubContainer *self = static_cast(get_mutable()); + SubContainer* self = static_cast(get_mutable()); CHECK(self != nullptr); return self->additional_info_; } @@ -116,60 +114,48 @@ TVM_REGISTER_OBJECT_TYPE(IntVectorObj); namespace tvm_ext { -TVM_REGISTER_GLOBAL("tvm_ext.ivec_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { - auto n = tvm::runtime::make_object(); - for (int i = 0; i < args.size(); ++i) { - n->vec.push_back(args[i].operator int()); - } - *rv = IntVector(n); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.ivec_get") -.set_body([](TVMArgs args, TVMRetValue *rv) { - IntVector p = args[0]; - *rv = p->vec[args[1].operator int()]; - }); - - -TVM_REGISTER_GLOBAL("tvm_ext.bind_add") -.set_body([](TVMArgs args_, TVMRetValue *rv_) { - PackedFunc pf = args_[0]; - int b = args_[1]; - *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue *rv) { - *rv = pf(b, args[0]); - }); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.sym_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { - Var a = args[0]; - Var b = args[1]; - *rv = a + b; - }); - -TVM_REGISTER_GLOBAL("device_api.ext_dev") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); - }); - -TVM_REGISTER_GLOBAL("tvm_ext.nd_create") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.ivec_create").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = tvm::runtime::make_object(); + for (int i = 0; i < args.size(); ++i) { + n->vec.push_back(args[i].operator int()); + } + *rv = IntVector(n); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.ivec_get").set_body([](TVMArgs args, TVMRetValue* rv) { + IntVector p = args[0]; + *rv = p->vec[args[1].operator int()]; +}); + +TVM_REGISTER_GLOBAL("tvm_ext.bind_add").set_body([](TVMArgs args_, TVMRetValue* rv_) { + PackedFunc pf = args_[0]; + int b = args_[1]; + *rv_ = PackedFunc([pf, b](TVMArgs args, TVMRetValue* rv) { *rv = pf(b, args[0]); }); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.sym_add").set_body([](TVMArgs args, TVMRetValue* rv) { + Var a = args[0]; + Var b = args[1]; + *rv = a + b; +}); + +TVM_REGISTER_GLOBAL("device_api.ext_dev").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = (*tvm::runtime::Registry::Get("device_api.cpu"))(); +}); + +TVM_REGISTER_GLOBAL("tvm_ext.nd_create").set_body([](TVMArgs args, TVMRetValue* rv) { int additional_info = args[0]; *rv = NDSubClass(additional_info); CHECK_EQ(rv->type_code(), kTVMNDArrayHandle); - }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two").set_body([](TVMArgs args, TVMRetValue* rv) { NDSubClass a = args[0]; NDSubClass b = args[1]; *rv = a.AddWith(b); }); -TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info").set_body([](TVMArgs args, TVMRetValue* rv) { NDSubClass a = args[0]; *rv = a.get_additional_info(); }); @@ -177,17 +163,14 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info") } // namespace tvm_ext // External function exposed to runtime. -extern "C" float TVMTestAddOne(float y) { - return y + 1; -} +extern "C" float TVMTestAddOne(float y) { return y + 1; } // This callback approach allows extension allows tvm to extract // This way can be helpful when we want to use a header only // minimum version of TVM Runtime. extern "C" int TVMExtDeclare(TVMFunctionHandle pregister) { - const PackedFunc& fregister = - *static_cast(pregister); - auto mul = [](TVMArgs args, TVMRetValue *rv) { + const PackedFunc& fregister = *static_cast(pregister); + auto mul = [](TVMArgs args, TVMRetValue* rv) { int x = args[0]; int y = args[1]; *rv = x * y; diff --git a/apps/howto_deploy/cpp_deploy.cc b/apps/howto_deploy/cpp_deploy.cc index a386dffa0b30..b7a60f49d917 100644 --- a/apps/howto_deploy/cpp_deploy.cc +++ b/apps/howto_deploy/cpp_deploy.cc @@ -21,11 +21,12 @@ * \brief Example code on load and run TVM module.s * \file cpp_deploy.cc */ -#include #include #include -#include #include +#include + +#include void Verify(tvm::runtime::Module mod, std::string fname) { // Get the function from the module. @@ -52,10 +53,8 @@ void Verify(tvm::runtime::Module mod, std::string fname) { int device_type = kDLCPU; int device_id = 0; int64_t shape[1] = {10}; - TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, - device_type, device_id, &x); - TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, - device_type, device_id, &y); + TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &x); + TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &y); for (int i = 0; i < shape[0]; ++i) { static_cast(x->data)[i] = i; } @@ -72,8 +71,7 @@ void Verify(tvm::runtime::Module mod, std::string fname) { int main(void) { // Normally we can directly - tvm::runtime::Module mod_dylib = - tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so"); + tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile("lib/test_addone_dll.so"); LOG(INFO) << "Verify dynamic loading from test_addone_dll.so"; Verify(mod_dylib, "addone"); // For libraries that are directly packed as system lib and linked together with the app diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index 81bab497bebb..37e3968ca312 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -39,15 +39,15 @@ */ #include "../../src/runtime/c_runtime_api.cc" #include "../../src/runtime/cpu_device_api.cc" -#include "../../src/runtime/workspace_pool.cc" +#include "../../src/runtime/file_util.cc" #include "../../src/runtime/library_module.cc" #include "../../src/runtime/module.cc" -#include "../../src/runtime/registry.cc" -#include "../../src/runtime/file_util.cc" -#include "../../src/runtime/threading_backend.cc" -#include "../../src/runtime/thread_pool.cc" #include "../../src/runtime/ndarray.cc" #include "../../src/runtime/object.cc" +#include "../../src/runtime/registry.cc" +#include 
"../../src/runtime/thread_pool.cc" +#include "../../src/runtime/threading_backend.cc" +#include "../../src/runtime/workspace_pool.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. diff --git a/apps/ios_rpc/tests/ios_rpc_mobilenet.py b/apps/ios_rpc/tests/ios_rpc_mobilenet.py new file mode 100644 index 000000000000..e8f81ffddcec --- /dev/null +++ b/apps/ios_rpc/tests/ios_rpc_mobilenet.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import rpc, relay +from tvm.contrib.download import download_testdata +from tvm.relay.expr_functor import ExprMutator +from tvm.relay import transform +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay.quantize.quantize import prerequisite_optimize +from tvm.contrib import util, xcode, graph_runtime, coreml_runtime +from tvm.contrib.target import coreml as _coreml + +import os +import re +import sys +import numpy as np +from mxnet import gluon +from PIL import Image +import coremltools + +# Set to be address of tvm proxy. +proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"] +# Set your desination via env variable. 
+# Should in format "platform=iOS,id=" +destination = os.environ["TVM_IOS_RPC_DESTINATION"] + +if not re.match(r"^platform=.*,id=.*$", destination): + print("Bad format: {}".format(destination)) + print("Example of expected string: platform=iOS,id=1234567890abcabcabcabc1234567890abcabcab") + sys.exit(1) + +proxy_port = 9090 +key = "iphone" + +# Change target configuration, this is setting for iphone6s +#arch = "x86_64" +#sdk = "iphonesimulator" +arch = "arm64" +sdk = "iphoneos" +target_host = "llvm -target=%s-apple-darwin" % arch + +# override metal compiler to compile to iphone +@tvm.register_func("tvm_callback_metal_compile") +def compile_metal(src): + return xcode.compile_metal(src, sdk=sdk) + +def prepare_input(): + img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true' + img_name = 'cat.png' + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'imagenet1000_clsid_to_human.txt' + img_path = download_testdata(img_url, 'cat.png', module='data') + synset_path = download_testdata(synset_url, synset_name, module='data') + with open(synset_path) as f: + synset = eval(f.read()) + image = Image.open(img_path).resize((224, 224)) + + image = np.array(image) - np.array([123., 117., 104.]) + image /= np.array([58.395, 57.12, 57.375]) + image = image.transpose((2, 0, 1)) + image = image[np.newaxis, :] + return image.astype('float32'), synset + + +def get_model(model_name, data_shape): + gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True) + mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) + # we want a probability so add a softmax operator + func = mod["main"] + func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs) + + return func, params + + +def test_mobilenet(): + temp = util.tempdir() + image, synset = prepare_input() + model, params = get_model('mobilenetv2_1.0', image.shape) + + def run(mod, target): + with relay.build_config(opt_level=3): + graph, lib, _params = relay.build(mod, target=target, + target_host=target_host, params=params) + path_dso = temp.relpath("deploy.dylib") + lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk) + xcode.codesign(path_dso) + + # Start RPC test server that contains the compiled library. + xcode.popen_test_rpc(proxy_host, proxy_port, key, + destination=destination, libs=[path_dso]) + + # connect to the proxy + remote = rpc.connect(proxy_host, proxy_port, key=key) + + if target == "metal": + ctx = remote.metal(0) + else: + ctx = remote.cpu(0) + lib = remote.load_module("deploy.dylib") + m = graph_runtime.create(graph, lib, ctx) + + m.set_input('data', tvm.nd.array(image, ctx)) + m.set_input(**_params) + m.run() + tvm_output = m.get_output(0) + top1 = np.argmax(tvm_output.asnumpy()[0]) + print('TVM prediction top-1:', top1, synset[top1]) + + # evaluate + ftimer = m.module.time_evaluator("run", ctx, number=3, repeat=10) + prof_res = np.array(ftimer().results) * 1000 + print("%-19s (%s)" % ("%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))) + + def annotate(func, compiler): + """ + An annotator for Core ML. + """ + # Bind free variables to the constant values. 
+ bind_dict = {} + for arg in func.params: + name = arg.name_hint + if name in params: + bind_dict[arg] = relay.const(params[name]) + + func = relay.bind(func, bind_dict) + + # Annotate the entire graph for Core ML + mod = tvm.IRModule() + mod["main"] = func + + seq = tvm.transform.Sequential([ + transform.SimplifyInference(), + transform.FoldConstant(), + transform.FoldScaleAxis(), + transform.AnnotateTarget(compiler), + transform.MergeCompilerRegions(), + transform.PartitionGraph() + ]) + + with relay.build_config(opt_level=3): + mod = seq(mod) + + return mod + + # CPU + run(model, target_host) + # Metal + run(model, "metal") + # CoreML + run(annotate(model, "coremlcompiler"), target_host) + +if __name__ == "__main__": + test_mobilenet() diff --git a/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj b/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj index f635d2c5cf19..b33c892cf002 100644 --- a/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj +++ b/apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj @@ -34,6 +34,7 @@ C02637661F1C2690007247A9 /* TVMRuntime.mm in Sources */ = {isa = PBXBuildFile; fileRef = C02637651F1C2690007247A9 /* TVMRuntime.mm */; }; C02637691F1C26AF007247A9 /* ViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = C02637681F1C26AF007247A9 /* ViewController.mm */; }; C05A2C891F1DCE0900D4798B /* tvmrpcLauncher.mm in Sources */ = {isa = PBXBuildFile; fileRef = C05A2C881F1DCE0900D4798B /* tvmrpcLauncher.mm */; }; + D7685AD324390EAE00D1469C /* CoreML.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = D7685AD224390EAD00D1469C /* CoreML.framework */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -62,6 +63,7 @@ C05A2C861F1DCE0900D4798B /* tvmrpcLauncher.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = tvmrpcLauncher.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; C05A2C881F1DCE0900D4798B /* tvmrpcLauncher.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = tvmrpcLauncher.mm; sourceTree = ""; }; C05A2C8A1F1DCE0900D4798B /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + D7685AD224390EAD00D1469C /* CoreML.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreML.framework; path = System/Library/Frameworks/CoreML.framework; sourceTree = SDKROOT; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -69,6 +71,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( + D7685AD324390EAE00D1469C /* CoreML.framework in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -88,6 +91,7 @@ C026374D1F1C25E8007247A9 /* tvmrpc */, C05A2C871F1DCE0900D4798B /* tvmrpcLauncher */, C026374C1F1C25E8007247A9 /* Products */, + D7685AD124390EAD00D1469C /* Frameworks */, ); indentWidth = 2; sourceTree = ""; @@ -137,6 +141,14 @@ path = tvmrpcLauncher; sourceTree = ""; }; + D7685AD124390EAD00D1469C /* Frameworks */ = { + isa = PBXGroup; + children = ( + D7685AD224390EAD00D1469C /* CoreML.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXNativeTarget section */ @@ -249,7 +261,7 @@ ); runOnlyForDeploymentPostprocessing = 0; shellPath = /bin/sh; - shellScript = "libpath=${CONFIGURATION_BUILD_DIR}/${CONTENTS_FOLDER_PATH}/Frameworks/tvm\nmkdir -p ${libpath}\nrm -rf ${libpath}/*\n \nif [ -f ${SRCROOT}/rpc_config.txt ]; then\n head -n 1 ${SRCROOT}/rpc_config.txt > 
${libpath}/rpc_config.txt\n tail -n +2 ${SRCROOT}/rpc_config.txt | xargs -J % cp % ${libpath}\nfi\n\n"; + shellScript = "libpath=${CONFIGURATION_BUILD_DIR}/${CONTENTS_FOLDER_PATH}/Frameworks/tvm\nmkdir -p ${libpath}\nrm -rf ${libpath}/*\n \nif [ -f ${SRCROOT}/rpc_config.txt ]; then\n head -n 1 ${SRCROOT}/rpc_config.txt > ${libpath}/rpc_config.txt\n tail -n +2 ${SRCROOT}/rpc_config.txt | xargs -J % cp -r % ${libpath}\nfi\n\n"; }; /* End PBXShellScriptBuildPhase section */ @@ -309,7 +321,7 @@ ALWAYS_SEARCH_USER_PATHS = NO; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; CLANG_CXX_LIBRARY = "libc++"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; @@ -358,7 +370,7 @@ ALWAYS_SEARCH_USER_PATHS = NO; CLANG_ANALYZER_NONNULL = YES; CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; - CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; CLANG_CXX_LIBRARY = "libc++"; CLANG_ENABLE_MODULES = YES; CLANG_ENABLE_OBJC_ARC = YES; diff --git a/apps/ios_rpc/tvmrpc/AppDelegate.h b/apps/ios_rpc/tvmrpc/AppDelegate.h index 0c54a47e7a2d..a810aeafa47f 100644 --- a/apps/ios_rpc/tvmrpc/AppDelegate.h +++ b/apps/ios_rpc/tvmrpc/AppDelegate.h @@ -25,7 +25,6 @@ @interface AppDelegate : UIResponder -@property (strong, nonatomic) UIWindow *window; - +@property(strong, nonatomic) UIWindow* window; @end diff --git a/apps/ios_rpc/tvmrpc/TVMRuntime.h b/apps/ios_rpc/tvmrpc/TVMRuntime.h index 96a5c1bfa318..f6a6dc64c53a 100644 --- a/apps/ios_rpc/tvmrpc/TVMRuntime.h +++ b/apps/ios_rpc/tvmrpc/TVMRuntime.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,8 @@ #define DMLC_LOG_CUSTOMIZE 1 #define TVM_METAL_RUNTIME 1 -#include #include +#include #include namespace tvm { @@ -52,8 +52,7 @@ using FEventHandler = std::function(data) - maxLength:size]; + ssize_t nbytes = [stream_ write:reinterpret_cast(data) maxLength:size]; if (nbytes < 0) { - NSLog(@"%@",[stream_ streamError].localizedDescription); + NSLog(@"%@", [stream_ streamError].localizedDescription); throw dmlc::Error("Stream error"); } return nbytes; @@ -81,12 +82,12 @@ size_t Recv(void* data, size_t size) final { NSOutputStream* stream_; }; -FEventHandler CreateServerEventHandler( - NSOutputStream *outputStream, std::string name, std::string remote_key) { +FEventHandler CreateServerEventHandler(NSOutputStream* outputStream, std::string name, + std::string remote_key) { std::unique_ptr ch(new NSStreamChannel(outputStream)); - std::shared_ptr sess = RPCSession::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); return [sess](const std::string& in_bytes, int flag) { - return sess->ServerEventHandler(in_bytes, flag); + return sess->ServerAsyncIOEventHandler(in_bytes, flag); }; } @@ -101,9 +102,7 @@ FEventHandler CreateServerEventHandler( } } // Get Path. 
- std::string GetPath(const std::string& file_name) { - return base_ + file_name; - } + std::string GetPath(const std::string& file_name) { return base_ + file_name; } private: std::string base_; @@ -113,49 +112,44 @@ void LaunchSyncServer() { // only load dylib from frameworks. NSBundle* bundle = [NSBundle mainBundle]; NSString* base = [bundle privateFrameworksPath]; - NSString* path = [base stringByAppendingPathComponent: @"tvm/rpc_config.txt"]; + NSString* path = [base stringByAppendingPathComponent:@"tvm/rpc_config.txt"]; std::string name = [path UTF8String]; std::ifstream fs(name, std::ios::in); std::string url, key; int port; - CHECK(fs >> url >> port >> key) - << "Invalid RPC config file " << name; - RPCConnect(url, port, "server:" + key) - ->ServerLoop(); + CHECK(fs >> url >> port >> key) << "Invalid RPC config file " << name; + RPCConnect(url, port, "server:" + key, TVMArgs(nullptr, nullptr, 0))->ServerLoop(); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string name = args[0]; - std::string fmt = GetFileFormat(name, ""); - NSString* base; - if (fmt == "dylib") { - // only load dylib from frameworks. - NSBundle* bundle = [NSBundle mainBundle]; - base = [[bundle privateFrameworksPath] - stringByAppendingPathComponent: @"tvm"]; - } else { - // Load other modules in tempdir. - base = NSTemporaryDirectory(); - } - NSString* path = [base stringByAppendingPathComponent: - [NSString stringWithUTF8String:name.c_str()]]; - name = [path UTF8String]; - *rv = Module::LoadFromFile(name, fmt); - LOG(INFO) << "Load module from " << name << " ..."; - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); +}); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + std::string fmt = GetFileFormat(name, ""); + NSString* base; + if (fmt == "dylib") { + // only load dylib from frameworks. + NSBundle* bundle = [NSBundle mainBundle]; + base = [[bundle privateFrameworksPath] stringByAppendingPathComponent:@"tvm"]; + } else { + // Load other modules in tempdir. + base = NSTemporaryDirectory(); + } + NSString* path = + [base stringByAppendingPathComponent:[NSString stringWithUTF8String:name.c_str()]]; + name = [path UTF8String]; + *rv = Module::LoadFromFile(name, fmt); + LOG(INFO) << "Load module from " << name << " ..."; +}); } // namespace runtime } // namespace tvm @implementation TVMRuntime -+(void) launchSyncServer { ++ (void)launchSyncServer { tvm::runtime::LaunchSyncServer(); } diff --git a/apps/ios_rpc/tvmrpc/ViewController.h b/apps/ios_rpc/tvmrpc/ViewController.h index 3a3c928f8112..b188a87b20d3 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.h +++ b/apps/ios_rpc/tvmrpc/ViewController.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,12 +24,11 @@ #import #include "TVMRuntime.h" -@interface ViewController : UIViewController -{ +@interface ViewController : UIViewController { // input socket stream - NSInputStream *inputStream_; + NSInputStream* inputStream_; // output socket stream - NSOutputStream *outputStream_; + NSOutputStream* outputStream_; // temporal receive buffer. std::string recvBuffer_; // Whether connection is initialized. @@ -46,11 +45,11 @@ tvm::runtime::FEventHandler handler_; } -@property (weak, nonatomic) IBOutlet UITextField *proxyURL; -@property (weak, nonatomic) IBOutlet UITextField *proxyPort; -@property (weak, nonatomic) IBOutlet UITextField *proxyKey; -@property (weak, nonatomic) IBOutlet UILabel *statusLabel; -@property (weak, nonatomic) IBOutlet UITextView *infoText; +@property(weak, nonatomic) IBOutlet UITextField* proxyURL; +@property(weak, nonatomic) IBOutlet UITextField* proxyPort; +@property(weak, nonatomic) IBOutlet UITextField* proxyKey; +@property(weak, nonatomic) IBOutlet UILabel* statusLabel; +@property(weak, nonatomic) IBOutlet UITextView* infoText; - (IBAction)connect:(id)sender; - (IBAction)disconnect:(id)sender; diff --git a/apps/ios_rpc/tvmrpc/ViewController.mm b/apps/ios_rpc/tvmrpc/ViewController.mm index 0f7611002042..6c618c48096f 100644 --- a/apps/ios_rpc/tvmrpc/ViewController.mm +++ b/apps/ios_rpc/tvmrpc/ViewController.mm @@ -21,12 +21,12 @@ * \file ViewController.mm */ -#include #import "ViewController.h" +#include @implementation ViewController -- (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event { +- (void)stream:(NSStream*)strm handleEvent:(NSStreamEvent)event { std::string buffer; switch (event) { case NSStreamEventOpenCompleted: { @@ -45,7 +45,7 @@ - (void)stream:(NSStream *)strm handleEvent:(NSStreamEvent)event { break; } case NSStreamEventErrorOccurred: { - NSLog(@"%@",[strm streamError].localizedDescription); + NSLog(@"%@", [strm streamError].localizedDescription); break; } case NSStreamEventEndEncountered: { @@ -64,8 +64,7 @@ - (void)onReadAvailable { constexpr int kRPCMagic = 0xff271; if (!initialized_) { int code; - size_t nbytes = [inputStream_ read:reinterpret_cast(&code) - maxLength:sizeof(code)]; + size_t nbytes = [inputStream_ read:reinterpret_cast(&code) maxLength:sizeof(code)]; if (nbytes != sizeof(code)) { self.infoText.text = @"Fail to receive remote confirmation code."; [self close]; @@ -115,7 +114,7 @@ - (void)onShutdownReceived { - (void)onWriteAvailable { if (initSendPtr_ < initBytes_.length()) { initSendPtr_ += [outputStream_ write:reinterpret_cast(&initBytes_[initSendPtr_]) - maxLength:(initBytes_.length() - initSendPtr_)]; + maxLength:(initBytes_.length() - initSendPtr_)]; } if (initialized_) { try { @@ -148,13 +147,10 @@ - (void)open { // Initialize the network. 
CFReadStreamRef readStream; CFWriteStreamRef writeStream; - CFStreamCreatePairWithSocketToHost( - NULL, - (__bridge CFStringRef) self.proxyURL.text, - [self.proxyPort.text intValue], - &readStream, &writeStream); - inputStream_ = (__bridge_transfer NSInputStream *)readStream; - outputStream_ = (__bridge_transfer NSOutputStream *)writeStream; + CFStreamCreatePairWithSocketToHost(NULL, (__bridge CFStringRef)self.proxyURL.text, + [self.proxyPort.text intValue], &readStream, &writeStream); + inputStream_ = (__bridge_transfer NSInputStream*)readStream; + outputStream_ = (__bridge_transfer NSOutputStream*)writeStream; [inputStream_ setDelegate:self]; [outputStream_ setDelegate:self]; [inputStream_ scheduleInRunLoop:[NSRunLoop currentRunLoop] forMode:NSDefaultRunLoopMode]; diff --git a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm b/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm index c4a6f8bd240f..eb538f07bf49 100644 --- a/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm +++ b/apps/ios_rpc/tvmrpcLauncher/tvmrpcLauncher.mm @@ -32,16 +32,15 @@ @interface tvmrpcLauncher : XCTestCase @implementation tvmrpcLauncher - (void)setUp { - [super setUp]; + [super setUp]; } - (void)tearDown { - [super tearDown]; + [super tearDown]; } - (void)testRPC { [TVMRuntime launchSyncServer]; } - @end diff --git a/apps/lldb/tvm.py b/apps/lldb/tvm.py index 135aeff5258a..fb5c4de1c4cc 100644 --- a/apps/lldb/tvm.py +++ b/apps/lldb/tvm.py @@ -36,7 +36,6 @@ def __lldb_init_module(debugger, _): "tvm::Attrs", "tvm::BijectiveLayout", "tvm::Buffer", - "tvm::BuildConfig", "tvm::Channel", "tvm::EnvFunc", "tvm::Expr", diff --git a/apps/rocm_rpc/rocm_runtime_pack.cc b/apps/rocm_rpc/rocm_runtime_pack.cc index a137a9b28f8a..de5c50452340 100644 --- a/apps/rocm_rpc/rocm_runtime_pack.cc +++ b/apps/rocm_rpc/rocm_runtime_pack.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,7 @@ #define TVM_USE_MIOPEN 1 #define __HIP_PLATFORM_HCC__ 1 -#include "../../src/runtime/rocm/rocm_device_api.cc" -#include "../../src/runtime/rocm/rocm_module.cc" #include "../../src/contrib/miopen/conv_forward.cc" #include "../../src/contrib/miopen/miopen_utils.cc" +#include "../../src/runtime/rocm/rocm_device_api.cc" +#include "../../src/runtime/rocm/rocm_module.cc" diff --git a/apps/sgx/src/build_model.py b/apps/sgx/src/build_model.py index 6e0933efd381..b988574fc558 100755 --- a/apps/sgx/src/build_model.py +++ b/apps/sgx/src/build_model.py @@ -37,7 +37,7 @@ def main(): net, params = relay.testing.resnet.get_workload( layers=18, batch_size=dshape[0], image_shape=dshape[1:]) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build( net, 'llvm --system-lib', params=params) diff --git a/cmake/config.cmake b/cmake/config.cmake index 8df86f495002..651fa7ce0445 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -154,6 +154,10 @@ set(USE_TFLITE OFF) # /path/to/tensorflow: tensorflow root path when use tflite library set(USE_TENSORFLOW_PATH none) +# Required for full builds with TFLite. Not needed for runtime with TFLite. 
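The `apps/sgx/src/build_model.py` hunk above swaps the deprecated `relay.build_config` context manager for `tvm.transform.PassContext`. A minimal sketch of the new idiom follows — the ResNet-18 test workload and the plain `llvm` target are placeholder assumptions chosen only for illustration, not part of this patch:

```python
import tvm
from tvm import relay
from tvm.relay import testing

# Any Relay module works here; the resnet test workload is just a stand-in.
net, params = testing.resnet.get_workload(batch_size=1)

# tvm.transform.PassContext replaces the older relay.build_config context
# manager; the opt_level argument carries over unchanged.
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(net, "llvm", params=params)
```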
+# /path/to/flatbuffers: flatbuffers root path when using tflite library +set(USE_FLATBUFFERS_PATH none) + # Possible values: # - OFF: disable tflite support for edgetpu # - /path/to/edgetpu: use specific path to edgetpu library diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index 5b56982a42c5..30b4ccbc5618 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +include(ExternalProject) + set(PICK_SIM "sim") set(PICK_HW "target") set(PICK_NONE "OFF") @@ -77,6 +79,13 @@ if(USE_HEXAGON_DEVICE STREQUAL "${PICK_SIM}") include_directories("${HEXAGON_TOOLCHAIN}/include/iss") link_directories("${HEXAGON_TOOLCHAIN}/lib/iss") list(APPEND TVM_RUNTIME_LINKER_LIBS "-lwrapper") + ExternalProject_Add(sim_dev + SOURCE_DIR "${CMAKE_SOURCE_DIR}/src/runtime/hexagon/sim/driver" + CMAKE_ARGS + "-DCMAKE_C_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang" + "-DCMAKE_CXX_COMPILER=${HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" + INSTALL_COMMAND "true" + ) elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") find_hexagon_sdk_root() find_hexagon_toolchain() @@ -87,7 +96,11 @@ elseif(USE_HEXAGON_DEVICE STREQUAL "${PICK_HW}") include_directories( "${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64") include_directories("${HEXAGON_TOOLCHAIN}/include/iss") - list(APPEND TVM_RUNTIME_LINKER_LIBS "-ldl") + list(APPEND TVM_RUNTIME_LINKER_LIBS "dl") + if(BUILD_FOR_ANDROID) + # Hexagon runtime uses __android_log_print, which is in liblog. + list(APPEND TVM_RUNTIME_LINKER_LIBS "log") + endif() endif() file(GLOB RUNTIME_HEXAGON_SRCS src/runtime/hexagon/*.cc) diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 4af39e088b23..d9508470c0a2 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -89,7 +89,8 @@ elseif(PYTHON) # VTA FPGA driver sources if(USE_VTA_FPGA) - file(GLOB FPGA_RUNTIME_SRCS ${VTA_HW_PATH}/src/*.cc) + file(GLOB FSIM_RUNTIME_SRCS ${VTA_HW_PATH}/src/*.cc) + file(GLOB FPGA_RUNTIME_SRCS vta/runtime/*.cc) # Rules for Zynq-class FPGAs with pynq OS support (see pynq.io) if(${VTA_TARGET} STREQUAL "pynq" OR ${VTA_TARGET} STREQUAL "ultra96") @@ -101,13 +102,14 @@ elseif(PYTHON) endif() # Target lib: vta add_library(vta SHARED ${FPGA_RUNTIME_SRCS}) - target_include_directories(vta PUBLIC vta/include) + target_include_directories(vta PUBLIC vta/runtime) foreach(__def ${VTA_DEFINITIONS}) string(SUBSTRING ${__def} 3 -1 __strip_def) target_compile_definitions(vta PUBLIC ${__strip_def}) endforeach() if(${VTA_TARGET} STREQUAL "pynq" OR ${VTA_TARGET} STREQUAL "ultra96") + target_include_directories(vta PUBLIC ${VTA_HW_PATH}/include) target_link_libraries(vta ${__cma_lib}) elseif(${VTA_TARGET} STREQUAL "de10nano") # DE10-Nano rules #target_compile_definitions(vta PUBLIC VTA_MAX_XFER=2097152) # (1<<21) diff --git a/cmake/modules/contrib/CoreML.cmake b/cmake/modules/contrib/CoreML.cmake new file mode 100644 index 000000000000..a61e9f6eef5f --- /dev/null +++ b/cmake/modules/contrib/CoreML.cmake @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_COREML) + message(STATUS "Build with contrib.coreml") + find_library(FOUNDATION_LIB Foundation) + find_library(COREML_LIB Coreml) + file(GLOB COREML_CONTRIB_SRC src/runtime/contrib/coreml/*.mm) + list(APPEND TVM_RUNTIME_LINKER_LIBS ${FOUNDATION_LIB} ${COREML_LIB}) + list(APPEND RUNTIME_SRCS ${COREML_CONTRIB_SRC}) +endif(USE_COREML) diff --git a/cmake/modules/contrib/TFLite.cmake b/cmake/modules/contrib/TFLite.cmake index ec03c960dfa7..c16a76d4b37e 100644 --- a/cmake/modules/contrib/TFLite.cmake +++ b/cmake/modules/contrib/TFLite.cmake @@ -17,7 +17,7 @@ if(NOT USE_TFLITE STREQUAL "OFF") message(STATUS "Build with contrib.tflite") - if (USE_TENSORFLOW_PATH STREQUAL "none") + if (USE_TENSORFLOW_PATH STREQUAL "none") set(USE_TENSORFLOW_PATH ${CMAKE_CURRENT_SOURCE_DIR}/tensorflow) endif() @@ -40,5 +40,8 @@ if(NOT USE_TFLITE STREQUAL "OFF") find_library(TFLITE_CONTRIB_LIB libtensorflow-lite.a ${USE_TFLITE}) list(APPEND TVM_RUNTIME_LINKER_LIBS ${TFLITE_CONTRIB_LIB}) - list(APPEND TVM_RUNTIME_LINKER_LIBS rt dl flatbuffers) + + if (NOT USE_FLATBUFFERS_PATH STREQUAL "none") + include_directories(${USE_FLATBUFFERS_PATH}/include) + endif() endif() diff --git a/cmake/modules/contrib/TF_TVMDSOOP.cmake b/cmake/modules/contrib/TF_TVMDSOOP.cmake index e92822a397ae..1509e83a9be3 100644 --- a/cmake/modules/contrib/TF_TVMDSOOP.cmake +++ b/cmake/modules/contrib/TF_TVMDSOOP.cmake @@ -44,7 +44,6 @@ if(NOT USE_TF_TVMDSOOP STREQUAL "OFF") set(OP_LIBRARY_NAME tvm_dso_op) file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc) add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS}) - set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "") set(TFTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) if (NOT BUILD_TVMDSOOP_ONLY STREQUAL "ON") diff --git a/cmake/util/FindVulkan.cmake b/cmake/util/FindVulkan.cmake index 850b99ceb299..9d9a973c6558 100644 --- a/cmake/util/FindVulkan.cmake +++ b/cmake/util/FindVulkan.cmake @@ -62,7 +62,7 @@ macro(find_vulkan use_vulkan) if(Vulkan_FOUND) get_filename_component(VULKAN_LIBRARY_PATH ${Vulkan_LIBRARY} DIRECTORY) find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools - HINTS ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools) + HINTS ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib) find_path(_libspirv libspirv.h HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools) find_path(_spirv spirv.hpp HINTS ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1) diff --git a/dmlc_tvm_commit_id.txt b/dmlc_tvm_commit_id.txt index 4a7c7b3313aa..fab59b842257 100644 --- a/dmlc_tvm_commit_id.txt +++ b/dmlc_tvm_commit_id.txt @@ -1 +1 @@ -c9cddddf1213b99485ad9dd5b4262a98b325abda +520ca0a8b39aeb4765369f169477265230ea7c6c \ No newline at end of file diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 10c8c62d970b..f1a928a5a0d6 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -16,22 +16,22 @@ # under the License. 
# CI docker CPU env -# tag: v0.55 -FROM ubuntu:16.04 +# tag: v0.62 +FROM ubuntu:18.04 RUN apt-get update --fix-missing COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh -COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh +COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh +RUN bash /install/ubuntu1804_install_python.sh COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh -COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh -RUN bash /install/ubuntu_install_llvm.sh +COPY install/ubuntu1804_install_llvm.sh /install/ubuntu1804_install_llvm.sh +RUN bash /install/ubuntu1804_install_llvm.sh # Rust env (build early; takes a while) COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh @@ -63,3 +63,11 @@ RUN bash /install/ubuntu_install_antlr.sh # Chisel deps for TSIM COPY install/ubuntu_install_chisel.sh /install/ubuntu_install_chisel.sh RUN bash /install/ubuntu_install_chisel.sh + +# TFLite deps +COPY install/ubuntu_install_tflite.sh /install/ubuntu_install_tflite.sh +RUN bash /install/ubuntu_install_tflite.sh + +# TensorFlow deps +COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh +RUN bash /install/ubuntu_install_tensorflow.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index 162549650994..6a1023a321ff 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -56,9 +56,6 @@ RUN bash /install/ubuntu_install_nodejs.sh COPY install/ubuntu_install_rocm.sh /install/ubuntu_install_rocm.sh RUN bash /install/ubuntu_install_rocm.sh -COPY install/ubuntu_install_opengl.sh /install/ubuntu_install_opengl.sh -RUN bash /install/ubuntu_install_opengl.sh - # DL Frameworks COPY install/ubuntu_install_mxnet.sh /install/ubuntu_install_mxnet.sh RUN bash /install/ubuntu_install_mxnet.sh @@ -111,7 +108,4 @@ ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/compact:${LD_LIBRARY_P ENV LD_LIBRARY_PATH=/opt/rocm/lib:${LD_LIBRARY_PATH} ENV PATH=/node_modules/.bin:${PATH} -ENV VULKAN_SDK=/usr/local/VulkanSDK/1.0.65.0/x86_64 -ENV PATH=${PATH}:${VULKAN_SDK}/bin -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${VULKAN_SDK}/lib -ENV VK_LAYER_PATH=${VULKAN_SDK}/etc/explicit_layer.d +ENV VULKAN_SDK=/usr diff --git a/docker/Dockerfile.ci_lint b/docker/Dockerfile.ci_lint index 1c72ee70b63e..aeed9cad416a 100644 --- a/docker/Dockerfile.ci_lint +++ b/docker/Dockerfile.ci_lint @@ -18,13 +18,17 @@ # For lint test # CI docker lint env # tag: v0.60 -FROM ubuntu:16.04 +FROM ubuntu:18.04 -RUN apt-get update && apt-get install -y sudo wget -COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh +RUN apt-get update --fix-missing + +RUN apt-get update && apt-get install -y wget git sudo make + +COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh +RUN bash /install/ubuntu1804_install_python.sh + +RUN apt-get update && apt-get install -y doxygen graphviz -RUN apt-get install -y doxygen graphviz git RUN pip3 install cpplint pylint==2.4.4 mypy # java deps for rat @@ -33,3 +37,9 @@ RUN bash /install/ubuntu_install_java.sh COPY install/ubuntu_install_rat.sh /install/ubuntu_install_rat.sh RUN bash /install/ubuntu_install_rat.sh + +COPY install/ubuntu1804_install_clang_format.sh /install/ubuntu1804_install_clang_format.sh +RUN bash 
/install/ubuntu1804_install_clang_format.sh + +COPY install/ubuntu_install_nodejs.sh /install/ubuntu_install_nodejs.sh +RUN bash /install/ubuntu_install_nodejs.sh diff --git a/docker/Dockerfile.ci_emscripten b/docker/Dockerfile.ci_wasm similarity index 60% rename from docker/Dockerfile.ci_emscripten rename to docker/Dockerfile.ci_wasm index e2a9a7b76dd9..965bc01d22d8 100644 --- a/docker/Dockerfile.ci_emscripten +++ b/docker/Dockerfile.ci_wasm @@ -14,26 +14,36 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# For CPU -FROM ubuntu:16.04 +FROM ubuntu:18.04 RUN apt-get update --fix-missing COPY install/ubuntu_install_core.sh /install/ubuntu_install_core.sh RUN bash /install/ubuntu_install_core.sh -COPY install/ubuntu_install_python.sh /install/ubuntu_install_python.sh -RUN bash /install/ubuntu_install_python.sh - -COPY install/ubuntu_install_emscripten.sh /install/ubuntu_install_emscripten.sh -RUN bash /install/ubuntu_install_emscripten.sh +COPY install/ubuntu1804_install_python.sh /install/ubuntu1804_install_python.sh +RUN bash /install/ubuntu1804_install_python.sh COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh RUN bash /install/ubuntu_install_python_package.sh -RUN chmod a+rwx -R /emsdk-portable -RUN cp -r /emsdk-portable /emsdk-portable-backup -RUN mv /emsdk-portable /emsdk-portable-x -RUN mv /emsdk-portable-backup /emsdk-portable -RUN cp /root/.emscripten /emsdk-portable/ +COPY install/ubuntu1804_install_llvm.sh /install/ubuntu1804_install_llvm.sh +RUN bash /install/ubuntu1804_install_llvm.sh + +COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh +RUN bash /install/ubuntu_install_java.sh + +COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh +RUN bash /install/ubuntu_install_antlr.sh + +COPY install/ubuntu_install_nodejs.sh /install/ubuntu_install_nodejs.sh +RUN bash /install/ubuntu_install_nodejs.sh + +COPY install/ubuntu_install_emscripten.sh /install/ubuntu_install_emscripten.sh +RUN bash /install/ubuntu_install_emscripten.sh + +ENV EMSDK=/emsdk +ENV PATH=${PATH}:${EMSDK}:${EMSDK}/upstream/emscripten +ENV EMSCRIPTEN=${EMSDK}/upstream/emscripten +ENV BINARYEN=${EMSDK}/upstream +ENV LLVM=${EMSDK}/upstream/bin diff --git a/docker/bash.sh b/docker/bash.sh index 61823f9b6700..d46ee18b34a9 100755 --- a/docker/bash.sh +++ b/docker/bash.sh @@ -83,12 +83,14 @@ ${DOCKER_BINARY} run --rm --pid=host\ -v ${WORKSPACE}:/workspace \ -v ${SCRIPT_DIR}:/docker \ -w /workspace \ + --ulimit stack=16777216:16777216 \ -e "CI_BUILD_HOME=/workspace" \ -e "CI_BUILD_USER=$(id -u -n)" \ -e "CI_BUILD_UID=$(id -u)" \ -e "CI_BUILD_GROUP=$(id -g -n)" \ -e "CI_BUILD_GID=$(id -g)" \ - -e "PYTHONPATH=python:topi/python"\ + -e "PYTHONPATH=/workspace/python:/workspace/topi/python"\ + -e "CI_PYTEST_ADD_OPTIONS=$CI_PYTEST_ADD_OPTIONS" \ ${CUDA_ENV}\ ${CI_DOCKER_EXTRA_PARAMS[@]} \ ${DOCKER_IMAGE_NAME}\ diff --git a/docker/build.sh b/docker/build.sh index defa28245544..43f0a08700a4 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -67,6 +67,7 @@ fi if [[ "$1" == "--cache-from" ]]; then shift 1 cached_image="$1" + CI_DOCKER_BUILD_EXTRA_PARAMS+=("--cache-from tvm.$CONTAINER_TYPE") CI_DOCKER_BUILD_EXTRA_PARAMS+=("--cache-from $cached_image") shift 1 fi @@ -162,6 +163,7 @@ ${DOCKER_BINARY} run --rm --pid=host \ -e "CI_BUILD_UID=$(id -u)" \ -e "CI_BUILD_GROUP=$(id -g -n)" \ -e "CI_BUILD_GID=$(id -g)" \ + -e "CI_PYTEST_ADD_OPTIONS=$CI_PYTEST_ADD_OPTIONS" \ 
${CUDA_ENV}\ ${CI_DOCKER_EXTRA_PARAMS[@]} \ ${DOCKER_IMG_NAME} \ diff --git a/docker/install/ubuntu1804_install_clang_format.sh b/docker/install/ubuntu1804_install_clang_format.sh new file mode 100755 index 000000000000..e830433bb039 --- /dev/null +++ b/docker/install/ubuntu1804_install_clang_format.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main\ + >> /etc/apt/sources.list.d/llvm.list +echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main\ + >> /etc/apt/sources.list.d/llvm.list + +wget -q -O - http://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - +apt-get update && apt-get install -y clang-format-10 diff --git a/docker/install/ubuntu1804_install_llvm.sh b/docker/install/ubuntu1804_install_llvm.sh new file mode 100755 index 000000000000..4f7c9df7ac8e --- /dev/null +++ b/docker/install/ubuntu1804_install_llvm.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -e +set -u +set -o pipefail + +echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-9 main\ + >> /etc/apt/sources.list.d/llvm.list +echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic-9 main\ + >> /etc/apt/sources.list.d/llvm.list + + +echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main\ + >> /etc/apt/sources.list.d/llvm.list +echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic-10 main\ + >> /etc/apt/sources.list.d/llvm.list + +echo deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic main\ + >> /etc/apt/sources.list.d/llvm.list +echo deb-src http://apt.llvm.org/bionic/ llvm-toolchain-bionic main\ + >> /etc/apt/sources.list.d/llvm.list + +wget -q -O - http://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - +apt-get update && apt-get install -y llvm-9 llvm-10 llvm-11 clang-9 clang-10 clang-11 diff --git a/docker/install/ubuntu1804_install_python.sh b/docker/install/ubuntu1804_install_python.sh new file mode 100755 index 000000000000..6b4d6fb4f727 --- /dev/null +++ b/docker/install/ubuntu1804_install_python.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u +set -o pipefail + +# install python and pip, don't modify this, modify install_python_package.sh +apt-get update +apt-get install -y software-properties-common +apt-get install -y python3-dev python3-setuptools + +# Install pip +cd /tmp && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py + +# Pin pip version +pip3 install pip==19.3.1 diff --git a/docker/install/ubuntu_install_caffe2.sh b/docker/install/ubuntu_install_caffe2.sh index a9c1d19eca8a..fa091f950497 100755 --- a/docker/install/ubuntu_install_caffe2.sh +++ b/docker/install/ubuntu_install_caffe2.sh @@ -20,6 +20,20 @@ set -e set -u set -o pipefail -python3 -m caffe2.python.models.download -i -f squeezenet -python3 -m caffe2.python.models.download -i -f resnet50 -python3 -m caffe2.python.models.download -i -f vgg19 +# caffe2.python.module.download generates a progress bar. in non +# interactive use this results in huge progress debris in the log +# files. There is no option to disable the progress bar so work +# around it by stripping the progress bar output + +filter_progress_bar() +{ + # Progress bars are the 'goto start of line' escape sequence + # ESC[1000D[ repeated, the end of the progress bar is the end of + # line. We can selectively remove progress bars by dropping lines + # that beging with the escape sequence. 
+ sed "/^\x1b\[1000D/d" +} + +python3 -m caffe2.python.models.download -i -f squeezenet | filter_progress_bar +python3 -m caffe2.python.models.download -i -f resnet50 | filter_progress_bar +python3 -m caffe2.python.models.download -i -f vgg19 | filter_progress_bar diff --git a/docker/install/ubuntu_install_darknet.sh b/docker/install/ubuntu_install_darknet.sh index 0238b920c844..c48724c6065b 100755 --- a/docker/install/ubuntu_install_darknet.sh +++ b/docker/install/ubuntu_install_darknet.sh @@ -22,5 +22,8 @@ set -o pipefail #install the necessary dependancies, cffi, opencv wget -q 'https://github.com/siju-samuel/darknet/blob/master/lib/libdarknet.so?raw=true' -O libdarknet.so -pip2 install opencv-python cffi +debian_version=`cat /etc/debian_version` +if [ "$debian_version" == "stretch/sid" ]; then + pip2 install opencv-python cffi +fi pip3 install opencv-python cffi diff --git a/docker/install/ubuntu_install_emscripten.sh b/docker/install/ubuntu_install_emscripten.sh index 0012cd08cf28..2e48cccbe2a6 100755 --- a/docker/install/ubuntu_install_emscripten.sh +++ b/docker/install/ubuntu_install_emscripten.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,20 +20,8 @@ set -e set -u set -o pipefail -alias make="make -j4" - -# Get latest cmake -wget -q https://cmake.org/files/v3.8/cmake-3.8.2-Linux-x86_64.tar.gz -tar xf cmake-3.8.2-Linux-x86_64.tar.gz -export PATH=/cmake-3.8.2-Linux-x86_64/bin/:${PATH} - -wget -q https://s3.amazonaws.com/mozilla-games/emscripten/releases/emsdk-portable.tar.gz -tar xf emsdk-portable.tar.gz -cd emsdk-portable -./emsdk update +cd / +git clone https://github.com/emscripten-core/emsdk.git +cd emsdk ./emsdk install latest ./emsdk activate latest -# Clone and pull latest sdk -./emsdk install clang-incoming-64bit -./emsdk activate clang-incoming-64bit -cd .. diff --git a/docker/install/ubuntu_install_mxnet.sh b/docker/install/ubuntu_install_mxnet.sh index d587843d4dec..aa04d4c19177 100755 --- a/docker/install/ubuntu_install_mxnet.sh +++ b/docker/install/ubuntu_install_mxnet.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip3 install mxnet-mkl==1.6.0 +pip3 install mxnet==1.6.0 diff --git a/docker/install/ubuntu_install_nnpack.sh b/docker/install/ubuntu_install_nnpack.sh index 6eb94ae6f5c2..744f76a162bb 100755 --- a/docker/install/ubuntu_install_nnpack.sh +++ b/docker/install/ubuntu_install_nnpack.sh @@ -20,7 +20,7 @@ set -e set -u set -o pipefail -apt-get update && apt-get install -y --no-install-recommends git cmake +apt-get update && apt-get install -y --no-install-recommends git cmake python-setuptools git clone https://github.com/Maratyszcza/NNPACK NNPACK git clone https://github.com/Maratyszcza/pthreadpool NNPACK/pthreadpool diff --git a/docker/install/ubuntu_install_nodejs.sh b/docker/install/ubuntu_install_nodejs.sh index 8da9e2485797..b36da6295ec0 100755 --- a/docker/install/ubuntu_install_nodejs.sh +++ b/docker/install/ubuntu_install_nodejs.sh @@ -25,8 +25,6 @@ apt-get install -y curl # The node install script fetched and executed here will update the # apt source list, hence the second apt-get update is necessary. 
-curl -s -S -L https://deb.nodesource.com/setup_8.x | bash - +curl -s -S -L https://deb.nodesource.com/setup_14.x | bash - apt-get update apt-get install -y nodejs - -npm install eslint jsdoc ws diff --git a/docker/install/ubuntu_install_redis.sh b/docker/install/ubuntu_install_redis.sh index 9679fddf1894..939b36679b53 100755 --- a/docker/install/ubuntu_install_redis.sh +++ b/docker/install/ubuntu_install_redis.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,5 +21,4 @@ set -u set -o pipefail apt-get update && apt-get install -y redis-server -pip2 install xgboost psutil pip3 install xgboost psutil diff --git a/docker/install/ubuntu_install_rust.sh b/docker/install/ubuntu_install_rust.sh index ff22ea31cdd9..310e6507e3f3 100755 --- a/docker/install/ubuntu_install_rust.sh +++ b/docker/install/ubuntu_install_rust.sh @@ -29,5 +29,11 @@ curl -s -S -L https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default . $CARGO_HOME/env rustup component add rustfmt +# install wasmtime +export WASMTIME_HOME=/opt/wasmtime +curl https://wasmtime.dev/install.sh -sSf | bash +export PATH="${WASMTIME_HOME}/bin:${PATH}" +rustup target add wasm32-wasi + # make rust usable by all users chmod -R a+w /opt/rust diff --git a/docker/install/ubuntu_install_tensorflow.sh b/docker/install/ubuntu_install_tensorflow.sh index 8a51b63b5652..e187695c024d 100755 --- a/docker/install/ubuntu_install_tensorflow.sh +++ b/docker/install/ubuntu_install_tensorflow.sh @@ -20,4 +20,4 @@ set -e set -u set -o pipefail -pip3 install tensorflow==1.13.1 keras h5py +pip3 install tensorflow==2.1.0 keras h5py diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh index df65753aace4..123ff520d725 100755 --- a/docker/install/ubuntu_install_tflite.sh +++ b/docker/install/ubuntu_install_tflite.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,21 +21,26 @@ set -u set -o pipefail # Download, build and install flatbuffers -git clone --branch=v1.10.0 --depth=1 --recursive https://github.com/google/flatbuffers.git +git clone --branch=v1.12.0 --depth=1 --recursive https://github.com/google/flatbuffers.git cd flatbuffers cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release make install -j8 cd .. -rm -rf flatbuffers # Install flatbuffers python packages. pip3 install flatbuffers -pip2 install flatbuffers + +# Build the TFLite static library, necessary for building with TFLite ON. +# The library is built at: +# tensorflow/tensorflow/lite/tools/make/gen/*/lib/libtensorflow-lite.a. 
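The comment block above describes building the TFLite static library and installing the schema-generated `tflite` Python package; the clone and build commands that follow carry this out. For orientation, here is a hedged sketch of how those pieces are typically consumed from TVM's TFLite frontend — the model file name, input tensor name, shape, and dtype are assumptions for illustration, not values provided by this script:

```python
import tvm
from tvm import relay
import tflite.Model  # schema-generated package installed by this script

# The model path, input tensor name, shape, and dtype below are illustrative
# assumptions; substitute the values of your own .tflite model.
with open("model.tflite", "rb") as f:
    buf = f.read()
tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0)

mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={"input": (1, 224, 224, 3)},
    dtype_dict={"input": "float32"},
)

with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm", params=params)
```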
+git clone https://github.com/tensorflow/tensorflow --branch=r2.1 +./tensorflow/tensorflow/lite/tools/make/download_dependencies.sh +./tensorflow/tensorflow/lite/tools/make/build_lib.sh # Setup tflite from schema mkdir tflite cd tflite -wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r1.13/tensorflow/lite/schema/schema.fbs +wget -q https://raw.githubusercontent.com/tensorflow/tensorflow/r2.1/tensorflow/lite/schema/schema.fbs flatc --python schema.fbs cat <setup.py @@ -43,7 +48,7 @@ import setuptools setuptools.setup( name="tflite", - version="1.13.1", + version="2.1.0", author="google", author_email="google@google.com", description="TFLite", @@ -63,9 +68,8 @@ cat <__init__.py name = "tflite" EOM -# Install tflite over python2 and python3 +# Install tflite over python3 python3 setup.py install -python2 setup.py install cd .. rm -rf tflite diff --git a/docker/install/ubuntu_install_vulkan.sh b/docker/install/ubuntu_install_vulkan.sh index 5fb40829e0bc..b7d2d4672b0c 100755 --- a/docker/install/ubuntu_install_vulkan.sh +++ b/docker/install/ubuntu_install_vulkan.sh @@ -20,10 +20,7 @@ set -e set -u set -o pipefail -wget -q https://sdk.lunarg.com/sdk/download/1.0.65.0/linux/vulkansdk-linux-x86_64-1.0.65.0.run - -bash vulkansdk-linux-x86_64-1.0.65.0.run -mv VulkanSDK /usr/local/VulkanSDK -cd /usr/local/VulkanSDK/1.0.65.0 -./build_tools.sh -./build_samples.sh +wget -qO - http://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - +wget -qO /etc/apt/sources.list.d/lunarg-vulkan-1.2.135-xenial.list http://packages.lunarg.com/vulkan/1.2.135/lunarg-vulkan-1.2.135-xenial.list +apt update +apt install -y vulkan-sdk diff --git a/docker/with_the_same_user b/docker/with_the_same_user index 1288afd006c0..2338f6351e82 100644 --- a/docker/with_the_same_user +++ b/docker/with_the_same_user @@ -41,6 +41,7 @@ getent passwd "${CI_BUILD_UID}" || adduser --gid "${CI_BUILD_GID}" --uid "${CI_B --gecos "${CI_BUILD_USER} (generated by with_the_same_user script)" \ --disabled-password --home "${CI_BUILD_HOME}" --quiet "${CI_BUILD_USER}" usermod -a -G sudo "${CI_BUILD_USER}" +# This is a grotesque hack to get PYTEST_ADD_OPTS available to all task scripts. echo "${CI_BUILD_USER} ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-nopasswd-sudo if [[ ! -z $CUDA_VISIBLE_DEVICES ]]; then diff --git a/docs/api_links.rst b/docs/api/links.rst similarity index 92% rename from docs/api_links.rst rename to docs/api/links.rst index b2a66a5ccd85..8c22cf8ab0f8 100644 --- a/docs/api_links.rst +++ b/docs/api/links.rst @@ -16,10 +16,10 @@ under the License. Links to API References -================================== +======================= This page contains links to API references that are build with different doc build system. * `C++ doyxgen API `_ -* `Javascript jsdoc API `_ +* `Typescript typedoc API `_ * `Java Javadoc API `_ diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index b482d30515d4..8ac4e1ff7d3a 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -48,9 +48,9 @@ tvm.contrib.dlpack .. automodule:: tvm.contrib.dlpack :members: -tvm.contrib.emscripten -~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: tvm.contrib.emscripten +tvm.contrib.emcc +~~~~~~~~~~~~~~~~ +.. 
automodule:: tvm.contrib.emcc :members: tvm.contrib.miopen diff --git a/docs/api/python/index.rst b/docs/api/python/index.rst index c279dc2d1d9d..bee6e56a8cab 100644 --- a/docs/api/python/index.rst +++ b/docs/api/python/index.rst @@ -37,9 +37,11 @@ Python API relay/transform relay/analysis relay/backend + relay/dataflow_pattern relay/testing autotvm rpc + micro contrib graph_runtime vta/index diff --git a/docs/api/python/micro.rst b/docs/api/python/micro.rst new file mode 100644 index 000000000000..1a93f74834c7 --- /dev/null +++ b/docs/api/python/micro.rst @@ -0,0 +1,23 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.micro +--------- +.. automodule:: tvm.micro + :members: + :imported-members: + :autosummary: diff --git a/docs/api/python/relay/dataflow_pattern.rst b/docs/api/python/relay/dataflow_pattern.rst new file mode 100644 index 000000000000..fe1d4e95e507 --- /dev/null +++ b/docs/api/python/relay/dataflow_pattern.rst @@ -0,0 +1,25 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +tvm.relay.dataflow_pattern +-------------------------- + +.. automodule:: tvm.relay.dataflow_pattern + :members: + :imported-members: + :exclude-members: Object, Node + :autosummary: diff --git a/docs/api/python/runtime.rst b/docs/api/python/runtime.rst index 30d1b98650a3..c51a2d452065 100644 --- a/docs/api/python/runtime.rst +++ b/docs/api/python/runtime.rst @@ -23,28 +23,3 @@ tvm.runtime :imported-members: :exclude-members: NDArray :autosummary: - - -.. autoclass:: tvm.runtime.PackedFunc - :members: - :inherited-members: - -.. autofunction:: tvm.register_func - -.. autofunction:: tvm.get_global_func - - -.. autoclass:: tvm.runtime.Module - :members: - -.. autofunction:: tvm.runtime.load_module - -.. autofunction:: tvm.runtime.system_lib - -.. autofunction:: tvm.runtime.enabled - - -.. autoclass:: tvm.runtime.Object - :members: - -.. 
autofunction:: tvm.register_object diff --git a/docs/api/python/tir.rst b/docs/api/python/tir.rst index 8ef247aff2f7..9f2581b8c0a8 100644 --- a/docs/api/python/tir.rst +++ b/docs/api/python/tir.rst @@ -38,3 +38,10 @@ tvm.tir.analysis :members: :imported-members: :autosummary: + + +tvm.tir.stmt_functor +-------------------- +.. automodule:: tvm.tir.stmt_functor + :members: + :autosummary: diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index cef2999bef52..65f2375341c1 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -50,10 +50,12 @@ List of operators topi.expand_dims topi.reshape topi.unravel_index + topi.sparse_to_dense topi.squeeze topi.concatenate topi.split topi.take + topi.gather topi.gather_nd topi.full topi.full_like @@ -154,10 +156,12 @@ topi .. autofunction:: topi.expand_dims .. autofunction:: topi.reshape .. autofunction:: topi.unravel_index +.. autofunction:: topi.sparse_to_dense .. autofunction:: topi.squeeze .. autofunction:: topi.concatenate .. autofunction:: topi.split .. autofunction:: topi.take +.. autofunction:: topi.gather .. autofunction:: topi.gather_nd .. autofunction:: topi.full .. autofunction:: topi.full_like @@ -212,6 +216,8 @@ topi.nn .. autofunction:: topi.nn.conv2d_hwcn .. autofunction:: topi.nn.depthwise_conv2d_nchw .. autofunction:: topi.nn.depthwise_conv2d_nhwc +.. autofunction:: topi.nn.conv3d_ncdhw +.. autofunction:: topi.nn.conv3d_transpose_ncdhw .. autofunction:: topi.nn.fifo_buffer topi.image @@ -231,6 +237,8 @@ topi.generic .. autofunction:: topi.generic.schedule_conv2d_nchw .. autofunction:: topi.generic.schedule_depthwise_conv2d_nchw +.. autofunction:: topi.generic.schedule_conv3d_ncdhw +.. autofunction:: topi.generic.schedule_conv3d_transpose_ncdhw .. autofunction:: topi.generic.schedule_reduce .. autofunction:: topi.generic.schedule_broadcast .. autofunction:: topi.generic.schedule_injective diff --git a/docs/conf.py b/docs/conf.py index 6ef86ca5a39e..7ece63bd7aa8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -183,7 +183,7 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None), - 'numpy': ('https://docs.scipy.org/doc/numpy/', None), + 'numpy': ('https://numpy.org/doc/stable', None), 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), 'matplotlib': ('https://matplotlib.org/', None), } @@ -211,7 +211,7 @@ 'reference_url': { 'tvm': None, 'matplotlib': 'https://matplotlib.org/', - 'numpy': 'https://docs.scipy.org/doc/numpy/' + 'numpy': 'https://numpy.org/doc/stable' }, 'examples_dirs': examples_dirs, 'gallery_dirs': gallery_dirs, @@ -283,7 +283,7 @@ def process_docstring(app, what, name, obj, options, lines): def setup(app): app.connect('autodoc-process-docstring', process_docstring) - app.add_stylesheet('css/tvm_theme.css') + app.add_css_file('css/tvm_theme.css') app.add_config_value('recommonmark_config', { 'url_resolver': lambda url: github_doc_root + url, 'auto_doc_ref': True diff --git a/docs/contribute/code_guide.rst b/docs/contribute/code_guide.rst index 3d135eb8c8d5..c932e93a11f1 100644 --- a/docs/contribute/code_guide.rst +++ b/docs/contribute/code_guide.rst @@ -34,6 +34,47 @@ C++ Code Styles pass by value is better than pass by const reference in such cases. - Favor ``const`` member function when possible. +We use `clang-format` to enforce the code style. Because different version +of clang-format might change by its version, it is recommended to use the same +version of the clang-format as the master. 
+You can also use the following command via docker. + +.. code:: bash + + docker/bash.sh tvmai/ci-lint clang-format-10 [path-to-file] + + +clang-format is also not perfect, when necessary, you can use disble clang-format on certain code regions. + +.. code :: c + + // clang-format off + void Test() { + // clang-format will be disabled in this region. + } + // clang-format on + + +Because clang-format may not recognize macros, it is recommended to use macro like normal function styles. + + +.. code :: c + + #define MACRO_IMPL { custom impl; } + #define MACRO_FUNC(x) + + // not preferred, because clang-format might recognize it as types. + virtual void Func1() MACRO_IMPL + + // preferred + virtual void Func2() MACRO_IMPL; + + void Func3() { + // preferred + MACRO_FUNC(xyz); + } + + Python Code Styles ------------------ - The functions and classes are documented in `numpydoc `_ format. diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst index 51626a16eb1a..7e0ba372b183 100644 --- a/docs/contribute/pull_request.rst +++ b/docs/contribute/pull_request.rst @@ -29,7 +29,14 @@ This is a quick guide to submit a pull request, please also refer to the detaile git rebase upstream/master - Make sure code style check pass by typing the following command, and all the existing test-cases pass. -- ``docker/bash.sh tvmai/ci-lint ./tests/scripts/task_lint.sh``. (Note: You must install docker beforehand so you can run a docker image.) + + .. code:: bash + + # Reproduce the lint procedure in the CI. + docker/bash.sh tvmai/ci-lint ./tests/scripts/task_lint.sh + # Run clang-format check for all the files that changed since upstream/master + docker/bash.sh tvmai/ci-lint ./tests/lint/git-clang-format.sh upstream/master + - Add test-cases to cover the new features or bugfix the patch introduces. - Document the code you wrote, see more at :ref:`doc_guide` - Send the pull request and fix the problems reported by automatic checks. @@ -44,11 +51,11 @@ This is a quick guide to submit a pull request, please also refer to the detaile - The patch can be merged after the reviewers approve the pull request. + CI Environment -------------- We use docker container to create stable CI environments that can be deployed to multiple machines. -You can find the prebuilt images in ``_ . Because we want a relatively stable CI environment and make use of pre-cached image, all of the CI images are built and maintained by committers. @@ -118,3 +125,6 @@ If you want to run a single test: rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc TVM_FFI=ctypes python -m pytest -v tests/python/unittest/test_pass_storage_rewrite.py + + # Additionally if you want to run a single test, for example test_all_elemwise inside a file. + TVM_FFI=ctypes python -m pytest -v -k "test_all_elemwise" tests/python/frontend/tflite/test_forward.py diff --git a/docs/deploy/android.md b/docs/deploy/android.md deleted file mode 100644 index 788ab412db62..000000000000 --- a/docs/deploy/android.md +++ /dev/null @@ -1,39 +0,0 @@ - - - - - - - - - - - - - - - - - -# Deploy to Android - - -## Build model for Android Target - -Relay compilation of model for android target could follow same approach like android_rpc. -The code below will save the compilation output which is required on android target. 
- -``` -lib.export_library("deploy_lib.so", ndk.create_shared) -with open("deploy_graph.json", "w") as fo: - fo.write(graph.json()) -with open("deploy_param.params", "wb") as fo: - fo.write(relay.save_param_dict(params)) -``` - -deploy_lib.so, deploy_graph.json, deploy_param.params will go to android target. - -## TVM Runtime for Android Target - -Refer [here](https://github.com/apache/incubator-tvm/blob/master/apps/android_deploy/README.md#build-and-installation) to build CPU/OpenCL version flavor TVM runtime for android target. -From android java TVM API to load model & execute can be referred at this [java](https://github.com/apache/incubator-tvm/blob/master/apps/android_deploy/app/src/main/java/org/apache/tvm/android/demo/MainActivity.java) sample source. diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst new file mode 100644 index 000000000000..c724eab8d996 --- /dev/null +++ b/docs/deploy/android.rst @@ -0,0 +1,42 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Deploy to Android +================= + +Build model for Android Target +------------------------------ + +Relay compilation of model for android target could follow same approach like android_rpc. +The code below will save the compilation output which is required on android target. + + +.. code:: python + + lib.export_library("deploy_lib.so", ndk.create_shared) + with open("deploy_graph.json", "w") as fo: + fo.write(graph.json()) + with open("deploy_param.params", "wb") as fo: + fo.write(relay.save_param_dict(params)) + +deploy_lib.so, deploy_graph.json, deploy_param.params will go to android target. + +TVM Runtime for Android Target +------------------------------ + +Refer `here `_ to build CPU/OpenCL version flavor TVM runtime for android target. +From android java TVM API to load model & execute can be referred at this `java `_ sample source. diff --git a/docs/deploy/aocl_fpga.md b/docs/deploy/aocl_fpga.md deleted file mode 100644 index 24f8b65d2e99..000000000000 --- a/docs/deploy/aocl_fpga.md +++ /dev/null @@ -1,109 +0,0 @@ - - - - - - - - - - - - - - - - - -AOCL Backend Example -==================== - -TVM supports Intel FPGA SDK for OpenCL also known as AOCL. Here is a tutorial for how to use TVM with AOCL. - -***Note***: This feature is still experimental. We cannot use AOCL to deploy an end to end neural networks for now. In addition, we only tested compilation for emulation mode of AOCL. - -We use two python scripts for this tutorial. - -- build.py - a script to synthesize FPGA bitstream. 
-``` -import tvm -from tvm import te -tgt_host="llvm" -tgt="aocl_sw_emu" - -n = te.var("n") -A = te.placeholder((n,), name='A') -B = te.placeholder((n,), name='B') -C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - -s = te.create_schedule(C.op) -px, x = s[C].split(C.op.axis[0], nparts=1) - -s[C].bind(px, tvm.thread_axis("pipeline")) - -fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") - -fadd.save("myadd.o") -fadd.imported_modules[0].save("myadd.aocx") - -tvm.contrib.cc.create_shared("myadd.so", ["myadd.o"]) -``` - -- run.py - a script to use FPGA as an accelerator. -``` -import tvm -import numpy as np -import os - -tgt="aocl_sw_emu" - -fadd = tvm.runtime.load("myadd.so") -fadd_dev = tvm.runtime.load("myadd.aocx") -fadd.import_module(fadd_dev) - -ctx = tvm.context(tgt, 0) - -n = 1024 -a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx) -b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx) -c = tvm.nd.array(np.zeros(n, dtype="float32"), ctx) - -fadd(a, b, c) -tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) -``` - -Setup ------ - -- Install AOCL 17.1 on Ubuntu 16.04.4 LTS. -- Install BSP for your FPGA device. -- Install FPGA device driver. -- Create an ICD file at /etc/OpenCL/vendors/Altera.icd so that the OpenCL platform can be found. -``` -/opt/intelFPGA/17.1/hld/linux64/lib/libalteracl.so -``` -- Create an FCD file for example at /opt/Intel/OpenCL/Boards/s5_ref.fcd so that your FPGA device can be found. -``` -/opt/intelFPGA/17.1/hld/board/s5_ref/linux64/lib/libaltera_s5_ref_mmd.so -``` -- Setup TVM with AOCL and OpenCL enabled. - -Emulation ---------- - -- Run software emulation -``` -export CL_CONTEXT_EMULATOR_DEVICE_INTELFPGA=1 - -python build.py -python run.py -``` - -- Run on FPGA devices (not tested) - - Change tgt value to "aocl -device=s5_ref" on build.py and run.py -``` -unset CL_CONTEXT_EMULATOR_DEVICE_INTELFPGA - -python build.py -python run.py -``` diff --git a/docs/deploy/aws_fpga.md b/docs/deploy/aws_fpga.md deleted file mode 100644 index 894585f14b8a..000000000000 --- a/docs/deploy/aws_fpga.md +++ /dev/null @@ -1,170 +0,0 @@ - - - - - - - - - - - - - - - - - -HLS Backend Example -=================== - -TVM supports Xilinx FPGA board with SDAccel. Here is a tutorial for how to deploy TVM to AWS F1 FPGA instance. - -***Note***: This feature is still experimental. We cannot use SDAccel to deploy an end to end neural networks for now. - -We use two python scripts for this tutorial. - -- build.py - a script to synthesize FPGA bitstream. -```python -import tvm -from tvm import te - -tgt_host="llvm" -tgt="sdaccel" - -n = te.var("n") -A = te.placeholder((n,), name='A') -B = te.placeholder((n,), name='B') -C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - -s = te.create_schedule(C.op) -px, x = s[C].split(C.op.axis[0], nparts=1) - -s[C].bind(px, tvm.thread_axis("pipeline")) - -fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") - -fadd.save("myadd.o") -fadd.imported_modules[0].save("myadd.xclbin") - -tvm.contrib.cc.create_shared("myadd.so", ["myadd.o"]) -``` - -- run.py - a script to use FPGA as an accelerator. 
-```python -import tvm -import numpy as np -import os - -tgt="sdaccel" - -fadd = tvm.runtime.load("myadd.so") -if os.environ.get("XCL_EMULATION_MODE"): - fadd_dev = tvm.runtime.load("myadd.xclbin") -else: - fadd_dev = tvm.runtime.load("myadd.awsxclbin") -fadd.import_module(fadd_dev) - -ctx = tvm.context(tgt, 0) - -n = 1024 -a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx) -b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx) -c = tvm.nd.array(np.zeros(n, dtype="float32"), ctx) - -fadd(a, b, c) -tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) -``` - -Setup ------ - -- Launch an instance using the FPGA Developer AMI. We don't need an F1 instance for emulation and synthesis, so it is recommended to use a lower cost instance for them. - -- Setup AWS FPGA development kit. -```bash -git clone https://github.com/aws/aws-fpga.git -cd aws-fpga -source sdaccel_setup.sh -source ${XILINX_SDX}/settings64.sh -``` - -- Setup TVM with OpenCL enabled. - -Emulation ---------- - -- Create emconfig.json for emulation. -```bash -emconfigutil --platform ${AWS_PLATFORM} --nd 1 -``` - -- Copy emconfig.json to the python binary directory. It is because the current Xilinx toolkit assumes that both host binary and the emconfig.json file are in the same path. -```bash -cp emconfig.json $(dirname $(which python)) -``` - -- Run software emulation -```bash -export XCL_EMULATION_MODE=1 -export XCL_TARGET=sw_emu - -python build.py -python run.py -``` - -- Run hardware emulation -```bash -export XCL_EMULATION_MODE=1 -export XCL_TARGET=hw_emu - -python build.py -python run.py -``` - - -Synthesis ---------- - -- Run synthesis with the following script. - -```bash -unset XCL_EMULATION_MODE -export XCL_TARGET=hw - -python build.py -``` - -- Create AWS FPGA image and upload it to AWS S3. -``` -${SDACCEL_DIR}/tools/create_sdaccel_afi.sh -xclbin=myadd.xclbin -o=myadd \ - -s3_bucket= -s3_dcp_key= -s3_logs_key= -``` -This also generates an awsxclbin file, which is necessary to use the AWS FPGA image on F1 instances. - -Run ---- - -- Launch Amazon EC2 F1 instance. - -- Copy `myadd.so`, `myadd.awsxclbin`, and `run.py` to the F1 instance. - -- Setup AWS FPGA development kit. -```bash -git clone https://github.com/aws/aws-fpga.git -cd aws-fpga -source sdaccel_setup.sh -``` - -- Setup TVM with OpenCL enabled. - -- Become root and setup environment variables. -```bash -sudo sh -source ${INSTALL_ROOT}/setup.sh -``` - -- Run -```bash -python run.py -``` diff --git a/docs/deploy/cpp_deploy.md b/docs/deploy/cpp_deploy.md deleted file mode 100644 index 3a99846c0820..000000000000 --- a/docs/deploy/cpp_deploy.md +++ /dev/null @@ -1,52 +0,0 @@ - - - - - - - - - - - - - - - - - -Deploy TVM Module using C++ API -=============================== - -We provide an example on how to deploy TVM modules in [apps/howto_deploy](https://github.com/apache/incubator-tvm/tree/master/apps/howto_deploy) - -To run the example, you can use the following command - -```bash -cd apps/howto_deploy -./run_example.sh -``` - -Get TVM Runtime Library ------------------------ - -The only thing we need is to link to a TVM runtime in your target platform. -TVM provides a minimum runtime, which costs around 300K to 600K depending on how much modules we use. -In most cases, we can use ```libtvm_runtime.so``` that comes with the build. - -If somehow you find it is hard to build ```libtvm_runtime```, checkout [tvm_runtime_pack.cc](https://github.com/apache/incubator-tvm/tree/master/apps/howto_deploy/tvm_runtime_pack.cc). 
-It is an example all in one file that gives you TVM runtime. -You can compile this file using your build system and include this into your project. - -You can also checkout [apps](https://github.com/apache/incubator-tvm/tree/master/apps/) for example applications build with TVM on iOS, Android and others. - -Dynamic Library vs. System Module ---------------------------------- -TVM provides two ways to use the compiled library. -You can checkout [prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/apps/howto_deploy/prepare_test_libs.py) -on how to generate the library and [cpp_deploy.cc](https://github.com/apache/incubator-tvm/tree/master/apps/howto_deploy/cpp_deploy.cc) on how to use them. - -- Store library as a shared library and dynamically load the library into your project. -- Bundle the compiled library into your project in system module mode. - -Dynamic loading is more flexible and can load new modules on the fly. System module is a more ```static``` approach. We can use system module in places where dynamic library loading is banned. diff --git a/docs/deploy/cpp_deploy.rst b/docs/deploy/cpp_deploy.rst new file mode 100644 index 000000000000..a298f958bc78 --- /dev/null +++ b/docs/deploy/cpp_deploy.rst @@ -0,0 +1,56 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Deploy TVM Module using C++ API +=============================== + +We provide an example on how to deploy TVM modules in `apps/howto_deploy `_ + +To run the example, you can use the following command + + +.. code:: bash + + cd apps/howto_deploy + ./run_example.sh + + +Get TVM Runtime Library +----------------------- + +The only thing we need is to link to a TVM runtime in your target platform. +TVM provides a minimum runtime, which costs around 300K to 600K depending on how much modules we use. +In most cases, we can use ``libtvm_runtime.so`` that comes with the build. + +If somehow you find it is hard to build ``libtvm_runtime``, checkout +`tvm_runtime_pack.cc `_. +It is an example all in one file that gives you TVM runtime. +You can compile this file using your build system and include this into your project. + +You can also checkout `apps `_ for example applications build with TVM on iOS, Android and others. + +Dynamic Library vs. System Module +--------------------------------- +TVM provides two ways to use the compiled library. +You can checkout `prepare_test_libs.py `_ +on how to generate the library and `cpp_deploy.cc `_ on how to use them. + +- Store library as a shared library and dynamically load the library into your project. +- Bundle the compiled library into your project in system module mode. + +Dynamic loading is more flexible and can load new modules on the fly. System module is a more ``static`` approach. 
We can use system module in places where dynamic library loading is banned. diff --git a/docs/deploy/hls.rst b/docs/deploy/hls.rst new file mode 100644 index 000000000000..64717ed1e678 --- /dev/null +++ b/docs/deploy/hls.rst @@ -0,0 +1,183 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +HLS Backend Example +=================== + +TVM supports Xilinx FPGA board with SDAccel. Here is a tutorial for how to deploy TVM to AWS F1 FPGA instance. + +.. note:: + + This feature is still experimental. We cannot use SDAccel to deploy an end to end neural networks for now. + +We use two python scripts for this tutorial. + +- build.py - a script to synthesize FPGA bitstream. + + .. code:: python + + import tvm + from tvm import te + + tgt_host="llvm" + tgt="sdaccel" + + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + + s = te.create_schedule(C.op) + px, x = s[C].split(C.op.axis[0], nparts=1) + + s[C].bind(px, tvm.thread_axis("pipeline")) + + fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") + + fadd.save("myadd.o") + fadd.imported_modules[0].save("myadd.xclbin") + + tvm.contrib.cc.create_shared("myadd.so", ["myadd.o"]) + +- run.py - a script to use FPGA as an accelerator. + + .. code:: python + + import tvm + import numpy as np + import os + + tgt="sdaccel" + + fadd = tvm.runtime.load("myadd.so") + if os.environ.get("XCL_EMULATION_MODE"): + fadd_dev = tvm.runtime.load("myadd.xclbin") + else: + fadd_dev = tvm.runtime.load("myadd.awsxclbin") + fadd.import_module(fadd_dev) + + ctx = tvm.context(tgt, 0) + + n = 1024 + a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx) + c = tvm.nd.array(np.zeros(n, dtype="float32"), ctx) + + fadd(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) + + +Setup +----- + +- Launch an instance using the FPGA Developer AMI. We don't need an F1 instance for emulation and synthesis, so it is recommended to use a lower cost instance for them. +- Setup AWS FPGA development kit. + + .. code:: bash + + git clone https://github.com/aws/aws-fpga.git + cd aws-fpga + source sdaccel_setup.sh + source ${XILINX_SDX}/settings64.sh + +- Setup TVM with OpenCL enabled. + +Emulation +--------- + +- Create emconfig.json for emulation. + + .. code:: bash + + emconfigutil --platform ${AWS_PLATFORM} --nd 1 + +- Copy emconfig.json to the python binary directory. It is because the current Xilinx toolkit assumes that both host binary and the emconfig.json file are in the same path. + + .. code:: bash + + cp emconfig.json $(dirname $(which python)) + +- Run software emulation + + .. 
code:: bash + + export XCL_EMULATION_MODE=1 + export XCL_TARGET=sw_emu + + python build.py + python run.py + +- Run hardware emulation + + .. code:: bash + + export XCL_EMULATION_MODE=1 + export XCL_TARGET=hw_emu + + python build.py + python run.py + +Synthesis +--------- + +- Run synthesis with the following script. + + .. code:: bash + + unset XCL_EMULATION_MODE + export XCL_TARGET=hw + + python build.py + +- Create AWS FPGA image and upload it to AWS S3. + + .. code:: bash + + ${SDACCEL_DIR}/tools/create_sdaccel_afi.sh \ + -xclbin=myadd.xclbin -o=myadd \ + -s3_bucket= -s3_dcp_key= \ + -s3_logs_key= + + This also generates an awsxclbin file, which is necessary to use the AWS FPGA image on F1 instances. + +Run +--- + +- Launch Amazon EC2 F1 instance. +- Copy ``myadd.so``, ``myadd.awsxclbin``, and ``run.py`` to the F1 instance. +- Setup AWS FPGA development kit. + + .. code:: bash + + git clone https://github.com/aws/aws-fpga.git + cd aws-fpga + source sdaccel_setup.sh + +- Setup TVM with OpenCL enabled. +- Become root and setup environment variables. + + .. code:: bash + + sudo sh + source ${INSTALL_ROOT}/setup.sh + +- Run + + .. code:: bash + + python run.py diff --git a/docs/deploy/index.rst b/docs/deploy/index.rst index db2938635b82..53455ed50881 100644 --- a/docs/deploy/index.rst +++ b/docs/deploy/index.rst @@ -67,5 +67,4 @@ target device without relying on RPC. see the following resources on how to do s cpp_deploy android integrate - aocl_fpga - aws_fpga + hls diff --git a/docs/deploy/integrate.md b/docs/deploy/integrate.md deleted file mode 100644 index 42896149d283..000000000000 --- a/docs/deploy/integrate.md +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - - - - - - - - - - - - -Integrate TVM into Your Project -=============================== - -TVM's runtime is designed to be lightweight and portable. -There are several ways you can integrate TVM into your project. - -This article introduces possible ways to integrate TVM -as a JIT compiler to generate functions on your system. - - -## DLPack Support - -TVM's generated function follows the PackedFunc convention. -It is a function that can take positional arguments including -standard types such as float, integer, string. -The PackedFunc takes DLTensor pointer in [dlpack](https://github.com/dmlc/dlpack) convention. -So the only thing you need to solve is to create a corresponding DLTensor object. - - - -## Integrate User Defined C++ Array - -The only thing we have to do in C++ is to convert your array to DLTensor and pass in its address as -```DLTensor*``` to the generated function. - - -## Integrate User Defined Python Array - -Assume you have a python object ```MyArray```. There are three things that you need to do - -- Add ```_tvm_tcode``` field to your array which returns ```tvm.TypeCode.ARRAY_HANDLE``` -- Support ```_tvm_handle``` property in your object, which returns the address of DLTensor in python integer -- Register this class by ```tvm.register_extension``` - -```python -# Example code -import tvm - -class MyArray(object): - _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE - - @property - def _tvm_handle(self): - dltensor_addr = self.get_dltensor_addr() - return dltensor_addr - -# You can put registration step in a separate file mypkg.tvm.py -# and only optionally import that if you only want optional dependency. -tvm.register_extension(MyArray) -``` diff --git a/docs/deploy/integrate.rst b/docs/deploy/integrate.rst new file mode 100644 index 000000000000..99c968f14045 --- /dev/null +++ b/docs/deploy/integrate.rst @@ -0,0 +1,69 @@ +.. 
Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Integrate TVM into Your Project +=============================== + +TVM's runtime is designed to be lightweight and portable. +There are several ways you can integrate TVM into your project. + +This article introduces possible ways to integrate TVM +as a JIT compiler to generate functions on your system. + + +DLPack Support +-------------- + +TVM's generated function follows the PackedFunc convention. +It is a function that can take positional arguments including +standard types such as float, integer, string. +The PackedFunc takes DLTensor pointer in `DLPack `_ convention. +So the only thing you need to solve is to create a corresponding DLTensor object. + + + +Integrate User Defined C++ Array +-------------------------------- + +The only thing we have to do in C++ is to convert your array to DLTensor and pass in its address as +``DLTensor*`` to the generated function. + + +## Integrate User Defined Python Array + +Assume you have a python object ``MyArray``. There are three things that you need to do + +- Add ``_tvm_tcode`` field to your array which returns ``tvm.TypeCode.ARRAY_HANDLE`` +- Support ``_tvm_handle`` property in your object, which returns the address of DLTensor in python integer +- Register this class by ``tvm.register_extension`` + +.. code:: python + + # Example code + import tvm + + class MyArray(object): + _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE + + @property + def _tvm_handle(self): + dltensor_addr = self.get_dltensor_addr() + return dltensor_addr + + # You can put registration step in a separate file mypkg.tvm.py + # and only optionally import that if you only want optional dependency. + tvm.register_extension(MyArray) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index a66328fef7c9..8674c8e2c07e 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``. :: inline Schedule create_schedule(Array ops) { - return ScheduleNode::make(ops); + return Schedule(ops); } ``Schedule`` consists of collections of ``Stage`` and output ``Operation``. diff --git a/docs/dev/convert_layout.rst b/docs/dev/convert_layout.rst index 7345c15b6702..07ebc2048dd3 100644 --- a/docs/dev/convert_layout.rst +++ b/docs/dev/convert_layout.rst @@ -92,7 +92,7 @@ These steps happen for each operator in sequence, where ConvertLayout pass keeps .. code-block:: python @reg.register_convert_op_layout("nn.conv2d") - def convert_conv2d(attrs, inputs, tinfos, desired_layout): + def convert_conv2d(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv2d op. 
Parameters @@ -103,8 +103,9 @@ These steps happen for each operator in sequence, where ConvertLayout pass keeps The args of the Relay expr to be legalized tinfos : list of types List of input and output types - desired_layout : str - The desired layout + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and kernel inputs respectively. Returns ------- @@ -113,19 +114,30 @@ These steps happen for each operator in sequence, where ConvertLayout pass keeps """ from tvm import relay - data_layout = attrs['data_layout'] - kernel_layout = attrs['kernel_layout'] data, weight = inputs - assert desired_layout == 'NCHW', \ - "Currently only transformation to NCHW layout is supported." - if desired_layout == 'NCHW': - new_attrs = dict(attrs) - new_attrs['data_layout'] = desired_layout - new_attrs['kernel_layout'] = 'OIHW' + new_attrs = dict(attrs) + + # We expect 2 desired layouts to be specified, one for the data and one for the kernel. + assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" + + # Use the first entry in desired layouts which specifies the data layout. + # The expected ordering of layouts for this operator is defined by this function. + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + + assert desired_data_layout != "default", "Data layout cannot be default" + + new_attrs['data_layout'] = desired_data_layout + + if desired_data_layout == 'NCHW': + if desired_kernel_layout != 'default': + new_attrs['kernel_layout'] = desired_kernel_layout + else: + new_attrs['kernel_layout'] = 'OIHW' # Actual insertion of layout transforms is taken care internally # by ConvertLayout pass. return relay.nn.conv2d(data, weight, **new_attrs) - return None + + raise ValueError('Layout %s is not yet supported' % desired_data_layout) **FInferCorrectLayout - Layout inference** - Currently, this attribute is exposed only in C++. This function takes original input layouts and the new input layouts (passed from the previous operator or from the python callback for layout alteration), and infers the final data layouts. Layout inference is called for each operator. The usage might vary for different operator categories. For layout agnostic operators, we just want to return the new data layouts in this function. For lightly-layout and heavily-layout sensitive operators, we can change the operator attributes (like axis for concatenate, pad_width for pad) so that we can adapt to the new data layout, preventing insertion of layout transforms. Let's look at a couple of examples to understand this better. @@ -218,6 +230,8 @@ Second example is for a lightly-layout sensitive operator - batch normalization. ConvertLayout pass is extremely easy to use. The pass is not a part of default relay.build pipeline. The intended usage is to call it between the framework-to-relay parser and relay.build module call. +In order to specify the layouts to convert to, we create a mapping of heavily-layout sensitive operators to a list of the desired layouts for that operator. The first example below specifies data layout, we allow the kernel layout to be automatically converted to one that is supported by TVM (for that particular data layout and operator). This is specified by the use of the "default" keyword. The second example shows how we could have also converted to a specific kernel layout of our choosing. It's worth noting that the following examples will convert to the same layouts i.e. 
`{'nn.conv2d': ['NCHW', 'default']} == {'nn.conv2d': ['NCHW', 'OIHW']}` + .. code-block:: python # TFlite framework to Relay parser - Default layout is NHWC @@ -225,17 +239,29 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r shape_dict=shape_dict, dtype_dict=dtype_dict) + # We assume our model's heavily-layout sensitive operators only consist of nn.conv2d + desired_layouts = {'nn.conv2d': ['NCHW', 'default']} + # Convert the layout to NCHW # RemoveUnunsedFunctions is used to clean up the graph. seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(), - relay.transform.ConvertLayout('NCHW')]) - with relay.transform.PassContext(opt_level=3): + relay.transform.ConvertLayout(desired_layouts)]) + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) # Call relay compilation with relay.build_config(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) + +.. code-block:: python + + desired_layouts = {'nn.conv2d': ['NCHW', 'OIHW']} + pass = relay.transform.ConvertLayout(desired_layouts) + + +The ordering of the layouts is defined by the implementation of `register_convert_op_layout("OPNAME")`, you can refer to the docstring which should explicitly state the expected layout. In the examples above it's [data_layout, kernel_layout]. + Current implementation has support for almost all the operators commonly used in image classification models. However, if one encounters too many data layout transforms in the graph, it is highly likely that there is an operator whose layouts need special handling as described in Section 3. Some pull requests that can help in such a situation are - Layout inference for `Batch Norm `_ - Batch normalization falls into the category of lightly-sensitive operator. The PR shows how to handle the layout inference for batch norm. diff --git a/docs/dev/relay_add_op.rst b/docs/dev/relay_add_op.rst index f494cc618850..7dca251dd532 100644 --- a/docs/dev/relay_add_op.rst +++ b/docs/dev/relay_add_op.rst @@ -99,7 +99,7 @@ the arguments to the call node, as below. TVM_REGISTER_GLOBAL("relay.op._make.add") .set_body_typed([](Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); - return CallNode::make(op, {lhs, rhs}, Attrs(), {}); + return Call(op, {lhs, rhs}, Attrs(), {}); }); Including a Python API Hook diff --git a/docs/dev/relay_add_pass.rst b/docs/dev/relay_add_pass.rst index 8a6f8be0aea8..a82ae4ff717a 100644 --- a/docs/dev/relay_add_pass.rst +++ b/docs/dev/relay_add_pass.rst @@ -138,7 +138,7 @@ is shown below. if (g->tuple == t) { return GetRef(g); } else { - return TupleGetItemNode::make(t, g->index); + return TupleGetItem(t, g->index); } } @@ -261,7 +261,7 @@ the pass. body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(var, value, body); + return Let(var, value, body); } } } @@ -292,7 +292,7 @@ pointed to by ``op->index``. The reason we need to check is because .. code:: c Expr VisitExpr_(const CallNode* call) final { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); Expr res = ExprMutator::VisitExpr_(call); call = res.as(); // We don't constant fold function with zero arguments. diff --git a/docs/dev/relay_pass_infra.rst b/docs/dev/relay_pass_infra.rst index b40b06e21d0a..446a91bceff7 100644 --- a/docs/dev/relay_pass_infra.rst +++ b/docs/dev/relay_pass_infra.rst @@ -344,13 +344,13 @@ registration. .. code:: c++ // Create a simple Relay program. 
- auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool()); - auto x = relay::VarNode::make("x", relay::Type()); - auto f = relay::FunctionNode::make(tvm::Array{ x }, x, relay::Type(), {}); + auto tensor_type = relay::TensorType({}, tvm::Bool()); + auto x = relay::Var("x", relay::Type()); + auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); - auto y = relay::VarNode::make("y", tensor_type); - auto call = relay::CallNode::make(f, tvm::Array{ y }); - auto fx = relay::FunctionNode::make(tvm::Array{ y }, call, relay::Type(), {}); + auto y = relay::Var("y", tensor_type); + auto call = relay::Call(f, tvm::Array{ y }); + auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); // Create a module for optimization. auto mod = IRModule::FromExpr(fx); diff --git a/docs/index.rst b/docs/index.rst index 258547a34acd..5e3fa45a5c66 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -37,7 +37,7 @@ API Reference langref/index api/python/index - api_links + api/links Developer Guide --------------- @@ -47,7 +47,7 @@ Developer Guide dev/index Frontends ----------------- +--------- .. toctree:: :maxdepth: 1 diff --git a/docs/install/nnpack.md b/docs/install/nnpack.md deleted file mode 100644 index e1bcb701d63e..000000000000 --- a/docs/install/nnpack.md +++ /dev/null @@ -1,100 +0,0 @@ - - - - - - - - - - - - - - - - - -# NNPACK Contrib Installation - -[NNPACK](https://github.com/Maratyszcza/NNPACK) is an acceleration package -for neural network computations, which can run on x86-64, ARMv7, or ARM64 architecture CPUs. -Using NNPACK, higher-level libraries like _MXNet_ can speed up -the execution on multi-core CPU computers, including laptops and mobile devices. - -***Note***: AS TVM already has natively tuned schedules, NNPACK is here mainly for reference and comparison purpose. -For regular use prefer native tuned TVM implementation. - -_TVM_ supports NNPACK for forward propagation (inference only) in convolution, max-pooling, and fully-connected layers. -In this document, we give a high level overview of how to use NNPACK with _TVM_. - -## Conditions -The underlying implementation of NNPACK utilizes several acceleration methods, -including [fft](https://arxiv.org/abs/1312.5851) and [winograd](https://arxiv.org/abs/1509.09308). -These algorithms work better on some special `batch size`, `kernel size`, and `stride` settings than on other, -so depending on the context, not all convolution, max-pooling, or fully-connected layers can be powered by NNPACK. -When favorable conditions for running NNPACKS are not met, - -NNPACK only supports Linux and OS X systems. Windows is not supported at present. - -## Build/Install NNPACK - -If the trained model meets some conditions of using NNPACK, -you can build TVM with NNPACK support. -Follow these simple steps: -* Build NNPACK shared library with the following commands. _TVM_ will link NNPACK dynamically. - -Note: The following NNPACK installation instructions have been tested on Ubuntu 16.04. - -### Build [Ninja](https://ninja-build.org/) - -NNPACK need a recent version of Ninja. So we need to install ninja from source. -```bash -git clone git://github.com/ninja-build/ninja.git -cd ninja -./configure.py --bootstrap -``` - -Set the environment variable PATH to tell bash where to find the ninja executable. For example, assume we cloned ninja on the home directory ~. then we can added the following line in ~/.bashrc. 
-```bash -export PATH="${PATH}:~/ninja" -``` - -### Build [NNPACK](https://github.com/Maratyszcza/NNPACK) - -The new CMAKE version of NNPACK download [Peach](https://github.com/Maratyszcza/PeachPy) and other dependencies alone - -Note: at least on OS X, running `ninja install` below will overwrite googletest libraries installed in `/usr/local/lib`. If you build googletest again to replace the nnpack copy, be sure to pass `-DBUILD_SHARED_LIBS=ON` to `cmake`. - -```bash -git clone --recursive https://github.com/Maratyszcza/NNPACK.git -cd NNPACK -# Add PIC option in CFLAG and CXXFLAG to build NNPACK shared library -sed -i "s|gnu99|gnu99 -fPIC|g" CMakeLists.txt -sed -i "s|gnu++11|gnu++11 -fPIC|g" CMakeLists.txt -mkdir build -cd build -# Generate ninja build rule and add shared library in configuration -cmake -G Ninja -D BUILD_SHARED_LIBS=ON .. -ninja -sudo ninja install - -# Add NNPACK lib folder in your ldconfig -echo "/usr/local/lib" > /etc/ld.so.conf.d/nnpack.conf -sudo ldconfig -``` - -## Build TVM with NNPACK support - -```bash -git clone --recursive https://github.com/apache/incubator-tvm tvm -``` - -* Set `set(USE_NNPACK ON)` in config.cmake. -* Set `NNPACK_PATH` to the $(YOUR_NNPACK_INSTALL_PATH) - -after configuration use `make` to build TVM - -```bash -make -``` diff --git a/docs/install/nnpack.rst b/docs/install/nnpack.rst new file mode 100644 index 000000000000..10497ba05654 --- /dev/null +++ b/docs/install/nnpack.rst @@ -0,0 +1,118 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +NNPACK Contrib Installation +=========================== + +`NNPACK `_ is an acceleration package +for neural network computations, which can run on x86-64, ARMv7, or ARM64 architecture CPUs. +Using NNPACK, higher-level libraries like _MXNet_ can speed up +the execution on multi-core CPU computers, including laptops and mobile devices. + +.. note:: + + AS TVM already has natively tuned schedules, NNPACK is here mainly for reference and comparison purpose. + For regular use prefer native tuned TVM implementation. + +TVM supports NNPACK for forward propagation (inference only) in convolution, max-pooling, and fully-connected layers. +In this document, we give a high level overview of how to use NNPACK with TVM. + +Conditions +---------- + +The underlying implementation of NNPACK utilizes several acceleration methods, +including fft and winograd. +These algorithms work better on some special `batch size`, `kernel size`, and `stride` settings than on other, +so depending on the context, not all convolution, max-pooling, or fully-connected layers can be powered by NNPACK. +When favorable conditions for running NNPACKS are not met, + +NNPACK only supports Linux and OS X systems. Windows is not supported at present. 
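+
+As a quick illustration (a minimal sketch; it assumes TVM was built with ``set(USE_NNPACK ON)`` as described in the next section, and that your TVM version exposes the ``is_available`` helper in ``tvm.contrib.nnpack``), you can check at runtime whether the NNPACK functions were actually linked in before relying on them:
+
+.. code:: python
+
+   from tvm.contrib import nnpack
+
+   # is_available() reports whether the NNPACK runtime functions were
+   # registered in this TVM build; if not, fall back to the native TVM schedules.
+   if nnpack.is_available():
+       print("NNPACK runtime functions are available")
+   else:
+       print("NNPACK not available, use the native TVM schedules instead")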
+
+Build/Install NNPACK
+--------------------
+
+If the trained model meets some conditions for using NNPACK,
+you can build TVM with NNPACK support.
+Follow these simple steps:
+
+Build the NNPACK shared library with the following commands. TVM will link NNPACK dynamically.
+
+Note: The following NNPACK installation instructions have been tested on Ubuntu 16.04.
+
+Build Ninja
+~~~~~~~~~~~
+
+NNPACK needs a recent version of Ninja, so we need to install ninja from source.
+
+.. code:: bash
+
+   git clone git://github.com/ninja-build/ninja.git
+   cd ninja
+   ./configure.py --bootstrap
+
+
+Set the environment variable PATH to tell bash where to find the ninja executable. For example, assuming we cloned ninja into the home directory ~, we can add the following line to ~/.bashrc.
+
+
+.. code:: bash
+
+   export PATH="${PATH}:~/ninja"
+
+
+Build NNPACK
+~~~~~~~~~~~~
+
+The new CMake build of NNPACK downloads `PeachPy <https://github.com/Maratyszcza/PeachPy>`_ and its other dependencies on its own.
+
+Note: at least on OS X, running `ninja install` below will overwrite googletest libraries installed in `/usr/local/lib`. If you build googletest again to replace the nnpack copy, be sure to pass `-DBUILD_SHARED_LIBS=ON` to `cmake`.
+
+.. code:: bash
+
+   git clone --recursive https://github.com/Maratyszcza/NNPACK.git
+   cd NNPACK
+   # Add PIC option in CFLAG and CXXFLAG to build NNPACK shared library
+   sed -i "s|gnu99|gnu99 -fPIC|g" CMakeLists.txt
+   sed -i "s|gnu++11|gnu++11 -fPIC|g" CMakeLists.txt
+   mkdir build
+   cd build
+   # Generate ninja build rule and add shared library in configuration
+   cmake -G Ninja -D BUILD_SHARED_LIBS=ON ..
+   ninja
+   sudo ninja install
+
+   # Add NNPACK lib folder in your ldconfig
+   echo "/usr/local/lib" > /etc/ld.so.conf.d/nnpack.conf
+   sudo ldconfig
+
+
+Build TVM with NNPACK support
+-----------------------------
+
+.. code:: bash
+
+   git clone --recursive https://github.com/apache/incubator-tvm tvm
+
+- Set `set(USE_NNPACK ON)` in config.cmake.
+- Set `NNPACK_PATH` to $(YOUR_NNPACK_INSTALL_PATH).
+
+After configuration, use `make` to build TVM.
+
+
+.. code:: bash
+
+   make
diff --git a/docs/langref/index.rst b/docs/langref/index.rst
index 0d296118da26..dcea9fa50c3d 100644
--- a/docs/langref/index.rst
+++ b/docs/langref/index.rst
@@ -46,6 +46,7 @@ algebraic data types, and operators in Relay, respectively.
    relay_type
    relay_adt
    relay_op
+   relay_pattern
 
 Hybrid Script
 -------------
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 798d440f7425..cef96ef65931 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -69,6 +69,8 @@ This level enables typical convnet models.
 
    tvm.relay.nn.conv2d
    tvm.relay.nn.conv2d_transpose
+   tvm.relay.nn.conv3d
+   tvm.relay.nn.conv3d_transpose
    tvm.relay.nn.dense
    tvm.relay.nn.max_pool2d
    tvm.relay.nn.max_pool3d
@@ -118,6 +120,7 @@ This level enables additional math and transform operators.
    tvm.relay.zeros_like
    tvm.relay.ones
    tvm.relay.ones_like
+   tvm.relay.gather
    tvm.relay.gather_nd
    tvm.relay.full
    tvm.relay.full_like
@@ -130,6 +133,7 @@ This level enables additional math and transform operators.
    tvm.relay.tile
    tvm.relay.reverse
    tvm.relay.unravel_index
+   tvm.relay.sparse_to_dense
 
 **Level 4: Broadcast and Reductions**
 
@@ -224,4 +228,4 @@ This level supports dialect operators.
   :nosignatures:
 
    tvm.relay.qnn.op.requantize
-   tvm.relay.qnn.op.conv2d
\ No newline at end of file
+   tvm.relay.qnn.op.conv2d
diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
new file mode 100644
index 000000000000..962dcc677bcb
--- /dev/null
+++ b/docs/langref/relay_pattern.rst
@@ -0,0 +1,409 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+   or more contributor license agreements. See the NOTICE file
+   distributed with this work for additional information
+   regarding copyright ownership. The ASF licenses this file
+   to you under the Apache License, Version 2.0 (the
+   "License"); you may not use this file except in compliance
+   with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+   software distributed under the License is distributed on an
+   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+   KIND, either express or implied. See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+
+=========================
+Pattern Matching in Relay
+=========================
+
+There are many places in TVM where we identify pure data-flow sub-graphs of the Relay program and attempt to transform them in some way; example passes include fusion, quantization, external code generation, and device-specific optimizations such as bitpacking and the layer slicing used by VTA.
+
+Many of these passes today require a lot of boring boilerplate code to implement, as well as requiring users to think in terms of visitors and AST matching. Many of these transformations can easily be described in terms of graph rewrites. In order to build a rewriter or other advanced machinery we first need a language of patterns to describe what we can match.
+
+Such a language is not just useful for building a rewriter but also for providing extension points for existing passes. For example, the fusion pass could be parameterized by a set of fusion patterns that describe the capability of your hardware, and the quantization pass could take a set of patterns that describe which operators can be quantized on a given platform.
+
+In the backend world, we could use the same machinery to build a higher-level API on top of bring-your-own code generation. This API takes a set of patterns describing your hardware capabilities and an external compiler, providing a relatively smooth heterogeneous experience out of the box.
+
+Pattern Examples
+================
+
+There are quite a few properties of operators that are worth matching. Below we examine how to match tree properties, and expand on some use cases that are not fully explored in the prototype. This section
+demonstrates how to write patterns. It is recommended to check `tests/python/relay/test_dataflow_pattern.py`_
+for more use cases.
+
+.. _tests/python/relay/test_dataflow_pattern.py: https://github.com/apache/incubator-tvm/blob/master/tests/python/relay/test_dataflow_pattern.py
+
+.. note::
+
+    If you cannot find the corresponding pattern node to match the Relay node you want,
+    you are welcome to raise an issue or submit a PR to add it.
+
+Matching One of Two Ops
+***********************
+
+The first example is a simple case where we want to match one operator with a single input OR
+another operator with a single input:
+
+..
code-block:: python + + def test_match_op_or(): + is_add_or_sub = is_op('add') | is_op('subtract') + assert is_add_or_sub.match(relay.op.op.get("add")) + assert is_add_or_sub.match(relay.op.op.get("subtract")) + + +Matching an Op with Attributes +****************************** + +The next example is a dense operation with any operator that is marked element-wise: + +.. code-block:: python + + def test_no_match_attr(): + op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE}) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(relay.op.nn.dense(x, y)) + +Here is another example to match an op with a specific attribute: + +.. code-block:: python + + def test_match_data_layout(): + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"}) + x = relay.var('x') + y = relay.var('y') + assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) + +Matching an Optional Op +*********************** + +The next example is matching a pattern with one optional operator. In this pattern, +we can match the graph of conv2d+bias_add+relu or the graph of conv2d+bias_add. + +.. code-block:: python + + def test_match_optional(): + conv_node = is_op('nn.conv2d')(wildcard(), wildcard()) + bias_node = is_op('nn.bias_add')(conv_node, wildcard()) + pat = bias_node.optional(lambda x: is_op('nn.relu')(x)) + + x = relay.var('x') + y = relay.var('y') + z = relay.var('z') + conv2d = relay.op.nn.conv2d(x, y) + bias = relay.op.nn.bias_add(conv2d, z) + assert pat.match(bias) + relu = relay.op.nn.relu(bias) + assert pat.match(relu) + +Matching Non-Call Nodes +*********************** + +Sometimes we may also want to match a pattern that includes Tuple or TupleGetItem nodes. +Since there are not call nodes, we need to use specific pattern nodes to match them: + +.. code-block:: python + + def test_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.var('z') + tuple_pattern = is_tuple((wildcard(), wildcard(), wildcard())) + assert tuple_pattern.match(relay.expr.Tuple((x,y,z))) + +The next example is matching a pattern of batch_norm -> get(0) -> relu: + +.. code-block:: python + + def test_match_tuple_get_item(): + bn_node = is_op('nn.batch_norm')(wildcard(), wildcard(), wildcard(), wildcard(), wildcard()) + tuple_get_item_node = is_tuple_get_item(bn_node, 0) + pat = is_op('nn.relu')(tuple_get_item_node) + + x = relay.var('x', shape=(1, 8)) + gamma = relay.var("gamma", shape=(8,)) + beta = relay.var("beta", shape=(8,)) + moving_mean = relay.var("moving_mean", shape=(8,)) + moving_var = relay.var("moving_var", shape=(8,)) + bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var) + tuple_get_item_node = bn_node[0] + out = relay.nn.relu(tuple_get_item_node) + pat.match(out) + +The next example is matching a constant node regarding its values. This is useful to check +if a specific parameter in a subgraph has been bound or not. + +.. code-block:: python + + def test_match_constant(): + conv2d = is_op('nn.conv2d')(wildcard(), is_constant()) + pattern = is_op('nn.bias_add')(conv2d, wildcard()) + + x = relay.var('x', shape=(1, 3, 224, 224)) + w = relay.var('w', shape=(3, 3, 3, 3)) + b = relay.var('b', shape=(3, )) + conv2d = relay.op.nn.conv2d(x, w) + out = relay.op.nn.bias_add(conv2d, b) + func = relay.Function([x, w, b], out) + mod = tvm.IRModule.from_expr(func) + + # Two inputs of the conv2d in the graph are VarNode by default, so no match. 
+       assert not pattern.match(mod['main'].body)
+
+       # The second input (weight) has been bound to constant values so it is now a constant node.
+       mod["main"] = bind_params_by_name(mod["main"],
+                                         {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
+       assert pattern.match(mod['main'].body)
+
+On the other hand, if you need to match the constant with a specific value, you can directly
+use ``is_expr``. This could be useful for algebraic simplification.
+
+.. code-block:: python
+
+    def test_match_plus_zero():
+        zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0)))
+        pattern = wildcard() + zero
+
+        x = relay.Var('x')
+        y = x + relay.const(0)
+        assert pattern.match(y)
+
+The next example is matching function nodes with a specific attribute:
+
+.. code-block:: python
+
+    def test_match_function():
+        pattern = wildcard().has_attr({"Composite": "add"})
+
+        x = relay.var('x')
+        y = relay.var('y')
+        f = relay.Function([x, y], x + y).with_attr("Composite", "add")
+        assert pattern.match(f)
+
+Matching Diamonds and Post-Dominator Graphs
+*******************************************
+
+The next example is matching a diamond with two inputs at the top of the diamond::
+
+    def test_match_diamond():
+        # Pattern
+        is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
+        path1 = is_op('nn.relu')(is_conv2d)
+        path2 = is_op('nn.leaky_relu')(is_conv2d)
+        diamond = is_op('add')(path1, path2)
+
+        # Expr
+        inp = relay.var('input')
+        weight = relay.var('weight')
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+        out = relu + leaky_relu
+
+        # Check
+        assert diamond.match(out)
+
+The final example is matching diamonds with a post-dominator relationship. We embed dominator analysis as a type of matching in the pattern language in order to allow for pattern matching with unknown topology. This is important because we want to be able to use the language to describe fusion patterns, like elementwise operations followed by a conv2d::
+
+    def test_match_dom_diamond():
+        # Pattern
+        is_conv2d = is_op('nn.conv2d')(is_var(), is_var())
+        reduction = is_op('add')(wildcard(), wildcard())
+        diamond = dominates(is_conv2d, is_elemwise, reduction)
+
+        # Expr
+        inp = relay.var('input')
+        weight = relay.var('weight')
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+        out = relu + leaky_relu
+
+        # Check
+        assert diamond.match(out)
+
+Pattern Language Design
+=======================
+
+The pattern language proposed is designed to be a mirror of Relay's IR with additional support for common scenarios. The goal of the pattern language is to provide a regular-expression like capability for matching data-flow graphs and doing rewriting.
+
+The high-level design is to introduce a language of patterns; for now we propose the language as::
+
+    Pattern ::= expr
+        | *
+        | pattern(pattern1, ... patternN)
+        | has_type(pattern, type)
+        | has_attr(pattern, attrs)
+        | is_var(name)
+        | is_constant()
+        | is_expr(expr)
+        | is_op(op_name)
+        | is_tuple()
+        | is_tuple_get_item()
+        | pattern1 `|` pattern2
+        | dominates(parent_pattern, path_pattern, child_pattern)
+
+The above language then provides a matching interface that can both select sub-graphs and verify that a graph matches the pattern.
+
+Expression Pattern
+******************
+
+Match a literal expression.
+
+Wildcard
+********
+
+Match any expression.
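+
+As a small sketch of how this composes (assuming ``wildcard`` and ``is_op`` are imported from ``tvm.relay.dataflow_pattern``, as in the earlier examples), a wildcard can be used on its own or dropped anywhere a sub-pattern is expected:
+
+.. code-block:: python
+
+    def test_wildcard():
+        # A bare wildcard matches any Relay expression.
+        wc = wildcard()
+        x = relay.var('x')
+        y = relay.var('y')
+        assert wc.match(x)
+        assert wc.match(x + y)
+
+        # As a sub-pattern it matches either argument of an add.
+        add_pat = is_op('add')(wildcard(), wildcard())
+        assert add_pat.match(x + y)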
+
+Type Pattern
+************
+
+Check that the expression matched by the nested pattern has a particular type.
+
+Attribute Pattern
+*****************
+
+Check that the operator matched by the pattern has an attribute with a particular value.
+
+Variable Pattern
+****************
+
+Check that the expression is a Relay Variable, and optionally provide a name to match against the Variable's name.
+
+
+Alternate
+*********
+
+Either match the first pattern or the second pattern.
+
+Domination
+**********
+
+Match the child pattern, find a match for the parent pattern, ensuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that every node between the child and the parent matches the path pattern.
+
+Applications
+============
+
+The pattern language provides not only pattern matching but also pattern processing.
+Here we introduce two pattern processing approaches and provide some examples.
+
+Pattern Rewriting
+*****************
+
+If you would like to replace the matched pattern with another subgraph, you can leverage
+the ``rewrite`` transformation. Here is an example of rewriting a series of arithmetic operators
+with a single batch_norm op:
+
+.. code-block:: python
+
+    class BatchnormCallback(DFPatternCallback):
+        # A callback class to rewrite the matched pattern to a batch_norm op.
+        def __init__(self):
+            self.x = wildcard()
+            self.var = wildcard()
+            self.mean = wildcard()
+            self.beta = wildcard()
+            self.gamma = wildcard()
+            self.eps = wildcard()
+
+            self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta
+
+        def callback(self, pre, post, node_map):
+            x = node_map[self.x][0]
+            var = node_map[self.var][0]
+            mean = node_map[self.mean][0]
+            beta = node_map[self.beta][0]
+            gamma = node_map[self.gamma][0]
+            eps = node_map[self.eps][0]
+            return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())[0]
+
+    # A graph of arithmetic operators that are functionally equivalent to batch_norm.
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    from tvm.relay.dataflow_pattern import rewrite
+    out = rewrite(BatchnormCallback(), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+The function ``def callback(self, pre, post, node_map)`` will be invoked when the rewriter matches
+``self.pattern``. ``node_map`` is a dictionary mapping from pattern nodes to matched nodes in the graph.
+
+Pattern Partitioning
+********************
+
+If you would like to perform more complex processing on matched subgraphs than ``rewrite`` allows,
+you may consider partitioning the matched subgraphs into separate
+Relay functions and applying further processing to those functions. Here we use ``pattern.partition``
+to create a new Relay function for each matched subgraph. The functionality is similar to
+the op fusion pass in TVM:
+
+.. code-block:: python
+
+    # A pattern matching conv2d+relu.
+    pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
+
+    # A graph.
+ x = relay.var('input') + w = relay.var('weight') + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + print('relu') + # free_var %x: Tensor[(1, 3, 224, 224), float32] + # free_var %w: Tensor[(3, 3, 3, 3), float32] + # %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 222, 222), float32] */; + # free_var %b: Tensor[(3), float32] + # nn.bias_add(%0, %b) /* ty=Tensor[(1, 3, 222, 222), float32] */ + + # After partition. + print(pattern.partition(relu)) + # free_var %x: Tensor[(1, 3, 224, 224), float32] + # free_var %w: Tensor[(3, 3, 3, 3), float32] + # free_var %b: Tensor[(3), float32] + # %1 = fn (%FunctionVar_0_0, %FunctionVar_0_1, + # %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_") { + # %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]); + # nn.bias_add(%0, %FunctionVar_0_2) + # }; + # %1(%x, %w, %b) + +Note that you can also specify the attributes for the created functions: + +.. code-block:: python + + print(pattern.partition(relu, {'Composite': 'one_layer'})) + # free_var %x: Tensor[(1, 3, 224, 224), float32] + # free_var %w: Tensor[(3, 3, 3, 3), float32] + # free_var %b: Tensor[(3), float32] + # %1 = fn (%FunctionVar_0_0, %FunctionVar_0_1, + # %FunctionVar_0_2, Composite="one_layer", + # PartitionedFromPattern="nn.conv2d_nn.bias_add_") { + # %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]); + # nn.bias_add(%0, %FunctionVar_0_2) + # }; + # %1(%x, %w, %b) + +If you need a customized checking function that cannot be specified using pattern language, +you can specify ``check`` function when partitioning. The following example demonstrates a +case that checks input data layout of a subgraph: + +.. code-block:: python + + def check(pre): + conv = pre.args[0] + return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1) + + pattern.partition(relu, check=check) + +In this example, we check if the first argument of the matched subgraph (i.e., ``pre.args[0]``) +has data layout "NCHW" and if its batch size is 1. This feature is useful if the conditions +of matching a pattern cannot be verified by analyzing the pattern itself. diff --git a/docs/vta/install.md b/docs/vta/install.md deleted file mode 100644 index a938a67218ff..000000000000 --- a/docs/vta/install.md +++ /dev/null @@ -1,419 +0,0 @@ - - - - - - - - - - - - - - - - - -VTA Installation Guide -====================== - -We present three installation guides, each extending on the previous one: -1. [Simulator installation](#vta-simulator-installation) -2. [PYNQ-based test setup](#vta-pynq-based-test-setup) -3. [Custom test setup for Intel FPGA](#vta-custom-test-setup-for-intel-fpga) -4. [FPGA toolchain installation](#vta-fpga-toolchain-installation) - -## VTA Simulator Installation - -You need [TVM installed](https://tvm.apache.org/docs/install/index.html) on your machine. -For a quick and easy start, checkout the [Docker Guide](https://tvm.apache.org/docs/install/docker.html). - -You'll need to set the following paths to use VTA: -```bash -export TVM_PATH= -export VTA_HW_PATH=$TVM_PATH/3rdparty/vta-hw -``` - -The VTA functional simulation library needs to be enabled when building TVM. -```bash -cd -mkdir build -cp cmake/config.cmake build/. -echo 'set(USE_VTA_FSIM ON)' >> build/config.cmake -cd build && cmake .. && make -j4 -``` - -Add the VTA python library to your python path to run the VTA examples. 
- -```bash -export PYTHONPATH=/path/to/vta/python:${PYTHONPATH} -``` - -### Testing your VTA Simulation Setup - -To ensure that you've properly installed the VTA python package, run the following 2D convolution testbench. - -```bash -python /vta/tests/python/integration/test_benchmark_topi_conv2d.py -``` - -> Note: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves, by evaluating the convolutions in software. - -You are invited to try out our [VTA programming tutorials](https://tvm.apache.org/docs/vta/tutorials/index.html). - - -### Advanced Configuration (optional) - -VTA is a generic configurable deep learning accelerator. -The configuration is specified by `vta_config.json` under `3rdparty/vta-hw/config`. -This file provides an architectural specification of the VTA accelerator to parameterize the TVM compiler stack and the VTA hardware stack. - -The VTA configuration file also specifies the TVM compiler target. -When `TARGET` is set to `sim`, all TVM workloads execute on the VTA simulator. -You can modify the content of the configuration file to rebuild VTA to a different parameterization. -To do so, - -```bash -cd -vim 3rdparty/vta-hw/config/vta_config.json -# edit vta_config.json -make -``` - -## VTA Pynq-Based Test Setup - -This second guide extends the *VTA Simulator Installation* guide above to run FPGA hardware tests of the complete TVM and VTA software-hardware stack. -In terms of hardware components you'll need: -* The [Pynq](http://www.pynq.io/) FPGA development board which can be acquired for $200, or $150 for academics from [Digilent](https://store.digilentinc.com/pynq-z1-python-productivity-for-zynq/). -* An Ethernet-to-USB adapter to connect the Pynq board to your development machine. -* An 8+GB micro SD card. -* An AC to DC 12V 3A power adapter. - -This guide covers the following themes: -1. Pynq board setup instructions. -2. Pynq-side RPC server build and deployment. -3. Revisiting the test examples from the *VTA Simulator Installation* guide, this time executing on the Pynq board. - -### Pynq Board Setup - -Setup your Pynq board based on the [Pynq board getting started tutorial](http://pynq.readthedocs.io/en/latest/getting_started.html). -You should follow the instructions up to and including the *Turning On the PYNQ-Z1* step (no need to pursue the tutorial beyond this point). -* Make sure that you've downloaded the latest Pynq image, [PYNQ-Z1 v2.4](http://www.pynq.io/board.html)(released February 22rd 2019), and have imaged your SD card with it (we recommend the free [Etcher](https://etcher.io/) program). -* For this test setup, follow the ["Connect to a Computer"](http://pynq.readthedocs.io/en/latest/getting_started.html#connect-to-a-computer) Ethernet setup instructions. 
To be able to talk to the board, make sure to [assign your computer a static IP address](http://pynq.readthedocs.io/en/latest/appendix.html#assign-your-computer-a-static-ip) - -Once the board is powered on and connected to your development machine, try connecting to it to make sure you've properly set up your Pynq board: -```bash -# To connect to the Pynq board use the [username, password] combo: [xilinx, xilinx] -ssh xilinx@192.168.2.99 -``` - -### Pynq-Side RPC Server Build & Deployment - -Because the direct board-to-computer connection prevents the board from directly accessing the internet, we'll need to mount the Pynq's file system to your development machine's file system with [sshfs](https://www.digitalocean.com/community/tutorials/how-to-use-sshfs-to-mount-remote-file-systems-over-ssh). Next we directly clone the TVM repository into the sshfs mountpoint on your development machine. - -```bash -# On the Host-side -mkdir -sshfs xilinx@192.168.2.99:/home/xilinx -cd -git clone --recursive https://github.com/apache/incubator-tvm tvm -# When finished, you can leave the moutpoint and unmount the directory -cd ~ -sudo umount -``` - -Now that we've cloned the VTA repository in the Pynq's file system, we can ssh into it and launch the build of the TVM-based RPC server. -The build process should take roughly 5 minutes. - -```bash -ssh xilinx@192.168.2.99 -# Build TVM runtime library (takes 5 mins) -cd /home/xilinx/tvm -mkdir build -cp cmake/config.cmake build/. -echo 'set(USE_VTA_FPGA ON)' >> build/config.cmake -# Copy pynq specific configuration -cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json -cd build -cmake .. -make runtime vta -j2 -# Build VTA RPC server (takes 1 min) -cd .. -sudo ./apps/vta_rpc/start_rpc_server.sh # pw is 'xilinx' -``` - -You should see the following being displayed when starting the RPC server. In order to run the next examples, you'll need to leave the RPC server running in an `ssh` session. -``` -INFO:root:RPCServer: bind to 0.0.0.0:9091 -``` - -Tips regarding the Pynq RPC Server: -* The RPC server should be listening on port `9091`. If not, an earlier process might have terminated unexpectedly and it's recommended in this case to just reboot the Pynq, and re-run the RPC server. -* To kill the RPC server, just send the `Ctrl + c` command. You can re-run it with `sudo ./apps/pynq_rpc/start_rpc_server.sh`. -* If unresponsive, the board can be rebooted by power-cycling it with the physical power switch. - -### Testing your Pynq-based Hardware Setup - -Before running the examples on your development machine, you'll need to configure your host environment as follows: -```bash -# On the Host-side -export VTA_RPC_HOST=192.168.2.99 -export VTA_RPC_PORT=9091 -``` - -In addition, you'll need to edit the `vta_config.json` file on the host to indicate that we are targeting the Pynq platform, by setting the `TARGET` field to `"pynq"`. -> Note: in contrast to our simulation setup, there are no libraries to compile on the host side since the host offloads all of the computation to the Pynq board. - -```bash -# On the Host-side -cd -cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json -``` - -This time again, we will run the 2D convolution testbench. -Beforehand, we need to program the Pynq board FPGA with a VTA bitstream, and build the VTA runtime via RPC. 
-The following `test_program_rpc.py` script will perform two operations: -* FPGA programming, by downloading a pre-compiled bitstream from a [VTA bitstream repository](https://github.com/uwsaml/vta-distro) that matches the default `vta_config.json` configuration set by the host, and sending it over to the Pynq via RPC to program the Pynq's FPGA. -* Runtime building on the Pynq, which needs to be run every time the `vta_config.json` configuration is modified. This ensures that the VTA software runtime that generates the accelerator's executable via just-in-time (JIT) compilation matches the specifications of the VTA design that is programmed on the FPGA. The build process takes about 30 seconds to complete so be patient! - -```bash -# On the Host-side -python /vta/tests/python/pynq/test_program_rpc.py -``` - -> Tip: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq `ssh` session. - -We are now ready to run the 2D convolution testbench in hardware. - -```bash -# On the Host-side -python /vta/tests/python/integration/test_benchmark_topi_conv2d.py -``` - -The performance metrics measured on the Pynq board will be reported for each convolutional layer. - -You can also try out our [VTA programming tutorials](https://tvm.apache.org/docs/vta/tutorials/index.html). - -## VTA Custom Test Setup for Intel FPGA - -Similar to the PYNQ side setup steps, this third guide bring us the details on how can we setup up the Linux environment for Intel FPGA boards like DE10-Nano. - -In terms of hardware components, you would need the [DE10-Nano Development Kit](https://www.terasic.com.tw/cgi-bin/page/archive.pl?Language=English&No=1046), which can be acquired for $130, or $100 for academics from [Terasic](https://www.terasic.com.tw/). A microSD card would be delivered the kit. Power cables and USB cables would be included as well. However, an additional Ethernet cable would be needed to connect the board to LAN. - -The rest part of this guide would provide the steps to - -* Flash the microSD card with latest Angstrom Linux image -* Cross compilation setup -* Device-side RPC server setup and deployment - -### DE10-Nano Board Setup - -Before powering up the device, we need to flash the microSD card image with latest Angstrom Linux image. - -#### Flash SD Card and Boot Angstrom Linux - -To flash SD card and boot Linux on DE10-Nano, it is recommended to navigate to the [Resource](https://www.terasic.com.tw/cgi-bin/page/archive.pl?Language=English&CategoryNo=167&No=1046&PartNo=4) tab of the DE10-Nano product page from Terasic Inc. -After registration and login on the webpage, the prebuilt Angstrom Linux image would be available for downloading and flashing. -Specifically, to flash the downloaded Linux SD card image into your physical SD card: - -First, extract the gzipped archive file. - -``` bash -tar xf de10-nano-image-Angstrom-v2016.12.socfpga-sdimg.2017.03.31.tgz -``` - -This would produce a single SD card image named `de10-nano-image-Angstrom-v2016.12.socfpga-sdimg` (approx. 2.4 GB), it contains all the file systems to boot Angstrom Linux. - -Second, plugin a SD card that is ready to flash in your PC, and identify the device id for the disk with `fdisk -l`, or `gparted` if you feel better to use GUI. The typical device id for your disk would likely to be `/dev/sdb`. 
- -Then, flash the disk image into your physical SD card with the following command: - -``` bash -# NOTE: root privilege is typically required to run the following command. -dd if=de10-nano-image-Angstrom-v2016.12.socfpga-sdimg of=/dev/sdb status=progress -``` -This would take a few minutes for your PC to write the whole file systems into the SD card. -After this process completes, you are ready to unmount the SD card and insert it into your DE10-Nano board. -Now you can connect the power cable and serial port to boot the Angstrom Linux. - -> Note: When boot up from the microSD card, you might notice the incompatibility of the linux kernel `zImage` in the microSD card. -> In this case, you might need to build the `zImage` file of your own from [socfpga-4.9.78-ltsi](https://github.com/altera-opensource/linux-socfpga/tree/socfpga-4.9.78-ltsi) branch of the [linux-socfpga](https://github.com/altera-opensource/linux-socfpga) repository. -> For a quick fix, you can also download a prebuilt version of the `zImage` file [here](https://raw.githubusercontent.com/liangfu/de10-nano-supplement/master/zImage). - -After connecting the usb cables to the DE10-Nano board, power on the board by connecting the power cable. You may then connect to the serial port of the device by using `minicom` on your host PC: - -``` bash -# NOTE: root privilege is typically required to run the following command. -minicom -D /dev/ttyUSB0 -``` - -The default user name for the device would be `root`, and the password is empty for the default user. - -You may now start to install supporting Python3 packages (TVM has dropped the support for Python2), specifically, they are `numpy`, `attrs` and `decorator`. - -> Note: You might fail to install `numpy` by using `pip3` on the DE10-Nano device. -> In that case, you have the option to either build your own filesystem image for the board from [meta-de10-nano](https://github.com/intel/meta-de10-nano) repository; -> an alternative option is to download prebuilt packages from existing Linux distributions, e.g. Debian. -> For a quick fix, we have concatenated the supplementary binary files [here](https://raw.githubusercontent.com/liangfu/de10-nano-supplement/master/rootfs_supplement.tgz), and you can extract the files into the root filesystem. - -#### Install Required Python Packages - -After accessing bash terminal from the serial port, we need to install required Python packages before building and installing TVM and VTA programs. - -#### Build Additional Components to Use VTA Bitstream - -To use the above built bitstream on DE10-Nano hardware, several additional components need to be compiled for the system. -Specifically, to compile application executables for the system, you need to download and install [SoCEDS](http://fpgasoftware.intel.com/soceds/18.1/?edition=standard&download_manager=dlm3&platform=linux) (recommended), or alternatively install the `g++-arm-linux-gnueabihf` package on your host machine. You would also need a `cma` kernel module to allocate contigous memory, and a driver for communicating with the VTA subsystem. - -## VTA FPGA Toolchain Installation - -This last guide allows users to generate custom VTA bitstreams using free-to-use Xilinx or Intel compilation toolchains. - -### Xilinx Toolchain Installation - -We recommend using `Vivado 2018.3` since our scripts have been tested to work on this version of the Xilinx toolchains. -Our guide is written for Linux (Ubuntu) installation. 
- -You’ll need to install Xilinx’ FPGA compilation toolchain, [Vivado HL WebPACK 2018.3](https://www.xilinx.com/products/design-tools/vivado.html), which a license-free version of the Vivado HLx toolchain. - -#### Obtaining and Launching the Vivado GUI Installer - -1. Go to the [download webpage](https://www.xilinx.com/support/download/index.html/content/xilinx/en/downloadNav/vivado-design-tools/2018-3.html), and download the Linux Self Extracting Web Installer for Vivado HLx 2018.3: WebPACK and Editions. -2. You’ll have to sign in with a Xilinx account. This requires a Xilinx account creation that will take 2 minutes. -3. Complete the Name and Address Verification by clicking “Next”, and you will get the opportunity to download a binary file, called `Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin`. -4. Now that the file is downloaded, go to your `Downloads` directory, and change the file permissions so it can be executed: -```bash -chmod u+x Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin -``` -5. Now you can execute the binary: -```bash -./Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin -``` - -#### Xilinx Vivado GUI Installer Steps - -At this point you've launched the Vivado 2018.3 Installer GUI program. - -1. Click “Next” on the *Welcome* screen. -2. On the *Select Install Type* screen, enter your Xilinx user credentials under the “User Authentication” box and select the “Download and Install Now” option before clicking “Next” . -3. On the *Accept License Agreements* screen, accept all terms before clicking “Next”. -4. On the *Select Edition to Install* screen, select the “Vivado HL WebPACK” before clicking “Next” . -5. Under the *Vivado HL WebPACK* screen, before hitting “Next", check the following options (the rest should be unchecked): - * Design Tools -> Vivado Design Suite -> Vivado - * Devices -> Production Devices -> SoCs -> Zynq-7000 (if you are targeting the Pynq board) - * Devices -> Production Devices -> SoCs -> UltraScale+ MPSoC (if you are targeting the Ultra-96 board) -6. Your total download size should be about 5GB and the amount of Disk Space Required 23GB. -7. On the *Select Destination Directory* screen, set the installation directory before clicking “Next”. It might highlight some paths as red - that’s because the installer doesn’t have the permission to write to the directory. In that case select a path that doesn’t require special write permissions (e.g. your home directory). -8. On the *Installation Summary* screen, hit “Install”. -9. An *Installation Progress* window will pop-up to track progress of the download and the installation. -10. This process will take about 20-30 minutes depending on your connection speed. -11. A pop-up window will inform you that the installation completed successfully. Click "OK". -12. Finally the *Vivado License Manager* will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process. - -#### Environment Setup - -The last step is to update your `~/.bashrc` with the following lines. This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line. -```bash -# Xilinx Vivado 2018.3 environment -export XILINX_VIVADO=${XILINX_PATH}/Vivado/2018.3 -export PATH=${XILINX_VIVADO}/bin:${PATH} -``` - -### Intel Toolchain Installation - -It is recommended to use `Intel Quartus Prime 18.1`, since the test scripts contained in this document have been tested on this version. 
- -You would need to install Intel's FPGA compilation toolchain, [Quartus Prime Lite](http://fpgasoftware.intel.com/?edition=lite), which is a license-free version of the Intel Quartus Prime software. - -#### Obtaining and Launching the Quartus GUI Installer - -1. Go to the [download center](http://fpgasoftware.intel.com/?edition=lite), and download the linux version of `Quartus Prime (include Nios II EDS)` and `Cyclone V device support` files in the `Separate file` tab. This avoid downloading unused device support files. -2. Sign in the form if you have an account, or register on the right side of the web page to create an account. -3. After signed in, you are able to download the installer and the device support files. -4. Now that the files are downloaded, go to your `Downloads` directory, and change the file permissions: -```bash -chmod u+x QuartusLiteSetup-18.1.0.625-linux.run -``` -5. Now ensure both the installer and device support files are in the same directory, and you can run the install with: -```bash -./QuartusLiteSetup-18.1.0.625-linux.run -``` -6. Follow the instructions on the pop-up GUI form, and install all the content in the `/usr/local` directory. After installation, `/usr/local/intelFPGA_lite/18.1` would be created and the Quartus program along with other programs would be available in the folder. - -#### Environment Setup - -Similar to what should be done for Xilinx toolchain, the following line should be added to your `~/.bashrc`. -```bash -# Intel Quartus 18.1 environment -export QUARTUS_ROOTDIR="/usr/local/intelFPGA_lite/18.1/quartus" -export PATH=${QUARTUS_ROOTDIR}/bin:${PATH} -export PATH=${QUARTUS_ROOTDIR}/sopc_builder/bin:${PATH} -``` -This would add quartus binary path into your `PATH` environment variable, so you can launch compilation scripts from the command line. - -### HLS-based Custom VTA Bitstream Compilation for PYNQ - -High-level hardware parameters are listed in the VTA configuration file and can be customized by the user. -For this custom VTA bitstream compilation exercise, we'll change the frequency of our design, so it can be clocked a little faster. -* Set the `HW_FREQ` field to `142`. The Pynq board supports 100, 142, 167 and 200MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violation and thus faulty hardware execution. -* Set the `HW_CLK_TARGET` to `6`. This parameters refers to the target clock period in nano seconds for HLS - a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142MHz clock would require a 7ns target, but we intentionally lower the clock target to 6ns to more aggressively pipeline our design. - -Bitstream generation is driven by a top-level `Makefile` under `/3rdparty/vta-hw/hardware/xilinx/`. - -If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter: -```bash -cd /3rdparty/vta-hw/hardware/xilinx -make ip MODE=sim -``` - -If you just want to generate the HLS-based VTA IP cores without launching the entire design place and route, enter: -```bash -make ip -``` -You'll be able to view the HLS synthesis reports under `/3rdparty/vta-hw/build/hardware/xilinx/hls/` `//solution0/syn/report/_csynth.rpt` -> Note: The `` name is a string that summarizes the VTA configuration parameters listed in the `vta_config.json`. The `` name refers to the specific module (or HLS function) that compose the high-level VTA pipeline. 
- -Finally to run the full hardware compilation and generate the VTA bitstream, run: - -```bash -make -``` - -This process is lengthy, and can take around up to an hour to complete depending on your machine's specs. -We recommend setting the `VTA_HW_COMP_THREADS` variable in the Makefile to take full advantage of all the cores on your development machine. - -Once the compilation completes, the generated bitstream can be found under `/3rdparty/vta-hw/build/hardware/xilinx/vivado//export/vta.bit`. - -### Chisel-based Custom VTA Bitstream Compilation for DE10-Nano - -Similar to the HLS-based design, high-level hardware parameters in Chisel-based design are listed in the VTA configuration file [Configs.scala](https://github.com/apache/incubator-tvm/blob/master/3rdparty/vta-hw/hardware/chisel/src/main/scala/core/Configs.scala), and they can be customized by the user. - -For Intel FPGA, bitstream generation is driven by a top-level `Makefile` under `/3rdparty/vta-hw/hardware/intel`. - -If you just want to generate the Chisel-based VTA IP core for the DE10-Nano board without compiling the design for the FPGA hardware, enter: -```bash -cd /3rdparty/vta-hw/hardware/intel -make ip -``` -Then you'll be able to locate the generated verilog file at `/3rdparty/vta-hw/build/hardware/intel/chisel//VTA.DefaultDe10Config.v`. - -If you would like to run the full hardware compilation for the `de10nano` board: -```bash -make -``` - -This process might be a bit lengthy, and might take up to half an hour to complete depending on the performance of your PC. The Quartus Prime software would automatically detect the number of cores available on your PC and try to utilize all of them to perform such process. - -Once the compilation completes, the generated bistream can be found under `/3rdparty/vta-hw/build/hardware/intel/quartus//export/vta.rbf`. You can also open the Quartus project file (.qpf) available at `/3rdparty/vta-hw/build/hardware/intel/quartus//de10_nano_top.qpf` to look around the generated reports. - -### Use the Custom Bitstream - -We can program the new VTA FPGA bitstream by setting the bitstream path of the `vta.program_fpga()` function in the tutorial examples, or in the `test_program_rpc.py` script. - -```python -vta.program_fpga(remote, bitstream="/3rdparty/vta-hw/build/hardware/xilinx/vivado//export/vta.bit") -``` - -Instead of downloading a pre-built bitstream from the VTA bitstream repository, TVM will instead use the new bitstream you just generated, which is a VTA design clocked at a higher frequency. -Do you observe a noticeable performance increase on the ImageNet classification example? diff --git a/docs/vta/install.rst b/docs/vta/install.rst new file mode 100644 index 000000000000..b68fab7da2d1 --- /dev/null +++ b/docs/vta/install.rst @@ -0,0 +1,488 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. 
See the License for the
+   specific language governing permissions and limitations
+   under the License.
+
+VTA Installation Guide
+======================
+
+We present five installation guides, each extending on the previous one:
+
+1. `Simulator Installation`_
+2. `Xilinx Pynq FPGA Setup`_
+3. `Intel DE10 FPGA Setup`_
+4. `Bitstream Generation with Xilinx Toolchains`_
+5. `Bitstream Generation with Intel Toolchains`_
+
+
+Simulator Installation
+----------------------
+
+You need `TVM installed `_ on your machine.
+For a quick and easy start, check out the `Docker Guide `_.
+
+You'll need to set the following paths to use VTA:
+
+.. code:: bash
+
+   export TVM_PATH=
+   export VTA_HW_PATH=$TVM_PATH/3rdparty/vta-hw
+
+The VTA functional simulation library needs to be enabled when building TVM.
+
+.. code:: bash
+
+   cd
+   mkdir build
+   cp cmake/config.cmake build/.
+   echo 'set(USE_VTA_FSIM ON)' >> build/config.cmake
+   cd build && cmake .. && make -j4
+
+Add the VTA python library to your python path to run the VTA examples.
+
+.. code:: bash
+
+   export PYTHONPATH=/path/to/vta/python:${PYTHONPATH}
+
+Testing your VTA Simulation Setup
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To ensure that you've properly installed the VTA python package, run the following 2D convolution testbench.
+
+.. code:: bash
+
+   python /vta/tests/python/integration/test_benchmark_topi_conv2d.py
+
+You are invited to try out our `VTA programming tutorials `_.
+
+ **Note**: You'll notice that for every convolution layer, the throughput gets reported in GOPS. These numbers are actually the computational throughput that the simulator achieves by evaluating the convolutions in software.
+
+Advanced Configuration (optional)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+VTA is a generic configurable deep learning accelerator.
+The configuration is specified by ``vta_config.json`` under ``3rdparty/vta-hw/config``.
+This file provides an architectural specification of the VTA accelerator to parameterize the TVM compiler stack and the VTA hardware stack.
+
+The VTA configuration file also specifies the TVM compiler target.
+When ``TARGET`` is set to ``sim``, all TVM workloads execute on the VTA simulator.
+You can modify the content of the configuration file to rebuild VTA to a different parameterization.
+To do so,
+
+.. code:: bash
+
+   cd
+   vim 3rdparty/vta-hw/config/vta_config.json
+   # edit vta_config.json
+   make
+
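+For example, you can inspect the parameterization that is currently in effect from Python. This is
+a minimal sketch; it assumes the ``vta.get_env()`` helper and its ``TARGET``, ``BATCH``, ``BLOCK_IN``
+and ``BLOCK_OUT`` attributes from the VTA Python package, which may differ between versions.
+
+.. code:: python
+
+   import vta
+
+   # Parse 3rdparty/vta-hw/config/vta_config.json into an environment object.
+   env = vta.get_env()
+   print(env.TARGET)                               # "sim" for the simulator setup
+   print(env.BATCH, env.BLOCK_IN, env.BLOCK_OUT)   # GEMM intrinsic shape
+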
+
+
+Xilinx Pynq FPGA Setup
+----------------------
+
+This second guide extends the *VTA Simulator Installation* guide above to run FPGA hardware tests of the complete TVM and VTA software-hardware stack.
+In terms of hardware components you'll need:
+
+* The `Pynq `_ FPGA development board, which can be acquired for $200, or $150 for academics, from `Digilent `_.
+* An Ethernet-to-USB adapter to connect the Pynq board to your development machine.
+* An 8+GB micro SD card.
+* An AC to DC 12V 3A power adapter.
+
+This guide covers the following themes:
+
+1. Pynq board setup instructions.
+2. Pynq-side RPC server build and deployment.
+3. Revisiting the test examples from the *VTA Simulator Installation* guide, this time executing on the Pynq board.
+
+Pynq Board Setup
+^^^^^^^^^^^^^^^^
+
+Set up your Pynq board based on the `Pynq board getting started tutorial `_.
+
+You should follow the instructions up to and including the *Turning On the PYNQ-Z1* step (no need to pursue the tutorial beyond this point).
+
+* Make sure that you've downloaded the latest Pynq image, `PYNQ-Z1 v2.4 `_ (released February 22nd 2019), and have imaged your SD card with it (we recommend the free `Etcher `_ program).
+* For this test setup, follow the `"Connect to a Computer" `_ Ethernet setup instructions. To be able to talk to the board, make sure to `assign your computer a static IP address `_.
+
+Once the board is powered on and connected to your development machine, try connecting to it to make sure you've properly set up your Pynq board:
+
+.. code:: bash
+
+   # To connect to the Pynq board use the username/password combo: xilinx/xilinx
+   ssh xilinx@192.168.2.99
+
+Pynq-Side RPC Server Build & Deployment
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Because the direct board-to-computer connection prevents the board from directly accessing the internet, we'll need to mount the Pynq's file system to your development machine's file system with `sshfs `_. Next we directly clone the TVM repository into the sshfs mountpoint on your development machine.
+
+.. code:: bash
+
+   # On the Host-side
+   mkdir
+   sshfs xilinx@192.168.2.99:/home/xilinx
+   cd
+   git clone --recursive https://github.com/apache/incubator-tvm tvm
+   # When finished, you can leave the mountpoint and unmount the directory
+   cd ~
+   sudo umount
+
+Now that we've cloned the TVM repository in the Pynq's file system, we can ssh into it and launch the build of the TVM-based RPC server.
+The build process should take roughly 5 minutes.
+
+.. code:: bash
+
+   ssh xilinx@192.168.2.99
+   # Build TVM runtime library (takes 5 mins)
+   cd /home/xilinx/tvm
+   mkdir build
+   cp cmake/config.cmake build/.
+   echo 'set(USE_VTA_FPGA ON)' >> build/config.cmake
+   # Copy pynq specific configuration
+   cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json
+   cd build
+   cmake ..
+   make runtime vta -j2
+   # Build VTA RPC server (takes 1 min)
+   cd ..
+   sudo ./apps/vta_rpc/start_rpc_server.sh # pw is 'xilinx'
+
+
+You should see the following being displayed when starting the RPC server. In order to run the next examples, you'll need to leave the RPC server running in an ``ssh`` session.
+
+.. code:: bash
+
+   INFO:root:RPCServer: bind to 0.0.0.0:9091
+
+
+Tips regarding the Pynq RPC Server:
+
+* The RPC server should be listening on port ``9091``. If not, an earlier process might have terminated unexpectedly; in this case it's recommended to reboot the Pynq and re-run the RPC server.
+* To kill the RPC server, just send the ``Ctrl + c`` command. You can re-run it with ``sudo ./apps/vta_rpc/start_rpc_server.sh``.
+* If unresponsive, the board can be rebooted by power-cycling it with the physical power switch.
+
+Testing your Pynq-based Hardware Setup
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Before running the examples on your development machine, you'll need to configure your host environment as follows:
+
+.. code:: bash
+
+   # On the Host-side
+   export VTA_RPC_HOST=192.168.2.99
+   export VTA_RPC_PORT=9091
+
+
+In addition, you'll need to edit the ``vta_config.json`` file on the host to indicate that we are targeting the Pynq platform, by setting the ``TARGET`` field to ``"pynq"``.
+
+ **Note**: in contrast to our simulation setup, there are no libraries to compile on the host side since the host offloads all of the computation to the Pynq board.
+
+.. code:: bash
+
+   # On the Host-side
+   cd
+   cp 3rdparty/vta-hw/config/pynq_sample.json 3rdparty/vta-hw/config/vta_config.json
+
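+As a quick sanity check, you can verify from Python that the host can reach the board's RPC server.
+This is a minimal sketch, not part of the VTA test scripts; it only assumes the ``VTA_RPC_HOST`` and
+``VTA_RPC_PORT`` variables exported above.
+
+.. code:: python
+
+   import os
+   from tvm import rpc
+
+   # Read the same environment variables used by the VTA test scripts.
+   host = os.environ.get("VTA_RPC_HOST", "192.168.2.99")
+   port = int(os.environ.get("VTA_RPC_PORT", "9091"))
+
+   # This succeeds only if the Pynq-side RPC server is up and listening.
+   remote = rpc.connect(host, port)
+   print("Connected to the RPC server at %s:%d" % (host, port))
+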
+This time again, we will run the 2D convolution testbench.
+Beforehand, we need to program the Pynq board FPGA with a VTA bitstream, and build the VTA runtime via RPC.
+The following ``test_program_rpc.py`` script will perform two operations:
+
+* FPGA programming, by downloading a pre-compiled bitstream from a `VTA bitstream repository `_ that matches the default ``vta_config.json`` configuration set by the host, and sending it over to the Pynq via RPC to program the Pynq's FPGA.
+* Runtime building on the Pynq, which needs to be run every time the ``vta_config.json`` configuration is modified. This ensures that the VTA software runtime that generates the accelerator's executable via just-in-time (JIT) compilation matches the specifications of the VTA design that is programmed on the FPGA. The build process takes about 30 seconds to complete, so be patient!
+
+.. code:: bash
+
+   # On the Host-side
+   python /vta/tests/python/pynq/test_program_rpc.py
+
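+For reference, the script does roughly the following. This is a minimal sketch: the
+``vta.get_env()`` and ``vta.reconfig_runtime()`` helpers and the exact call order are assumptions
+about the VTA Python package and may differ between versions; ``vta.program_fpga()`` is the same
+call shown later in this guide.
+
+.. code:: python
+
+   import os
+   from tvm import rpc
+   import vta
+
+   env = vta.get_env()          # parameters from 3rdparty/vta-hw/config/vta_config.json
+   assert env.TARGET == "pynq"  # we set this field on the host above
+
+   remote = rpc.connect(os.environ["VTA_RPC_HOST"], int(os.environ["VTA_RPC_PORT"]))
+
+   # Program the FPGA; bitstream=None fetches a pre-built bitstream matching vta_config.json.
+   vta.program_fpga(remote, bitstream=None)
+
+   # Rebuild the VTA runtime on the board so JIT-compiled kernels match the programmed design.
+   vta.reconfig_runtime(remote)
+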
+We are now ready to run the 2D convolution testbench in hardware.
+
+.. code:: bash
+
+   # On the Host-side
+   python /vta/tests/python/integration/test_benchmark_topi_conv2d.py
+
+The performance metrics measured on the Pynq board will be reported for each convolutional layer.
+
+**Tip**: You can track progress of the FPGA programming and the runtime rebuilding steps by looking at the RPC server's logging messages in your Pynq ``ssh`` session.
+
+You can also try out our `VTA programming tutorials `_.
+
+
+
+Intel DE10 FPGA Setup
+---------------------
+
+Similar to the Pynq-side setup steps, this third guide details how to set up the Linux environment for Intel FPGA boards such as the DE10-Nano.
+
+In terms of hardware components, you will need the `DE10-Nano Development Kit `_, which can be acquired for $130, or $100 for academics, from `Terasic `_. A microSD card is shipped with the kit, and power and USB cables are included as well. However, an additional Ethernet cable is needed to connect the board to your LAN.
+
+The rest of this guide provides the steps to:
+
+* Flash the microSD card with the latest Angstrom Linux image
+* Set up cross compilation
+* Set up and deploy the device-side RPC server
+
+DE10-Nano Board Setup
+^^^^^^^^^^^^^^^^^^^^^
+
+Before powering up the device, we need to flash the microSD card with the latest Angstrom Linux image.
+
+Flash SD Card and Boot Angstrom Linux
+"""""""""""""""""""""""""""""""""""""
+
+To flash the SD card and boot Linux on the DE10-Nano, it is recommended to navigate to the `Resource `_ tab of the DE10-Nano product page from Terasic Inc.
+After registering and logging in on the webpage, the prebuilt Angstrom Linux image will be available for download and flashing.
+Specifically, to flash the downloaded Linux SD card image into your physical SD card:
+
+First, extract the gzipped archive file.
+
+.. code:: bash
+
+   tar xf de10-nano-image-Angstrom-v2016.12.socfpga-sdimg.2017.03.31.tgz
+
+This produces a single SD card image named ``de10-nano-image-Angstrom-v2016.12.socfpga-sdimg`` (approx. 2.4 GB); it contains all the file systems needed to boot Angstrom Linux.
+
+Second, plug an SD card that is ready to flash into your PC, and identify the device id for the disk with ``fdisk -l``, or with ``gparted`` if you prefer a GUI. The device id for your disk will typically be ``/dev/sdb``.
+
+Then, flash the disk image into your physical SD card with the following command:
+
+.. code:: bash
+
+   # NOTE: root privilege is typically required to run the following command.
+   dd if=de10-nano-image-Angstrom-v2016.12.socfpga-sdimg of=/dev/sdb status=progress
+
+Writing the whole file system to the SD card takes a few minutes.
+After this process completes, you are ready to unmount the SD card and insert it into your DE10-Nano board.
+Now you can connect the power cable and serial port to boot Angstrom Linux.
+
+ **Note**: When booting from the microSD card, you might notice that the Linux kernel ``zImage`` shipped on the card is incompatible.
+ In this case, you might need to build your own ``zImage`` file from the `socfpga-4.9.78-ltsi `_ branch of the `linux-socfpga `_ repository.
+ For a quick fix, you can also download a prebuilt version of the ``zImage`` file `from this link `_.
+
+After connecting the USB cables to the DE10-Nano board, power on the board by connecting the power cable. You may then connect to the serial port of the device by using ``minicom`` on your host PC:
+
+.. code:: bash
+
+   # NOTE: root privilege is typically required to run the following command.
+   minicom -D /dev/ttyUSB0
+
+The default user name for the device is ``root``, with an empty password.
+
+You may now install the supporting Python 3 packages (TVM has dropped support for Python 2): ``numpy``, ``attrs`` and ``decorator``.
+
+ **Note**: You might fail to install ``numpy`` by using ``pip3`` on the DE10-Nano device.
+ In that case, you can either build your own filesystem image for the board from the `meta-de10-nano `_ repository, or download prebuilt packages from existing Linux distributions, e.g. Debian.
+ For a quick fix, we have bundled the supplementary binary files `here `_, and you can extract the files into the root filesystem.
+
+Install Required Python Packages
+""""""""""""""""""""""""""""""""
+
+After accessing the bash terminal over the serial port, we need to install the required Python packages before building and installing the TVM and VTA programs.
+
+Build Additional Components to Use VTA Bitstream
+""""""""""""""""""""""""""""""""""""""""""""""""
+
+To use the bitstream built above on the DE10-Nano hardware, several additional components need to be compiled for the system.
+Specifically, to compile application executables for the system, you need to download and install `SoCEDS `_ (recommended), or alternatively install the ``g++-arm-linux-gnueabihf`` package on your host machine. You will also need a ``cma`` kernel module to allocate contiguous memory, and a driver for communicating with the VTA subsystem.
+
+
+Bitstream Generation with Xilinx Toolchains
+-------------------------------------------
+
+If you're interested in generating the Xilinx FPGA bitstream on your own instead of using the pre-built VTA bitstreams, follow the instructions below.
+
+Xilinx Toolchain Installation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We recommend using Vivado 2018.3 since our scripts have been tested to work on this version of the Xilinx toolchains.
+Our guide is written for Linux (Ubuntu) installation.
+
+You’ll need to install Xilinx’ FPGA compilation toolchain, `Vivado HL WebPACK 2018.3 `_, which is a license-free version of the Vivado HLx toolchain.
+
+Obtaining and Launching the Vivado GUI Installer
+""""""""""""""""""""""""""""""""""""""""""""""""
+
+1. Go to the `download webpage `_, and download the Linux Self Extracting Web Installer for Vivado HLx 2018.3: WebPACK and Editions.
+2. You’ll have to sign in with a Xilinx account.
This requires a Xilinx account creation that will take 2 minutes. +3. Complete the Name and Address Verification by clicking “Next”, and you will get the opportunity to download a binary file, called ``Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin``. +4. Now that the file is downloaded, go to your ``Downloads`` directory, and change the file permissions so it can be executed: + +.. code:: bash + + chmod u+x Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin + +5. Now you can execute the binary: + +.. code:: bash + + ./Xilinx_Vivado_SDK_Web_2018.3_1207_2324_Lin64.bin + +Xilinx Vivado GUI Installer Steps +""""""""""""""""""""""""""""""""" + +At this point you've launched the Vivado 2018.3 Installer GUI program. + +1. Click “Next” on the "Welcome" screen. +2. On the "Select Install Type" screen, enter your Xilinx user credentials under the “User Authentication” box and select the “Download and Install Now” option before clicking “Next”. +3. On the "Accept License Agreements" screen, accept all terms before clicking “Next”. +4. On the "Select Edition to Install" screen, select the “Vivado HL WebPACK” before clicking “Next”. +5. Under the "Vivado HL WebPACK" screen, before hitting “Next", check the following options (the rest should be unchecked): + * Design Tools -> Vivado Design Suite -> Vivado + * Devices -> Production Devices -> SoCs -> Zynq-7000 (if you are targeting the Pynq board) + * Devices -> Production Devices -> SoCs -> UltraScale+ MPSoC (if you are targeting the Ultra-96 board) +6. Your total download size should be about 5GB and the amount of Disk Space Required 23GB. +7. On the "Select Destination Directory" screen, set the installation directory before clicking “Next”. It might highlight some paths as red - that’s because the installer doesn’t have the permission to write to the directory. In that case select a path that doesn’t require special write permissions (e.g. your home directory). +8. On the "Installation Summary" screen, hit “Install”. +9. An "Installation Progress" window will pop-up to track progress of the download and the installation. +10. This process will take about 20-30 minutes depending on your connection speed. +11. A pop-up window will inform you that the installation completed successfully. Click "OK". +12. Finally the "Vivado License Manager" will launch. Select "Get Free ISE WebPACK, ISE/Vivado IP or PetaLinux License" and click "Connect Now" to complete the license registration process. + +Environment Setup +""""""""""""""""" + +The last step is to update your ``~/.bashrc`` with the following lines. This will include all of the Xilinx binary paths so you can launch compilation scripts from the command line. + +.. code:: bash + + # Xilinx Vivado 2018.3 environment + export XILINX_VIVADO=${XILINX_PATH}/Vivado/2018.3 + export PATH=${XILINX_VIVADO}/bin:${PATH} + +HLS-based Custom VTA Bitstream Compilation for PYNQ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +High-level hardware parameters are listed in the VTA configuration file and can be customized by the user. +For this custom VTA bitstream compilation exercise, we'll change the frequency of our design, so it can be clocked a little faster. + +* Set the ``HW_FREQ`` field to ``142``. The Pynq board supports 100, 142, 167 and 200MHz clocks. Note that the higher the frequency, the harder it will be to close timing. Increasing the frequency can lead to timing violation and thus faulty hardware execution. +* Set the ``HW_CLK_TARGET`` to ``6``. 
This parameter refers to the target clock period in nanoseconds for HLS - a lower clock period leads to more aggressive pipelining to achieve timing closure at higher frequencies. Technically a 142MHz clock would require a 7ns target, but we intentionally lower the clock target to 6ns to more aggressively pipeline our design.
+
+Bitstream generation is driven by a top-level ``Makefile`` under ``/3rdparty/vta-hw/hardware/xilinx/``.
+
+If you just want to simulate the VTA design in software emulation to make sure that it is functional, enter:
+
+.. code:: bash
+
+   cd /3rdparty/vta-hw/hardware/xilinx
+   make ip MODE=sim
+
+
+If you just want to generate the HLS-based VTA IP cores without launching the entire design place and route, enter:
+
+.. code:: bash
+
+   make ip
+
+You'll be able to view the HLS synthesis reports under ``/3rdparty/vta-hw/build/hardware/xilinx/hls///solution0/syn/report/_csynth.rpt``
+
+ **Note**: The ```` name is a string that summarizes the VTA configuration parameters listed in ``vta_config.json``. The ```` name refers to the specific module (or HLS function) that composes the high-level VTA pipeline.
+
+Finally, to run the full hardware compilation and generate the VTA bitstream, run ``make``.
+
+This process is lengthy, and can take up to an hour to complete depending on your machine's specs.
+We recommend setting the ``VTA_HW_COMP_THREADS`` variable in the Makefile to take full advantage of all the cores on your development machine.
+
+Once the compilation completes, the generated bitstream can be found under ``/3rdparty/vta-hw/build/hardware/xilinx/vivado//export/vta.bit``.
+
+Using a Custom Bitstream
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+We can program the new VTA FPGA bitstream by setting the bitstream path of the ``vta.program_fpga()`` function in the tutorial examples, or in the ``test_program_rpc.py`` script.
+
+.. code:: python
+
+   vta.program_fpga(remote, bitstream="/3rdparty/vta-hw/build/hardware/xilinx/vivado//export/vta.bit")
+
+Instead of downloading a pre-built bitstream from the VTA bitstream repository, TVM will now use the new bitstream you just generated, which is a VTA design clocked at a higher frequency.
+Do you observe a noticeable performance increase on the ImageNet classification example?
+
+
+
+Bitstream Generation with Intel Toolchains
+-------------------------------------------
+
+If you're interested in generating the Intel FPGA bitstream on your own instead of using the pre-built VTA bitstreams, follow the instructions below.
+
+Intel Toolchain Installation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+It is recommended to use ``Intel Quartus Prime 18.1``, since the test scripts contained in this document have been tested on this version.
+
+You will need to install Intel's FPGA compilation toolchain, `Quartus Prime Lite `_, which is a license-free version of the Intel Quartus Prime software.
+
+Obtaining and Launching the Quartus GUI Installer
+"""""""""""""""""""""""""""""""""""""""""""""""""
+
+1. Go to the `download center `_, and download the Linux version of "Quartus Prime (include Nios II EDS)" and "Cyclone V device support" files in the "Separate file" tab. This avoids downloading unused device support files.
+2. Sign in if you already have an account, or register on the right side of the web page to create one.
+3. After signing in, you will be able to download the installer and the device support files.
+4. Now that the files are downloaded, go to your ``Downloads`` directory, and change the file permissions:
+
+..
code:: bash + + chmod u+x QuartusLiteSetup-18.1.0.625-linux.run + +5. Now ensure both the installer and device support files are in the same directory, and you can run the install with: + +.. code:: bash + + ./QuartusLiteSetup-18.1.0.625-linux.run + +6. Follow the instructions on the pop-up GUI form, and install all the content in the ``/usr/local`` directory. After installation, ``/usr/local/intelFPGA_lite/18.1`` would be created and the Quartus program along with other programs would be available in the folder. + +Environment Setup +""""""""""""""""" + +Similar to what should be done for Xilinx toolchain, the following line should be added to your ``~/.bashrc``. + +.. code:: bash + + # Intel Quartus 18.1 environment + export QUARTUS_ROOTDIR="/usr/local/intelFPGA_lite/18.1/quartus" + export PATH=${QUARTUS_ROOTDIR}/bin:${PATH} + export PATH=${QUARTUS_ROOTDIR}/sopc_builder/bin:${PATH} + +This would add quartus binary path into your ``PATH`` environment variable, so you can launch compilation scripts from the command line. + +Chisel-based Custom VTA Bitstream Compilation for DE10-Nano +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to the HLS-based design, high-level hardware parameters in Chisel-based design are listed in the VTA configuration file `Configs.scala `_, and they can be customized by the user. + +For Intel FPGA, bitstream generation is driven by a top-level ``Makefile`` under ``/3rdparty/vta-hw/hardware/intel``. + +If you just want to generate the Chisel-based VTA IP core for the DE10-Nano board without compiling the design for the FPGA hardware, enter: + +.. code:: bash + + cd /3rdparty/vta-hw/hardware/intel + make ip + +Then you'll be able to locate the generated verilog file at ``/3rdparty/vta-hw/build/hardware/intel/chisel//VTA.DefaultDe10Config.v``. + +If you would like to run the full hardware compilation for the ``de10nano`` board: + +.. code:: bash + + make + +This process might be a bit lengthy, and might take up to half an hour to complete depending on the performance of your PC. The Quartus Prime software would automatically detect the number of cores available on your PC and try to utilize all of them to perform such process. + +Once the compilation completes, the generated bistream can be found under ``/3rdparty/vta-hw/build/hardware/intel/quartus//export/vta.rbf``. You can also open the Quartus project file (.qpf) available at ``/3rdparty/vta-hw/build/hardware/intel/quartus//de10_nano_top.qpf`` to look around the generated reports. + diff --git a/golang/sample/gen_mobilenet_lib.py b/golang/sample/gen_mobilenet_lib.py index 4f6a615d14c9..d4dcf2136f81 100644 --- a/golang/sample/gen_mobilenet_lib.py +++ b/golang/sample/gen_mobilenet_lib.py @@ -16,9 +16,8 @@ # under the License. 
import os -from tvm import relay +from tvm import relay, transform from tvm.contrib.download import download_testdata -import tflite.Model ################################################ @@ -49,7 +48,12 @@ def extract(path): # get TFLite model from buffer tflite_model_buf = open(model_file, "rb").read() -tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) +try: + import tflite + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) +except AttributeError: + import tflite.Model + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) ############################## @@ -73,7 +77,7 @@ def extract(path): target = 'llvm' # Build with Relay -with relay.build_config(opt_level=3): +with transform.PassContext(opt_level=3): graph, lib, params = relay.build_module.build( mod, target, params=params) diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc index af6e4303a85a..f599c405d5d5 100644 --- a/golang/src/gotvm.cc +++ b/golang/src/gotvm.cc @@ -24,14 +24,17 @@ // Standard includes #include +#include #include #include #include #include -#include // golang string compatible definition -typedef struct { char *p; int n; } _gostring_; +typedef struct { + char* p; + int n; +} _gostring_; #include #ifdef __cplusplus @@ -39,8 +42,8 @@ extern "C" { #endif // TVM runtime C interface -#include #include +#include /*! * \brief Convert native char array to _gostring_ structure. @@ -53,7 +56,7 @@ extern "C" { * \return _gostring_ object corresponding to native char array. * Caller is responsible to free the memory block allocated here. */ -static _gostring_ _native_to_gostring(const char *p, size_t l) { +static _gostring_ _native_to_gostring(const char* p, size_t l) { _gostring_ ret; ret.p = reinterpret_cast(malloc(l)); if (NULL == ret.p) { @@ -72,10 +75,10 @@ static _gostring_ _native_to_gostring(const char *p, size_t l) { * \param off is the offset in the string object. * \param v is the uint64_t value which need to embed into given string. */ -static void putuint64(std::string *s, size_t off, uint64_t v) { - for (int i = 0; i < 8; i++) { - (*s)[off + i] = (v >> (i * 8)) & 0xff; - } +static void putuint64(std::string* s, size_t off, uint64_t v) { + for (int i = 0; i < 8; i++) { + (*s)[off + i] = (v >> (i * 8)) & 0xff; + } } // TVM runtime C interface wrappers @@ -86,7 +89,7 @@ static void putuint64(std::string *s, size_t off, uint64_t v) { * \return char pointer to TVM-VERSION */ const char* _TVM_VERSION(void) { - const char *version = TVM_VERSION; + const char* version = TVM_VERSION; return version; } @@ -101,16 +104,16 @@ const char* _TVM_VERSION(void) { */ int _TVMFuncListGlobalNames(_gostring_* names) { int names_size; - char **names_array; + char** names_array; int result; - result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array); + result = TVMFuncListGlobalNames(&names_size, (char const***)&names_array); if (result) { return result; } size_t tot = 8; - for (int ii = 0; ii < names_size ; ++ii) { + for (int ii = 0; ii < names_size; ++ii) { tot += 8 + strlen(names_array[ii]); } @@ -118,7 +121,7 @@ int _TVMFuncListGlobalNames(_gostring_* names) { str.resize(tot); putuint64(&str, 0, names_size); size_t off = 8; - for (int64_t ii = 0; ii < names_size ; ++ii) { + for (int64_t ii = 0; ii < names_size; ++ii) { putuint64(&str, off, strlen(names_array[ii])); off += 8; str.replace(off, strlen(names_array[ii]), names_array[ii]); @@ -143,9 +146,9 @@ int _TVMFuncListGlobalNames(_gostring_* names) { * \param array index in native array. 
*/ void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { - TVMValue *from_p = reinterpret_cast(from_ptr); - TVMValue *to_p = reinterpret_cast(to_ptr); - memcpy(to_p+ind, from_p, sizeof(TVMValue)); + TVMValue* from_p = reinterpret_cast(from_ptr); + TVMValue* to_p = reinterpret_cast(to_ptr); + memcpy(to_p + ind, from_p, sizeof(TVMValue)); } /*! @@ -157,9 +160,9 @@ void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { * \param array index in native array. */ void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { - TVMValue *from_p = reinterpret_cast(from_ptr); - TVMValue *to_p = reinterpret_cast(to_ptr); - memcpy(to_p, from_p+ind, sizeof(TVMValue)); + TVMValue* from_p = reinterpret_cast(from_ptr); + TVMValue* to_p = reinterpret_cast(to_ptr); + memcpy(to_p, from_p + ind, sizeof(TVMValue)); } extern int goTVMCallback(void*, void*, int, void*, void*); @@ -175,21 +178,16 @@ extern int goTVMCallback(void*, void*, int, void*, void*); * * \returns the error status as TVM_DLL */ -int _TVMCallback(TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, +int _TVMCallback(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, void* resource_handle) { - return goTVMCallback(args, type_codes, num_args, ret, resource_handle); + return goTVMCallback(args, type_codes, num_args, ret, resource_handle); } /*! * _TVMPackedCFuncFinalizer is finalizer for packed function system. * */ -void _TVMPackedCFuncFinalizer(void* resource_handle) { - return; -} +void _TVMPackedCFuncFinalizer(void* resource_handle) { return; } /*! * /brief _ConvertFunction creates a packed function for with given resource handle. @@ -199,11 +197,8 @@ void _TVMPackedCFuncFinalizer(void* resource_handle) { * * /return is an int indicating the return status. */ -int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) { - int ret = TVMFuncCreateFromCFunc(_TVMCallback, - fptr, - _TVMPackedCFuncFinalizer, - fhandle); +int _ConvertFunction(void* fptr, TVMFunctionHandle* fhandle) { + int ret = TVMFuncCreateFromCFunc(_TVMCallback, fptr, _TVMPackedCFuncFinalizer, fhandle); return ret; } diff --git a/golang/src/gotvm.h b/golang/src/gotvm.h index 12b594b8c9a9..a053e39bd79a 100644 --- a/golang/src/gotvm.h +++ b/golang/src/gotvm.h @@ -32,11 +32,11 @@ extern "C" { #endif +#include #include #include #include #include -#include // Some type definitions for golang "C" typedef void* native_voidp; diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc index 416067dcdca1..644249fa75c9 100644 --- a/golang/src/tvm_runtime_pack.cc +++ b/golang/src/tvm_runtime_pack.cc @@ -23,15 +23,15 @@ */ #include "src/runtime/c_runtime_api.cc" #include "src/runtime/cpu_device_api.cc" -#include "src/runtime/workspace_pool.cc" +#include "src/runtime/file_util.cc" #include "src/runtime/library_module.cc" #include "src/runtime/module.cc" -#include "src/runtime/registry.cc" -#include "src/runtime/file_util.cc" -#include "src/runtime/threading_backend.cc" -#include "src/runtime/thread_pool.cc" #include "src/runtime/ndarray.cc" #include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/thread_pool.cc" +#include "src/runtime/threading_backend.cc" +#include "src/runtime/workspace_pool.cc" // NOTE: all the files after this are optional modules // that you can include remove, depending on how much feature you use. 
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3a71e5eb5fbf..8033294a0f99 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -24,14 +24,14 @@ #ifndef TVM_ARITH_ANALYZER_H_ #define TVM_ARITH_ANALYZER_H_ -#include -#include #include +#include +#include -#include -#include -#include #include +#include +#include +#include namespace tvm { /*! \brief namespace of arithmetic analysis. */ @@ -107,12 +107,13 @@ class ConstIntBound : public ObjectRef { */ class ConstIntBoundAnalyzer { public: + using BoundMapType = std::unordered_map; /*! * \brief analyze the expr * \param expr The expression of interest. * \return the result of the analysis. */ - ConstIntBound operator()(const PrimExpr& expr); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr); /*! * \brief analyze the expr with the intermediate memorized to avoid redundant computation @@ -120,8 +121,7 @@ class ConstIntBoundAnalyzer { * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ - ConstIntBound operator()(const PrimExpr& expr, - std::unordered_map* bound); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr, BoundMapType* bound); /*! * \brief Update constant int bound information of var. @@ -130,22 +130,21 @@ class ConstIntBoundAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const ConstIntBound& info, - bool override = false); + TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false); /*! * \brief Bind variable to a range. * * \param var The variable. * \param range The range we bind to. + * \param override Whether we allow overriding an existing var's range. */ - void Bind(const Var& var, const Range& range); + TVM_DLL void Bind(const Var& var, const Range& range, bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit ConstIntBoundAnalyzer(Analyzer* parent); - ~ConstIntBoundAnalyzer(); + TVM_DLL ~ConstIntBoundAnalyzer(); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -212,7 +211,7 @@ class ModularSetAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ModularSet operator()(const PrimExpr& expr); + TVM_DLL ModularSet operator()(const PrimExpr& expr); /*! * \brief Update constant int bound information of var. * @@ -220,15 +219,13 @@ class ModularSetAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const ModularSet& info, - bool override = false); + TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit ModularSetAnalyzer(Analyzer* parent); - ~ModularSetAnalyzer(); + TVM_DLL ~ModularSetAnalyzer(); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -252,7 +249,7 @@ class RewriteSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - PrimExpr operator()(const PrimExpr& expr); + TVM_DLL PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -261,9 +258,7 @@ class RewriteSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. 
*/ - void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); std::function EnterConstraint(const PrimExpr& constraint); @@ -272,7 +267,7 @@ class RewriteSimplifier { friend class ConstraintContext; friend class CanonicalSimplifier; explicit RewriteSimplifier(Analyzer* parent); - ~RewriteSimplifier(); + TVM_DLL ~RewriteSimplifier(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -288,7 +283,7 @@ class CanonicalSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - PrimExpr operator()(const PrimExpr& expr); + TVM_DLL PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -297,15 +292,13 @@ class CanonicalSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit CanonicalSimplifier(Analyzer* parent); - ~CanonicalSimplifier(); + TVM_DLL ~CanonicalSimplifier(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -363,12 +356,12 @@ class IntSetAnalyzer { * \param dom_map The domain map to indicate which variable to relax. * \return the result of the analysis. */ - IntSet operator()(const PrimExpr& expr, const Map& dom_map); + TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map); private: friend class Analyzer; explicit IntSetAnalyzer(Analyzer* parent); - ~IntSetAnalyzer(); + TVM_DLL ~IntSetAnalyzer(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -384,7 +377,7 @@ class IntSetAnalyzer { * If the analyzer uses memoization, we need to clear the internal * cache when information about a Var has been overridden. */ -class Analyzer { +class TVM_DLL Analyzer { public: /* * Disable copy constructor. @@ -411,8 +404,9 @@ class Analyzer { * * \param var The variable. * \param expr The expression we bind to. + * \param override Whether we allow overriding an existing var's expression. */ - void Bind(const Var& var, const PrimExpr& expr); + void Bind(const Var& var, const PrimExpr& expr, bool override = false); /*! * \brief Notify all the sub-analyzers that var * is created and binded to a range. @@ -421,14 +415,16 @@ class Analyzer { * * \param var The variable. * \param range The range we bind to. + * \param override Whether we allow overriding an existing var's expression. */ - void Bind(const Var& var, const Range& range); + void Bind(const Var& var, const Range& range, bool override = false); /*! * \brief Bind all the vars in the Map * * \param variables The {variable -> range} map. + * \param override Whether we allow overriding an existing var's expression. */ - void Bind(const Map& variables); + void Bind(const Map& variables, bool override = false); /*! * \brief Whether can we prove expr >= val. @@ -442,6 +438,19 @@ class Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound); + /*! + * \brief Whether can we prove expr < val. + + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param upper_bound The upper bound. + * \return Whether we can prove it. 
+ * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveLess(const PrimExpr& expr, int64_t upper_bound); /*! * \brief Whether can we prove condition. * diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index 6165a2ab546f..12b91cc033e5 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -23,25 +23,21 @@ #ifndef TVM_ARITH_BOUND_H_ #define TVM_ARITH_BOUND_H_ -#include -#include #include +#include +#include #include #include #include namespace tvm { -// forward delcare Tensor -namespace te { -class Tensor; -} namespace arith { +using tir::Region; +using tir::Stmt; using tir::Var; using tir::VarNode; -using tir::Domain; -using tir::Stmt; /*! * \brief Deduce the bound of the target variable in a expression, @@ -58,8 +54,7 @@ using tir::Stmt; * The deduce bound must implies e for all value in relax_map * \return An integer set that always satisfies the condition. */ -IntSet DeduceBound(PrimExpr v, PrimExpr cond, - const Map& hint_map, +IntSet DeduceBound(PrimExpr v, PrimExpr cond, const Map& hint_map, const Map& relax_map); /*! * \brief Same as DeduceBound with unordered_map signature. @@ -78,15 +73,13 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, /*! * \brief Infer a regular domain that covers all the calls or provides within the given statement. * \param body The given statement. - * \param tensor The name of the calls or provides. - * \param consider_calls If calls (read) are considered. - * \param consider_provides If provides (write) are considered. + * \param buffer The buffer to check the access info. + * \param consider_loads If loads are considered. + * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Domain DomainTouched(Stmt body, - const te::Tensor &tensor, - bool consider_calls, - bool consider_provides); +Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads, + bool consider_stores); } // namespace arith } // namespace tvm diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 86ef906fef0a..ae90bdea5310 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -26,14 +26,15 @@ #include #include + #include namespace tvm { namespace arith { +using tir::IterVar; using tir::Var; using tir::VarNode; -using tir::IterVar; //----------------------------------------------- // Integer set data structure. @@ -44,12 +45,7 @@ using tir::IterVar; /*! * \brief Sign type of an integer expression. */ -enum SignType { - kPositive, - kNegative, - kZero, - kUnknown -}; +enum SignType { kPositive, kNegative, kZero, kUnknown }; /*! * \brief Base class of all Integer set containers. @@ -77,9 +73,7 @@ class IntSet : public ObjectRef { * \brief access the internal node container * \return the pointer to the internal node container */ - const IntSetNode* operator->() const { - return static_cast(get()); - } + const IntSetNode* operator->() const { return static_cast(get()); } /*! * \brief Find a range that covers the region. * \param max_range The range to be covered. @@ -152,6 +146,13 @@ class IntSet : public ObjectRef { //----------------------------------------------- // Integer set legacy API. //------------------------------------------------ +/*! + * \brief Convert std::unordered_map to Map + * + * \param dom_map The domain map to convert. + * \return The converted map. + */ +Map ConvertDomMap(const std::unordered_map& dom_map); /*! 
* \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -160,8 +161,7 @@ class IntSet : public ObjectRef { * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const Map& dom_map); +IntSet EvalSet(PrimExpr e, const Map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -169,9 +169,7 @@ IntSet EvalSet(PrimExpr e, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map); - +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. @@ -180,8 +178,7 @@ IntSet EvalSet(PrimExpr e, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(Range r, - const Map& dom_map); +IntSet EvalSet(Range r, const Map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over @@ -191,8 +188,7 @@ IntSet EvalSet(Range r, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values. */ -IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map); +IntSet EvalSet(IntSet s, const std::unordered_map& dom_map); /*! * \brief Same as EvalSet, but takes unordered_map * @@ -200,11 +196,9 @@ IntSet EvalSet(IntSet s, * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(Range r, - const std::unordered_map& dom_map); - +IntSet EvalSet(Range r, const std::unordered_map& dom_map); /*! \brief Map from Expr to IntSet */ -using ExprIntSetMap = std::unordered_map; +using ExprIntSetMap = std::unordered_map; /*! * \brief Find the integer set of every sub-expression, given the * domain of each iteration variables. @@ -213,9 +207,8 @@ using ExprIntSetMap = std::unordered_map& dom_map); +ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, + const std::unordered_map& dom_map); /*! * \brief Create an union set of all sets diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 57f3af4bb67b..ae18cab0a9fa 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -26,15 +26,16 @@ #include #include + #include #include namespace tvm { namespace arith { +using tir::IterVar; using tir::Var; using tir::VarNode; -using tir::IterVar; /*! 
* \brief Represent integer constrains including (integer) variables, their ranges and @@ -60,10 +61,8 @@ class IntConstraintsNode : public Object { } bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { - return - equal(variables, other->variables) && - equal(ranges, other->ranges) && - equal(relations, other->relations); + return equal(variables, other->variables) && equal(ranges, other->ranges) && + equal(relations, other->relations); } void SHashReduce(SHashReducer hash_reduce) const { @@ -90,9 +89,7 @@ class IntConstraints : public ObjectRef { * \param relations The linear relations between the variables * (either equations or inequalities) */ - TVM_DLL IntConstraints(Array variables, - Map ranges, - Array relations); + TVM_DLL IntConstraints(Array variables, Map ranges, Array relations); TVM_DEFINE_OBJECT_REF_METHODS(IntConstraints, ObjectRef, IntConstraintsNode); }; @@ -126,11 +123,8 @@ class IntConstraintsTransformNode : public Object { } bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { - return - equal(src, other->src) && - equal(dst, other->dst) && - equal(src_to_dst, other->src_to_dst) && - equal(dst_to_src, other->dst_to_src); + return equal(src, other->src) && equal(dst, other->dst) && + equal(src_to_dst, other->src_to_dst) && equal(dst_to_src, other->dst_to_src); } void SHashReduce(SHashReducer hash_reduce) const { @@ -161,10 +155,8 @@ class IntConstraintsTransform : public ObjectRef { * \param dst_to_src mapping from variables in the \p dst to the variables in the \p src, * e.g., {m -> a, n -> -b} */ - TVM_DLL IntConstraintsTransform(IntConstraints src, - IntConstraints dst, - Map src_to_dst, - Map dst_to_src); + TVM_DLL IntConstraintsTransform(IntConstraints src, IntConstraints dst, + Map src_to_dst, Map dst_to_src); TVM_DEFINE_OBJECT_REF_METHODS(IntConstraintsTransform, ObjectRef, IntConstraintsTransformNode); }; @@ -176,20 +168,16 @@ class IntConstraintsTransform : public ObjectRef { * NOTE: Although in standard Smith Normal Form the diagonal elements satisfy * s_i | s_{i+1} (| means divides), the implement here does not guarantee it. * TODO(yzhliu): From sergei-grechanik: - * computing the proper Smith normal form may improve stability of automatic differentiation - * (generating the same gradient code for slightly different but equivalent input code - * U_{mxm} and V_{nxn} are invertible matrices. - * This function modifies \p S to be S_{mxn}, \p V to be V_{nxn}, - * \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. - * \param S the original A_{mxn}, it will be modified to S_{mxn} - * \param V an identity matrix, it will be modified to V_{nxn} - * \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} - * \param y the y in A x = y. it will be modified to U_{mxm} y_{mx1} + * computing the proper Smith normal form may improve stability of automatic + * differentiation (generating the same gradient code for slightly different but equivalent input + * code U_{mxm} and V_{nxn} are invertible matrices. This function modifies \p S to be S_{mxn}, \p V + * to be V_{nxn}, \p y to be U_{mxm} y_{mx1} and \p x to be V^{-1} x. \param S the original + * A_{mxn}, it will be modified to S_{mxn} \param V an identity matrix, it will be modified to + * V_{nxn} \param x the x in A x = y. it will be modified to V^{-1}_{nxn} x_{nx1} \param y the y + * in A x = y. 
it will be modified to U_{mxm} y_{mx1} */ -void SmithNormalFormDiag(std::vector> *S, - std::vector> *V, - std::vector* x, - std::vector *y); +void SmithNormalFormDiag(std::vector>* S, std::vector>* V, + std::vector* x, std::vector* y); /*! * \brief Solve linear equations. @@ -201,7 +189,7 @@ void SmithNormalFormDiag(std::vector> *S, * as well as inequalities inferred from the \p system_to_solve. * You can get the mapping from the original variables to the solution via ret->src_to_dst. */ -IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve); +IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve); } // namespace arith } // namespace tvm diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index d3ba3e980430..301d95636ca4 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -24,8 +24,8 @@ #ifndef TVM_ARITH_PATTERN_H_ #define TVM_ARITH_PATTERN_H_ -#include #include +#include #include namespace tvm { @@ -38,8 +38,7 @@ namespace arith { * \param vars List of variables to be used in detection. * \return [coeff[i]] if it is possible, empty array if it is not. */ -Array DetectLinearEquation(const PrimExpr& e, - const Array& vars); +Array DetectLinearEquation(const PrimExpr& e, const Array& vars); /*! * \brief Detect if expression corresponds to clip bound of the vars @@ -49,8 +48,7 @@ Array DetectLinearEquation(const PrimExpr& e, * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value * return empty if the e does not match the pattern. */ -Array DetectClipBound(const PrimExpr& e, - const Array& vars); +Array DetectClipBound(const PrimExpr& e, const Array& vars); } // namespace arith } // namespace tvm diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index e6d442754446..71a69a000944 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -29,47 +29,39 @@ #ifndef TVM_DRIVER_DRIVER_API_H_ #define TVM_DRIVER_DRIVER_API_H_ +#include #include -#include #include -#include +#include #include #include -#include -#include #include #include +#include +#include namespace tvm { /*! -* \brief Build an IRModule given a schedule, args and binds -* \param sch The schedule to lower. -* \param args The arguments to the function. -* \param name The name of the lowered function. -* \param binds Buffer assignments. -* \param config The build configuration. -* \return The result module. -*/ -TVM_DLL IRModule lower( - te::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config); + * \brief Build an IRModule given a schedule, args and binds + * \param sch The schedule to lower. + * \param args The arguments to the function. + * \param name The name of the lowered function. + * \param binds Buffer assignments. + * \return The result module. + */ +TVM_DLL IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds); /*! -* \brief Build a device and host module for a specific target from an IRModule. -* \param funcs The functions to be built. -* \param target The target device to build for. -* \param target_host The target for building host code. To use the default, pass Target() -* \param config The build configuration. -* \return The built module. 
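[Editor's aside, not part of the patch] With BuildConfig removed from the driver API, lowering and building are controlled entirely by the explicit arguments; per-pass options move to the PassContext config introduced later in this patch. A minimal sketch of the new call shape, assuming a te::Schedule `sch`, its tensor `args`, a tensor-to-buffer `binds` map and Target handles built elsewhere (template arguments do not survive in this rendering of the patch):

  // Lower the schedule into an IRModule; no BuildConfig parameter any more.
  IRModule lowered = lower(sch, args, "main", binds);
  // Build device and host modules for the chosen targets.
  runtime::Module rt_mod = build(lowered, target, target_host);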
-*/ -TVM_DLL runtime::Module build(const IRModule& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config); + * \brief Build a device and host module for a specific target from an IRModule. + * \param funcs The functions to be built. + * \param target The target device to build for. + * \param target_host The target for building host code. To use the default, pass Target() + * \return The built module. + */ +TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, + const Target& target_host); /*! * \brief Build a device and host module for a specific target from a map @@ -78,12 +70,9 @@ TVM_DLL runtime::Module build(const IRModule& funcs, * \param input The map contains target to an IRModule. * \param target_host The target for building host code. To use the default, * pass Target(). - * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map& input, - const Target& target_host, - const BuildConfig& config); +TVM_DLL runtime::Module build(const Map& input, const Target& target_host); /*! * \brief Build a device and host module for a specific target from a map @@ -92,12 +81,9 @@ TVM_DLL runtime::Module build(const Map& input, * \param input The map contains target string to an IRModule. * \param target_host The target for building host code. To use the default, * pass Target(). - * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map& input, - const Target& target_host, - const BuildConfig& config); +TVM_DLL runtime::Module build(const Map& input, const Target& target_host); } // namespace tvm #endif // TVM_DRIVER_DRIVER_API_H_ diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index f9cb62225584..466a4f00fd5f 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -27,11 +27,12 @@ #ifndef TVM_IR_ADT_H_ #define TVM_IR_ADT_H_ -#include -#include -#include #include #include +#include +#include +#include + #include namespace tvm { @@ -44,7 +45,7 @@ namespace tvm { class ConstructorNode : public RelayExprNode { public: /*! \brief The name (only a hint) */ - std::string name_hint; + String name_hint; /*! \brief Input to the constructor. */ Array inputs; /*! \brief The datatype the constructor will construct. */ @@ -66,9 +67,7 @@ class ConstructorNode : public RelayExprNode { bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const { // Use namehint for now to be consistent with the legacy relay impl // TODO(tvm-team) revisit, need to check the type var. - return - equal(name_hint, other->name_hint) && - equal(inputs, other->inputs); + return equal(name_hint, other->name_hint) && equal(inputs, other->inputs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -92,9 +91,7 @@ class Constructor : public RelayExpr { * \param inputs The input types. * \param belong_to The data type var the constructor will construct. 
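[Editor's aside, not part of the patch] With the constructor name hint switched from std::string to String, an ADT definition can be assembled straight from string literals. The "Option" type below is purely illustrative, and the GlobalTypeVar/TypeVar constructors are assumed from tvm/ir/type.h:

  GlobalTypeVar option_var("Option", TypeKind::kAdtHandle);
  TypeVar a("a", TypeKind::kType);
  Constructor none_ctor("None", {}, option_var);    // nullary constructor
  Constructor some_ctor("Some", {a}, option_var);   // carries one value of type `a`
  TypeData option_def(option_var, {a}, {none_ctor, some_ctor});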
*/ - TVM_DLL Constructor(std::string name_hint, - Array inputs, - GlobalTypeVar belong_to); + TVM_DLL Constructor(String name_hint, Array inputs, GlobalTypeVar belong_to); TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode); }; @@ -122,10 +119,8 @@ class TypeDataNode : public TypeNode { } bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const { - return - equal.DefEqual(header, other->header) && - equal.DefEqual(type_vars, other->type_vars) && - equal(constructors, other->constructors); + return equal.DefEqual(header, other->header) && equal.DefEqual(type_vars, other->type_vars) && + equal(constructors, other->constructors); } void SHashReduce(SHashReducer hash_reduce) const { @@ -157,9 +152,7 @@ class TypeData : public Type { * \param type_vars type variables. * \param constructors constructors field. */ - TVM_DLL TypeData(GlobalTypeVar header, - Array type_vars, - Array constructors); + TVM_DLL TypeData(GlobalTypeVar header, Array type_vars, Array constructors); TVM_DEFINE_OBJECT_REF_METHODS(TypeData, Type, TypeDataNode); }; diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index d12f1b85114c..4cdf8c5cbe94 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -27,7 +27,7 @@ * struct MyAttrs : public tvm::AttrsNode { * float learning_rate; * int num_hidden; - * std::string name; + * String name; * // declare attribute fields in header file * TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") { * TVM_ATTR_FIELD(num_hidden).set_lower_bound(1); @@ -50,12 +50,12 @@ #include #include -#include -#include #include -#include #include +#include +#include #include +#include namespace tvm { /*! @@ -63,34 +63,30 @@ namespace tvm { * \param ClassName The name of the class. * \param TypeKey The type key to be used by the TVM node system. */ -#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ - static constexpr const char* _type_key = TypeKey; \ - TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ - template \ +#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \ + static constexpr const char* _type_key = TypeKey; \ + TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \ + template \ void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*) - /*! * \brief Declare an attribute field. * \param FieldName The field name. */ -#define TVM_ATTR_FIELD(FieldName) \ - __fvisit__(#FieldName, &FieldName) - +#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName) /*! * \brief Create a NodeRef type that represents null. * \tparam TNodeRef the type to be created. * \return A instance that will represent None. */ -template +template inline TObjectRef NullValue() { - static_assert(TObjectRef::_type_is_nullable, - "Can only get NullValue for nullable types"); + static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } -template<> +template <> inline DataType NullValue() { return DataType(DataType::kHandle, 0, 0); } @@ -101,8 +97,7 @@ struct AttrError : public dmlc::Error { * \brief constructor * \param msg error message */ - explicit AttrError(const std::string &msg) - : dmlc::Error(msg) {} + explicit AttrError(std::string msg) : dmlc::Error("AttributeError:" + msg) {} }; /*! @@ -111,11 +106,11 @@ struct AttrError : public dmlc::Error { class AttrFieldInfoNode : public Object { public: /*! \brief name of the field */ - std::string name; + String name; /*! \brief type docstring information in str. */ - std::string type_info; + String type_info; /*! 
\brief detailed description of the type */ - std::string description; + String description; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -154,13 +149,13 @@ class BaseAttrsNode : public Object { * \param args The postional arguments in the form * [key0, value0, key1, value1, ..., key_n, value_n] */ - template - inline void InitBySeq(Args&& ...args); + template + inline void InitBySeq(Args&&... args); /*! * \brief Print readible docstring to ostream, add newline. * \param os the stream to print the docstring to. */ - inline void PrintDocString(std::ostream &os) const; // NOLINT(*) + inline void PrintDocString(std::ostream& os) const; // NOLINT(*) /*! * \brief Visit attributes that do not equal the default value. * @@ -206,15 +201,13 @@ class Attrs : public ObjectRef { class DictAttrsNode : public BaseAttrsNode { public: /*! \brief internal attrs map */ - Map dict; + Map dict; bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { return equal(dict, other->dict); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dict); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } // implementations void VisitAttrs(AttrVisitor* v) final; @@ -237,13 +230,25 @@ class DictAttrs : public Attrs { * \param dict The attributes. * \return The dict attributes. */ - TVM_DLL explicit DictAttrs(Map dict); - + TVM_DLL explicit DictAttrs(Map dict); TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; +/*! + * \brief Create an Attr object with all default values. + * \tparam TAttrNode the type to be created. + * \return A instance that will represent None. + */ +template +inline TAttrs AttrsWithDefaultValues() { + static_assert(std::is_base_of::value, "Can only take attr nodes"); + auto n = make_object(); + n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false); + return TAttrs(n); +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; @@ -252,18 +257,16 @@ using runtime::TVMArgValue; struct AttrNopEntry { using TSelf = AttrNopEntry; - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } - template + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } + template TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { return *this; } - template + template TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } - template + template TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } @@ -272,10 +275,8 @@ struct AttrNopEntry { // Wrapper for normal visitor. 
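[Editor's aside, not part of the patch] DictAttrs now keys its dictionary by String rather than std::string; the Map template arguments are lost in this rendering, but String/ObjectRef is assumed, matching the WithAttr helper later in the patch. The newly added AttrsWithDefaultValues<TAttrs>() plays the same role for statically declared attrs classes, filling every field with its declared default. A small sketch:

  Map<String, ObjectRef> meta = {{"Primitive", Integer(1)}};
  DictAttrs dict_attrs(meta);   // dictionary-backed attributes keyed by String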
class AttrNormalVisitor { public: - explicit AttrNormalVisitor(AttrVisitor* visitor) - : visitor_(visitor) { - } - template + explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {} + template AttrNopEntry operator()(const char* key, T* value) { visitor_->Visit(key, value); return AttrNopEntry(); @@ -290,16 +291,13 @@ class AttrsSEqualVisitor { bool result_{true}; // constructor AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal) - : lhs_(lhs), rhs_(rhs), equal_(equal) { - } - template + : lhs_(lhs), rhs_(rhs), equal_(equal) {} + template AttrNopEntry operator()(const char* key, T* lhs_value) { if (!result_) return AttrNopEntry(); - const T* rhs_value = - reinterpret_cast( - reinterpret_cast(rhs_) + - (reinterpret_cast(lhs_value) - - reinterpret_cast(lhs_))); + const T* rhs_value = reinterpret_cast( + reinterpret_cast(rhs_) + + (reinterpret_cast(lhs_value) - reinterpret_cast(lhs_))); if (!equal_(*lhs_value, *rhs_value)) { result_ = false; } @@ -314,10 +312,9 @@ class AttrsSEqualVisitor { class AttrsSHashVisitor { public: - explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) - : hash_reducer_(hash_reducer) {} + explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {} - template + template AttrNopEntry operator()(const char* key, T* value) { hash_reducer_(*value); return AttrNopEntry(); @@ -328,7 +325,7 @@ class AttrsSHashVisitor { }; // helper entry that does initialization, set default. -template +template struct AttrInitEntry { // The attributes using TSelf = AttrInitEntry; @@ -344,34 +341,31 @@ struct AttrInitEntry { ~AttrInitEntry() DMLC_THROW_EXCEPTION { if (value_missing_) { std::ostringstream os; - os << type_key_ << ": Cannot find required field \'" << key_ - << "\' during initialization"; + os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization"; throw AttrError(os.str()); } } // override fields. // This function sets the lower bound of the attribute TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - if (this->value_missing_) return *this; + if (this->value_missing_) return *this; const T& val = *value_; if (begin > val) { std::ostringstream os; os << type_key_ << "." << key_ << ": " - << "value " << val - << " is smaller than the lower bound " << begin; + << "value " << val << " is smaller than the lower bound " << begin; throw AttrError(os.str()); } return *this; } // This function sets the upper bound of the attribute TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - if (this->value_missing_) return *this; + if (this->value_missing_) return *this; const T& val = *value_; if (val > end) { std::ostringstream os; os << type_key_ << "." << key_ << ": " - << "value " << val - << " is bigger than the upper bound " << end; + << "value " << val << " is bigger than the upper bound " << end; throw AttrError(os.str()); } return *this; @@ -383,19 +377,17 @@ struct AttrInitEntry { value_missing_ = false; return *this; } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } }; // Template function to allow smart conversion // from Expr types into the constants. 
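[Editor's aside, not part of the patch] InitBySeq takes alternating key/value arguments and forwards them through a temporary PackedFunc into InitByPackedArgs, with the AttrInitEntry machinery above enforcing required fields and bounds. A sketch reusing the MyAttrs example from this header's top-of-file comment (its exact defaults are assumed):

  auto n = make_object<MyAttrs>();
  n->InitBySeq("num_hidden", 128, "learning_rate", 0.01, "name", "fc1");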
-template +template inline void SetValue(T* ptr, const TVMArgValue& val) { *ptr = val.operator T(); } -template +template inline void SetIntValue(T* ptr, const TVMArgValue& val) { if (val.type_code() == kDLInt) { *ptr = static_cast(val.value().v_int64); @@ -405,7 +397,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { } } -template<> +template <> inline void SetValue(std::string* ptr, const TVMArgValue& val) { if (val.type_code() == kTVMStr) { *ptr = val.operator std::string(); @@ -414,7 +406,7 @@ inline void SetValue(std::string* ptr, const TVMArgValue& val) { } } -template<> +template <> inline void SetValue(double* ptr, const TVMArgValue& val) { if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { *ptr = val.operator double(); @@ -430,36 +422,34 @@ inline void SetValue(double* ptr, const TVMArgValue& val) { } } } -template<> +template <> inline void SetValue(int* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(int64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(uint64_t* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } -template<> +template <> inline void SetValue(bool* ptr, const TVMArgValue& val) { SetIntValue(ptr, val); } // Visitor for value initialization -template +template class AttrInitVisitor { public: // Counter of number of matched attributes during visit. // This is used to decide if there is additional unmatched attributes. size_t hit_count_{0}; // constructor - AttrInitVisitor(const char* type_key, FFind ffind) - : type_key_(type_key), ffind_(ffind) { - } + AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {} - template + template AttrInitEntry operator()(const char* key, T* value) { TVMArgValue val; AttrInitEntry opt; @@ -482,10 +472,8 @@ class AttrInitVisitor { FFind ffind_; }; -template -inline AttrInitVisitor CreateInitVisitor( - const char* type_key, - FFind ffind) { +template +inline AttrInitVisitor CreateInitVisitor(const char* type_key, FFind ffind) { return AttrInitVisitor(type_key, ffind); } @@ -493,47 +481,47 @@ inline AttrInitVisitor CreateInitVisitor( * \brief Helper struct to get the type name known to tvm. * \tparam T the type we are interested in. 
*/ -template +template struct TypeName { static constexpr const char* value = T::ContainerType::_type_key; }; -template<> +template <> struct TypeName { static constexpr const char* value = "int"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "int64"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "uint64_t"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "DataType"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "str"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "bool"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "handle"; }; -template<> +template <> struct TypeName { static constexpr const char* value = "double"; }; @@ -542,25 +530,23 @@ class AttrDocEntry { public: using TSelf = AttrDocEntry; - explicit AttrDocEntry(ObjectPtr info) - : info_(info) { - } + explicit AttrDocEntry(ObjectPtr info) : info_(info) {} TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { info_->description = str; return *this; } - template + template TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) { std::ostringstream os; os << info_->type_info << ", default=" << value; info_->type_info = os.str(); return *this; } - template + template TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { return *this; } - template + template TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { return *this; } @@ -571,10 +557,9 @@ class AttrDocEntry { class AttrDocVisitor { public: - template + template AttrDocEntry operator()(const char* key, T* v) { - ObjectPtr info - = make_object(); + ObjectPtr info = make_object(); info->name = key; info->type_info = TypeName::value; fields_.push_back(AttrFieldInfo(info)); @@ -589,7 +574,7 @@ class AttrExistVisitor { std::string key_; bool exist_{false}; - template + template AttrNopEntry operator()(const char* key, T* v) { if (exist_) return AttrNopEntry(); if (key == key_) exist_ = true; @@ -597,12 +582,11 @@ class AttrExistVisitor { } }; -template +template struct AttrTriggerNonDefaultEntry { using TSelf = AttrTriggerNonDefaultEntry; // constructor - AttrTriggerNonDefaultEntry( - AttrVisitor* visitor, const char* key, T* data) + AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data) : visitor_(visitor), key_(key), data_(data) {} ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { @@ -610,37 +594,28 @@ struct AttrTriggerNonDefaultEntry { visitor_->Visit(key_, data_); } } - TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { - return *this; - } + TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } TSelf& set_default(const T& value) { if (tvm::StructuralEqual()(value, *data_)) { trigger_ = false; } return *this; } - TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { - return *this; - } - TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { - return *this; - } + TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } + TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } private: AttrVisitor* visitor_; - const char * key_; - T *data_; + const char* key_; + T* data_; bool trigger_{true}; }; class AttrNonDefaultVisitor { public: - explicit AttrNonDefaultVisitor(AttrVisitor* visitor) - : visitor_(visitor) { - } - template - AttrTriggerNonDefaultEntry - operator()(const char* key, T* value) { + explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : 
visitor_(visitor) {} + template + AttrTriggerNonDefaultEntry operator()(const char* key, T* value) { return AttrTriggerNonDefaultEntry(visitor_, key, value); } @@ -655,7 +630,7 @@ class AttrNonDefaultVisitor { * * \tparam DerivedType The final attribute type. */ -template +template class AttrsNode : public BaseAttrsNode { public: void VisitAttrs(AttrVisitor* v) { @@ -695,7 +670,7 @@ class AttrsNode : public BaseAttrsNode { CHECK_EQ(args.type_codes[i], kTVMStr); kwargs[args[i].operator std::string()] = args[i + 1]; } - auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) { + auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) { auto it = kwargs.find(key); if (it != kwargs.end()) { *val = it->second; @@ -715,8 +690,7 @@ class AttrsNode : public BaseAttrsNode { self()->__VisitAttrs__(visitor); if (!visitor.exist_) { std::ostringstream os; - os << DerivedType::_type_key - << ": does not have field \'" << visitor.key_ + os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ << "\', Possible fields:\n"; os << "----------------\n"; this->PrintDocString(os); @@ -746,21 +720,18 @@ class AttrsNode : public BaseAttrsNode { private: DerivedType* self() const { - return const_cast( - static_cast(this)); + return const_cast(static_cast(this)); } }; - -template -inline void BaseAttrsNode::InitBySeq(Args&& ...args) { - runtime::PackedFunc pf([this](const TVMArgs& args, TVMRetValue *rv) { - this->InitByPackedArgs(args); - }); +template +inline void BaseAttrsNode::InitBySeq(Args&&... args) { + runtime::PackedFunc pf( + [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); }); pf(std::forward(args)...); } -inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*) +inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*) Array entry = this->ListFieldInfo(); for (AttrFieldInfo info : entry) { os << info->name << " : " << info->type_info << '\n'; diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 67492ab24ba6..65653b75562d 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -41,15 +41,13 @@ namespace tvm { class EnvFuncNode : public Object { public: /*! \brief Unique name of the global function */ - std::string name; + String name; /*! \brief The internal packed function */ runtime::PackedFunc func; /*! \brief constructor */ EnvFuncNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const { // name uniquely identifies the env function. @@ -76,15 +74,13 @@ class EnvFunc : public ObjectRef { EnvFunc() {} explicit EnvFunc(ObjectPtr n) : ObjectRef(n) {} /*! \return The internal global function pointer */ - const EnvFuncNode* operator->() const { - return static_cast(get()); - } + const EnvFuncNode* operator->() const { return static_cast(get()); } /*! * \brief Invoke the function. * \param args The arguments * \returns The return value. */ - template + template runtime::TVMRetValue operator()(Args&&... args) const { const EnvFuncNode* n = operator->(); CHECK(n != nullptr); @@ -96,7 +92,7 @@ class EnvFunc : public ObjectRef { * \return The created global function. * \note The function can be unique */ - TVM_DLL static EnvFunc Get(const std::string& name); + TVM_DLL static EnvFunc Get(const String& name); /*! 
\brief specify container node */ using ContainerType = EnvFuncNode; }; @@ -104,7 +100,7 @@ class EnvFunc : public ObjectRef { /*! * \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc" */ -template +template class TypedEnvFunc; /*! @@ -116,7 +112,7 @@ class TypedEnvFunc; * \tparam Args The argument signature of the function. * \sa EnvFunc */ -template +template class TypedEnvFunc : public ObjectRef { public: /*! \brief short hand for this function type */ @@ -133,9 +129,7 @@ class TypedEnvFunc : public ObjectRef { return *this; } /*! \return The internal global function pointer */ - const EnvFuncNode* operator->() const { - return static_cast(get()); - } + const EnvFuncNode* operator->() const { return static_cast(get()); } /*! * \brief Invoke the function. * \param args The arguments @@ -144,8 +138,8 @@ class TypedEnvFunc : public ObjectRef { R operator()(Args... args) const { const EnvFuncNode* n = operator->(); CHECK(n != nullptr); - return runtime::detail::typed_packed_call_dispatcher - ::run(n->func, std::forward(args)...); + return runtime::detail::typed_packed_call_dispatcher::run(n->func, + std::forward(args)...); } /*! \brief specify container node */ using ContainerType = EnvFuncNode; diff --git a/include/tvm/ir/error.h b/include/tvm/ir/error.h index 94064ae8c8fa..ac7b96a3bd59 100644 --- a/include/tvm/ir/error.h +++ b/include/tvm/ir/error.h @@ -24,13 +24,13 @@ #ifndef TVM_IR_ERROR_H_ #define TVM_IR_ERROR_H_ -#include #include +#include -#include -#include #include +#include #include +#include namespace tvm { /*! @@ -51,7 +51,7 @@ namespace tvm { */ struct ErrorBuilder { public: - template + template ErrorBuilder& operator<<(const T& val) { // NOLINT(*) stream_ << val; return *this; @@ -78,12 +78,12 @@ class Error : public dmlc::Error { * \brief construct error from error builder. * \param err The error builder */ - Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) + Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*) /*! * \brief copy constructor. * \param other The other ereor. */ - Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) + Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*) /*! * \brief default constructor. */ Error() : dmlc::Error(""), span(nullptr) {} @@ -173,14 +173,12 @@ class ErrorReporter { */ void RenderErrors(const IRModule& module, bool use_color = true); - inline bool AnyErrors() { - return errors_.size() != 0; - } + inline bool AnyErrors() { return errors_.size() != 0; } private: std::vector errors_; - std::unordered_map, ObjectHash, ObjectEqual> node_to_error_; - std::unordered_map node_to_gv_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_to_error_; + std::unordered_map node_to_gv_; }; } // namespace tvm diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 6630bf3ded20..b2ce50d91f58 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,27 +24,31 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ -#include -#include -#include #include #include -#include +#include +#include +#include + #include #include +#include #include namespace tvm { +using tvm::runtime::String; + /*! * \brief Base type of all the expressions. 
* \sa Expr */ class BaseExprNode : public Object { public: - static constexpr const char* _type_key = "Expr"; + static constexpr const char* _type_key = "BaseExpr"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 58; TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); }; @@ -88,6 +92,7 @@ class PrimExprNode : public BaseExprNode { DataType dtype; static constexpr const char* _type_key = "PrimExpr"; + static constexpr const uint32_t _type_child_slots = 34; TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); }; @@ -109,9 +114,7 @@ class PrimExpr : public BaseExpr { TVM_DLL PrimExpr(float value); // NOLINT(*) /*! \return the data type of this expression. */ - DataType dtype() const { - return static_cast(get())->dtype; - } + DataType dtype() const { return static_cast(get())->dtype; } TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); @@ -158,10 +161,11 @@ class RelayExprNode : public BaseExprNode { * \return The corresponding TTypeNode pointer. * \tparam The specific TypeNode we look for. */ - template + template inline const TTypeNode* type_as() const; - static constexpr const char* _type_key = "relay.Expr"; + static constexpr const char* _type_key = "RelayExpr"; + static constexpr const uint32_t _type_child_slots = 22; TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); }; @@ -186,7 +190,7 @@ class GlobalVar; class GlobalVarNode : public RelayExprNode { public: /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; + String name_hint; void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); @@ -196,9 +200,7 @@ class GlobalVarNode : public RelayExprNode { bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { // name matters for global var. - return - equal(name_hint, other->name_hint) && - equal.FreeVarEqualImpl(this, other); + return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -216,7 +218,7 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(std::string name_hint); + TVM_DLL explicit GlobalVar(String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); }; @@ -319,35 +321,21 @@ class FloatImm : public PrimExpr { */ class Bool : public IntImm { public: - explicit Bool(bool value) - : IntImm(DataType::Bool(), value) { - } - Bool operator!() const { - return Bool((*this)->value == 0); - } - operator bool() const { - return (*this)->value != 0; - } + explicit Bool(bool value) : IntImm(DataType::Bool(), value) {} + Bool operator!() const { return Bool((*this)->value == 0); } + operator bool() const { return (*this)->value != 0; } TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); }; // Overload operators to make sure we have the most fine grained types. 
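[Editor's aside, not part of the patch] The reformatted Bool and Integer helpers keep fine-grained result types: logical operators on Bool return Bool, and Integer's comparators return Bool as well, while both still convert to native bool. A small sketch (Integer's plain int constructor is assumed from the unchanged part of the class):

  Bool flag(true);
  Bool inverted = !flag;              // Bool(false)
  bool both = flag && Bool(false);    // implicit conversion to bool, yields false
  Integer n(3);
  CHECK(n == 3);                      // operator== returns Bool, usable directly in CHECK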
-inline Bool operator||(const Bool& a, bool b) { - return Bool(a.operator bool() || b); -} -inline Bool operator||(bool a, const Bool& b) { - return Bool(a || b.operator bool()); -} +inline Bool operator||(const Bool& a, bool b) { return Bool(a.operator bool() || b); } +inline Bool operator||(bool a, const Bool& b) { return Bool(a || b.operator bool()); } inline Bool operator||(const Bool& a, const Bool& b) { return Bool(a.operator bool() || b.operator bool()); } -inline Bool operator&&(const Bool& a, bool b) { - return Bool(a.operator bool() && b); -} -inline Bool operator&&(bool a, const Bool& b) { - return Bool(a && b.operator bool()); -} +inline Bool operator&&(const Bool& a, bool b) { return Bool(a.operator bool() && b); } +inline Bool operator&&(bool a, const Bool& b) { return Bool(a && b.operator bool()); } inline Bool operator&&(const Bool& a, const Bool& b) { return Bool(a.operator bool() && b.operator bool()); } @@ -381,8 +369,7 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. */ - template::value>::type> + template ::value>::type> explicit Integer(Enum value) : Integer(static_cast(value)) { static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); @@ -399,8 +386,7 @@ class Integer : public IntImm { * \brief convert to int64_t */ operator int64_t() const { - CHECK(data_ != nullptr) - << " Trying to reference a null Integer"; + CHECK(data_ != nullptr) << " Trying to reference a null Integer"; return (*this)->value; } // comparators @@ -408,16 +394,12 @@ class Integer : public IntImm { if (data_ == nullptr) return Bool(false); return Bool((*this)->value == other); } - Bool operator!=(int other) const { - return !(*this == other); - } - template::value>::type> + Bool operator!=(int other) const { return !(*this == other); } + template ::value>::type> Bool operator==(Enum other) const { return *this == static_cast(other); } - template::value>::type> + template ::value>::type> Bool operator!=(Enum other) const { return *this != static_cast(other); } @@ -479,24 +461,21 @@ class Range : public ObjectRef { // implementataions inline const Type& RelayExprNode::checked_type() const { - CHECK(checked_type_.defined()) - << "internal error: the type checker has " - << "not populated the checked_type " - << "field for " - << GetRef(this); + CHECK(checked_type_.defined()) << "internal error: the type checker has " + << "not populated the checked_type " + << "field for " << GetRef(this); return this->checked_type_; } -template +template inline const TTypeNode* RelayExprNode::type_as() const { static_assert(std::is_base_of::value, "TType must be a special case of type"); CHECK(checked_type_.defined()) << "Type inference for this Expr has not completed. Try to call infer_type pass."; const TTypeNode* node = checked_type_.as(); - CHECK(node != nullptr) - << "Expected type to be " << TTypeNode::_type_key - << ", but get " << checked_type_->GetTypeKey(); + CHECK(node != nullptr) << "Expected type to be " << TTypeNode::_type_key << ", but get " + << checked_type_->GetTypeKey(); return node; } @@ -504,7 +483,7 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -template<> +template <> struct PackedFuncValueConverter { // common rule for both RetValue and ArgValue. 
static PrimExpr From(const TVMPODValue_& val) { diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index d55656f34b00..5b9e0714e202 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -24,12 +24,12 @@ #ifndef TVM_IR_FUNCTION_H_ #define TVM_IR_FUNCTION_H_ -#include #include +#include #include -#include -#include +#include +#include namespace tvm { @@ -96,7 +96,7 @@ class BaseFuncNode : public RelayExprNode { * * \endcode */ - template + template Optional GetAttr( const std::string& attr_key, Optional default_value = Optional(nullptr)) const { @@ -111,9 +111,8 @@ class BaseFuncNode : public RelayExprNode { } } // variant that uses TObjectRef to enable implicit conversion to default value. - template - Optional GetAttr( - const std::string& attr_key, TObjectRef default_value) const { + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); } /*! @@ -140,6 +139,7 @@ class BaseFuncNode : public RelayExprNode { } static constexpr const char* _type_key = "BaseFunc"; + static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); }; @@ -179,19 +179,16 @@ class BaseFunc : public RelayExpr { * * \endcode */ -template::value>::type> -inline TFunc WithAttr(TFunc func, - const std::string& attr_key, - ObjectRef attr_value) { +template ::value>::type> +inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = func.CopyOnWrite(); if (node->attrs.defined()) { node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); } else { - Map dict = {{attr_key, attr_value}}; + Map dict = {{attr_key, attr_value}}; node->attrs = DictAttrs(dict); } return func; diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index b0776dee661f..7af84b687f5f 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -24,15 +24,16 @@ #ifndef TVM_IR_MODULE_H_ #define TVM_IR_MODULE_H_ -#include +#include #include #include -#include +#include +#include #include -#include #include #include +#include namespace tvm { class IRModule; @@ -102,8 +103,7 @@ class IRModuleNode : public Object { * * It does not do type checking as AddTypeDef does. */ - TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, + TVM_DLL void AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false); /*! @@ -131,21 +131,21 @@ class IRModuleNode : public Object { * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalVar(const std::string& name) const; + TVM_DLL bool ContainGlobalVar(const String& name) const; /*! * \brief Check if the global_type_var_map_ contains a global type variable. * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const; + TVM_DLL bool ContainGlobalTypeVar(const String& name) const; /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; + TVM_DLL GlobalVar GetGlobalVar(const String& str) const; /*! * \brief Collect all global vars defined in this module. 
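[Editor's aside, not part of the patch] The IRModule query interface now takes tvm::String instead of std::string; string literals convert implicitly, so call sites keep their shape. A sketch, assuming a relay::Function `f` built elsewhere:

  IRModule mod = IRModule::FromExpr(f);          // binds the expression to "main"
  GlobalVar gv = mod->GetGlobalVar("main");
  BaseFunc main_fn = mod->Lookup("main");
  bool has_main = mod->ContainGlobalVar("main");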
@@ -158,7 +158,7 @@ class IRModuleNode : public Object { * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const; /*! * \brief Collect all global type vars defined in this module. @@ -172,7 +172,7 @@ class IRModuleNode : public Object { * \param cons name of the constructor * \returns Constructor of ADT, error if not found */ - TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const; + TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const; /*! * \brief Look up a global function by its variable. @@ -186,7 +186,7 @@ class IRModuleNode : public Object { * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL BaseFunc Lookup(const std::string& name) const; + TVM_DLL BaseFunc Lookup(const String& name) const; /*! * \brief Look up a global type definition by its variable. @@ -200,7 +200,7 @@ class IRModuleNode : public Object { * \param var The name of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupTypeDef(const std::string& var) const; + TVM_DLL TypeData LookupTypeDef(const String& var) const; /*! * \brief Look up a constructor by its tag. @@ -225,18 +225,18 @@ class IRModuleNode : public Object { * relative it will be resovled against the current * working directory. */ - TVM_DLL void Import(const std::string& path); + TVM_DLL void Import(const String& path); /*! * \brief Import Relay code from the file at path, relative to the standard library. * \param path The path of the Relay code to import. */ - TVM_DLL void ImportFromStd(const std::string& path); + TVM_DLL void ImportFromStd(const String& path); /*! * \brief The set of imported files. */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::unordered_set Imports() const; static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -250,12 +250,12 @@ class IRModuleNode : public Object { /*! \brief A map from string names to global variables that * ensures global uniqueness. */ - Map global_var_map_; + Map global_var_map_; /*! \brief A map from string names to global type variables (ADT names) * that ensures global uniqueness. */ - Map global_type_var_map_; + Map global_type_var_map_; /*! \brief A map from constructor tags to constructor objects * for convenient access @@ -265,7 +265,7 @@ class IRModuleNode : public Object { /*! \brief The files previously imported, required to ensure importing is idempotent for each module. */ - std::unordered_set import_set_; + std::unordered_set import_set_; friend class IRModule; }; @@ -283,9 +283,9 @@ class IRModule : public ObjectRef { */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}); + std::unordered_set import_set = {}); /*! \brief default constructor */ - IRModule() {} + IRModule() : IRModule(Map()) {} /*! * \brief constructor * \param n The object pointer. @@ -298,14 +298,6 @@ class IRModule : public ObjectRef { return static_cast(ptr); } - /*! - * \brief Construct an empty module. - * - * \returns The constructed module - */ - static IRModule Empty() { - return IRModule(Map()); - } /*! * \brief Construct a module from a standalone expression. 
* @@ -318,10 +310,9 @@ class IRModule : public ObjectRef { * * \returns A module with expr set as the main function. */ - TVM_DLL static IRModule FromExpr( - const RelayExpr& expr, - const Map& global_funcs = {}, - const Map& type_definitions = {}); + TVM_DLL static IRModule FromExpr(const RelayExpr& expr, + const Map& global_funcs = {}, + const Map& type_definitions = {}); /*! * \brief Parse text format source file into an IRModule. @@ -329,10 +320,14 @@ class IRModule : public ObjectRef { * \param source_path The path to the source file. * \return A Relay module. */ - TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); + TVM_DLL static IRModule FromText(const String& text, const String& source_path); /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; + + /*! \brief Declare whether Ref is nullable. */ + static constexpr bool _type_is_nullable = false; + // allow copy on write. TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode); }; @@ -346,7 +341,7 @@ class IRModule : public ObjectRef { * Use AsText if you want to store the text. * \sa AsText. */ -TVM_DLL std::string PrettyPrint(const ObjectRef& node); +TVM_DLL String PrettyPrint(const ObjectRef& node); /*! * \brief Render the node as a string in the text format. @@ -362,8 +357,7 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node); * \sa PrettyPrint. * \return The text representation. */ -TVM_DLL std::string AsText(const ObjectRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); +TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 48cf61d187d5..2bc2c90c7854 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -27,10 +27,11 @@ #include #include -#include #include #include #include +#include +#include #include #include @@ -39,10 +40,8 @@ namespace tvm { // forward declare name. -template -class OpMap; -class GenericOpMap; -class OpRegistry; +template +class OpAttrMap; // TODO(tvm-team): migrate low-level intrinsics to use Op /*! @@ -59,21 +58,21 @@ class OpRegistry; class OpNode : public RelayExprNode { public: /*! \brief name of the operator */ - std::string name; + String name; /*! \brief the type of the operator */ mutable FuncType op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. */ - std::string description; + String description; /* \brief Information of input arguments to the operator */ Array arguments; /*! * \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. */ - std::string attrs_type_key; + String attrs_type_key; /*! * \brief attribute type index, * this field varies in each run and is not exposed to frontend. @@ -122,13 +121,22 @@ class OpNode : public RelayExprNode { return is_primitive_ != 0; } - static constexpr const char* _type_key = "relay.Op"; + static constexpr const char* _type_key = "Op"; TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode); private: + /*! \return the internal attr registry index. */ + uint32_t AttrRegistryIndex() const { return index_; } + /*! 
\brief repr to be printed in registry*/ + std::string AttrRegistryName() const { return name; } + // friend class - friend class GenericOpMap; - friend class OpRegistry; + template + friend class AttrRegistryMapContainerMap; + template + friend class AttrRegistry; + friend class OpRegEntry; + friend bool IsPrimitiveOp(const RelayExpr&); // Program internal unique index of operator. // Used to help index the program. @@ -166,26 +174,26 @@ class Op : public RelayExpr { inline const OpNode* operator->() const; /*! * \brief Get additional registered attribute about operators. - * If nothing has been registered, an empty OpMap will be returned. + * If nothing has been registered, an empty OpAttrMap will be returned. * \param attr_name The name of the attribute. - * \return An OpMap of specified attr_name. + * \return An OpAttrMap of specified attr_name. * \tparam ValueType The type of the attribute. */ template - inline static OpMap GetAttr(const std::string& attr_name); + inline static OpAttrMap GetAttrMap(const String& attr_name); /*! - * \brief Checks if an attr is present in the registry. + * \brief Checks if an attr map is present in the registry. * \param attr_name The name of the attribute. * \return bool True if the attr is present. */ - inline static bool HasAttr(const std::string& attr_name); + TVM_DLL static bool HasAttrMap(const String& attr_name); /*! * \brief Get an Op for a given operator name. * Will raise an error if the op has not been registered. * \param op_name Name of the operator. * \return Pointer to a Op, valid throughout program lifetime. */ - TVM_DLL static const Op& Get(const std::string& op_name); + TVM_DLL static const Op& Get(const String& op_name); /*! \brief specify container node */ using ContainerType = OpNode; @@ -194,22 +202,16 @@ class Op : public RelayExpr { /*! * \brief Get generic attrmap given attr name * \param key The attribute key - * \return reference to GenericOpMap + * \return The attr map. */ - TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); - /*! - * \brief Checks if the key is present in the registry - * \param key The attribute key - * \return bool True if the key is present - */ - TVM_DLL static bool HasGenericAttr(const std::string& key); + TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer(const String& key); }; /*! * \brief Helper structure to register operators * \sa TVM_REGISTER_OP */ -class OpRegistry { +class OpRegEntry { public: /*! \return the operator */ const Op& op() const { return op_; } @@ -219,7 +221,7 @@ class OpRegistry { * \param descr the description string. * \return reference to self. */ - inline OpRegistry& describe(const std::string& descr); // NOLINT(*) + inline OpRegEntry& describe(const std::string& descr); // NOLINT(*) /*! * \brief Add argument information to the function. * \param name Name of the argument. @@ -227,8 +229,7 @@ class OpRegistry { * \param description Description of the argument. * \return reference to self. */ - inline OpRegistry& add_argument(const std::string& name, - const std::string& type, + inline OpRegEntry& add_argument(const std::string& name, const std::string& type, const std::string& description); /*! * \brief Attach the type function corresponding to the return type. @@ -237,31 +238,29 @@ class OpRegistry { * relation on variables. * \return reference to self. 
*/ - inline OpRegistry& add_type_rel( + inline OpRegEntry& add_type_rel( const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func); + runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> + type_rel_func); /*! * \brief Set the the attrs type key and index to be AttrsType. * \tparam AttrsType the attribute type to b set. * \return reference to self. */ - template - inline OpRegistry& set_attrs_type(); + template + inline OpRegEntry& set_attrs_type(); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. * \return reference to self. */ - inline OpRegistry& set_num_inputs(int32_t n); // NOLINT(*) + inline OpRegEntry& set_num_inputs(int32_t n); // NOLINT(*) /*! * \brief Set the support level of op. * \param level The support level. * \return reference to self. */ - inline OpRegistry& set_support_level(int32_t level); // NOLINT(*) + inline OpRegEntry& set_support_level(int32_t level); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. @@ -276,7 +275,7 @@ class OpRegistry { * \tparam ValueType The type of the value to be set. */ template - inline OpRegistry& set_attr(const std::string& attr_name, // NOLINT(*) + inline OpRegEntry& set_attr(const std::string& attr_name, // NOLINT(*) const ValueType& value, int plevel = 10); /*! @@ -286,76 +285,32 @@ class OpRegistry { inline void reset_attr(const std::string& attr_name); // set the name of the op to be the same as registry - inline OpRegistry& set_name() { // NOLINT(*) + inline OpRegEntry& set_name() { // NOLINT(*) if (get()->name.length() == 0) { get()->name = name; } return *this; } - /*! \return The global single registry */ - TVM_DLL static ::dmlc::Registry* Registry(); + /*! + * \brief Register or get a new entry. + * \param name The name of the operator. + * \return the corresponding entry. + */ + TVM_DLL static OpRegEntry& RegisterOrGet(const String& name); private: - friend class ::dmlc::Registry; + template + friend class AttrRegistry; // the name std::string name; /*! \brief The operator */ Op op_; // private constructor - TVM_DLL OpRegistry(); + TVM_DLL OpRegEntry(uint32_t reg_index); // return internal pointer to op. inline OpNode* get(); - // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, - runtime::TVMRetValue value, - int plevel); -}; - -/*! - * \brief Generic map to store additional information of Op. - */ -class GenericOpMap { - public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline const runtime::TVMRetValue& operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - * \tparam ValueType The content value type. - */ - template - inline ValueType get(const Op& op, ValueType def_value) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param expr The key to the map - * \param def_value The default value when the key does not exist - * or if expr is not an Op. - * \return the const reference to the content value. 
- * \tparam ValueType The content value type. - */ - template - inline ValueType get(const RelayExpr& expr, ValueType def_value) const; - - private: - friend class OpRegistry; - // the attribute field. - std::string attr_name_; - // internal data - std::vector > data_; - // The value - GenericOpMap() = default; + // update the attribute OpAttrMap + TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel); }; /*! @@ -363,27 +318,8 @@ class GenericOpMap { * \tparam ValueType The type of the value stored in map. */ template -class OpMap { +class OpAttrMap : public AttrRegistryMap { public: - /*! - * \brief Check if the map has op as key. - * \param op The key to the map - * \return 1 if op is contained in map, 0 otherwise. - */ - inline int count(const Op& op) const; - /*! - * \brief get the corresponding value element at op - * \param op The key to the map - * \return the const reference to the content value. - */ - inline ValueType operator[](const Op& op) const; - /*! - * \brief get the corresponding value element at op with default value. - * \param op The key to the map - * \param def_value The default value when the key does not exist. - * \return the const reference to the content value. - */ - inline ValueType get(const Op& op, ValueType def_value) const; /*! * \brief get the corresponding value element at op with default value. * \param expr The key to the map @@ -393,12 +329,15 @@ class OpMap { */ inline ValueType get(const RelayExpr& expr, ValueType def_value) const; + using TParent = AttrRegistryMap; + using TParent::count; + using TParent::get; + using TParent::operator[]; + private: friend class Op; // constructor - explicit OpMap(const GenericOpMap& map) : map_(map) {} - /*! \brief The internal map field */ - const GenericOpMap& map_; + explicit OpAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} }; #define TVM_STRINGIZE_DETAIL(x) #x @@ -410,8 +349,7 @@ class OpMap { #define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) // internal macros to make -#define TVM_OP_REGISTER_VAR_DEF \ - static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegistry& __make_##Op +#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op /*! 
* \def TVM_REGISTER_OP @@ -428,38 +366,26 @@ class OpMap { * * \endcode */ -#define TVM_REGISTER_OP(OpName) \ - TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ - ::tvm::OpRegistry::Registry() \ - ->__REGISTER_OR_GET__(OpName) \ - .set_name() +#define TVM_REGISTER_OP(OpName) \ + TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name() // implementations -inline const OpNode* Op::operator->() const { - return static_cast(get()); -} +inline const OpNode* Op::operator->() const { return static_cast(get()); } template -inline OpMap Op::GetAttr(const std::string& key) { - return OpMap(Op::GetGenericAttr(key)); +inline OpAttrMap Op::GetAttrMap(const String& key) { + return OpAttrMap(Op::GetAttrMapContainer(key)); } -inline bool Op::HasAttr(const std::string& key) { - return Op::HasGenericAttr(key); -} +inline OpNode* OpRegEntry::get() { return const_cast(op_.operator->()); } -inline OpNode* OpRegistry::get() { - return const_cast(op_.operator->()); -} - -inline OpRegistry& OpRegistry::describe( - const std::string& descr) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::describe(const std::string& descr) { // NOLINT(*) get()->description = descr; return *this; } -inline OpRegistry& OpRegistry::add_argument(const std::string& name, - const std::string& type, +inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type, const std::string& description) { auto n = make_object(); n->name = name; @@ -469,12 +395,10 @@ inline OpRegistry& OpRegistry::add_argument(const std::string& name, return *this; } -inline OpRegistry& OpRegistry::add_type_rel( +inline OpRegEntry& OpRegEntry::add_type_rel( const std::string& rel_name, - runtime::TypedPackedFunc&, - int, - const Attrs&, - const TypeReporter&)> type_rel_func) { + runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> + type_rel_func) { auto func_name = std::string("tvm.relay.type_relation.") + rel_name; TypeRelationFn env_type_rel_func; @@ -482,8 +406,7 @@ inline OpRegistry& OpRegistry::add_type_rel( auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } else { - runtime::Registry::Register(func_name) - .set_body(type_rel_func.packed()); + runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); auto env_func = EnvFunc::Get(func_name); env_type_rel_func = env_func; } @@ -517,38 +440,34 @@ inline OpRegistry& OpRegistry::add_type_rel( // A common example is sum(x, axis), where the choice of axis // can affect the type of the function. 
TypeConstraint type_rel = - TypeRelation(env_type_rel_func, - ty_call_args, - arg_types.size(), - Attrs()); + TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs()); - auto func_type = - FuncType(arg_types, out_param, type_params, {type_rel}); + auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); get()->op_type = func_type; return *this; } -inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_num_inputs(int32_t n) { // NOLINT(*) get()->num_inputs = n; return *this; } -template -inline OpRegistry& OpRegistry::set_attrs_type() { // NOLINT(*) +template +inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*) get()->attrs_type_key = AttrsType::_type_key; get()->attrs_type_index = AttrsType::RuntimeTypeIndex(); return *this; } -inline OpRegistry& OpRegistry::set_support_level(int32_t n) { // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*) get()->support_level = n; return *this; } template -inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) +inline OpRegEntry& OpRegEntry::set_attr( // NOLINT(*) const std::string& attr_name, const ValueType& value, int plevel) { CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; runtime::TVMRetValue rv; @@ -557,74 +476,18 @@ inline OpRegistry& OpRegistry::set_attr( // NOLINT(*) return *this; } -// member functions of OpMap -inline int GenericOpMap::count(const Op& op) const { - if (op.defined()) { - const uint32_t idx = op->index_; - return idx < data_.size() ? (data_[idx].second != 0) : 0; - } else { - return 0; - } -} - -inline const runtime::TVMRetValue& -GenericOpMap::operator[](const Op& op) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - CHECK(idx < data_.size() && data_[idx].second != 0) - << "Attribute " << attr_name_ << " has not been registered for Operator " - << op->name; - return data_[idx].first; -} +// member functions of OpAttrMap template -inline ValueType GenericOpMap::get(const Op& op, ValueType value) const { - CHECK(op.defined()); - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } -} - -template -inline ValueType GenericOpMap::get(const RelayExpr& expr, ValueType value) const { +inline ValueType OpAttrMap::get(const RelayExpr& expr, ValueType def_value) const { CHECK(expr.defined()); if (const OpNode* op = expr.as()) { - const uint32_t idx = op->index_; - if (idx < data_.size() && data_[idx].second != 0) { - return data_[idx].first; - } else { - return value; - } + return this->map_.get(GetRef(op), def_value); } else { - return value; + return def_value; } } -template -inline int OpMap::count(const Op& op) const { - return map_.count(op); -} - -template -inline ValueType OpMap::operator[](const Op& op) const { - return map_[op]; -} - -template -inline ValueType OpMap::get(const Op& op, - ValueType def_value) const { - return map_.get(op, def_value); -} - -template -inline ValueType OpMap::get(const RelayExpr& expr, - ValueType def_value) const { - return map_.get(expr, def_value); -} - /*! * \brief Check that an expression is a "primitive operator". 
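[Editor's aside, not part of the patch] With OpRegistry renamed to OpRegEntry and the per-attribute maps folded into the generic attr registry, registration and attribute lookup now read as below. The op name and the "TMyPattern" attribute key are purely illustrative:

  TVM_REGISTER_OP("mydialect.add")
      .describe("Illustrative element-wise addition.")
      .set_num_inputs(2)
      .add_argument("lhs", "Tensor", "The left operand.")
      .add_argument("rhs", "Tensor", "The right operand.")
      .set_support_level(1)
      .set_attr<Integer>("TMyPattern", Integer(1));

  // Lookup side: Op::GetAttr<T>(name) becomes Op::GetAttrMap<T>(name).
  const Op& add_op = Op::Get("mydialect.add");
  OpAttrMap<Integer> pattern_map = Op::GetAttrMap<Integer>("TMyPattern");
  Integer pattern = pattern_map.get(add_op, Integer(0));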
* diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h index 7194e903549c..84d6a7b0f877 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/span.h @@ -24,8 +24,9 @@ #ifndef TVM_IR_SPAN_H_ #define TVM_IR_SPAN_H_ -#include #include +#include + #include namespace tvm { @@ -40,7 +41,7 @@ class SourceName; class SourceNameNode : public Object { public: /*! \brief The source name. */ - std::string name; + String name; // override attr visitor void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } @@ -64,7 +65,7 @@ class SourceName : public ObjectRef { * \param name Name of the operator. * \return SourceName valid throughout program lifetime. */ - TVM_DLL static SourceName Get(const std::string& name); + TVM_DLL static SourceName Get(const String& name); TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); }; @@ -92,21 +93,18 @@ class SpanNode : public Object { } bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const { - return - equal(source, other->source) && - equal(lineno, other->lineno) && - equal(col_offset, other->col_offset); + return equal(source, other->source) && equal(lineno, other->lineno) && + equal(col_offset, other->col_offset); } - TVM_DLL static Span make(SourceName source, int lineno, int col_offset); - static constexpr const char* _type_key = "Span"; TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); }; - class Span : public ObjectRef { public: + TVM_DLL Span(SourceName source, int lineno, int col_offset); + TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h index e993cd9afacc..7a700258f23c 100644 --- a/include/tvm/ir/tensor_type.h +++ b/include/tvm/ir/tensor_type.h @@ -24,8 +24,8 @@ #ifndef TVM_IR_TENSOR_TYPE_H_ #define TVM_IR_TENSOR_TYPE_H_ -#include #include +#include namespace tvm { /*! @@ -36,6 +36,7 @@ namespace tvm { class BaseTensorTypeNode : public TypeNode { public: static constexpr const char* _type_key = "relay.BaseTensorType"; + static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(BaseTensorTypeNode, TypeNode); }; @@ -74,9 +75,7 @@ class TensorTypeNode : public BaseTensorTypeNode { } bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const { - return - equal(shape, other->shape) && - equal(dtype, other->dtype); + return equal(shape, other->shape) && equal(dtype, other->dtype); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 3680f6db9afe..5bfb51adb0ac 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -56,11 +56,12 @@ #ifndef TVM_IR_TRANSFORM_H_ #define TVM_IR_TRANSFORM_H_ -#include -#include -#include #include #include +#include +#include +#include + #include #include @@ -70,13 +71,11 @@ namespace transform { // Forward declare for TraceFunc. class PassInfo; -/*! \brief A callback for tracing passes, useful for debugging and logging. - * +/*! + * \brief A callback for tracing passes, useful for debugging and logging. */ using TraceFunc = - runtime::TypedPackedFunc; + runtime::TypedPackedFunc; /*! * \brief PassContextNode contains the information that a pass can rely on, @@ -93,23 +92,53 @@ class PassContextNode : public Object { /*! \brief The default optimization level. */ int opt_level{2}; - /*! \brief CPU is the default fallback device for heterogeneous execution. */ - int fallback_device{static_cast(kDLCPU)}; - /*! \brief The list of required passes. 
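With SpanNode::make removed in favor of a real constructor, call sites build spans directly. A small hedged sketch; the file name and positions are made up.

#include <tvm/ir/span.h>

void SpanExample() {
  using namespace tvm;
  // SourceName::Get now takes a String and is interned for the program lifetime.
  SourceName src = SourceName::Get("example.relay");
  // Previously written as SpanNode::make(src, 10, 4).
  Span span(src, /*lineno=*/10, /*col_offset=*/4);
  (void)span;
}
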
*/ - Array required_pass; + Array required_pass; /*! \brief The list of disabled passes. */ - Array disabled_pass; - + Array disabled_pass; + /*! \brief Trace function to be invoked before and after each pass. */ TraceFunc trace_func; + /*! \brief Pass specific configurations. */ + Map config; + PassContextNode() = default; + /*! + * \brief Get a config value from the pass context. + * + * \param key The config key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef. + */ + template + Optional GetConfig(const std::string& key, Optional default_value = + Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!config.defined()) return default_value; + auto it = config.find(key); + if (it != config.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetConfig(const std::string& key, TObjectRef default_value) const { + return GetConfig(key, Optional(default_value)); + } + void VisitAttrs(AttrVisitor* v) { v->Visit("opt_level", &opt_level); - v->Visit("fallback_device", &fallback_device); v->Visit("required_pass", &required_pass); v->Visit("disabled_pass", &disabled_pass); + v->Visit("config", &config); } static constexpr const char* _type_key = "transform.PassContext"; @@ -117,7 +146,6 @@ class PassContextNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); }; - /*! * \brief PassContext that is used to configure the pass behavior. * @@ -125,7 +153,6 @@ class PassContextNode : public Object { * * auto new_ctx = PassContext::Create(); * ctx->opt_level = 2; - * ctx->fallback_device = kDLCPU; * With scope(ctx); * // pass context in effect. * @@ -152,6 +179,7 @@ class PassContext : public ObjectRef { CHECK(get() != nullptr); return static_cast(get_mutable()); } + /*! * \brief Construct a PassContext containing the default configurations. * \return The new PassContext. @@ -171,6 +199,21 @@ class PassContext : public ObjectRef { */ TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const; + /*! + * \brief Register a valid configuration option and its ValueType for validation. + * + * \param key The configuration key. + * \tparam ValueType The value type to be registered + */ + template + static uint32_t RegisterConfigOption(const char* key) { + using ValueNodeType = typename ValueType::ContainerType; + // NOTE: we could further update the function later. + uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); + RegisterConfigOption(key, tindex); + return tindex; + } + // accessor. using ContainerType = PassContextNode; class Internal; @@ -180,12 +223,26 @@ class PassContext : public ObjectRef { TVM_DLL void EnterWithScope(); // The exit of a pass context scope. TVM_DLL void ExitWithScope(); + // Register configuration key value type. + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); // Classes to get the Python `with` like syntax. friend class Internal; friend class With; }; +#define TVM_PASS_CTX_CONFIG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_PassContext_tid + +/*! + * \brief Helper macro to register the object type to runtime. + * Makes sure that the runtime type table is correctly populated. 
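The new config map plus GetConfig gives passes typed access to their options. A hedged sketch of reading one option; the key "relay.ExampleOptLevel", the Integer value type, and the use of PassContext::Current() are assumptions made for illustration.

#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>

void ReadConfigExample() {
  using namespace tvm;
  using namespace tvm::transform;
  PassContext ctx = PassContext::Current();
  // Typed lookup: returns the stored value when the key is present,
  // otherwise the supplied default (here Integer(3)).
  Optional<Integer> opt = ctx->GetConfig<Integer>("relay.ExampleOptLevel", Integer(3));
  int64_t level = opt.value()->value;
  (void)level;
}
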
+ * + * Use this macro in the cc file for each terminal class. + */ +#define TVM_REGISTER_PASS_CONFIG_OPTION(Key, ValueType) \ + TVM_STR_CONCAT(TVM_PASS_CTX_CONFIG_VAR_DEF, __COUNTER__) = \ + ::tvm::transform::PassContext::RegisterConfigOption(Key) + /*! * \brief Meta data that will be used to help optimization and analysis. * \sa PassInfo @@ -196,10 +253,10 @@ class PassInfoNode : public Object { int opt_level; /*! \brief The name of an optimization/analysis pass. */ - std::string name; + String name; /*! \brief The passes that are required to perform the current pass. */ - Array required; + Array required; PassInfoNode() = default; @@ -226,9 +283,7 @@ class PassInfo : public ObjectRef { * \param name Name of the pass. * \param required The passes that are required to perform the current pass. */ - TVM_DLL PassInfo(int opt_level, - std::string name, - Array required); + TVM_DLL PassInfo(int opt_level, String name, Array required); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -264,8 +319,7 @@ class PassNode : public Object { * * \return The transformed module. */ - virtual IRModule operator()(IRModule mod, - const PassContext& pass_ctx) const = 0; + virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0; void VisitAttrs(AttrVisitor* v) {} @@ -303,8 +357,7 @@ class Pass : public ObjectRef { * * \return The transformed module. */ - IRModule operator()(IRModule mod, - const PassContext& pass_ctx) const { + IRModule operator()(IRModule mod, const PassContext& pass_ctx) const { const PassNode* node = operator->(); CHECK(node != nullptr); return node->operator()(std::move(mod), pass_ctx); @@ -333,7 +386,7 @@ class Sequential : public Pass { * This allows users to only provide a list of passes and execute them * under a given context. */ - TVM_DLL Sequential(Array passes, std::string name = "sequential"); + TVM_DLL Sequential(Array passes, String name = "sequential"); Sequential() = default; explicit Sequential(ObjectPtr n) : Pass(n) {} @@ -352,12 +405,9 @@ class Sequential : public Pass { * * \return The created module pass. */ -TVM_DLL Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const Array& required); - +TVM_DLL Pass +CreateModulePass(const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, Array required); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). @@ -365,7 +415,7 @@ TVM_DLL Pass CreateModulePass( * \param show_meta_data Whether should we show meta data. * \return The pass. 
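Config keys are expected to be declared once with the macro above so the context can validate their value type. A hedged sketch of declaring and populating such an option from C++; the key "relay.ExampleFlag", the String value type, and setting the public config member directly are illustrative choices, not prescribed by this diff.

#include <tvm/ir/transform.h>

namespace tvm {
namespace transform {

// Declares the option and fixes its value type for later validation.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ExampleFlag", String);

}  // namespace transform
}  // namespace tvm

void SetConfigExample() {
  using namespace tvm;
  using namespace tvm::transform;
  PassContext ctx = PassContext::Create();
  // config is a Map<String, ObjectRef>; Set copies-on-write as needed.
  ctx->config.Set(String("relay.ExampleFlag"), String("on"));
  With<PassContext> scope(ctx);  // option visible to passes run inside this scope
}
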
*/ -TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false); +TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false); } // namespace transform } // namespace tvm diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 0e65758a2e1c..65b454f08b52 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -49,11 +49,12 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include -#include -#include -#include #include +#include +#include +#include +#include + #include namespace tvm { @@ -81,6 +82,7 @@ class TypeNode : public Object { static constexpr const char* _type_key = "Type"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 14; TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); }; @@ -108,23 +110,18 @@ class PrimTypeNode : public TypeNode { */ runtime::DataType dtype; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); } bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const { return equal(dtype, other->dtype); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(dtype); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } static constexpr const char* _type_key = "PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; - /* * \brief Managed reference to PrimTypeNode. * \sa PrimTypeNode @@ -140,7 +137,6 @@ class PrimType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); }; - /*! * \brief Low-level raw pointer type. * @@ -158,17 +154,13 @@ class PointerTypeNode : public TypeNode { */ Type element_type; - void VisitAttrs(AttrVisitor* v) { - v->Visit("element_type", &element_type); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("element_type", &element_type); } bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const { return equal(element_type, other->element_type); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(element_type); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(element_type); } static constexpr const char* _type_key = "PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); @@ -189,7 +181,6 @@ class PointerType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); }; - /*! \brief Possible kinds of TypeVars. */ enum TypeKind : int { kType = 0, @@ -226,7 +217,7 @@ class TypeVarNode : public TypeNode { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; /*! \brief The kind of type parameter */ TypeKind kind; @@ -237,9 +228,7 @@ class TypeVarNode : public TypeNode { } bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const { - return - equal(kind, other->kind) && - equal.FreeVarEqualImpl(this, other); + return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -262,7 +251,7 @@ class TypeVar : public Type { * \param name_hint The name of the type var. * \param kind The kind of the type var. 
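The low-level PrimType and PointerType shown above are constructed directly; a tiny illustrative sketch.

#include <tvm/ir/type.h>

void LowLevelTypeExample() {
  using namespace tvm;
  // A 32-bit integer value type ...
  PrimType i32(runtime::DataType::Int(32));
  // ... and a typed pointer to it, e.g. for a buffer parameter.
  PointerType ptr(i32);
  (void)ptr;
}
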
*/ - TVM_DLL TypeVar(std::string name_hint, TypeKind kind); + TVM_DLL TypeVar(String name_hint, TypeKind kind); TVM_DEFINE_OBJECT_REF_METHODS(TypeVar, Type, TypeVarNode); }; @@ -278,7 +267,7 @@ class GlobalTypeVarNode : public TypeNode { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; /*! \brief The kind of type parameter */ TypeKind kind; @@ -289,9 +278,7 @@ class GlobalTypeVarNode : public TypeNode { bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const { // name matters for now in global type var. - return - equal(name_hint, other->name_hint) && - equal.FreeVarEqualImpl(this, other); + return equal(name_hint, other->name_hint) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -314,7 +301,7 @@ class GlobalTypeVar : public Type { * \param name_hint The name of the type var. * \param kind The kind of the type var. */ - TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind); + TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind); TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); }; @@ -339,9 +326,7 @@ class TupleTypeNode : public TypeNode { return equal(fields, other->fields); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(fields); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } static constexpr const char* _type_key = "TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); @@ -371,9 +356,7 @@ class TupleType : public Type { /*! * \return a type that represents void. */ -inline Type VoidType() { - return TupleType::Empty(); -} +inline Type VoidType() { return TupleType::Empty(); } /*! * \brief Check whether the tyep represents void. @@ -391,6 +374,7 @@ inline bool IsVoidType(const Type& type) { class TypeConstraintNode : public TypeNode { public: static constexpr const char* _type_key = "TypeConstraint"; + static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); }; @@ -437,11 +421,8 @@ class FuncTypeNode : public TypeNode { bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const { // type params first as they defines type vars. - return - equal.DefEqual(type_params, other->type_params) && - equal(arg_types, other->arg_types) && - equal(ret_type, other->ret_type) && - equal(type_constraints, other->type_constraints); + return equal.DefEqual(type_params, other->type_params) && equal(arg_types, other->arg_types) && + equal(ret_type, other->ret_type) && equal(type_constraints, other->type_constraints); } void SHashReduce(SHashReducer hash_reduce) const { @@ -469,9 +450,7 @@ class FuncType : public Type { * \param type_constraints The type constraints. * \sa FuncTypeNode for more docs about these fields. 
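Name hints now travel as tvm::String. A short hedged sketch of the affected constructors; the names, the kAdtHandle kind, and the two-field tuple are invented for illustration.

#include <tvm/ir/type.h>

void TypeCtorExample() {
  using namespace tvm;
  TypeVar tv("A", TypeKind::kType);              // name_hint is now a String
  GlobalTypeVar gv("MyADT", TypeKind::kAdtHandle);
  Type unit = VoidType();                        // empty TupleType
  TupleType pair({tv, unit});                    // a 2-field tuple type
  (void)gv; (void)pair;
}
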
*/ - TVM_DLL FuncType(Array arg_types, - Type ret_type, - Array type_params, + TVM_DLL FuncType(Array arg_types, Type ret_type, Array type_params, Array type_constraints); TVM_DEFINE_OBJECT_REF_METHODS(FuncType, Type, FuncTypeNode); @@ -498,14 +477,10 @@ class IncompleteTypeNode : public TypeNode { } bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const { - return - equal(kind, other->kind) && - equal.FreeVarEqualImpl(this, other); + return equal(kind, other->kind) && equal.FreeVarEqualImpl(this, other); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(kind); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(kind); } static constexpr const char* _type_key = "IncompleteType"; TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); @@ -526,7 +501,6 @@ class IncompleteType : public Type { TVM_DEFINE_OBJECT_REF_METHODS(IncompleteType, Type, IncompleteTypeNode); }; - /*! * \brief Reference Type High-level Relay IR. * @@ -548,9 +522,7 @@ class RelayRefTypeNode : public TypeNode { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } // Keep the relay prefix in the type as this type is specific // to the relay itself. diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 55071911fb80..2a6314cf7644 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -25,11 +25,12 @@ #define TVM_IR_TYPE_FUNCTOR_H_ #include -#include #include +#include + #include -#include #include +#include namespace tvm { @@ -37,16 +38,13 @@ template class TypeFunctor; // functions to be overriden. -#define TYPE_FUNCTOR_DEFAULT \ +#define TYPE_FUNCTOR_DEFAULT \ { return VisitTypeDefault_(op, std::forward(args)...); } - -#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitType_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define TVM_TYPE_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitType_(static_cast(n.get()), std::forward(args)...); \ + }); template class TypeFunctor { @@ -65,9 +63,7 @@ class TypeFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Type& n, Args... args) { - return VisitType(n, std::forward(args)...); - } + R operator()(const Type& n, Args... args) { return VisitType(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The expression node. @@ -80,8 +76,7 @@ class TypeFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitType_(const TensorTypeNode* op, - Args... args) TYPE_FUNCTOR_DEFAULT; + virtual R VisitType_(const TensorTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; @@ -126,8 +121,7 @@ class TypeFunctor { /*! * \brief A type visitor that recursively visit types. 
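The dispatch macro above wires VisitType_ overloads into a per-type vtable. A hedged sketch of a small functor built on top of it; the derived class and its return values are invented.

#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type_functor.h>

#include <string>

namespace tvm {

class TypeKindNamer : public TypeFunctor<std::string(const Type&)> {
 public:
  std::string VisitType_(const TensorTypeNode* op) final { return "tensor"; }
  std::string VisitType_(const TupleTypeNode* op) final { return "tuple"; }
  std::string VisitTypeDefault_(const Object* op) final { return "other"; }
};

// Usage: TypeKindNamer()(some_type) dispatches on the node's runtime type index.

}  // namespace tvm
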
*/ -class TVM_DLL TypeVisitor : - public TypeFunctor { +class TVM_DLL TypeVisitor : public TypeFunctor { public: void VisitType_(const TypeVarNode* op) override; void VisitType_(const IncompleteTypeNode* op) override; @@ -146,8 +140,7 @@ class TVM_DLL TypeVisitor : /*! * \brief TypeMutator that mutates expressions. */ -class TVM_DLL TypeMutator : - public TypeFunctor { +class TVM_DLL TypeMutator : public TypeFunctor { public: Type VisitType(const Type& t) override; Type VisitType_(const TypeVarNode* op) override; diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h index 06bcb7207c74..dbd241afa458 100644 --- a/include/tvm/ir/type_relation.h +++ b/include/tvm/ir/type_relation.h @@ -24,10 +24,10 @@ #ifndef TVM_IR_TYPE_RELATION_H_ #define TVM_IR_TYPE_RELATION_H_ -#include -#include -#include #include +#include +#include +#include namespace tvm { @@ -51,9 +51,7 @@ class TypeCallNode : public TypeNode { } bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(args, other->args); + return equal(func, other->func) && equal(args, other->args); } void SHashReduce(SHashReducer hash_reduce) const { @@ -105,7 +103,7 @@ class TypeReporterNode : public Object { * \return false if assertation can be proven to have failed * true if solver can still proceed. */ - TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0; + TVM_DLL virtual bool Assert(const PrimExpr& cond) = 0; /*! * \brief assert shape expression equals each other. * \param lhs The left operand. @@ -141,11 +139,9 @@ class TypeReporterNode : public Object { class TypeReporter : public ObjectRef { public: TypeReporter() {} - explicit TypeReporter(ObjectPtr n) : ObjectRef(n) { - } + explicit TypeReporter(ObjectPtr n) : ObjectRef(n) {} TypeReporterNode* operator->() const { - return const_cast( - static_cast(get())); + return const_cast(static_cast(get())); } using ContainerType = TypeReporterNode; }; @@ -169,11 +165,8 @@ class TypeReporter : public ObjectRef { * \return false if This relation cannot be resolved. * true if this relation has been resolved. */ -using TypeRelationFn = - TypedEnvFunc& args, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter)>; +using TypeRelationFn = TypedEnvFunc& args, int num_inputs, + const Attrs& attrs, const TypeReporter& reporter)>; /*! * \brief User defined type relation, it is an input-output relation on types. @@ -207,11 +200,8 @@ class TypeRelationNode : public TypeConstraintNode { } bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(args, other->args) && - equal(num_inputs, other->num_inputs) && - equal(attrs, other->attrs); + return equal(func, other->func) && equal(args, other->args) && + equal(num_inputs, other->num_inputs) && equal(attrs, other->attrs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -239,10 +229,7 @@ class TypeRelation : public TypeConstraint { * \param attrs Attributes to the relation function. * \sa TypeRelationNode for more docs about these fields. 
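add_type_rel (earlier in this diff) and TypeRelationFn expect a callable of the shape below. The relation itself is an invented identity relation, shown only to make the signature and the TypeReporter usage concrete.

#include <tvm/ir/type_relation.h>

namespace tvm {

// types holds the input types followed by the output slot; returning true
// means the relation has been resolved.
bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                 const TypeReporter& reporter) {
  reporter->Assign(types[1], types[0]);
  return true;
}

// Registration would then look like: .add_type_rel("Identity", IdentityRel)

}  // namespace tvm
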
*/ - TVM_DLL TypeRelation(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs); + TVM_DLL TypeRelation(TypeRelationFn func, Array args, int num_inputs, Attrs attrs); TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode); }; diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h new file mode 100644 index 000000000000..748b3a80969c --- /dev/null +++ b/include/tvm/node/attr_registry_map.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/attr_registry_map.h + * \brief Attribute map used in registry. + */ +#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ +#define TVM_NODE_ATTR_REGISTRY_MAP_H_ + +#include +#include + +namespace tvm { + +/*! + * \brief Generic attribute map. + * \tparam KeyType the type of the key. + */ +template +class AttrRegistryMapContainerMap { + public: + /*! + * \brief Check if the map has key. + * \param key The key to the map + * \return 1 if key is contained in map, 0 otherwise. + */ + int count(const KeyType& key) const { + if (key.defined()) { + const uint32_t idx = key->AttrRegistryIndex(); + return idx < data_.size() ? (data_[idx].second != 0) : 0; + } else { + return 0; + } + } + /*! + * \brief get the corresponding value element at key. + * \param key The key to the map + * \return the const reference to the content value. + */ + const runtime::TVMRetValue& operator[](const KeyType& key) const { + CHECK(key.defined()); + const uint32_t idx = key->AttrRegistryIndex(); + CHECK(idx < data_.size() && data_[idx].second != 0) + << "Attribute " << attr_name_ << " has not been registered for " << key->name; + return data_[idx].first; + } + /*! + * \brief get the corresponding value element at key with default value. + * \param key The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + * \tparam ValueType The content value type. + */ + template + ValueType get(const KeyType& key, ValueType def_value) const { + CHECK(key.defined()); + const uint32_t idx = key->AttrRegistryIndex(); + if (idx < data_.size() && data_[idx].second != 0) { + return data_[idx].first; + } else { + return def_value; + } + } + + private: + /*! \brief The name of the attr field */ + String attr_name_; + /*! \brief The internal data. */ + std::vector> data_; + /*! \brief The constructor */ + AttrRegistryMapContainerMap() = default; + template + friend class AttrRegistry; + friend class OpRegEntry; +}; + +/*! + * \brief Map used to store meta-data. + * \tparam KeyType The type of the key + * \tparam ValueType The type of the value stored in map. + */ +template +class AttrRegistryMap { + public: + /*! + * \brief constructor + * \param map The internal map. 
+ */ + explicit AttrRegistryMap(const AttrRegistryMapContainerMap& map) : map_(map) {} + /*! + * \brief Check if the map has op as key. + * \param key The key to the map + * \return 1 if op is contained in map, 0 otherwise. + */ + int count(const KeyType& key) const { return map_.count(key); } + /*! + * \brief get the corresponding value element at key. + * \param key The key to the map + * \return the const reference to the content value. + */ + ValueType operator[](const KeyType& key) const { return map_[key]; } + /*! + * \brief get the corresponding value element at key with default value. + * \param key The key to the map + * \param def_value The default value when the key does not exist. + * \return the const reference to the content value. + */ + ValueType get(const KeyType& key, ValueType def_value) const { return map_.get(key, def_value); } + + protected: + /*! \brief The internal map field */ + const AttrRegistryMapContainerMap& map_; +}; + +} // namespace tvm +#endif // TVM_NODE_ATTR_REGISTRY_MAP_H_ diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index ba1edf84383e..a3cfdaf267ac 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -23,47 +23,63 @@ #ifndef TVM_NODE_CONTAINER_H_ #define TVM_NODE_CONTAINER_H_ -#include +#include #include +#include #include -#include -#include -#include #include +#include +#include #include #include -#include +#include namespace tvm { -using runtime::String; -using runtime::StringObj; +using runtime::Array; +using runtime::ArrayNode; +using runtime::Downcast; +using runtime::IterAdapter; +using runtime::make_object; using runtime::Object; using runtime::ObjectPtr; +using runtime::ObjectPtrEqual; +using runtime::ObjectPtrHash; using runtime::ObjectRef; -using runtime::make_object; -using runtime::ObjectHash; -using runtime::ObjectEqual; +using runtime::String; +using runtime::StringObj; -/*! \brief array node content in array */ -class ArrayNode : public Object { - public: - /*! \brief the data content */ - std::vector data; +/*! \brief String-aware ObjectRef hash functor */ +struct ObjectHash { + size_t operator()(const ObjectRef& a) const { + if (const auto* str = a.as()) { + return String::HashBytes(str->data, str->size); + } + return ObjectPtrHash()(a); + } +}; - static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); +/*! \brief String-aware ObjectRef equal functor */ +struct ObjectEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { + if (a.same_as(b)) { + return true; + } + if (const auto* str_a = a.as()) { + if (const auto* str_b = b.as()) { + return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0; + } + } + return false; + } }; /*! \brief map node content */ class MapNode : public Object { public: /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map< - ObjectRef, - ObjectRef, - ObjectHash, ObjectEqual>; + using ContainerType = std::unordered_map; /*! \brief the data content */ ContainerType data; @@ -72,305 +88,6 @@ class MapNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object); }; - -/*! \brief specialized map node with string as key */ -class StrMapNode : public Object { - public: - /*! \brief The corresponding conatiner type */ - using ContainerType = std::unordered_map; - - /*! \brief the data content */ - ContainerType data; - - static constexpr const char* _type_key = "StrMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object); -}; - -/*! 
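With ArrayNode and IterAdapter moving to the runtime namespace, container.h now mainly adds the string-aware ObjectHash/ObjectEqual shown above. A hedged sketch of what the content-based String handling buys; the map and its keys are invented.

#include <tvm/node/container.h>

#include <unordered_map>

void StringKeyExample() {
  using namespace tvm;
  std::unordered_map<ObjectRef, int, ObjectHash, ObjectEqual> table;
  table[String("alpha")] = 1;
  // A distinct String object with identical contents hashes and compares equal,
  // so the lookup finds the same entry.
  bool hit = table.count(String("alpha")) == 1;
  (void)hit;
}
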
- * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - inline IterAdapter& operator++() { - ++iter_; - return *this; - } - inline IterAdapter operator+(difference_type offset) const { - return IterAdapter(iter_ + offset); - } - - template - typename std::enable_if::value, - typename T::difference_type>::type - inline operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - inline bool operator==(IterAdapter other) const { - return iter_ == other.iter_; - } - inline bool operator!=(IterAdapter other) const { - return !(*this == other); - } - inline const value_type operator*() const { - return Converter::convert(*iter_); - } - - private: - TIter iter_; -}; - -/*! - * \brief Array container of NodeRef in DSL graph. - * Array implements copy on write semantics, which means array is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam T The content NodeRef type. - */ -template::value>::type > -class Array : public ObjectRef { - public: - /*! - * \brief default constructor - */ - Array() { - data_ = make_object(); - } - /*! - * \brief move constructor - * \param other source - */ - Array(Array && other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - /*! - * \brief copy constructor - * \param other source - */ - Array(const Array &other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType begin, IterType end) { - assign(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Array(std::initializer_list init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(size_t n, const T& val) { - auto tmp_node = make_object(); - for (size_t i = 0; i < n; ++i) { - tmp_node->data.push_back(val); - } - data_ = std::move(tmp_node); - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(Array && other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(const Array & other) { - data_ = other.data_; - return *this; - } - /*! - * \brief reset the array to content from iterator. 
- * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - void assign(IterType begin, IterType end) { - auto n = make_object(); - for (IterType it = begin; it != end; ++it) { - n->data.push_back(T(*it)); - } - data_ = std::move(n); - } - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - inline const T operator[](size_t i) const { - return DowncastNoCheck( - static_cast(data_.get())->data[i]); - } - /*! \return The size of the array */ - inline size_t size() const { - if (data_.get() == nullptr) return 0; - return static_cast(data_.get())->data.size(); - } - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - inline ArrayNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { - ObjectPtr n = make_object(); - n->data = static_cast(data_.get())->data; - ObjectPtr(std::move(n)).swap(data_); - } - return static_cast(data_.get()); - } - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - inline void push_back(const T& item) { - ArrayNode* n = this->CopyOnWrite(); - n->data.push_back(item); - } - /*! - * \brief Resize the array. - * \param size The new size. - */ - inline void resize(size_t size) { - ArrayNode* n = this->CopyOnWrite(); - n->data.resize(size); - } - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - inline void Set(size_t i, const T& value) { - ArrayNode* n = this->CopyOnWrite(); - n->data[i] = value; - } - /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. - * \note This function performs copy on write optimization. - */ - template - inline void MutateByApply(F fmutate) { - ArrayNode* ptr = static_cast(data_.get()); - if (ptr == nullptr) return; - if (data_.unique()) { - // Copy on write optimization. - // Perform inplace update because this is an unique copy. - for (size_t i = 0; i < ptr->data.size(); ++i) { - // It is important to use move here - // to make prevent the element's ref count from increasing - // so fmutate itself can perform copy-on-write optimization - T old_elem = DowncastNoCheck(std::move(ptr->data[i])); - T new_elem = fmutate(std::move(old_elem)); - ptr->data[i] = std::move(new_elem); - } - } else { - // lazily trigger copy if there is element change. - ObjectPtr copy; - for (size_t i = 0; i < ptr->data.size(); ++i) { - T old_elem = DowncastNoCheck(ptr->data[i]); - T new_elem = fmutate(old_elem); - if (!new_elem.same_as(ptr->data[i])) { - // copy the old array - if (copy == nullptr) { - copy = runtime::make_object(*ptr); - } - copy->data[i] = std::move(new_elem); - } - } - // replace the data with the new copy. - if (copy != nullptr) { - data_ = std::move(copy); - } - } - } - - /*! 
\brief specify container node */ - using ContainerType = ArrayNode; - - struct ValueConverter { - using ResultType = T; - static inline T convert(const ObjectRef& n) { - return DowncastNoCheck(n); - } - }; - using iterator = IterAdapter::const_iterator>; - - using reverse_iterator = IterAdapter< - ValueConverter, - std::vector::const_reverse_iterator>; - - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(data_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(data_.get())->data.end()); - } - /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return reverse_iterator(static_cast(data_.get())->data.rbegin()); - } - /*! \return rend iterator */ - inline reverse_iterator rend() const { - return reverse_iterator(static_cast(data_.get())->data.rend()); - } -}; - /*! * \brief Map container of NodeRef->NodeRef in DSL graph. * Map implements copy on write semantics, which means map is mutable @@ -380,32 +97,27 @@ class Array : public ObjectRef { * \tparam K The key NodeRef type. * \tparam V The value NodeRef type. */ -template::value || - std::is_base_of::value >::type, - typename = typename std::enable_if::value>::type> +template ::value>::type, + typename = typename std::enable_if::value>::type> class Map : public ObjectRef { public: /*! * \brief default constructor */ - Map() { - data_ = make_object(); - } + Map() { data_ = make_object(); } /*! * \brief move constructor * \param other source */ - Map(Map && other) { // NOLINT(*) + Map(Map&& other) { // NOLINT(*) data_ = std::move(other.data_); } /*! * \brief copy constructor * \param other source */ - Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) + Map(const Map& other) : ObjectRef(other.data_) { // NOLINT(*) } /*! * \brief constructor from pointer @@ -418,7 +130,7 @@ class Map : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template Map(IterType begin, IterType end) { assign(begin, end); } @@ -426,15 +138,15 @@ class Map : public ObjectRef { * \brief constructor from initializer list * \param init The initalizer list */ - Map(std::initializer_list > init) { // NOLINT(*) + Map(std::initializer_list > init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! - * \brief constructor from vector - * \param init The vector + * \brief constructor from unordered_map + * \param init The unordered_map */ - template - Map(const std::unordered_map& init) { // NOLINT(*) + template + Map(const std::unordered_map& init) { // NOLINT(*) assign(init.begin(), init.end()); } /*! @@ -442,7 +154,7 @@ class Map : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Map& operator=(Map && other) { + Map& operator=(Map&& other) { data_ = std::move(other.data_); return *this; } @@ -451,7 +163,7 @@ class Map : public ObjectRef { * \param other The source of assignment * \return reference to self. */ - Map& operator=(const Map & other) { + Map& operator=(const Map& other) { data_ = other.data_; return *this; } @@ -461,7 +173,7 @@ class Map : public ObjectRef { * \param end end of iterator * \tparam IterType The type of iterator */ - template + template void assign(IterType begin, IterType end) { ObjectPtr n = make_object(); for (IterType i = begin; i != end; ++i) { @@ -475,8 +187,7 @@ class Map : public ObjectRef { * \return the corresonding element. 
*/ inline const V operator[](const K& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } /*! * \brief Read element from map. @@ -484,8 +195,7 @@ class Map : public ObjectRef { * \return the corresonding element. */ inline const V at(const K& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); + return DowncastNoCheck(static_cast(data_.get())->data.at(key)); } /*! \return The size of the array */ inline size_t size() const { @@ -506,7 +216,7 @@ class Map : public ObjectRef { * \return Handle to the internal node container(which ganrantees to be unique) */ inline MapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { + if (data_.get() == nullptr || !data_.unique()) { ObjectPtr n = make_object(); n->data = static_cast(data_.get())->data; ObjectPtr(std::move(n)).swap(data_); @@ -524,24 +234,18 @@ class Map : public ObjectRef { } /*! \return whether array is empty */ - inline bool empty() const { - return size() == 0; - } + inline bool empty() const { return size() == 0; } /*! \brief specify container node */ using ContainerType = MapNode; struct ValueConverter { using ResultType = std::pair; - static inline ResultType convert(const std::pair< - ObjectRef, - ObjectRef>& n) { - return std::make_pair(DowncastNoCheck(n.first), - DowncastNoCheck(n.second)); + static inline ResultType convert(const std::pair& n) { + return std::make_pair(DowncastNoCheck(n.first), DowncastNoCheck(n.second)); } }; - using iterator = IterAdapter< - ValueConverter, MapNode::ContainerType::const_iterator>; + using iterator = IterAdapter; /*! \return begin iterator */ inline iterator begin() const { @@ -553,152 +257,32 @@ class Map : public ObjectRef { } /*! 
\return begin iterator */ inline iterator find(const K& key) const { - return iterator( - static_cast(data_.get())->data.find(key)); + return iterator(static_cast(data_.get())->data.find(key)); } }; -// specialize of string map -template -class Map : public ObjectRef { - public: - // for code reuse - Map() { - data_ = make_object(); - } - Map(Map && other) { // NOLINT(*) - data_ = std::move(other.data_); - } - Map(const Map &other) : ObjectRef(other.data_) { // NOLINT(*) - } - explicit Map(ObjectPtr n) : ObjectRef(n) {} - template - Map(IterType begin, IterType end) { - assign(begin, end); - } - Map(std::initializer_list > init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - - template - Map(const std::unordered_map& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - Map& operator=(Map && other) { - data_ = std::move(other.data_); - return *this; - } - Map& operator=(const Map & other) { - data_ = other.data_; - return *this; - } - template - void assign(IterType begin, IterType end) { - auto n = make_object(); - for (IterType i = begin; i != end; ++i) { - n->data.emplace(std::make_pair(i->first, i->second)); - } - data_ = std::move(n); - } - inline const V operator[](const std::string& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); - } - inline const V at(const std::string& key) const { - return DowncastNoCheck( - static_cast(data_.get())->data.at(key)); - } - inline size_t size() const { - if (data_.get() == nullptr) return 0; - return static_cast(data_.get())->data.size(); - } - inline size_t count(const std::string& key) const { - if (data_.get() == nullptr) return 0; - return static_cast(data_.get())->data.count(key); - } - inline StrMapNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { - ObjectPtr n = make_object(); - n->data = static_cast(data_.get())->data; - ObjectPtr(std::move(n)).swap(data_); - } - return static_cast(data_.get()); - } - inline void Set(const std::string& key, const V& value) { - StrMapNode* n = this->CopyOnWrite(); - n->data[key] = value; - } - inline bool empty() const { - return size() == 0; - } - using ContainerType = StrMapNode; - - struct ValueConverter { - using ResultType = std::pair; - static inline ResultType convert(const std::pair< - std::string, - ObjectRef>& n) { - return std::make_pair(n.first, DowncastNoCheck(n.second)); - } - }; - - using iterator = IterAdapter< - ValueConverter, StrMapNode::ContainerType::const_iterator>; - - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(data_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(data_.get())->data.end()); - } - /*! \return begin iterator */ - inline iterator find(const std::string& key) const { - return iterator(static_cast(data_.get())->data.find(key)); - } -}; } // namespace tvm namespace tvm { namespace runtime { // Additional overloads for PackedFunc checking. 
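With the Map<std::string, V> specialization removed, string-keyed maps go through the same Map path using tvm::String keys. A short hedged sketch; keys and values are invented.

#include <tvm/node/container.h>

void MapExample() {
  using namespace tvm;
  Map<String, String> opts;                    // copy-on-write container
  opts.Set(String("layout"), String("NCHW"));  // triggers CopyOnWrite if shared
  String v = opts[String("layout")];
  size_t n = opts.count(String("layout"));     // 1
  (void)v; (void)n;
}
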
-template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; const ArrayNode* n = static_cast(ptr); - for (const auto& p : n->data) { + for (const ObjectRef& p : *n) { if (!ObjectTypeChecker::Check(p.get())) { return false; } } return true; } - static std::string TypeName() { - return "List[" + ObjectTypeChecker::TypeName() + "]"; - } -}; - -template -struct ObjectTypeChecker > { - static bool Check(const Object* ptr) { - if (ptr == nullptr) return true; - if (!ptr->IsInstance()) return false; - const StrMapNode* n = static_cast(ptr); - for (const auto& kv : n->data) { - if (!ObjectTypeChecker::Check(kv.second.get())) return false; - } - return true; - } - static std::string TypeName() { - return "Map[str, " + - ObjectTypeChecker::TypeName()+ ']'; - } + static std::string TypeName() { return "List[" + ObjectTypeChecker::TypeName() + "]"; } }; -template +template struct ObjectTypeChecker > { static bool Check(const Object* ptr) { if (ptr == nullptr) return true; @@ -711,10 +295,8 @@ struct ObjectTypeChecker > { return true; } static std::string TypeName() { - return "Map[" + - ObjectTypeChecker::TypeName() + - ", " + - ObjectTypeChecker::TypeName()+ ']'; + return "Map[" + ObjectTypeChecker::TypeName() + ", " + ObjectTypeChecker::TypeName() + + ']'; } }; } // namespace runtime diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index e11fda892c30..0837f35bd715 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -26,9 +26,9 @@ #include #include -#include #include #include +#include namespace tvm { @@ -60,16 +60,16 @@ using runtime::ObjectRef; * \tparam FType function signiture * This type if only defined for FType with function signature */ -template +template class NodeFunctor; -template +template class NodeFunctor { private: /*! \brief internal function pointer type */ - typedef R (*FPointer)(const ObjectRef&n, Args...); + typedef R (*FPointer)(const ObjectRef& n, Args...); /*! \brief refer to itself. */ - using TSelf = NodeFunctor; + using TSelf = NodeFunctor; /*! \brief internal function table */ std::vector func_; @@ -92,9 +92,8 @@ class NodeFunctor { * \return The result. */ R operator()(const ObjectRef& n, Args... args) const { - CHECK(can_dispatch(n)) - << "NodeFunctor calls un-registered function on type " - << n->GetTypeKey(); + CHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " + << n->GetTypeKey(); return (*func_[n->type_index()])(n, std::forward(args)...); } /*! @@ -103,37 +102,32 @@ class NodeFunctor { * \tparam TNode the type of Node to be dispatched. * \return reference to self. */ - template + template TSelf& set_dispatch(FPointer f) { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); if (func_.size() <= tindex) { func_.resize(tindex + 1, nullptr); } - CHECK(func_[tindex] == nullptr) - << "Dispatch for " << TNode::_type_key - << " is already set"; + CHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set"; func_[tindex] = f; return *this; } /*! - * \brief unset the dispacher for type TNode - * - * \tparam TNode the type of Node to be dispatched. - * \return reference to self. - */ - template + * \brief unset the dispacher for type TNode + * + * \tparam TNode the type of Node to be dispatched. + * \return reference to self. 
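NodeFunctor keeps a per-type-index function table. A hedged sketch of registering and invoking a handler; the functor, its return value, and the choice of PrimTypeNode are illustrative only.

#include <tvm/ir/type.h>
#include <tvm/node/functor.h>

#include <string>

void FunctorExample() {
  using namespace tvm;
  NodeFunctor<std::string(const ObjectRef&)> namer;
  // Registers a handler keyed by PrimTypeNode's runtime type index.
  namer.set_dispatch<PrimTypeNode>([](const ObjectRef& n) { return std::string("PrimType"); });
  PrimType t(runtime::DataType::Int(32));
  std::string s = namer(t);  // dispatches through the table; CHECKs if unregistered
  (void)s;
}
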
+ */ + template TSelf& clear_dispatch() { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); - CHECK_LT(tindex, func_.size()) - << "clear_dispatch: index out of range"; + CHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; func_[tindex] = nullptr; return *this; } }; - -#define TVM_REG_FUNC_VAR_DEF(ClsName) \ - static TVM_ATTRIBUTE_UNUSED auto & __make_functor ## _ ## ClsName +#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName /*! * \brief Useful macro to set NodeFunctor dispatch in a global static field. @@ -176,8 +170,7 @@ class NodeFunctor { * \param ClsName The name of the class * \param FField The static function that returns a singleton of NodeFunctor. */ -#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ - TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = \ - ClsName::FField() +#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \ + TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField() } // namespace tvm #endif // TVM_NODE_FUNCTOR_H_ diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index 471a0de361b7..59295c2ce427 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -34,35 +34,35 @@ #ifndef TVM_NODE_NODE_H_ #define TVM_NODE_NODE_H_ -#include -#include -#include -#include +#include #include #include -#include #include #include +#include +#include +#include +#include #include -#include -#include #include +#include +#include namespace tvm { -using runtime::TypeIndex; +using runtime::Downcast; +using runtime::GetRef; +using runtime::make_object; using runtime::Object; using runtime::ObjectPtr; +using runtime::ObjectPtrEqual; +using runtime::ObjectPtrHash; using runtime::ObjectRef; -using runtime::GetRef; -using runtime::Downcast; -using runtime::ObjectHash; -using runtime::ObjectEqual; -using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +using runtime::TypeIndex; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 9ed87df46618..e8ff26be42b3 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -23,18 +23,18 @@ #ifndef TVM_NODE_REFLECTION_H_ #define TVM_NODE_REFLECTION_H_ +#include +#include #include -#include +#include #include -#include #include -#include -#include -#include +#include +#include -#include #include #include +#include namespace tvm { @@ -51,7 +51,7 @@ using runtime::ObjectRef; */ class AttrVisitor { public: -//! \cond Doxygen_Suppress + //! \cond Doxygen_Suppress TVM_DLL virtual ~AttrVisitor() = default; TVM_DLL virtual void Visit(const char* key, double* value) = 0; TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0; @@ -63,14 +63,13 @@ class AttrVisitor { TVM_DLL virtual void Visit(const char* key, DataType* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0; TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0; - template::value>::type> + template ::value>::type> void Visit(const char* key, ENum* ptr) { static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); this->Visit(key, reinterpret_cast(ptr)); } -//! \endcond + //! \endcond }; /*! @@ -147,6 +146,22 @@ class ReflectionVTable { */ TVM_DLL ObjectPtr CreateInitObject(const std::string& type_key, const std::string& repr_bytes = "") const; + /*! + * \brief Create an object by giving kwargs about its fields. 
+ * + * \param type_key The type key. + * \param kwargs the arguments in format key1, value1, ..., key_n, value_n. + * \return The created object. + */ + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const runtime::TVMArgs& kwargs); + /*! + * \brief Create an object by giving kwargs about its fields. + * + * \param type_key The type key. + * \param kwargs The field arguments. + * \return The created object. + */ + TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map& kwargs); /*! * \brief Get an field object by the attr name. * \param self The pointer to the object. @@ -154,7 +169,7 @@ class ReflectionVTable { * \return The corresponding attribute value. * \note This function will throw an exception if the object does not contain the field. */ - TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const; + TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const; /*! * \brief List all the fields in the object. @@ -166,7 +181,7 @@ class ReflectionVTable { TVM_DLL static ReflectionVTable* Global(); class Registry; - template + template inline Registry Register(); private: @@ -174,7 +189,7 @@ class ReflectionVTable { std::vector fvisit_attrs_; /*! \brief Structural equal function. */ std::vector fsequal_reduce_; - /*! \brief Structural hash function. */ + /*! \brief Structural hash function. */ std::vector fshash_reduce_; /*! \brief Creation function. */ std::vector fcreate_; @@ -186,7 +201,7 @@ class ReflectionVTable { class ReflectionVTable::Registry { public: Registry(ReflectionVTable* parent, uint32_t type_index) - : parent_(parent), type_index_(type_index) { } + : parent_(parent), type_index_(type_index) {} /*! * \brief Set fcreate function. * \param f The creator function. @@ -213,10 +228,8 @@ class ReflectionVTable::Registry { uint32_t type_index_; }; - -#define TVM_REFLECTION_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \ - __make_reflectiion +#define TVM_REFLECTION_REG_VAR_DEF \ + static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry __make_reflectiion /*! * \brief Directly register reflection VTable. @@ -228,7 +241,11 @@ class ReflectionVTable::Registry { * // Example SEQualReduce traits for runtime StringObj. * * struct StringObjTrait { - * static constexpr const std::nullptr_t VisitAttrs = nullptr; + * static constexpr const std::nullptr_t VisitAttrs = nullptr; + * + * static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { + * hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); + * } * * static bool SEqualReduce(const runtime::StringObj* lhs, * const runtime::StringObj* rhs, @@ -247,122 +264,108 @@ class ReflectionVTable::Registry { * \note This macro can be called in different place as TVM_REGISTER_OBJECT_TYPE. * And can be used to register the related reflection functions for runtime objects. */ -#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ - TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::ReflectionVTable::Global()->Register() \ +#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \ + TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \ + ::tvm::ReflectionVTable::Global()->Register() /*! * \brief Register a node type to object registry and reflection registry. * \param TypeName The name of the type. * \note This macro will call TVM_REGISTER_OBJECT_TYPE for the type as well. 
*/ -#define TVM_REGISTER_NODE_TYPE(TypeName) \ - TVM_REGISTER_OBJECT_TYPE(TypeName); \ +#define TVM_REGISTER_NODE_TYPE(TypeName) \ + TVM_REGISTER_OBJECT_TYPE(TypeName); \ TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait) \ - .set_creator([](const std::string&) -> ObjectPtr { \ - return ::tvm::runtime::make_object(); \ - }) - + .set_creator([](const std::string&) -> ObjectPtr { \ + return ::tvm::runtime::make_object(); \ + }) // Implementation details namespace detail { -template +template struct ImplVisitAttrs { static constexpr const std::nullptr_t VisitAttrs = nullptr; }; -template +template struct ImplVisitAttrs { - static void VisitAttrs(T* self, AttrVisitor* v) { - self->VisitAttrs(v); - } + static void VisitAttrs(T* self, AttrVisitor* v) { self->VisitAttrs(v); } }; -template +template struct ImplSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; }; -template +template struct ImplSEqualReduce { static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal) { return self->SEqualReduce(other, equal); } }; -template +template struct ImplSHashReduce { static constexpr const std::nullptr_t SHashReduce = nullptr; }; -template +template struct ImplSHashReduce { static void SHashReduce(const T* self, SHashReducer hash_reduce) { self->SHashReduce(hash_reduce); } }; -template -struct ReflectionTrait : - public ImplVisitAttrs, - public ImplSEqualReduce, - public ImplSHashReduce { -}; +template +struct ReflectionTrait : public ImplVisitAttrs, + public ImplSEqualReduce, + public ImplSHashReduce {}; -template::value> +template ::value> struct SelectVisitAttrs { static constexpr const std::nullptr_t VisitAttrs = nullptr; }; -template +template struct SelectVisitAttrs { static void VisitAttrs(Object* self, AttrVisitor* v) { TraitName::VisitAttrs(static_cast(self), v); } }; -template::value> +template ::value> struct SelectSEqualReduce { static constexpr const std::nullptr_t SEqualReduce = nullptr; }; -template +template struct SelectSEqualReduce { - static bool SEqualReduce(const Object* self, - const Object* other, - SEqualReducer equal) { - return TraitName::SEqualReduce(static_cast(self), - static_cast(other), + static bool SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) { + return TraitName::SEqualReduce(static_cast(self), static_cast(other), equal); } }; -template::value> +template ::value> struct SelectSHashReduce { static constexpr const std::nullptr_t SHashReduce = nullptr; }; -template +template struct SelectSHashReduce { - static void SHashReduce(const Object* self, - SHashReducer hash_reduce) { - return TraitName::SHashReduce(static_cast(self), - hash_reduce); + static void SHashReduce(const Object* self, SHashReducer hash_reduce) { + return TraitName::SHashReduce(static_cast(self), hash_reduce); } }; } // namespace detail -template -inline ReflectionVTable::Registry -ReflectionVTable::Register() { +template +inline ReflectionVTable::Registry ReflectionVTable::Register() { uint32_t tindex = T::RuntimeTypeIndex(); if (tindex >= fvisit_attrs_.size()) { fvisit_attrs_.resize(tindex + 1, nullptr); @@ -372,20 +375,16 @@ ReflectionVTable::Register() { fshash_reduce_.resize(tindex + 1, nullptr); } // functor that implemnts the redirection. 
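TVM_REGISTER_NODE_TYPE wires a node's VisitAttrs/SEqualReduce/SHashReduce methods (when present) into the reflection vtable through ReflectionTrait. The node below is an invented example of the expected shape, not part of this change.

#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>

namespace tvm {

class ExampleNode : public Object {
 public:
  String name;
  int64_t value;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("value", &value);
  }
  bool SEqualReduce(const ExampleNode* other, SEqualReducer equal) const {
    return equal(name, other->name) && equal(value, other->value);
  }
  void SHashReduce(SHashReducer hash_reduce) const {
    hash_reduce(name);
    hash_reduce(value);
  }

  static constexpr const char* _type_key = "test.Example";
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(ExampleNode, Object);
};

// Registers the runtime type, a default creator, and the reflection hooks above.
TVM_REGISTER_NODE_TYPE(ExampleNode);

}  // namespace tvm
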
- fvisit_attrs_[tindex] = - ::tvm::detail::SelectVisitAttrs::VisitAttrs; + fvisit_attrs_[tindex] = ::tvm::detail::SelectVisitAttrs::VisitAttrs; - fsequal_reduce_[tindex] = - ::tvm::detail::SelectSEqualReduce::SEqualReduce; + fsequal_reduce_[tindex] = ::tvm::detail::SelectSEqualReduce::SEqualReduce; - fshash_reduce_[tindex] = - ::tvm::detail::SelectSHashReduce::SHashReduce; + fshash_reduce_[tindex] = ::tvm::detail::SelectSHashReduce::SHashReduce; return Registry(this, tindex); } -inline void ReflectionVTable:: -VisitAttrs(Object* self, AttrVisitor* visitor) const { +inline void ReflectionVTable::VisitAttrs(Object* self, AttrVisitor* visitor) const { uint32_t tindex = self->type_index(); if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << self->GetTypeKey() @@ -394,8 +393,7 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const { fvisit_attrs_[tindex](self, visitor); } -inline bool ReflectionVTable::GetReprBytes(const Object* self, - std::string* repr_bytes) const { +inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr_bytes) const { uint32_t tindex = self->type_index(); if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) { if (repr_bytes != nullptr) { diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 57824306620c..532425a51b3e 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -24,6 +24,7 @@ #define TVM_NODE_REPR_PRINTER_H_ #include + #include namespace tvm { diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index f719e24f619c..9424f6dc30f2 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -23,9 +23,10 @@ #ifndef TVM_NODE_STRUCTURAL_EQUAL_H_ #define TVM_NODE_STRUCTURAL_EQUAL_H_ -#include -#include #include +#include +#include + #include namespace tvm { @@ -43,26 +44,13 @@ class BaseValueEqual { return diff > -atol && diff < atol; } - bool operator()(const int64_t& lhs, const int64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { - return lhs == rhs; - } - bool operator()(const int& lhs, const int& rhs) const { - return lhs == rhs; - } - bool operator()(const bool& lhs, const bool& rhs) const { - return lhs == rhs; - } - bool operator()(const std::string& lhs, const std::string& rhs) const { - return lhs == rhs; - } - bool operator()(const DataType& lhs, const DataType& rhs) const { - return lhs == rhs; - } - template::value>::type> + bool operator()(const int64_t& lhs, const int64_t& rhs) const { return lhs == rhs; } + bool operator()(const uint64_t& lhs, const uint64_t& rhs) const { return lhs == rhs; } + bool operator()(const int& lhs, const int& rhs) const { return lhs == rhs; } + bool operator()(const bool& lhs, const bool& rhs) const { return lhs == rhs; } + bool operator()(const std::string& lhs, const std::string& rhs) const { return lhs == rhs; } + bool operator()(const DataType& lhs, const DataType& rhs) const { return lhs == rhs; } + template ::value>::type> bool operator()(const ENum& lhs, const ENum& rhs) const { return lhs == rhs; } @@ -127,9 +115,7 @@ class SEqualReducer : public BaseValueEqual { * \note This function may save the equality condition of (lhs == rhs) in an internal * stack and try to resolve later. 
*/ - virtual bool SEqualReduce(const ObjectRef& lhs, - const ObjectRef& rhs, - bool map_free_vars) = 0; + virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0; /*! * \brief Lookup the graph node equal map for vars that are already mapped. * @@ -185,7 +171,7 @@ class SEqualReducer : public BaseValueEqual { * \param rhs The right operand. * \return the immediate check result. */ - template + template bool operator()(const Array& lhs, const Array& rhs) const { // quick specialization for Array to reduce amount of recursion // depth as array comparison is pretty common. @@ -210,9 +196,7 @@ class SEqualReducer : public BaseValueEqual { } /*! \return Get the internal handler. */ - Handler* operator->() const { - return handler_; - } + Handler* operator->() const { return handler_; } private: /*! \brief Internal class pointer. */ diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index affc5f4dc377..ed89d841cd65 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -23,11 +23,12 @@ #ifndef TVM_NODE_STRUCTURAL_HASH_H_ #define TVM_NODE_STRUCTURAL_HASH_H_ -#include -#include #include -#include +#include +#include + #include +#include namespace tvm { @@ -36,39 +37,25 @@ namespace tvm { */ class BaseValueHash { public: - size_t operator()(const double& key) const { - return std::hash()(key); - } + size_t operator()(const double& key) const { return std::hash()(key); } - size_t operator()(const int64_t& key) const { - return std::hash()(key); - } + size_t operator()(const int64_t& key) const { return std::hash()(key); } - size_t operator()(const uint64_t& key) const { - return std::hash()(key); - } + size_t operator()(const uint64_t& key) const { return std::hash()(key); } - size_t operator()(const int& key) const { - return std::hash()(key); - } + size_t operator()(const int& key) const { return std::hash()(key); } - size_t operator()(const bool& key) const { - return std::hash()(key); - } + size_t operator()(const bool& key) const { return std::hash()(key); } - size_t operator()(const std::string& key) const { - return std::hash()(key); - } + size_t operator()(const std::string& key) const { return std::hash()(key); } size_t operator()(const runtime::DataType& key) const { - return std::hash()( - static_cast(key.code()) | - (static_cast(key.bits()) << 8) | - (static_cast(key.lanes()) << 16)); + return std::hash()(static_cast(key.code()) | + (static_cast(key.bits()) << 8) | + (static_cast(key.lanes()) << 16)); } - template::value>::type> + template ::value>::type> bool operator()(const ENum& key) const { return std::hash()(static_cast(key)); } @@ -173,9 +160,8 @@ class SHashReducer { * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. */ - template::value>::type> + template ::value>::type> void operator()(const T& key) const { // handle normal values. handler_->SHashReduceHashedValue(BaseValueHash()(key)); @@ -184,17 +170,13 @@ class SHashReducer { * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. */ - void operator()(const ObjectRef& key) const { - return handler_->SHashReduce(key, map_free_vars_); - } + void operator()(const ObjectRef& key) const { return handler_->SHashReduce(key, map_free_vars_); } /*! * \brief Push hash of key to the current sequence of hash values. * \param key The key to be hashed. * \note This function indicate key could contain var defintions. 
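For context, a node that opts into structural equality and structural hashing pairs the two reducers declared above; a minimal sketch (not part of this patch; PairNode and its fields are hypothetical):

// Illustrative sketch only: SEqualReduce and SHashReduce should visit the same
// fields in the same order so that equal nodes hash equally.
class PairNode : public Object {
 public:
  ObjectRef first;
  ObjectRef second;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("first", &first);
    v->Visit("second", &second);
  }

  bool SEqualReduce(const PairNode* other, SEqualReducer equal) const {
    // The reducer may record these conditions and resolve them later.
    return equal(first, other->first) && equal(second, other->second);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    // Hash the same fields, in the same order, that SEqualReduce compares.
    hash_reduce(first);
    hash_reduce(second);
  }

  static constexpr const char* _type_key = "test.Pair";
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  TVM_DECLARE_FINAL_OBJECT_INFO(PairNode, Object);
};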
*/ - void DefHash(const ObjectRef& key) const { - return handler_->SHashReduce(key, true); - } + void DefHash(const ObjectRef& key) const { return handler_->SHashReduce(key, true); } /*! * \brief Implementation for hash for a free var. * \param var The variable. @@ -205,9 +187,7 @@ class SHashReducer { } /*! \return Get the internal handler. */ - Handler* operator->() const { - return handler_; - } + Handler* operator->() const { return handler_; } private: /*! \brief Internal class pointer. */ diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 1ee7c9c09728..b2164ba8c1f7 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -24,13 +24,14 @@ #ifndef TVM_RELAY_ADT_H_ #define TVM_RELAY_ADT_H_ -#include #include +#include #include #include #include -#include + #include +#include #include namespace tvm { @@ -72,16 +73,11 @@ class PatternWildcard; /*! \brief PatternWildcard container node */ class PatternWildcardNode : public PatternNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("span", &span); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("span", &span); } - bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { - return true; - } + bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} static constexpr const char* _type_key = "relay.PatternWildcard"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode); @@ -131,9 +127,7 @@ class PatternVarNode : public PatternNode { return equal.DefEqual(var, other->var); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(var); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(var); } static constexpr const char* _type_key = "relay.PatternVar"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode); @@ -167,9 +161,7 @@ class PatternConstructorNode : public PatternNode { } bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const { - return - equal(constructor, other->constructor) && - equal(patterns, other->patterns); + return equal(constructor, other->constructor) && equal(patterns, other->patterns); } void SHashReduce(SHashReducer hash_reduce) const { @@ -210,9 +202,7 @@ class PatternTupleNode : public PatternNode { return equal(patterns, other->patterns); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(patterns); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(patterns); } static constexpr const char* _type_key = "relay.PatternTuple"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode); @@ -297,10 +287,8 @@ class MatchNode : public ExprNode { bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(data, other->data) && - equal(clauses, other->clauses) && - equal(complete, other->complete); + return equal(data, other->data) && equal(clauses, other->clauses) && + equal(complete, other->complete); } void SHashReduce(SHashReducer hash_reduce) const { diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index a2c0c75b66ef..b4b1b9dcc4e8 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -24,11 +24,12 @@ #ifndef TVM_RELAY_ANALYSIS_H_ #define TVM_RELAY_ANALYSIS_H_ +#include #include #include #include -#include #include + #include #include @@ -73,9 +74,9 @@ TVM_DLL bool ConstantCheck(const Expr& e); * 
`let f = (\x -> x) in let g = (\x -> x + 1) in f(g(2))` also bound x twice, * although x is not shadowed. * - * \param expr the expression to check. + * \param expr the expression to check. * - * \return true iff all Var in expr is bound at most once. + * \return true iff all Var in expr is bound at most once. */ TVM_DLL bool WellFormed(const Expr& expr); @@ -233,8 +234,7 @@ TVM_DLL Array UnmatchedCases(const Match& match, const IRModule& mod); * * \return The reference count mapping. */ -TVM_DLL std::unordered_map -GetExprRefCount(const Expr& body); +TVM_DLL std::unordered_map GetExprRefCount(const Expr& body); } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 2d1b9028732d..83b4ddaead43 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -26,6 +26,8 @@ #include #include +#include + #include namespace tvm { @@ -38,39 +40,39 @@ struct ArgsortAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis along which to sort the input tensor." - "If not given, the flattened array is used."); - TVM_ATTR_FIELD(is_ascend).set_default(true) - .describe("Whether to sort in ascending or descending order." - "By default, sort in ascending order"); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("DType of the output indices."); + TVM_ATTR_FIELD(axis).set_default(-1).describe( + "Axis along which to sort the input tensor." + "If not given, the flattened array is used."); + TVM_ATTR_FIELD(is_ascend).set_default(true).describe( + "Whether to sort in ascending or descending order." + "By default, sort in ascending order"); + TVM_ATTR_FIELD(dtype) + .set_default(NullValue()) + .describe("DType of the output indices."); } }; struct TopKAttrs : public tvm::AttrsNode { - int k; + Optional k; int axis; bool is_ascend; std::string ret_type; DataType dtype; TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") { - TVM_ATTR_FIELD(k).set_default(1) - .describe("Number of top elements to select"); - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis along which to sort the input tensor."); - TVM_ATTR_FIELD(ret_type).set_default("both") - .describe("The return type [both, values, indices]." - "both - return both top k data and indices." - "values - return top k data only." - "indices - return top k indices only."); - TVM_ATTR_FIELD(is_ascend).set_default(false) - .describe("Whether to sort in ascending or descending order." - "By default, sort in descending order"); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("Data type of the output indices."); + TVM_ATTR_FIELD(k).describe("Number of top elements to select"); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis along which to sort the input tensor."); + TVM_ATTR_FIELD(ret_type).set_default("both").describe( + "The return type [both, values, indices]." + "both - return both top k data and indices." + "values - return top k data only." + "indices - return top k indices only."); + TVM_ATTR_FIELD(is_ascend).set_default(false).describe( + "Whether to sort in ascending or descending order." 
+ "By default, sort in descending order"); + TVM_ATTR_FIELD(dtype) + .set_default(NullValue()) + .describe("Data type of the output indices."); } }; diff --git a/include/tvm/relay/attrs/annotation.h b/include/tvm/relay/attrs/annotation.h index cc21e34b4125..4a2eb63c7e6a 100644 --- a/include/tvm/relay/attrs/annotation.h +++ b/include/tvm/relay/attrs/annotation.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_ANNOTATION_H_ #include + #include namespace tvm { @@ -38,9 +39,8 @@ struct OnDeviceAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(device_type) - .describe( - "The virutal device/context type that an expression is annotated with.") - .set_default(0); + .describe("The virutal device/context type that an expression is annotated with.") + .set_default(0); } }; @@ -51,9 +51,7 @@ struct CastHintAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") { - TVM_ATTR_FIELD(dtype) - .describe( - "The data type denoted to be cast."); + TVM_ATTR_FIELD(dtype).describe("The data type denoted to be cast."); } }; @@ -65,8 +63,7 @@ struct CompilerAttrs : public tvm::AttrsNode { std::string compiler; TVM_DECLARE_ATTRS(CompilerAttrs, "relay.attrs.CompilerAttrs") { - TVM_ATTR_FIELD(compiler) - .describe("A 3rd party compiler used for code generation."); + TVM_ATTR_FIELD(compiler).describe("A 3rd party compiler used for code generation."); } }; diff --git a/include/tvm/relay/attrs/bitserial.h b/include/tvm/relay/attrs/bitserial.h index 962afc29fdbc..ed04c59ec865 100644 --- a/include/tvm/relay/attrs/bitserial.h +++ b/include/tvm/relay/attrs/bitserial.h @@ -27,6 +27,7 @@ #include #include + #include namespace tvm { @@ -112,23 +113,18 @@ struct BinaryDenseAttrs : public tvm::AttrsNode { bool unipolar; TVM_DECLARE_ATTRS(BinaryDenseAttrs, "relay.attrs.BinaryDenseAttrs") { - TVM_ATTR_FIELD(units) - .describe("Number of hidden units of the dense transformation."); - TVM_ATTR_FIELD(data_bits) - .set_default(1) - .describe("Number of bits to pack for incoming tensor."); + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(data_bits).set_default(1).describe( + "Number of bits to pack for incoming tensor."); TVM_ATTR_FIELD(weight_bits) - .set_default(1) - .describe("Number of bits to pack for weight tensor."); + .set_default(1) + .describe("Number of bits to pack for weight tensor."); TVM_ATTR_FIELD(pack_dtype) - .set_default(NullValue()) - .describe("Datatype to pack bits into before computation."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); - TVM_ATTR_FIELD(unipolar) - .set_default(true) - .describe("Whether to use unipolar or bipolar quantization for inputs."); + .set_default(NullValue()) + .describe("Datatype to pack bits into before computation."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + TVM_ATTR_FIELD(unipolar).set_default(true).describe( + "Whether to use unipolar or bipolar quantization for inputs."); } }; diff --git a/include/tvm/relay/attrs/debug.h b/include/tvm/relay/attrs/debug.h index ed9ed4ee0626..112228bb41ee 100644 --- a/include/tvm/relay/attrs/debug.h +++ b/include/tvm/relay/attrs/debug.h @@ -25,6 +25,8 @@ #define TVM_RELAY_ATTRS_DEBUG_H_ #include +#include + #include namespace tvm { @@ -37,8 +39,7 @@ struct DebugAttrs : public tvm::AttrsNode { EnvFunc debug_func; TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") { - TVM_ATTR_FIELD(debug_func) 
- .describe("The function to use when debugging."); + TVM_ATTR_FIELD(debug_func).describe("The function to use when debugging."); } }; diff --git a/include/tvm/relay/attrs/device_copy.h b/include/tvm/relay/attrs/device_copy.h index 2486fcdf473d..7da92b3ff763 100644 --- a/include/tvm/relay/attrs/device_copy.h +++ b/include/tvm/relay/attrs/device_copy.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_DEVICE_COPY_H_ #include + #include namespace tvm { @@ -39,13 +40,11 @@ struct DeviceCopyAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") { TVM_ATTR_FIELD(src_dev_type) - .describe( - "The virtual device/context type where the op copies data from.") - .set_default(0); + .describe("The virtual device/context type where the op copies data from.") + .set_default(0); TVM_ATTR_FIELD(dst_dev_type) - .describe( - "The virtual device/context type where the op copies data to.") - .set_default(0); + .describe("The virtual device/context type where the op copies data to.") + .set_default(0); } }; diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 52bb2efc63a8..cf5a6eff74bc 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -40,26 +41,58 @@ struct ResizeAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { - TVM_ATTR_FIELD(size).set_default(NullValue >()) - .describe("Output Size."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("bilinear") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Describes how to transform the coordinate in the resized tensor" - "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details" - "Available options are half_pixel, align_corners and asymmetric"); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); + } +}; + +/*! 
\brief Attributes used in image resize3d operator */ +struct Resize3dAttrs : public tvm::AttrsNode { + Array size; + String layout; + String method; + String coordinate_transformation_mode; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Resize3dAttrs, "relay.attrs.Resize3dAttrs") { + TVM_ATTR_FIELD(size).set_default(NullValue >()).describe("Output Size."); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Resize3d is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("trilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." + "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; @@ -72,22 +105,22 @@ struct CropAndResizeAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") { - TVM_ATTR_FIELD(crop_size).set_default(NullValue >()) - .describe("Target Size."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Resize is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("bilinear") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation"); - TVM_ATTR_FIELD(extrapolation_value).set_default(0.0) + TVM_ATTR_FIELD(crop_size).set_default(NullValue >()).describe("Target Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation"); + TVM_ATTR_FIELD(extrapolation_value) + .set_default(0.0) .describe("Specify value for extrapolation."); - TVM_ATTR_FIELD(out_dtype) - .set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(out_dtype).set_default(NullValue()).describe("Output data type."); } }; @@ -101,31 +134,67 @@ struct Dilation2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Dilation2DAttrs, "relay.attrs.Dilation2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the sliding window. 
[stride_height, stride_width]."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilations).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilations) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use. [dilation_height, dilation_width]"); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("IHW") - .describe("Dimension ordering of weight. Can be 'IHW', 'HWI', etc." - "'I', 'H', 'W' stands for input_channel, height, and width" - "dimensions respectively."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("IHW") + .describe( + "Dimension ordering of weight. Can be 'IHW', 'HWI', etc." + "'I', 'H', 'W' stands for input_channel, height, and width" + "dimensions respectively."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); } }; +/*! \brief Attributes used in image affine_grid operator */ +struct AffineGridAttrs : public tvm::AttrsNode { + Array target_shape; + + TVM_DECLARE_ATTRS(AffineGridAttrs, "relay.attrs.AffineGridAttrs") { + TVM_ATTR_FIELD(target_shape).describe("Specifies the output shape (H, W)."); + } +}; + +/*! \brief Attributes used in image grid_sample operator */ +struct GridSampleAttrs : public tvm::AttrsNode { + String method; + String layout; + + TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") { + TVM_ATTR_FIELD(method) + .set_default("bilinear") + .describe( + "Specify the mode to use for scaling." + "bilinear - Bilinear Interpolation"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Resize is applied on the 'H' and" + "'W' dimensions."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_IMAGE_H_ diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index d232f867a777..7429c396ea00 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -26,6 +26,7 @@ #include #include + #include #include @@ -46,15 +47,10 @@ struct AllocStorageAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") { TVM_ATTR_FIELD(dtype) - .describe( - "The dtype of the tensor to allocate.") - .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(device_id) - .describe( - "The device id on which to allocate memory."); - TVM_ATTR_FIELD(device_type) - .describe( - "The device type on which to allocate memory."); + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(device_id).describe("The device id on which to allocate memory."); + TVM_ATTR_FIELD(device_type).describe("The device type on which to allocate memory."); } }; @@ -68,16 +64,13 @@ struct AllocTensorAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") { TVM_ATTR_FIELD(dtype) - .describe( - "The dtype of the tensor to allocate.") - .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(const_shape) - .describe( - "The shape of constant used to aid in type inference."); + .describe("The dtype of the tensor to allocate.") + .set_default(DataType::Float(32, 1)); + TVM_ATTR_FIELD(const_shape).describe("The shape of constant used to aid in type inference."); TVM_ATTR_FIELD(assert_shape) - .describe( - "The shape to cast the return type of the allocation to, "\ - "used to specify the shape obtained via further analysis."); + .describe( + "The shape to cast the return type of the allocation to, " + "used to specify the shape obtained via further analysis."); } }; @@ -88,10 +81,9 @@ struct ShapeFuncAttrs : public tvm::AttrsNode { Array is_input; TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") { - TVM_ATTR_FIELD(is_input) - .describe( - "A bool indicating whether the shape function should"\ - "expect shape or input in each position."); + TVM_ATTR_FIELD(is_input).describe( + "A bool indicating whether the shape function should" + "expect shape or input in each position."); } }; diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h index 536e4145db29..abe63e583ddc 100644 --- a/include/tvm/relay/attrs/nn.h +++ b/include/tvm/relay/attrs/nn.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -42,13 +43,10 @@ struct BiasAddAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis to add the bias") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1); } }; - /*! 
\brief Attributes used in 1D convolution operators */ struct Conv1DAttrs : public tvm::AttrsNode { Array strides; @@ -63,31 +61,44 @@ struct Conv1DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, })) + TVM_ATTR_FIELD(strides) + .set_default(Array({ + 1, + })) .describe("Specifies the stride of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "on both sides for padding number of points"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, })) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({ + 1, + })) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Currently unused but may be added in the future."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Currently unused but may be added in the future."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Convolution is applied on the 'W'" - "dimension."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIW") - .describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." - "'O', 'I', 'W' stands for num_filter, input_channel, and width" - "dimensions respectively."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCW") + .describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the 'W'" + "dimension."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIW") + .describe( + "Dimension ordering of weight. Can be 'OIW', or 'WIO', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -96,7 +107,6 @@ struct Conv1DAttrs : public tvm::AttrsNode { } }; - /*! 
\brief Attributes used in convolution operators */ struct Conv2DAttrs : public tvm::AttrsNode { Array strides; @@ -111,42 +121,53 @@ struct Conv2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -156,14 +177,13 @@ struct Conv2DAttrs : public tvm::AttrsNode { }; /*! \brief Attributes used in winograd weight transformation operators */ -struct ConvWinogradWeightTransformAttrs : - public tvm::AttrsNode { +struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode { int tile_size; TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs, - "relay.attrs.ConvWinogradWeightTransformAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + "relay.attrs.ConvWinogradWeightTransformAttrs") { + TVM_ATTR_FIELD(tile_size).describe( + "Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); } }; @@ -182,44 +202,55 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv2DWinogradAttrs, "relay.attrs.Conv2DWinogradAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(tile_size).describe( + "The tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." 
+ "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. 
TVM_ATTR_FIELD(out_dtype) @@ -261,43 +292,54 @@ struct Conv3DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom," - "right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Convolution is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") - .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." - "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," - "and width dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." 
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -306,6 +348,82 @@ struct Conv3DAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in transposed convolution operator */ +struct Conv3DTransposeAttrs : public tvm::AttrsNode { + IndexExpr channels; + Array kernel_size; + Array strides; + Array padding; + Array output_padding; + Array dilation; + int groups; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + TVM_DECLARE_ATTRS(Conv3DTransposeAttrs, "relay.attrs.Conv3DTransposeAttrs") { + TVM_ATTR_FIELD(channels) + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); + TVM_ATTR_FIELD(kernel_size) + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0, 0, 0})) + .describe( + "Zero-padding added to one side of the output." + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : front, bottom, right will use same padding as back, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : front, bottom, right will use same padding as back, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of data and weight. Can be 'OIDHW', 'OIDHW16o16i', etc." 
+ "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + } +}; + /*! \brief Attributes used in 3d winograd convolution operators */ struct Conv3DWinogradAttrs : public tvm::AttrsNode { int tile_size; @@ -321,45 +439,56 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") { - TVM_ATTR_FIELD(tile_size) - .describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(tile_size).describe( + "The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)"); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom," - "right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom," + "right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." 
+ " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Convolution is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW") - .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." - "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," - "and width dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCDHW") + .describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Convolution is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIDHW") + .describe( + "Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc." + "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height," + "and width dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -368,14 +497,12 @@ struct Conv3DWinogradAttrs : public tvm::AttrsNode { } }; - /*! \brief Attributes used in softmax operators */ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("The axis to sum over when computing softmax."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("The axis to sum over when computing softmax."); } }; @@ -395,53 +522,77 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") { TVM_ATTR_FIELD(channels) - .set_default(NullValue()) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) - .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("The strides of the convolution."); - TVM_ATTR_FIELD(output_padding).set_default(Array({0, 0})) - .describe("Zero-padding added to one side of the output." 
- "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0, 0})) + .describe( + "Zero-padding added to one side of the output." + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of data. Can be 'NCHW', 'NHWC', etc." 
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); } }; +/*! \brief Attributes used in dilate operator */ +struct DilateAttrs : public tvm::AttrsNode { + Array strides; + + TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") { + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Dilation stride on each dimension, 1 means no dilation."); + } +}; + /*! \brief Attributes used in 1D transposed convolution operator */ struct Conv1DTransposeAttrs : public tvm::AttrsNode { IndexExpr channels; @@ -458,42 +609,54 @@ struct Conv1DTransposeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relay.attrs.Conv1DTransposeAttrs") { TVM_ATTR_FIELD(channels) - .set_default(NullValue()) - .describe("The dimensionality of the output space" - "i.e. the number of output channels in the convolution."); + .set_default(NullValue()) + .describe( + "The dimensionality of the output space" + "i.e. the number of output channels in the convolution."); TVM_ATTR_FIELD(kernel_size) - .describe("The dimensions of the convolution window.") - .set_default(NullValue >()); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("The strides of the convolution."); - TVM_ATTR_FIELD(output_padding).set_default(Array({0})) - .describe("Zero-padding added to one side of the output."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("Symmetric or asymmetric padding." - "Single value: the input is implicitly zero-padded on both sides." - "Two values: padding[0] is used for left input padding, " - "padding[1] is used for right input padding,"); - TVM_ATTR_FIELD(dilation).set_default(Array({1})) - .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); - TVM_ATTR_FIELD(data_layout).set_default("NCW") - .describe("Dimension ordering of data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Convolution is applied on the" - "'W' dimension."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIW") - .describe("Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." - "'O', 'I', 'W' stands for num_filter, input_channel, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. 
Default to be same as input layout."); + .describe("The dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("The strides of the convolution."); + TVM_ATTR_FIELD(output_padding) + .set_default(Array({0})) + .describe("Zero-padding added to one side of the output."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "Symmetric or asymmetric padding." + "Single value: the input is implicitly zero-padded on both sides." + "Two values: padding[0] is used for left input padding, " + "padding[1] is used for right input padding,"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCW") + .describe( + "Dimension ordering of data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Convolution is applied on the" + "'W' dimension."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIW") + .describe( + "Dimension ordering of data and weight. Can be 'OIW', 'OIW16o16i', etc." + "'O', 'I', 'W' stands for num_filter, input_channel, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Default to be same as input layout."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -509,23 +672,25 @@ struct MaxPool2DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. 
Pooling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -539,25 +704,28 @@ struct AvgPool2DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool2DAttrs, "relay.attrs.AvgPool2DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. 
Pooling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; @@ -566,11 +734,11 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") { - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -580,13 +748,14 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output height and width."); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Pooling is applied on the 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(output_size) + .set_default(Array({})) + .describe("Output height and width."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Pooling is applied on the 'H' and" + "'W' dimensions."); } }; @@ -595,17 +764,17 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode { std::string layout; TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") { - TVM_ATTR_FIELD(output_size).set_default(Array({})) - .describe("Output depth, height and width."); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on 'D', 'H' and" - "'W' dimensions."); + TVM_ATTR_FIELD(output_size) + .set_default(Array({})) + .describe("Output depth, height and width."); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on 'D', 'H' and" + "'W' dimensions."); } }; - /*! 
\brief Attributes for 1D max pool operator */ struct MaxPool1DAttrs : public tvm::AttrsNode { Array pool_size; @@ -615,22 +784,24 @@ struct MaxPool1DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. Pooling is applied on the 'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NWC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -644,28 +815,30 @@ struct AvgPool1DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool1DAttrs, "relay.attrs.AvgPool1DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCW") - .describe("Dimension ordering of input data. Can be 'NCW', 'NHC', etc." - "'N', 'C', 'W' stands for batch, channel, and width" - "dimensions respectively. 
Pooling is applied on the 'W' dimension."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCW").describe( + "Dimension ordering of input data. Can be 'NCW', 'NHC', etc." + "'N', 'C', 'W' stands for batch, channel, and width" + "dimensions respectively. Pooling is applied on the 'W' dimension."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; - /*! \brief Attributes for 3D max pool operator */ struct MaxPool3DAttrs : public tvm::AttrsNode { Array pool_size; @@ -675,23 +848,25 @@ struct MaxPool3DAttrs : public tvm::AttrsNode { bool ceil_mode; TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. 
Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); } }; @@ -705,37 +880,38 @@ struct AvgPool3DAttrs : public tvm::AttrsNode { bool count_include_pad; TVM_DECLARE_ATTRS(AvgPool3DAttrs, "relay.attrs.AvgPool3DAttrs") { - TVM_ATTR_FIELD(pool_size) - .describe("Size of the pooling windows."); - TVM_ATTR_FIELD(strides).set_default(Array({1, 1, 1})) - .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "three int : back, bottom, right will use same padding as front, top, left" - "six int : padding width in the order of (front, top, left, back, bottom, right)"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Pooling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(ceil_mode).set_default(false) - .describe("When true, will use ceil instead of floor to compute the output shape."); - TVM_ATTR_FIELD(count_include_pad).set_default(false) - .describe("When true, will include padding to compute the average"); + TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows."); + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "three int : back, bottom, right will use same padding as front, top, left" + "six int : padding width in the order of (front, top, left, back, bottom, right)"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Pooling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(ceil_mode).set_default(false).describe( + "When true, will use ceil instead of floor to compute the output shape."); + TVM_ATTR_FIELD(count_include_pad) + .set_default(false) + .describe("When true, will include padding to compute the average"); } }; - /*! \brief Attributes for dense operator */ struct DenseAttrs : public tvm::AttrsNode { IndexExpr units; DataType out_dtype; TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") { - TVM_ATTR_FIELD(units) - .describe("Number of hidden units of the dense transformation."); + TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -772,21 +948,22 @@ struct UpSamplingAttrs : public tvm::AttrsNode { bool align_corners; TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") { - TVM_ATTR_FIELD(scale_h) - .describe("The upsampling factor for height"); - TVM_ATTR_FIELD(scale_w) - .describe("The upsampling factor for width"); - TVM_ATTR_FIELD(layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." 
- "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Upsampling is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("nearest_neighbor") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "bilinear - Bilinear Interpolation" - "bicubic - Bicubic Interpolation"); - TVM_ATTR_FIELD(align_corners).set_default(false) + TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Upsampling is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("nearest_neighbor") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation" + "bicubic - Bicubic Interpolation"); + TVM_ATTR_FIELD(align_corners) + .set_default(false) .describe("Should be true to preserve the values at the corner pixels"); } }; @@ -801,26 +978,27 @@ struct UpSampling3DAttrs : public tvm::AttrsNode { std::string coordinate_transformation_mode; TVM_DECLARE_ATTRS(UpSampling3DAttrs, "relay.attrs.UpSampling3DAttrs") { - TVM_ATTR_FIELD(scale_d) - .describe("The upsampling factor for depth"); - TVM_ATTR_FIELD(scale_h) - .describe("The upsampling factor for height"); - TVM_ATTR_FIELD(scale_w) - .describe("The upsampling factor for width"); - TVM_ATTR_FIELD(layout).set_default("NCDHW") - .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." - "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" - "dimensions respectively. Upsampling is applied on the 'D', 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(method).set_default("nearest_neighbor") - .describe("Specify the mode to use for scaling." - "nearest_neighbor - Nearest Neighbor" - "trilinear - Trilinear Interpolation"); - TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel") - .describe("Describes how to transform the coordinate in the resized tensor" - "to the coordinate in the original tensor." - "Refer to the ONNX Resize operator specification for details" - "Available options are half_pixel, align_corners and asymmetric"); + TVM_ATTR_FIELD(scale_d).describe("The upsampling factor for depth"); + TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height"); + TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width"); + TVM_ATTR_FIELD(layout).set_default("NCDHW").describe( + "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc." + "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" + "dimensions respectively. Upsampling is applied on the 'D', 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method) + .set_default("nearest_neighbor") + .describe( + "Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "trilinear - Trilinear Interpolation"); + TVM_ATTR_FIELD(coordinate_transformation_mode) + .set_default("half_pixel") + .describe( + "Describes how to transform the coordinate in the resized tensor" + "to the coordinate in the original tensor." 
+ "Refer to the ONNX Resize operator specification for details" + "Available options are half_pixel, align_corners and asymmetric"); } }; @@ -831,15 +1009,17 @@ struct PadAttrs : public tvm::AttrsNode { std::string pad_mode; TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { - TVM_ATTR_FIELD(pad_value).set_default(0.0) - .describe("The value used for padding when mode is 'constant'."); - TVM_ATTR_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ..., (before_N, after_N))"); - TVM_ATTR_FIELD(pad_mode).set_default("constant") - .describe("Padding type to use. \"constant\" pads with constant_value, " - "\"edge\" pads using the edge values of the input array, " - "\"reflect\" pads by reflecting values with respect to the edges."); + TVM_ATTR_FIELD(pad_value).set_default(0.0).describe( + "The value used for padding when mode is 'constant'."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(pad_mode) + .set_default("constant") + .describe( + "Padding type to use. \"constant\" pads with constant_value, " + "\"edge\" pads using the edge values of the input array, " + "\"reflect\" pads by reflecting values with respect to the edges."); } }; @@ -849,11 +1029,12 @@ struct MirrorPadAttrs : public tvm::AttrsNode { Array > pad_width; TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") { - TVM_ATTR_FIELD(mode).set_default("SYMMETRIC") - .describe("Specifies how mirroring should be performed."); - TVM_ATTR_FIELD(pad_width) - .describe("Number of values padded to the edges of each axis, " - "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(mode) + .set_default("SYMMETRIC") + .describe("Specifies how mirroring should be performed."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); } }; @@ -862,30 +1043,28 @@ struct LeakyReluAttrs : public tvm::AttrsNode { double alpha; TVM_DECLARE_ATTRS(LeakyReluAttrs, "relay.attrs.LeakyReluAttrs") { - TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25) - .describe("Slope coefficient for the negative half axis."); + TVM_ATTR_FIELD(alpha).set_lower_bound(0.0).set_default(0.25).describe( + "Slope coefficient for the negative half axis."); } }; - /*! \brief Attributes for prelu operator */ struct PReluAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") { - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Specify which shape axis the channel is specified."); + TVM_ATTR_FIELD(axis).set_default(1).describe( + "Specify which shape axis the channel is specified."); } }; - /*! 
\brief Attributes used in dropout operator */ struct DropoutAttrs : public tvm::AttrsNode { double rate; TVM_DECLARE_ATTRS(DropoutAttrs, "relay.attrs.DropoutAttrs") { TVM_ATTR_FIELD(rate) - .describe("Fraction of the input that gets dropped out during training time") - .set_default(0.5); + .describe("Fraction of the input that gets dropped out during training time") + .set_default(0.5); } }; // struct DropoutAttrs @@ -897,24 +1076,22 @@ struct BatchNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") { - TVM_ATTR_FIELD(axis) - .describe("Specify which shape axis denotes the channel.") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1); TVM_ATTR_FIELD(epsilon) - .describe("Small float added to variance to avoid dividing by zero") - .set_default(1e-5); + .describe("Small float added to variance to avoid dividing by zero") + .set_default(1e-5); TVM_ATTR_FIELD(center) - .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored") - .set_default(true); + .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored") + .set_default(true); TVM_ATTR_FIELD(scale) - .describe("If True, multiply by gamma. If False, gamma is not used. " - "When the next layer is piecewise linear (also, e.g., nn.relu), " - "this can be disabled since the scaling will be done by the next layer.") - .set_default(true); + .describe( + "If True, multiply by gamma. If False, gamma is not used. " + "When the next layer is piecewise linear (also, e.g., nn.relu), " + "this can be disabled since the scaling will be done by the next layer.") + .set_default(true); } }; // struct BatchNormAttrs - /*! \brief Attributes used in instance_norm operator */ struct InstanceNormAttrs : public tvm::AttrsNode { int axis; @@ -923,21 +1100,18 @@ struct InstanceNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(InstanceNormAttrs, "relay.attrs.InstanceNormAttrs") { - TVM_ATTR_FIELD(axis) - .describe("Specify which shape axis denotes the channel.") - .set_default(1); + TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1); TVM_ATTR_FIELD(epsilon) - .describe("Small float added to variance to avoid dividing by zero") - .set_default(1e-5); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + .describe("Small float added to variance to avoid dividing by zero") + .set_default(1e-5); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct InstanceNormAttrs - /*! 
\brief Attributes used in layer_norm operator */ struct LayerNormAttrs : public tvm::AttrsNode { int axis; @@ -946,18 +1120,39 @@ struct LayerNormAttrs : public tvm::AttrsNode { bool scale; TVM_DECLARE_ATTRS(LayerNormAttrs, "relay.attrs.LayerNormAttrs") { - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Specify which shape axis denotes the channel."); - TVM_ATTR_FIELD(epsilon).set_default(1e-5) - .describe("Small float added to variance to avoid dividing by zero"); - TVM_ATTR_FIELD(center).set_default(true) - .describe("If true, add offset of beta to normalized tensor; " - "otherwise, beta is ignored."); - TVM_ATTR_FIELD(scale).set_default(true) - .describe("If true, multiply by gamma; otherwise, gamma is ignored."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe( + "Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); } }; // struct LayerNormAttrs +/*! \brief Attributes used in group_norm operator */ +struct GroupNormAttrs : public tvm::AttrsNode { + int num_groups; + int axis; + double epsilon; + bool center; + bool scale; + + TVM_DECLARE_ATTRS(GroupNormAttrs, "relay.attrs.GroupNormAttrs") { + TVM_ATTR_FIELD(num_groups) + .set_default(0) + .describe("Specify number of groups to separate the channels into."); + TVM_ATTR_FIELD(axis).set_default(1).describe("Specify which shape axis denotes the channel."); + TVM_ATTR_FIELD(epsilon).set_default(1e-5).describe( + "Small float added to variance to avoid dividing by zero"); + TVM_ATTR_FIELD(center).set_default(true).describe( + "If true, add offset of beta to normalized tensor; " + "otherwise, beta is ignored."); + TVM_ATTR_FIELD(scale).set_default(true).describe( + "If true, multiply by gamma; otherwise, gamma is ignored."); + } +}; // struct GroupNormAttrs /*! \brief Attributes for LRN operator */ struct LRNAttrs : public tvm::AttrsNode { @@ -968,34 +1163,26 @@ struct LRNAttrs : public tvm::AttrsNode { double beta; TVM_DECLARE_ATTRS(LRNAttrs, "relay.attrs.LRNAttrs") { - TVM_ATTR_FIELD(size).set_default(5) - .describe("The size of the local region to be considered for normalization."); - TVM_ATTR_FIELD(axis).set_default(1) - .describe("Axis of input data layout channel."); - TVM_ATTR_FIELD(bias).set_default(2) - .describe("The offset parameter to avoid division by 0."); - TVM_ATTR_FIELD(alpha).set_default(0.0001) - .describe("The scaling parameter."); - TVM_ATTR_FIELD(beta).set_default(0.75) - .describe("The exponent parameter."); + TVM_ATTR_FIELD(size).set_default(5).describe( + "The size of the local region to be considered for normalization."); + TVM_ATTR_FIELD(axis).set_default(1).describe("Axis of input data layout channel."); + TVM_ATTR_FIELD(bias).set_default(2).describe("The offset parameter to avoid division by 0."); + TVM_ATTR_FIELD(alpha).set_default(0.0001).describe("The scaling parameter."); + TVM_ATTR_FIELD(beta).set_default(0.75).describe("The exponent parameter."); } }; - /*! 
\brief Attributes for L2Normalize operator */ struct L2NormalizeAttrs : public tvm::AttrsNode { double eps; Array axis; TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") { - TVM_ATTR_FIELD(eps) - .describe("A lower bound value for the norm, to avoid division by 0."); - TVM_ATTR_FIELD(axis) - .describe("Axis over the normalization applied."); + TVM_ATTR_FIELD(eps).describe("A lower bound value for the norm, to avoid division by 0."); + TVM_ATTR_FIELD(axis).describe("Axis over the normalization applied."); } }; - /*! \brief Attributes for DeformableConv2D operator */ struct DeformableConv2DAttrs : public tvm::AttrsNode { Array strides; @@ -1011,46 +1198,59 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode { DataType out_dtype; TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") { - TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + TVM_ATTR_FIELD(strides) + .set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); - TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) - .describe("If padding is non-zero, then the input is implicitly zero-padded" - "Padding support both symmetric and asymmetric as" - "one int : same padding used on all sides" - "two int : bottom, right will use same padding as top, left" - "four int : padding width in the order of (top, left, bottom, right)"); - TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + TVM_ATTR_FIELD(padding) + .set_default(Array({0, 0})) + .describe( + "If padding is non-zero, then the input is implicitly zero-padded" + "Padding support both symmetric and asymmetric as" + "one int : same padding used on all sides" + "two int : bottom, right will use same padding as top, left" + "four int : padding width in the order of (top, left, bottom, right)"); + TVM_ATTR_FIELD(dilation) + .set_default(Array({1, 1})) .describe("Specifies the dilation rate to use for dilated convolution."); - TVM_ATTR_FIELD(deformable_groups).set_default(1) - .describe("Controls the connections between inputs and offsets." - "Input channels are partitioned into multiple deformable groups. Offsets" - "are shared across input channels in the same deformable group."); - TVM_ATTR_FIELD(groups).set_default(1) - .describe("Controls the connections between inputs and outputs." - "At groups=1, all inputs are convolved to all outputs." - "At groups=2, the operation becomes equivalent to having two convolution" - "layers side by side, each seeing half the input channels, and producing" - "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(deformable_groups) + .set_default(1) + .describe( + "Controls the connections between inputs and offsets." + "Input channels are partitioned into multiple deformable groups. Offsets" + "are shared across input channels in the same deformable group."); + TVM_ATTR_FIELD(groups).set_default(1).describe( + "Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); TVM_ATTR_FIELD(channels) - .describe("The number of output channels in the convolution." - " If it is not set, inferred by shape of the weight.") + .describe( + "The number of output channels in the convolution." 
+ " If it is not set, inferred by shape of the weight.") .set_default(NullValue()); TVM_ATTR_FIELD(kernel_size) .describe("Specifies the dimensions of the convolution window.") .set_default(NullValue >()); - TVM_ATTR_FIELD(data_layout).set_default("NCHW") - .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Convolution is applied on the 'H' and" - "'W' dimensions."); - TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") - .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." - "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" - "dimensions respectively."); - TVM_ATTR_FIELD(out_layout).set_default("") - .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." - "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" - "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(data_layout) + .set_default("NCHW") + .describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout) + .set_default("OIHW") + .describe( + "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout) + .set_default("") + .describe( + "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) @@ -1079,6 +1279,36 @@ struct SubPixelAttrs : public tvm::AttrsNode { } }; // struct SubPixelAttrs +/*! \brief Attributes used in correlation operators */ +struct CorrelationAttrs : public tvm::AttrsNode { + int kernel_size; + int max_displacement; + int stride1; + int stride2; + Array padding; + bool is_multiply; + String layout; + + TVM_DECLARE_ATTRS(CorrelationAttrs, "relay.attrs.CorrelationAttrs") { + TVM_ATTR_FIELD(kernel_size) + .describe("Kernel size for correlation, must be an odd number.") + .set_default(1); + TVM_ATTR_FIELD(max_displacement).describe("Max displacement of Correlation.").set_default(1); + TVM_ATTR_FIELD(stride1).describe("Stride for data1.").set_default(1); + TVM_ATTR_FIELD(stride2).describe("Stride for data2.").set_default(1); + TVM_ATTR_FIELD(padding) + .describe("Padding for data1 and data2.") + .set_default(Array{0, 0}); + TVM_ATTR_FIELD(is_multiply) + .describe("Operation type is either multiplication or substraction.") + .set_default(true); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." 
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively."); + } +}; // struct CorrelationAttrs + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_NN_H_ diff --git a/include/tvm/relay/attrs/reduce.h b/include/tvm/relay/attrs/reduce.h index 443efb5b0c32..f57c1f4ddc58 100644 --- a/include/tvm/relay/attrs/reduce.h +++ b/include/tvm/relay/attrs/reduce.h @@ -25,6 +25,7 @@ #define TVM_RELAY_ATTRS_REDUCE_H_ #include + #include namespace tvm { @@ -37,7 +38,8 @@ struct ReduceAttrs : public tvm::AttrsNode { bool exclude; TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue>()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue>()) .describe(R"code(The axis or axes along which to perform the reduction. The default, `axis=()`, will compute over all elements into a @@ -51,11 +53,11 @@ struct ReduceAttrs : public tvm::AttrsNode { If `exclude` is true, reduction will be performed on the axes that are NOT in axis instead.)code"); - TVM_ATTR_FIELD(keepdims).set_default(false) - .describe("If this is set to `True`, the reduced axes are left " - "in the result as dimension with size one."); - TVM_ATTR_FIELD(exclude).set_default(false) - .describe("Whether to perform reduction on axis that are NOT in axis instead."); + TVM_ATTR_FIELD(keepdims).set_default(false).describe( + "If this is set to `True`, the reduced axes are left " + "in the result as dimension with size one."); + TVM_ATTR_FIELD(exclude).set_default(false).describe( + "Whether to perform reduction on axis that are NOT in axis instead."); } }; } // namespace relay diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index ae2ac11b1e53..cbc60340d924 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -27,6 +27,7 @@ #include #include #include + #include namespace tvm { @@ -37,8 +38,7 @@ struct CastAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(CastAttrs, "relay.attrs.CastAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type"); + TVM_ATTR_FIELD(dtype).describe("Target data type"); } }; // struct CastAttrs. @@ -48,11 +48,11 @@ struct ExpandDimsAttrs : public tvm::AttrsNode { int num_newaxis; TVM_DECLARE_ATTRS(ExpandDimsAttrs, "relay.attrs.ExpandDimsAttrs") { - TVM_ATTR_FIELD(axis) - .describe("The axis at which the input array is expanded." - "Should lie in range `[-data.ndim - 1, data.ndim]`." - "If `axis < 0`, it is the first axis inserted;" - "If `axis >= 0`, it is the last axis inserted in Python's negative indexing."); + TVM_ATTR_FIELD(axis).describe( + "The axis at which the input array is expanded." + "Should lie in range `[-data.ndim - 1, data.ndim]`." + "If `axis < 0`, it is the first axis inserted;" + "If `axis >= 0`, it is the last axis inserted in Python's negative indexing."); TVM_ATTR_FIELD(num_newaxis) .describe("Number of axises to be inserted. Should be >= 0.") .set_lower_bound(0) @@ -65,8 +65,9 @@ struct ConcatenateAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") { TVM_ATTR_FIELD(axis) - .describe("The axis at which the input arrays are concatenated." - "Should lie in range `[-ndim, ndim)`.") + .describe( + "The axis at which the input arrays are concatenated." 
+ "Should lie in range `[-ndim, ndim)`.") .set_default(0); } }; // struct ConcatenateAttrs @@ -75,50 +76,65 @@ struct ConcatenateAttrs : public tvm::AttrsNode { struct TransposeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("The target axes order, reverse order if not specified."); + TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified."); } }; // struct TransposeAttrs /*! \brief Attributes used in reshape operators */ struct ReshapeAttrs : public tvm::AttrsNode { - Array newshape; + Optional> newshape; bool reverse; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { - TVM_ATTR_FIELD(newshape) - .describe("The new shape. Should be compatible with the original shape."); + TVM_ATTR_FIELD(newshape).describe( + "The new shape. Should be compatible with the original shape."); TVM_ATTR_FIELD(reverse) .describe("Infer the special values from right to left if true") .set_default(false); } }; // struct ReshapeAttrs +struct ScatterAttrs : public tvm::AttrsNode { + Integer axis; + + TVM_DECLARE_ATTRS(ScatterAttrs, "relay.attrs.ScatterAttrs") { + TVM_ATTR_FIELD(axis).set_default(0).describe("The axis over which to select values."); + } +}; + +struct GatherAttrs : public tvm::AttrsNode { + Integer axis; + + TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherAttrs") { + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) + .describe("The axis over which to select values."); + } +}; + struct TakeAttrs : public tvm::AttrsNode { Integer axis; std::string mode; TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe("The axis over which to select values."); - TVM_ATTR_FIELD(mode).set_default("clip") - .describe("Specify how out-of-bound indices will behave." - "clip - clip to the range (default)" - "wrap - wrap around the indices" - "fast - no clip or wrap around (user must make sure indices are in-bound)"); + TVM_ATTR_FIELD(mode).set_default("clip").describe( + "Specify how out-of-bound indices will behave." + "clip - clip to the range (default)" + "wrap - wrap around the indices" + "fast - no clip or wrap around (user must make sure indices are in-bound)"); } }; /*! \brief Attributes that specify a tensor */ struct InitOpAttrs : public tvm::AttrsNode { - Array shape; + Optional> shape; DataType dtype; TVM_DECLARE_ATTRS(InitOpAttrs, "relay.attrs.InitOpAttrs") { - TVM_ATTR_FIELD(shape) - .describe("Target shape."); - TVM_ATTR_FIELD(dtype) - .describe("Target data type.") - .set_default(NullValue()); + TVM_ATTR_FIELD(shape).describe("Target shape."); + TVM_ATTR_FIELD(dtype).describe("Target data type.").set_default(NullValue()); } }; // struct InitOpAttrs @@ -130,14 +146,10 @@ struct ArangeAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") { - TVM_ATTR_FIELD(start) - .describe("Start of interval. The interval includes this value."); - TVM_ATTR_FIELD(stop) - .describe("Stop of interval. The interval does not include this value."); - TVM_ATTR_FIELD(step) - .describe("Spacing between values."); - TVM_ATTR_FIELD(dtype) - .describe("Target data type."); + TVM_ATTR_FIELD(start).describe("Start of interval. The interval includes this value."); + TVM_ATTR_FIELD(stop).describe("Stop of interval. 
The interval does not include this value."); + TVM_ATTR_FIELD(step).describe("Spacing between values."); + TVM_ATTR_FIELD(dtype).describe("Target data type."); } }; // struct ArangeAttrs @@ -145,8 +157,8 @@ struct ArangeAttrs : public tvm::AttrsNode { struct StackAttrs : public tvm::AttrsNode { Integer axis; TVM_DECLARE_ATTRS(StackAttrs, "relay.attrs.StackAttrs") { - TVM_ATTR_FIELD(axis).set_default(0) - .describe("The axis in the result array along which the input arrays are stacked."); + TVM_ATTR_FIELD(axis).set_default(0).describe( + "The axis in the result array along which the input arrays are stacked."); } }; // struct StackAttrs @@ -155,9 +167,9 @@ struct RepeatAttrs : public tvm::AttrsNode { Integer repeats; Integer axis; TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") { - TVM_ATTR_FIELD(repeats) - .describe("The number of repetitions for each element."); - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(repeats).describe("The number of repetitions for each element."); + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe(" The axis along which to repeat values."); } }; // struct RepeatAttrs @@ -166,9 +178,9 @@ struct RepeatAttrs : public tvm::AttrsNode { struct TileAttrs : public tvm::AttrsNode { Array reps; TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") { - TVM_ATTR_FIELD(reps) - .describe("The number of times for repeating the tensor a." - "Each dim sizeof reps must be a positive integer."); + TVM_ATTR_FIELD(reps).describe( + "The number of times for repeating the tensor a." + "Each dim sizeof reps must be a positive integer."); } }; // struct TileAttrs @@ -176,7 +188,8 @@ struct TileAttrs : public tvm::AttrsNode { struct ReverseAttrs : public tvm::AttrsNode { Integer axis; TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") { - TVM_ATTR_FIELD(axis).set_default(NullValue()) + TVM_ATTR_FIELD(axis) + .set_default(NullValue()) .describe("The axis along which to reverse elements."); } }; // struct ReverseAttrs @@ -188,11 +201,12 @@ struct SqueezeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SqueezeAttrs, "relay.attrs.SqueezeAttrs") { TVM_ATTR_FIELD(axis) - .describe("The axis to squeeze in the input tensor." - "If `axis = None`, all axis of dimension 1 get squeezed;" - "Else, the dimension in axes get squeezed." - "It is an error if an axis does not has dimension 1.") - .set_default(NullValue >()); + .describe( + "The axis to squeeze in the input tensor." + "If `axis = None`, all axis of dimension 1 get squeezed;" + "Else, the dimension in axes get squeezed." + "It is an error if an axis does not has dimension 1.") + .set_default(NullValue>()); } }; // struct SqueezeAttrs @@ -202,29 +216,36 @@ struct SplitAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { TVM_ATTR_FIELD(indices_or_sections) - .describe("Indices or sections to split into. Accepts an int or a tuple" - "If indices_or_sections is an integer, the input will be divided equally" - "along given axis. If such a split is not possible, an error is raised." - "If indices_or_sections is a tuple of sorted integers," - "the entries indicate where along axis the array is split."); - TVM_ATTR_FIELD(axis).set_default(0) - .describe("the axis to be splitted."); + .describe( + "Indices or sections to split into. Accepts an int or a tuple" + "If indices_or_sections is an integer, the input will be divided equally" + "along given axis. If such a split is not possible, an error is raised." 
+ "If indices_or_sections is a tuple of sorted integers," + "the entries indicate where along axis the array is split."); + TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be splitted."); } }; /*! \brief Attributes for StridedSlice operator */ struct StridedSliceAttrs : public tvm::AttrsNode { - Array begin; - Array end; - Array strides; + Optional> begin; + Optional> end; + Optional> strides; + std::string slice_mode; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { - TVM_ATTR_FIELD(begin) - .describe("Indices for begin of slice, begin index is also inclusive"); - TVM_ATTR_FIELD(end) - .describe("Indices for end of slice, end index is exclusive"); - TVM_ATTR_FIELD(strides).set_default(Array({})) - .describe("Stride values of the slice"); + TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); + TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive"); + TVM_ATTR_FIELD(strides).describe( + "Stride values of the slice, a stride can be negative, which causes a reverse slice."); + TVM_ATTR_FIELD(slice_mode) + .set_default("end") + .describe( + "The slice mode [end, size]." + "end - The default slice mode, ending indices for the slice." + "size - The input strides will be ignored, input end in this mode indicates the size" + "of a slice starting at the location specified by begin. If end[i] is -1," + "all remaining elements in that dimension are included in the slice"); } }; @@ -232,10 +253,10 @@ struct SliceLikeAttrs : public tvm::AttrsNode { Array axes; TVM_DECLARE_ATTRS(SliceLikeAttrs, "relay.attrs.SliceLikeAttrs") { - TVM_ATTR_FIELD(axes) - .describe("List of axes on which input data will be sliced according to the " - "corresponding size of the second input. By default will slice " - "on all axes. Negative axes mean counting in reverse."); + TVM_ATTR_FIELD(axes).describe( + "List of axes on which input data will be sliced according to the " + "corresponding size of the second input. By default will slice " + "on all axes. Negative axes mean counting in reverse."); } }; @@ -245,10 +266,8 @@ struct ClipAttrs : public tvm::AttrsNode { double a_max; TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { - TVM_ATTR_FIELD(a_min) - .describe("The minimum clip value."); - TVM_ATTR_FIELD(a_max) - .describe("The maximum clip value."); + TVM_ATTR_FIELD(a_min).describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max).describe("The maximum clip value."); } }; @@ -258,10 +277,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { std::string dst_layout; TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relay.attrs.LayoutTransformAttrs") { - TVM_ATTR_FIELD(src_layout) - .describe("The source layout of the tensor. (e.g. NCHW)"); - TVM_ATTR_FIELD(dst_layout) - .describe("The destination layout of the tensor. (e.g. NCHW16c)"); + TVM_ATTR_FIELD(src_layout).describe("The source layout of the tensor. (e.g. NCHW)"); + TVM_ATTR_FIELD(dst_layout).describe("The destination layout of the tensor. (e.g. 
NCHW16c)"); } }; @@ -270,9 +287,7 @@ struct ShapeOfAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type") - .set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); } }; @@ -281,21 +296,27 @@ struct SequenceMaskAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(SequenceMaskAttrs, "relay.attrs.SequenceMaskAttrs") { - TVM_ATTR_FIELD(mask_value).set_default(0) - .describe("The masking value."); - TVM_ATTR_FIELD(axis).set_default(0) - .describe("The axis of the length dimension. Can only be 0 or 1."); + TVM_ATTR_FIELD(mask_value).set_default(0).describe("The masking value."); + TVM_ATTR_FIELD(axis).set_default(0).describe( + "The axis of the length dimension. Can only be 0 or 1."); } }; // struct SequenceMaskAttrs. +/*! \brief Attributes used in sparse_to_dense operator */ +struct SparseToDenseAttrs : public tvm::AttrsNode { + Array output_shape; + + TVM_DECLARE_ATTRS(SparseToDenseAttrs, "relay.attrs.SparseToDenseAttrs") { + TVM_ATTR_FIELD(output_shape).describe("Shape of the dense output tensor"); + } +}; // struct SparseToDenseAttrs + /*! \brief Attributes for ndarray_size operator */ struct NdarraySizeAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") { - TVM_ATTR_FIELD(dtype) - .describe("Target data type") - .set_default(NullValue()); + TVM_ATTR_FIELD(dtype).describe("Target data type").set_default(NullValue()); } }; @@ -306,12 +327,9 @@ struct OneHotAttrs : public tvm::AttrsNode { DataType dtype; TVM_DECLARE_ATTRS(OneHotAttrs, "relay.attrs.OneHotAttrs") { - TVM_ATTR_FIELD(depth).set_default(1) - .describe("Depth of the one hot dimension."); - TVM_ATTR_FIELD(axis).set_default(-1) - .describe("Axis to fill."); - TVM_ATTR_FIELD(dtype).set_default(NullValue()) - .describe("Output data type."); + TVM_ATTR_FIELD(depth).set_default(1).describe("Depth of the one hot dimension."); + TVM_ATTR_FIELD(axis).set_default(-1).describe("Axis to fill."); + TVM_ATTR_FIELD(dtype).set_default(NullValue()).describe("Output data type."); } }; // struct OneHotAttrs diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index c4a30ce8b159..550e24b8de26 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -41,39 +42,32 @@ struct MultiBoxPriorAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(MultiBoxPriorAttrs, "relay.attrs.MultiBoxPriorAttrs") { TVM_ATTR_FIELD(sizes) - .set_default(Array({static_cast(1.0)})) - .describe("List of sizes of generated MultiBoxPriores."); + .set_default(Array({static_cast(1.0)})) + .describe("List of sizes of generated MultiBoxPriores."); TVM_ATTR_FIELD(ratios) - .set_default(Array({static_cast(1.0)})) - .describe("List of aspect ratios of generated MultiBoxPriores."); + .set_default(Array({static_cast(1.0)})) + .describe("List of aspect ratios of generated MultiBoxPriores."); TVM_ATTR_FIELD(steps) - .set_default(Array({static_cast(-1.0), - static_cast(-1.0)})) - .describe("Priorbox step across y and x, -1 for auto calculation."); + .set_default(Array({static_cast(-1.0), static_cast(-1.0)})) + .describe("Priorbox step across y and x, -1 for auto calculation."); TVM_ATTR_FIELD(offsets) - .set_default(Array({static_cast(0.5), - static_cast(0.5)})) - .describe("Priorbox center offsets, y and x respectively."); - 
TVM_ATTR_FIELD(clip).set_default(false) - .describe("Whether to clip out-of-boundary boxes."); + .set_default(Array({static_cast(0.5), static_cast(0.5)})) + .describe("Priorbox center offsets, y and x respectively."); + TVM_ATTR_FIELD(clip).set_default(false).describe("Whether to clip out-of-boundary boxes."); } }; -struct MultiBoxTransformLocAttrs - : public tvm::AttrsNode { +struct MultiBoxTransformLocAttrs : public tvm::AttrsNode { bool clip; double threshold; Array variances; - TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, - "relay.attrs.MultiBoxTransformLocAttrs") { - TVM_ATTR_FIELD(clip).set_default(true) - .describe("Clip out-of-boundary boxes."); - TVM_ATTR_FIELD(threshold).set_default(0.01) - .describe("Threshold to be a positive prediction."); + TVM_DECLARE_ATTRS(MultiBoxTransformLocAttrs, "relay.attrs.MultiBoxTransformLocAttrs") { + TVM_ATTR_FIELD(clip).set_default(true).describe("Clip out-of-boundary boxes."); + TVM_ATTR_FIELD(threshold).set_default(0.01).describe("Threshold to be a positive prediction."); TVM_ATTR_FIELD(variances) - .set_default(Array({0.1f, 0.1f , 0.2f, 0.2f})) - .describe("Variances to be decoded from box regression output."); + .set_default(Array({0.1f, 0.1f, 0.2f, 0.2f})) + .describe("Variances to be decoded from box regression output."); } }; @@ -84,12 +78,11 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { int score_index; TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { - TVM_ATTR_FIELD(score_threshold).set_default(0.0) - .describe("Lower limit of score for valid bounding boxes."); - TVM_ATTR_FIELD(id_index).set_default(0) - .describe("Axis index of id."); - TVM_ATTR_FIELD(score_index).set_default(1) - .describe("Index of the scores/confidence of boxes."); + TVM_ATTR_FIELD(score_threshold) + .set_default(0.0) + .describe("Lower limit of score for valid bounding boxes."); + TVM_ATTR_FIELD(id_index).set_default(0).describe("Axis index of id."); + TVM_ATTR_FIELD(score_index).set_default(1).describe("Index of the scores/confidence of boxes."); } }; @@ -106,25 +99,30 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { Integer stride; TVM_DECLARE_ATTRS(YoloReorgAttrs, "relay.attrs.YoloReorgAttrs") { - TVM_ATTR_FIELD(stride) - .set_default(1) - .describe("Stride value for yolo reorg"); + TVM_ATTR_FIELD(stride).set_default(1).describe("Stride value for yolo reorg"); } }; @@ -206,10 +202,8 @@ struct ProposalAttrs : public tvm::AttrsNode { .describe( "The size of the receptive field each unit in the convolution layer of the rpn," "for example the product of all stride's prior to this layer."); - TVM_ATTR_FIELD(threshold) - .set_default(0.7) - .describe( - "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); + TVM_ATTR_FIELD(threshold).set_default(0.7).describe( + "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); TVM_ATTR_FIELD(rpn_pre_nms_top_n) .set_default(6000) .describe("Number of top scoring boxes to apply NMS. -1 to use all boxes"); diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 1d0120675e99..eeef7cd7bdb3 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,10 +24,10 @@ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ - #include -#include #include +#include + #include #include @@ -42,17 +42,19 @@ namespace tvm { */ namespace relay { -#define RELAY_DEBUG(...) 
\ -{ auto fdebug = runtime::Registry::Get("relay.debug"); \ - CHECK(fdebug) << "Could not find Relay Python debugger function."; \ - (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ -} +#define RELAY_DEBUG(...) \ + { \ + auto fdebug = runtime::Registry::Get("relay.debug"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ + } -#define RELAY_DEBUG_INTERP(...) \ -{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \ - CHECK(fdebug) << "Could not find Relay Python debugger function."; \ - (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ -} +#define RELAY_DEBUG_INTERP(...) \ + { \ + auto fdebug = runtime::Registry::Get("relay.debug_interp"); \ + CHECK(fdebug) << "Could not find Relay Python debugger function."; \ + (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \ + } /*! * \brief Symbolic expression for tensor shape. @@ -91,11 +93,9 @@ class IdNode : public Object { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); } static constexpr const char* _type_key = "relay.Id"; TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object); @@ -107,7 +107,7 @@ class Id : public ObjectRef { * \brief The constructor * \param name_hint The name of the variable. */ - TVM_DLL explicit Id(std::string name_hint); + TVM_DLL explicit Id(String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode); }; diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h new file mode 100644 index 000000000000..bb53ad32d9f4 --- /dev/null +++ b/include/tvm/relay/dataflow_matcher.h @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_matcher.h + * \brief A pattern matcher for matching dataflow properties. + */ +#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_ +#define TVM_RELAY_DATAFLOW_MATCHER_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relay { + +class DFPatternCallback; +/*! + * \brief Base type of all dataflow pattern callbacks. + * \sa DFPatternCallback + */ +class DFPatternCallbackNode : public Object { + public: + /*! \brief Pattern this callback matches */ + DFPattern pattern_; + /*! \brief Function to call when finding a matched expression */ + PackedFunc function_; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "DFPatternCallbackNode"; + TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object); +}; + +/*! 
+ * \brief Managed reference to dataflow pattern callbacks. + * \sa DFPatternCallbackNode + */ +class DFPatternCallback : public ObjectRef { + public: + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback); + TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); +}; + +/*! + * \brief Determine if a pattern matches an expression + * + * \param pattern The pattern to match + * \param expr The expression to match + * + * \return Return true if the pattern and the expression match, return false otherwise. + */ +bool MatchPattern(DFPattern pattern, Expr expr); + +/*! + * \brief Rewrite an expression based on some number of DFPatternCallbacks + * + * \param callbacks An array of DFPatternCallback Nodes + * \param expr The expression to rewrite + * + * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the + * functions inside the callbacks + */ +Expr RewritePatterns(Array callbacks, Expr expr); + +/*! + * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls + * + * \param pattern The pattern to match + * \param expr The expression to patition + * \param attrs A set of parameter names and values to apply to the partitioned function + * \param check A callback function for checking more complicated properties of the matched + * expressions, returns true if the match is accepted and false otherwise + * + * \return Return the paritioned Expr. + */ +Expr PartitionPattern(DFPattern pattern, Expr expr, Map attrs, PackedFunc check); + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_DATAFLOW_MATCHER_H_ diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h new file mode 100644 index 000000000000..11ac7e39f4a3 --- /dev/null +++ b/include/tvm/relay/dataflow_pattern.h @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_pattern.h + * \brief A pattern language for matching dataflow properties. + */ +#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_ +#define TVM_RELAY_DATAFLOW_PATTERN_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief Base type of all dataflow patterns. + * \sa DFPattern + */ +class DFPatternNode : public Object { + public: + static constexpr const char* _type_key = "DFPatternNode"; + TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); +}; + +/*! + * \brief Managed reference to dataflow patterns. + * \sa DFPatternNode + */ +class DFPattern : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode); +}; + +/*! + * \brief Pattern for Relay Expression. + */ +class ExprPatternNode : public DFPatternNode { + public: + /*! \brief The expression to match. 
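+ *
+ * For instance (illustrative sketch), an ExprPattern wrapping the relay
+ * `add` operator matches every use of that operator by structural equality:
+ *
+ * \code
+ * ExprPattern add_op(Op::Get("add"));
+ * \endcode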
*/ + Expr expr; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } + + static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a literal expression. + * + * \note Uses structural equality on expressions to check equality. + * + */ +class ExprPattern : public DFPattern { + public: + TVM_DLL explicit ExprPattern(Expr expr); + TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relay Variable + */ +class VarPattern; +/*! \brief Container for Var */ +class VarPatternNode : public DFPatternNode { + public: + /*! + * \brief The name of the Var (optional). + */ + String name; + /*! + * \brief type annotation of the variable. + * This field records user provided type annotation of the Var. + * This field is optional and can be None. + */ + Type type_annotation; + + /*! \return The name hint of the variable */ + const String& name_hint() const { return name; } + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("type_annotation", &type_annotation); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode); +}; + +class VarPattern : public DFPattern { + public: + TVM_DLL VarPattern(String name_hint, Type type_annotation); + TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); +}; + +/*! + * \brief A Pattern to Match a Relay Constant + */ +class ConstantPattern; +/*! \brief Container for Constant */ +class ConstantPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode); +}; + +class ConstantPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode); +}; + +/*! + * \brief Call corresponds to operator invocation. + * Corresponds to the operator in computational graph terminology. + */ +class CallPattern; +/*! \brief CallPattern container. */ +class CallPatternNode : public DFPatternNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be relay::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, Var). + */ + DFPattern op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The type arguments passed to polymorphic(template) function. + * + * This is the advance feature that is only used when the function is + * polymorphic. It is safe to be ignored in most cases. For example, in the + * following code, the type_args of addone call is [int]. 
+ * + * \code + * + * template + * T addone(T a) { return a + 1; } + * + * void main() { + * int x = addone(10); + * } + * + * \endcode + */ + tvm::Array type_args; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("type_args", &type_args); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); +}; + +class CallPattern : public DFPattern { + public: + TVM_DLL CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args); + TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); +}; + +/*! \brief Tuple of multiple Exprs */ +class TuplePattern; +/*! \brief Tuple container */ +class TuplePatternNode : public DFPatternNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); +}; + +class TuplePattern : public DFPattern { + public: + TVM_DLL explicit TuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); +}; + +/*! \brief Get index-th field out of a tuple. */ +class TupleGetItemPattern; +class TupleGetItemPatternNode : public DFPatternNode { + public: + /*! \brief The tuple Expression */ + DFPattern tuple; + /*! \brief which value to get */ + int index; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple", &tuple); + v->Visit("index", &index); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.TupleGetItemPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); +}; + +class TupleGetItemPattern : public DFPattern { + public: + TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); +}; + +class AltPattern; +/*! + * \brief Pattern for Alternate Expressions. + */ +class AltPatternNode : public DFPatternNode { + public: + /*! \brief The left optional pattern. */ + DFPattern left; + /*! \brief The right optional pattern. */ + DFPattern right; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.AltPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches either of two patterns + */ +class AltPattern : public DFPattern { + public: + TVM_DLL AltPattern(DFPattern left, DFPattern right); + TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode); +}; + +/*! + * \brief Wildcard Pattern. + */ +class WildcardPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches anything. + */ +class WildcardPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); +}; + +class TypePattern; +/*! + * \brief Pattern for Types. + */ +class TypePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! 
\brief The type to match */ + Type type; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("type", &type); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.TypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class TypePattern : public DFPattern { + public: + TVM_DLL TypePattern(DFPattern pattern, Type type); + TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); +}; + +class ShapePattern; +/*! + * \brief Pattern for Shapes. + */ +class ShapePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The type to match */ + Array shape; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class ShapePattern : public DFPattern { + public: + TVM_DLL ShapePattern(DFPattern pattern, Array type); + TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); +}; + +class DataTypePattern; +/*! + * \brief Pattern for Types. + */ +class DataTypePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The type to match */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class DataTypePattern : public DFPattern { + public: + TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); +}; + +class AttrPattern; +/*! + * \brief Pattern for Attributes. + */ +class AttrPatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The attribute to match */ + Attrs attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("attrs", &attrs); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.AttrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches attributes in another pattern + */ +class AttrPattern : public DFPattern { + public: + TVM_DLL AttrPattern(DFPattern pattern, Attrs attrs); + TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); +}; + +class DominatorPattern; +/*! + * \brief Dominated Graph Pattern + * Pattern for fuzzy subgraphs where all outputs of the parent are used finally by the child, and + * every operation between the parent and the child matches the path. + */ +class DominatorPatternNode : public DFPatternNode { + public: + /*! \brief The parent. */ + DFPattern parent; + /*! \brief The path. */ + DFPattern path; + /*! \brief The child. 
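+ *
+ * Together with the parent and the path this expresses a fuzzy region such as
+ * "a conv2d whose outputs reach an add through any chain of intermediate ops".
+ * A minimal construction sketch (op names and argument counts are illustrative
+ * only, not part of a fixed recipe):
+ *
+ * \code
+ * auto wc = []() { return WildcardPattern(runtime::make_object<WildcardPatternNode>()); };
+ * DFPattern conv = CallPattern(ExprPattern(Op::Get("nn.conv2d")), {wc(), wc()}, Attrs(), {});
+ * DFPattern add = CallPattern(ExprPattern(Op::Get("add")), {wc(), wc()}, Attrs(), {});
+ * DominatorPattern dom(conv, wc(), add);  // path: any op between conv and add
+ * \endcode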
*/ + DFPattern child; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("parent", &parent); + v->Visit("path", &path); + v->Visit("child", &child); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.DominatorPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a variable length dominator path + */ +class DominatorPattern : public DFPattern { + public: + TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child); + TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_DATAFLOW_PATTERN_H_ diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h new file mode 100644 index 000000000000..98c81c929409 --- /dev/null +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_pattern_functor.h + * \brief A set of passes for operating on pattern graphs. + */ +#ifndef TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ +#define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ + +#include + +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signature + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... args) { + CHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const AltPatternNode* op, Args... 
args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPatternDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around DFPatternFunctor. + * Recursively visit the content. + * + * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once. 
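+ *
+ * A minimal sketch of a user-defined visitor (the subclass name is
+ * hypothetical) that counts how many call patterns occur in a pattern:
+ *
+ * \code
+ * class CallPatternCounter : public DFPatternVisitor {
+ *  public:
+ *   size_t count = 0;
+ *   void VisitDFPattern_(const CallPatternNode* op) override {
+ *     ++count;
+ *     DFPatternVisitor::VisitDFPattern_(op);  // keep recursing into op and args
+ *   }
+ * };
+ * \endcode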
+ */ +class DFPatternVisitor : public DFPatternFunctor { + public: + void VisitDFPattern(const DFPattern& pattern) override; + void VisitDFPattern_(const AltPatternNode* op) override; + void VisitDFPattern_(const AttrPatternNode* op) override; + void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const ConstantPatternNode* op) override; + void VisitDFPattern_(const DataTypePatternNode* op) override; + void VisitDFPattern_(const DominatorPatternNode* op) override; + void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const ShapePatternNode* op) override; + void VisitDFPattern_(const TupleGetItemPatternNode* op) override; + void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; + void VisitDFPattern_(const VarPatternNode* op) override; + void VisitDFPattern_(const WildcardPatternNode* op) override; + + protected: + // set of already-visited nodes + std::unordered_set visited_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index fe240c30e471..779bcc34272f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -26,10 +26,12 @@ #include #include -#include #include -#include +#include + #include +#include + #include "./base.h" #include "./type.h" @@ -63,9 +65,7 @@ class ConstantNode : public ExprNode { TensorType tensor_type() const; /*! \return Whether it is scalar(rank-0 tensor) */ - bool is_scalar() const { - return data->ndim == 0; - } + bool is_scalar() const { return data->ndim == 0; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); @@ -77,9 +77,7 @@ class ConstantNode : public ExprNode { return equal(data, other->data); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(data); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); } static constexpr const char* _type_key = "relay.Constant"; TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode); @@ -172,9 +170,7 @@ class VarNode : public ExprNode { Type type_annotation; /*! \return The name hint of the variable */ - const std::string& name_hint() const { - return vid->name_hint; - } + const String& name_hint() const { return vid->name_hint; } void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); @@ -184,9 +180,7 @@ class VarNode : public ExprNode { } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { - return - equal(type_annotation, other->type_annotation) && - equal.FreeVarEqualImpl(this, other); + return equal(type_annotation, other->type_annotation) && equal.FreeVarEqualImpl(this, other); } void SHashReduce(SHashReducer hash_reduce) const { @@ -194,11 +188,9 @@ class VarNode : public ExprNode { hash_reduce.FreeVarHashImpl(this); } - TVM_DLL static Var make(std::string name_hint, - Type type_annotation); + TVM_DLL static Var make(String name_hint, Type type_annotation); - TVM_DLL static Var make(Id vid, - Type type_annotation); + TVM_DLL static Var make(Id vid, Type type_annotation); static constexpr const char* _type_key = "relay.Var"; TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, ExprNode); @@ -211,8 +203,7 @@ class Var : public Expr { * \param name_hint The name hint of a variable. * \param type_annotation The type annotation of a variable. 
*/ - TVM_DLL Var(std::string name_hint, Type type_annotation) : - Var(Id(name_hint), type_annotation) {} + TVM_DLL Var(String name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {} /*! * \brief The constructor @@ -278,11 +269,8 @@ class CallNode : public ExprNode { bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { // skip type_args check for primitive ops. equal->MarkGraphNode(); - return - equal(op, other->op) && - equal(args, other->args) && - equal(attrs, other->attrs) && - (IsPrimitiveOp(op) || equal(type_args, other->type_args)); + return equal(op, other->op) && equal(args, other->args) && equal(attrs, other->attrs) && + (IsPrimitiveOp(op) || equal(type_args, other->type_args)); } void SHashReduce(SHashReducer hash_reduce) const { @@ -308,9 +296,7 @@ class Call : public Expr { * \param attrs The attributes of the call node. * \param type_args The type arguments passed to a polymorphic function. */ - TVM_DLL Call(Expr op, - Array args, - Attrs attrs = Attrs(), + TVM_DLL Call(Expr op, Array args, Attrs attrs = Attrs(), Array type_args = Array()); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); @@ -348,10 +334,8 @@ class LetNode : public ExprNode { bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal.DefEqual(var, other->var) && equal(value, other->value) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -410,10 +394,8 @@ class IfNode : public ExprNode { bool SEqualReduce(const IfNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(cond, other->cond) && - equal(true_branch, other->true_branch) && - equal(false_branch, other->false_branch); + return equal(cond, other->cond) && equal(true_branch, other->true_branch) && + equal(false_branch, other->false_branch); } void SHashReduce(SHashReducer hash_reduce) const { @@ -457,9 +439,7 @@ class TupleGetItemNode : public ExprNode { } bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const { - return - equal(tuple, other->tuple) && - equal(index, other->index); + return equal(tuple, other->tuple) && equal(index, other->index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -576,9 +556,7 @@ class RefWriteNode : public ExprNode { bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); - return - equal(ref, other->ref) && - equal(value, other->value); + return equal(ref, other->ref) && equal(value, other->value); } void SHashReduce(SHashReducer hash_reduce) const { @@ -630,6 +608,7 @@ class TempExprNode : public ExprNode { static constexpr const char* _type_key = "relay.TempExpr"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; + static constexpr const uint32_t _type_child_slots = 0; TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode); }; diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 04b275431f2b..1189643c8181 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -25,16 +25,16 @@ #ifndef TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_ -#include #include +#include +#include #include #include -#include #include #include -#include #include +#include namespace tvm { namespace relay { @@ -54,15 +54,13 @@ template class ExprFunctor; // 
functions to be overriden. -#define EXPR_FUNCTOR_DEFAULT \ +#define EXPR_FUNCTOR_DEFAULT \ { return VisitExprDefault_(op, std::forward(args)...); } -#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); template class ExprFunctor { @@ -81,9 +79,7 @@ class ExprFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Expr& n, Args... args) { - return VisitExpr(n, std::forward(args)...); - } + R operator()(const Expr& n, Args... args) { return VisitExpr(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The expression node. @@ -96,22 +92,15 @@ class ExprFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitExpr_(const ConstantNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const TupleNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const VarNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const GlobalVarNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const FunctionNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ConstantNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const TupleNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const GlobalVarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FunctionNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const IfNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const OpNode* op, - Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const IfNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -154,8 +143,7 @@ class ExprFunctor { * ExprVisitor treats Expr as dataflow graph, * and only visit each Expr node once. */ -class ExprVisitor - : public ::tvm::relay::ExprFunctor { +class ExprVisitor : public ::tvm::relay::ExprFunctor { public: void VisitExpr(const Expr& expr) override; void VisitExpr_(const VarNode* op) override; @@ -189,16 +177,13 @@ class ExprVisitor * The mutated results are memoized in a map and reused so that * local transformation on the dataflow preserves the graph structure. */ -class ExprMutator - : public ::tvm::relay::ExprFunctor { +class ExprMutator : public ::tvm::relay::ExprFunctor { public: /*! * \brief Mutate is alias for VisitExpr * \return expr. 
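 *
 * A minimal sketch of a mutator (the pass below is hypothetical) that
 * rewrites add calls into subtract calls:
 *
 * \code
 * class AddToSub : public ExprMutator {
 *   Expr VisitExpr_(const CallNode* op) final {
 *     Expr post = ExprMutator::VisitExpr_(op);  // mutate the arguments first
 *     const CallNode* call = post.as<CallNode>();
 *     if (call && call->op == Op::Get("add")) {
 *       return Call(Op::Get("subtract"), call->args, call->attrs, call->type_args);
 *     }
 *     return post;
 *   }
 * };
 * \endcode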
*/ - Expr Mutate(const Expr& expr) { - return this->VisitExpr(expr); - } + Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); } Expr VisitExpr(const Expr& expr) override; Expr VisitExpr_(const VarNode* op) override; Expr VisitExpr_(const ConstantNode* op) override; @@ -229,7 +214,7 @@ class ExprMutator protected: /*! \brief Internal map used for memoization. */ - std::unordered_map memo_; + std::unordered_map memo_; }; /*! @@ -283,7 +268,8 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { * recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions * of the graph and processes them iteratatively to prevent stack overflows * - * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive behavior. + * Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive + * behavior. */ class MixedModeMutator : public ::tvm::relay::ExprMutator { public: @@ -293,14 +279,14 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); }; /*! - * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will be - * able to rewrite the op only with data about the original node `pre` and the same node with + * \brief Users should override Rewrite_ methods to implement their pass. Rewrite_ functions will + * be able to rewrite the op only with data about the original node `pre` and the same node with * modified inputs `post` and should not recurse. * * \param pre The expression node before rewriting. * \param post The expression with rewritten inputs. */ - virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post;} + virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; } virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; } virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; } @@ -350,9 +336,7 @@ class ExprRewriter { * \param post The expression node with rewritten inputs. * \return The result of the call */ - Expr operator()(const Expr& pre, const Expr& post) { - return Rewrite(pre, post); - } + Expr operator()(const Expr& pre, const Expr& post) { return Rewrite(pre, post); } /*! * \brief The functor call. * \param pre The expression node before rewriting. diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 744d7c4e111c..3783e320f57c 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -24,9 +24,9 @@ #ifndef TVM_RELAY_FEATURE_H_ #define TVM_RELAY_FEATURE_H_ +#include #include #include -#include #include @@ -65,9 +65,7 @@ class FeatureSet { public: FeatureSet(const FeatureSet&) = default; /*! \brief A singleton set containing a single Feature. */ - explicit FeatureSet(Feature ft) { - bs_.set(static_cast(ft)); - } + explicit FeatureSet(Feature ft) { bs_.set(static_cast(ft)); } explicit FeatureSet(const tvm::Array& ft) { for (Integer i : ft) { (*this) += Feature(static_cast(i)); @@ -93,25 +91,25 @@ class FeatureSet { FeatureSet fs; return fs; } - template + template FeatureSet& operator+=(const T& rhs) { bs_ |= FeatureSet(rhs).bs_; return *this; } /*! \brief Set union. 
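 * e.g. FeatureSet(fVar) + fLet yields a set containing both features
 * (assuming the fVar and fLet values of the Feature enum defined above).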
*/ - template + template FeatureSet operator+(const T& rhs) const { FeatureSet fs(*this); fs += rhs; return fs; } - template + template FeatureSet& operator-=(const T& rhs) { bs_ &= ~(FeatureSet(rhs)).bs_; return *this; } /*! \brief Set difference. */ - template + template FeatureSet operator-(const T& rhs) const { FeatureSet fs(*this); fs -= rhs; @@ -124,14 +122,12 @@ class FeatureSet { * * \return true only if this is a subset of rhs. */ - bool is_subset_of(const FeatureSet& rhs) const { - return ((*this) - rhs).bs_.none(); - } + bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); } private: std::bitset bs_; FeatureSet() = default; - explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } + explicit FeatureSet(const std::bitset& bs) : bs_(bs) {} }; /*! diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 33b813b76f18..d52a66cdadeb 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -26,8 +26,8 @@ #include #include -#include +#include namespace tvm { namespace relay { @@ -71,12 +71,9 @@ class FunctionNode : public BaseFuncNode { bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const { // Important to make def equal first. equal->MarkGraphNode(); - return - equal.DefEqual(params, other->params) && - equal.DefEqual(type_params, other->type_params) && - equal(ret_type, other->ret_type) && - equal(attrs, other->attrs) && - equal(body, other->body); + return equal.DefEqual(params, other->params) && + equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) && + equal(attrs, other->attrs) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -100,7 +97,6 @@ class FunctionNode : public BaseFuncNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode); }; - /*! * \brief Managed reference to FunctionNode. * \sa FunctionNode @@ -115,10 +111,7 @@ class Function : public BaseFunc { * \param ty_params The type parameters. * \param attrs Additional function attributes. */ - TVM_DLL Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array ty_params, + TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, tvm::DictAttrs attrs = NullValue()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); @@ -148,6 +141,8 @@ constexpr const char* kSkipOptimization = "SkipOptimization"; constexpr const char* kComposite = "Composite"; /*! \brief Mark the function to be inlined. */ constexpr const char* kInline = "Inline"; +/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ +constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; } // namespace attr } // namespace relay diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index ae1f84a616a4..bda73ed3a51b 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -36,12 +36,11 @@ #include #include -#include #include +#include #include #include - namespace tvm { namespace relay { @@ -64,8 +63,8 @@ namespace relay { * \param target Compiler target flag to compile the functions on the context. * \return A function that takes in an expression and returns a value. */ -runtime::TypedPackedFunc -CreateInterpreter(IRModule mod, DLContext context, Target target); +runtime::TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, + Target target); /*! \brief The container type of Closures used by the interpreter. 
*/ class InterpreterClosureObj : public runtime::vm::ClosureObj { @@ -96,8 +95,7 @@ class InterpreterClosureObj : public runtime::vm::ClosureObj { class InterpreterClosure : public runtime::vm::Closure { public: TVM_DLL InterpreterClosure(tvm::Map env, Function func); - TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, - InterpreterClosureObj); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, InterpreterClosureObj); }; /*! \brief The container type of RecClosure. */ @@ -130,9 +128,7 @@ struct RefValueObj : Object { RefValueObj() {} - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("value", &value); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); } static constexpr const char* _type_key = "relay.RefValue"; TVM_DECLARE_FINAL_OBJECT_INFO(RefValueObj, Object); @@ -164,9 +160,7 @@ struct ConstructorValueObj : Object { class ConstructorValue : public ObjectRef { public: - TVM_DLL ConstructorValue(int32_t tag, - tvm::Array fields, - Constructor construtor = {}); + TVM_DLL ConstructorValue(int32_t tag, tvm::Array fields, Constructor construtor = {}); TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj); }; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index fa47da226dff..12845158a22f 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -25,8 +25,8 @@ #define TVM_RELAY_OP_H_ #include -#include #include +#include namespace tvm { namespace relay { @@ -34,8 +34,7 @@ namespace relay { using Op = tvm::Op; using OpNode = tvm::OpNode; -#define RELAY_REGISTER_OP(OpName) \ - TVM_REGISTER_OP(OpName) +#define RELAY_REGISTER_OP(OpName) TVM_REGISTER_OP(OpName) } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 5b2fdd3ab4e1..acd4a03aed03 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -24,21 +24,22 @@ #ifndef TVM_RELAY_OP_ATTR_TYPES_H_ #define TVM_RELAY_OP_ATTR_TYPES_H_ -#include -#include -#include #include -#include +#include #include +#include +#include +#include #include + #include namespace tvm { namespace relay { +using tir::BijectiveLayoutNode; using tir::Layout; using tir::LayoutAxis; -using tir::BijectiveLayoutNode; /*! \brief operator pattern used in graph fusion */ enum OpPatternKind { @@ -104,10 +105,8 @@ using TShapeDataDependant = bool; & these are always placeholders. * \return The output compute description of the operator. */ -using FTVMCompute = runtime::TypedPackedFunc< - Array(const Attrs& attrs, - const Array& inputs, - const Type& out_type)>; +using FTVMCompute = runtime::TypedPackedFunc( + const Attrs& attrs, const Array& inputs, const Type& out_type)>; /*! * \brief Build the computation schedule for @@ -118,10 +117,8 @@ using FTVMCompute = runtime::TypedPackedFunc< * \param target The build target. * \return schedule The computation schedule. */ -using FTVMSchedule = runtime::TypedPackedFunc< - te::Schedule(const Attrs& attrs, - const Array& outs, - const Target& target)>; +using FTVMSchedule = runtime::TypedPackedFunc& outs, const Target& target)>; /*! * \brief Generate the strategy of operators. This function is a generic @@ -143,11 +140,9 @@ using FTVMStrategy = GenericFunc; * and dtype of the inputs. * \return new_expr The modified expression. 
*/ -using FTVMAlterOpLayout = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& tinfos, - const Type& out_type)>; +using FTVMAlterOpLayout = + runtime::TypedPackedFunc& args, + const Array& tinfos, const Type& out_type)>; /*! * \brief Convert the layout of operators or replace the @@ -157,14 +152,14 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< * \param inputs The input symbols of the original node. * \param tinfos An array of placeholders, use for getting the inferred shape * and dtype of the inputs. - * \param desired_layout The desired layout. + * \param desired_layouts Specify an array of desired layouts for each input. + * For example a conv2d op: Array("NHWC", "OHWI"), this + * specifies the desired layout for data then kernel. * \return new_expr The modified expression. */ -using FTVMConvertOpLayout = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& tinfos, - const std::string& desired_layout)>; +using FTVMConvertOpLayout = runtime::TypedPackedFunc& args, const Array& tinfos, + const Array& desired_layouts)>; /*! * \brief Legalizes an expression with another expression. This function will be * invoked in Legalize pass. It is a target-dependent pass. @@ -174,10 +169,8 @@ using FTVMConvertOpLayout = runtime::TypedPackedFunc< * and dtype of the inputs. * \return new_expr The modified expression. */ -using FTVMLegalize = runtime::TypedPackedFunc< - Expr(const Attrs& attrs, - const Array& args, - const Array& arg_types)>; +using FTVMLegalize = runtime::TypedPackedFunc& args, + const Array& arg_types)>; /*! * \brief Annotates an expression to indicate if an op should be compiled using @@ -189,9 +182,8 @@ using FTVMLegalize = runtime::TypedPackedFunc< * \return true if this op should be registered to invoke a specific compiler * for codegen, otherwise, false. */ -using FTVMAnnotateTarget = runtime::TypedPackedFunc< - bool(const Attrs& attrs, // NOLINT(*) - const Array& args)>; +using FTVMAnnotateTarget = runtime::TypedPackedFunc& args)>; /*! * \brief Forward rewriting rule for a specific op. @@ -207,10 +199,8 @@ using FTVMAnnotateTarget = runtime::TypedPackedFunc< * \note When we register the function, we can register * a different signature with ctx to be a specific node type. */ -using FForwardRewrite = runtime::TypedPackedFunc< - Expr(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx)>; +using FForwardRewrite = runtime::TypedPackedFunc& new_args, const ObjectRef& ctx)>; /*! * \brief Gradient for a specific op. @@ -219,8 +209,8 @@ using FForwardRewrite = runtime::TypedPackedFunc< * \param output_grad the gradient of the Expr. * \return the gradient for each parameters. */ -using FPrimalGradient = runtime::TypedPackedFunc(const Expr& orig_call, - const Expr& output_grad)>; +using FPrimalGradient = + runtime::TypedPackedFunc(const Expr& orig_call, const Expr& output_grad)>; /*! * \brief The codegeneration strategy for dynamic dimensions. @@ -233,10 +223,8 @@ enum AnyCodegenStrategy { /*! \brief A runtime representation of shape. 
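 * e.g. a batch of one 3-channel 224x224 NCHW image has Shape {1, 3, 224, 224}.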
*/ using Shape = Array; -using FShapeFunc = runtime::TypedPackedFunc< - Array(const Attrs& attrs, - const Array& inputs, - const Array& out_ndims)>; +using FShapeFunc = runtime::TypedPackedFunc( + const Attrs& attrs, const Array& inputs, const Array& out_ndims)>; } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/op_strategy.h b/include/tvm/relay/op_strategy.h index a4da95a36b07..c5785369f8d5 100644 --- a/include/tvm/relay/op_strategy.h +++ b/include/tvm/relay/op_strategy.h @@ -25,11 +25,12 @@ #ifndef TVM_RELAY_OP_STRATEGY_H_ #define TVM_RELAY_OP_STRATEGY_H_ -#include -#include #include #include #include +#include +#include + #include namespace tvm { @@ -45,7 +46,7 @@ class OpImplementationNode : public Object { /*! \brief Schedule function */ FTVMSchedule fschedule; /*! \brief Name of the implementation */ - std::string name; + String name; /*! \brief Priority level */ int plevel; @@ -70,8 +71,7 @@ class OpImplementation : public ObjectRef { * \param out_type The output type information. * \return The output compute description of the operator. */ - TVM_DLL Array Compute(const Attrs& attrs, - const Array& inputs, + TVM_DLL Array Compute(const Attrs& attrs, const Array& inputs, const Type& out_type); /*! * \brief Build the computation schedule. @@ -80,8 +80,7 @@ class OpImplementation : public ObjectRef { * \param target The build target. * \return The computation schedule. */ - TVM_DLL te::Schedule Schedule(const Attrs& attrs, - const Array& outs, + TVM_DLL te::Schedule Schedule(const Attrs& attrs, const Array& outs, const Target& target); TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode); @@ -119,8 +118,8 @@ class OpSpecialization : public ObjectRef { * \param name Name of the implementation * \param plevel Priority level of the implementation */ - TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, - std::string name, int plevel); + TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, + int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode); }; @@ -133,9 +132,7 @@ class OpStrategyNode : public Object { /*! \brief List of operator specializations. 
*/ Array specializations; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("specializations", &specializations); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); } static constexpr const char* _type_key = "relay.OpStrategy"; TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode); @@ -153,8 +150,8 @@ class OpStrategy : public ObjectRef { * \param name Name of the implementation * \param plevel Priority level of the implementation */ - TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, - std::string name, int plevel); + TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, + int plevel); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode); }; diff --git a/include/tvm/relay/pattern_functor.h b/include/tvm/relay/pattern_functor.h index 6e0fb17ed233..de3bafa49074 100644 --- a/include/tvm/relay/pattern_functor.h +++ b/include/tvm/relay/pattern_functor.h @@ -25,16 +25,16 @@ #ifndef TVM_RELAY_PATTERN_FUNCTOR_H_ #define TVM_RELAY_PATTERN_FUNCTOR_H_ -#include #include +#include #include -#include #include +#include +#include "./adt.h" #include "./expr.h" #include "./op.h" -#include "./adt.h" namespace tvm { namespace relay { @@ -54,15 +54,13 @@ template class PatternFunctor; // functions to be overriden. -#define PATTERN_FUNCTOR_DEFAULT \ +#define PATTERN_FUNCTOR_DEFAULT \ { return VisitPatternDefault_(op, std::forward(args)...); } -#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitPattern_(static_cast(n.get()), \ - std::forward(args)...); \ - }); +#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitPattern_(static_cast(n.get()), std::forward(args)...); \ + }); template class PatternFunctor { @@ -96,14 +94,10 @@ class PatternFunctor { return vtable(n, this, std::forward(args)...); } // Functions that can be overriden by subclass - virtual R VisitPattern_(const PatternWildcardNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternVarNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternConstructorNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; - virtual R VisitPattern_(const PatternTupleNode* op, - Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternWildcardNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternVarNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternConstructorNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; + virtual R VisitPattern_(const PatternTupleNode* op, Args... args) PATTERN_FUNCTOR_DEFAULT; virtual R VisitPatternDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); throw; @@ -144,8 +138,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor { +class PatternMutator : public ::tvm::relay::PatternFunctor { public: Pattern Mutate(const Pattern& pat); Pattern VisitPattern_(const PatternWildcardNode* op) override; @@ -163,8 +156,9 @@ class PatternMutator virtual Var VisitVar(const Var& v); /*! \brief Used to visit the vars inside of patterns. 
*/ virtual Constructor VisitConstructor(const Constructor& c); + private: - std::unordered_map var_map_; + std::unordered_map var_map_; }; } // namespace relay diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h index 3c1c4a33c3d1..4b5cd89f0b0c 100644 --- a/include/tvm/relay/qnn/attrs.h +++ b/include/tvm/relay/qnn/attrs.h @@ -25,6 +25,7 @@ #define TVM_RELAY_QNN_ATTRS_H_ #include + #include namespace tvm { @@ -39,19 +40,20 @@ struct RequantizeAttrs : public tvm::AttrsNode { TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { TVM_ATTR_FIELD(axis) - .describe("The output channel axis for channel wise quantization. Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); - TVM_ATTR_FIELD(rounding).set_default("UPWARD") - .describe("Defines the rounding direction when the value is midway between" - "two representable values. There are two supported modes - UPWARD" - "or TONEAREST. Both modes behave exactly same except at the" - "midpoints between the two representable values. At the midpoint," - "UPWARD rounds towards positive infinity (for example -1.5 will be" - "rounded to -1). TONEAREST is the standard rounding where the" - "value is rounded away from zero at midpoints (for example, -1.5" - "rounds to -2). More context can be found at following gblic manual" - "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); + .describe( + "The output channel axis for channel wise quantization. Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); + TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe( + "Defines the rounding direction when the value is midway between" + "two representable values. There are two supported modes - UPWARD" + "or TONEAREST. Both modes behave exactly same except at the" + "midpoints between the two representable values. At the midpoint," + "UPWARD rounds towards positive infinity (for example -1.5 will be" + "rounded to -1). TONEAREST is the standard rounding where the" + "value is rounded away from zero at midpoints (for example, -1.5" + "rounds to -2). More context can be found at following gblic manual" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html."); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); @@ -64,12 +66,12 @@ struct QuantizeAttrs : public tvm::AttrsNode { int axis; TVM_DECLARE_ATTRS(QuantizeAttrs, "relay.attrs.QuantizeAttrs") { - TVM_ATTR_FIELD(out_dtype) - .describe("Output data type, can be one of [int8 or uint8]."); + TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [int8 or uint8]."); TVM_ATTR_FIELD(axis) - .describe("The output channel axis for channel wise quantization. Default value is -1," - "which corresponds to the last axis.") - .set_default(-1); + .describe( + "The output channel axis for channel wise quantization. 
Default value is -1," + "which corresponds to the last axis.") + .set_default(-1); } }; diff --git a/include/tvm/relay/qnn/transform.h b/include/tvm/relay/qnn/transform.h index 10cd19afe6f3..d1f07c924d6b 100644 --- a/include/tvm/relay/qnn/transform.h +++ b/include/tvm/relay/qnn/transform.h @@ -25,8 +25,8 @@ #ifndef TVM_RELAY_QNN_TRANSFORM_H_ #define TVM_RELAY_QNN_TRANSFORM_H_ -#include #include +#include namespace tvm { namespace relay { diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 2dcf7f31e2d0..b287c053e8a9 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -24,13 +24,13 @@ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ -#include -#include #include +#include #include #include -#include #include +#include +#include #include @@ -56,11 +56,9 @@ using Sequential = tvm::transform::Sequential; * * \return The created function pass. */ -TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< - Function(Function, IRModule, PassContext)>& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); +TVM_DLL Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required); /*! \brief Remove expressions which does not effect the program result. * @@ -79,17 +77,17 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc< TVM_DLL Pass DeadCodeElimination(bool inline_once = false); /*! -* \brief Convert all expressions of TensorType into GradCell, -* an algebraic data type defined in gradient.rly. -* -* This will delay or decrease memory usage. All calls to -* ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, -* rather only instantiate if needed. It also defines + and * operation -* between GradCell types which can increase performance when using -* zero-filled or one-filled tensors, which is the case in reverse mode ad. -* -* \return the pass -*/ + * \brief Convert all expressions of TensorType into GradCell, + * an algebraic data type defined in gradient.rly. + * + * This will delay or decrease memory usage. All calls to + * ones, ones_like, zeros, zeros_like will not immediately instantiate a tensor in memory, + * rather only instantiate if needed. It also defines + and * operation + * between GradCell types which can increase performance when using + * zero-filled or one-filled tensors, which is the case in reverse mode ad. + * + * \return the pass + */ TVM_DLL Pass LazyGradientInit(); /*! @@ -283,10 +281,12 @@ TVM_DLL Pass AlterOpLayout(); * layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout * using the InferCorrectLayout infrastructure. * - * \param desired_layout The desired layout. + * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input. + * For example: Map("nn.conv2d", Array("NHWC", "OHWI")), + * this specifies the desired layout for data then kernel for nn.conv2d. * \return The pass. */ -TVM_DLL Pass ConvertLayout(const std::string& desired_layout); +TVM_DLL Pass ConvertLayout(const Map>& desired_layouts); /*! * \brief Legalizes an expr with another expression. @@ -298,7 +298,7 @@ TVM_DLL Pass ConvertLayout(const std::string& desired_layout); * * \return The pass. */ -TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize"); +TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize"); /*! 
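// Illustrative sketch (not part of this patch) of the reworked ConvertLayout interface
// above: the pass now takes a per-operator layout map rather than a single layout string.
// The template arguments are assumed to be Map<String, Array<String>> (angle-bracket
// contents are elided throughout this diff), and namespace qualifiers are omitted.
Map<String, Array<String>> desired_layouts;
desired_layouts.Set(String("nn.conv2d"),
                    Array<String>{String("NHWC"), String("OHWI")});  // data layout, then kernel layout
Pass convert_layout = ConvertLayout(desired_layouts);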
* \brief Canonicalize cast expressions to make operator fusion more efficient. @@ -323,15 +323,6 @@ TVM_DLL Pass CanonicalizeCast(); */ TVM_DLL Pass EtaExpand(bool expand_constructor, bool expand_global_var); -/*! - * \brief Print the IR for a module to help debugging. - * - * \param show_meta_data The flag to control if meta data needs to be printed. - * - * \return the pass. - */ -TVM_DLL Pass PrintIR(bool show_meta_data = true); - /*! * \brief Partition a Relay program into regions that can be executed on * different backends. @@ -382,9 +373,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map& binds); * \return A type checked Function with its checked_type field populated. * \note this function mutates mod and is not thread-safe. */ -TVM_DLL Function InferType(const Function& f, - const IRModule& mod, - const GlobalVar& var); +TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalVar& var); /*! * \brief Apply rewrite rules to rewrite the expr in post DFS order. This @@ -398,8 +387,7 @@ TVM_DLL Function InferType(const Function& f, * an Expr consumed by multiple callers. * \return The rewritten expression. */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_attr_name, +TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); @@ -415,8 +403,7 @@ TVM_DLL Expr ForwardRewrite(const Expr& expr, * * \return The rewritten expression. */ -TVM_DLL Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, +TVM_DLL Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function fcontext = nullptr, std::function fmulti_ref_trigger = nullptr); diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index e8f402ac961d..a388c82a8d90 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -24,24 +24,25 @@ #ifndef TVM_RELAY_TYPE_H_ #define TVM_RELAY_TYPE_H_ -#include +#include +#include #include +#include #include -#include #include -#include #include + #include #include "base.h" - namespace tvm { namespace relay { // namespace update for backward compact // will be removed later. -using Any = tvm::tir::AnyNode; +using AnyNode = tvm::tir::AnyNode; +using Any = tvm::tir::Any; using Kind = TypeKind; using Type = tvm::Type; using TypeNode = tvm::TypeNode; diff --git a/include/tvm/runtime/c_backend_api.h b/include/tvm/runtime/c_backend_api.h index abfc792d574f..40cef83ee05b 100644 --- a/include/tvm/runtime/c_backend_api.h +++ b/include/tvm/runtime/c_backend_api.h @@ -45,11 +45,8 @@ extern "C" { * * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. */ -typedef int (*TVMBackendPackedCFunc)(TVMValue* args, - int* type_codes, - int num_args, - TVMValue* out_ret_value, - int* out_ret_tcode); +typedef int (*TVMBackendPackedCFunc)(TVMValue* args, int* type_codes, int num_args, + TVMValue* out_ret_value, int* out_ret_tcode); /*! * \brief Backend function for modules to get function @@ -61,9 +58,7 @@ typedef int (*TVMBackendPackedCFunc)(TVMValue* args, * \param out The result function. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *out); +TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* out); /*! * \brief Backend function to register system-wide library symbol. 
* @@ -76,7 +71,7 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); /*! * \brief Backend function to allocate temporal workspace. * - * \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. + * \note The result allocated space is ensured to be aligned to kTempAllocaAlignment. * * \param nbytes The size of the space requested. * \param device_type The device type which the space will be allocated. @@ -87,11 +82,8 @@ TVM_DLL int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr); * certain backends such as OpenGL. * \return nullptr when error is thrown, a valid ptr if success */ -TVM_DLL void* TVMBackendAllocWorkspace(int device_type, - int device_id, - uint64_t nbytes, - int dtype_code_hint, - int dtype_bits_hint); +TVM_DLL void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, + int dtype_code_hint, int dtype_bits_hint); /*! * \brief Backend function to free temporal workspace. @@ -103,9 +95,7 @@ TVM_DLL void* TVMBackendAllocWorkspace(int device_type, * * \sa TVMBackendAllocWorkspace */ -TVM_DLL int TVMBackendFreeWorkspace(int device_type, - int device_id, - void* ptr); +TVM_DLL int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr); /*! * \brief Environment for TVM parallel task. @@ -125,8 +115,7 @@ typedef struct { * \param penv The parallel environment backs the execution. * \param cdata The supporting closure data. */ -typedef int (*FTVMParallelLambda)( - int task_id, TVMParallelGroupEnv* penv, void* cdata); +typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv, void* cdata); /*! * \brief Backend function for running parallel jobs. @@ -138,9 +127,7 @@ typedef int (*FTVMParallelLambda)( * * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, - void* cdata, - int num_task); +TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task); /*! * \brief BSP barrrier between parallel threads @@ -150,22 +137,18 @@ TVM_DLL int TVMBackendParallelLaunch(FTVMParallelLambda flambda, */ TVM_DLL int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv); - /*! * \brief Simple static initialization function. * Run f once and set handle to be not null. * This function is mainly used for test purpose. * - * \param handle An global address to indicate f - * \param f The function to be ran + * \param handle A global address to indicate f + * \param f The function to be run * \param cdata The closure data to pass to the function. * \param nbytes Number of bytes in the closure data. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMBackendRunOnce(void** handle, - int (*f)(void*), - void *cdata, - int nbytes); +TVM_DLL int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes); #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 920ecfbf9b13..213c7059a5f9 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -63,15 +63,14 @@ // TVM version #define TVM_VERSION "0.7.dev1" - // TVM Runtime is DLPack compatible. #include #ifdef __cplusplus extern "C" { #endif -#include #include +#include /*! \brief type of array index. 
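// A minimal, hypothetical sketch of the parallel-launch contract declared in
// c_backend_api.h above: the runtime invokes the lambda once per task with its
// task_id and the shared closure data. FillWorker/FillInParallel are made-up names
// and error handling is omitted.
static int FillWorker(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
  float* out = static_cast<float*>(cdata);
  out[task_id] = static_cast<float>(task_id);  // each task writes its own slot
  return 0;                                    // 0 signals success
}

void FillInParallel(float* out) {
  TVMBackendParallelLaunch(FillWorker, out, /*num_task=*/4);  // request 4 tasks
}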
*/ typedef int64_t tvm_index_t; @@ -83,16 +82,29 @@ typedef enum { kOpenGL = 11, kDLMicroDev = 13, kDLHexagon = 14, + kDLWebGPU = 15 // AddExtraTVMType which is not in DLPack here } TVMDeviceExtType; /*! - * \brief The type code in used in the TVM FFI. + * \brief The type code in used and only used in TVM FFI for argument passing. + * + * DLPack consistency: + * 1) kTVMArgInt is compatible with kDLInt + * 2) kTVMArgFloat is compatible with kDLFloat + * 3) kDLUInt is not in ArgTypeCode, but has a spared slot + * + * Downstream consistency: + * The kDLInt, kDLUInt, kDLFloat are kept consistent with the original ArgType code + * + * It is only used in argument passing, and should not be confused with + * DataType::TypeCode, which is DLPack-compatible. + * + * \sa tvm::runtime::DataType::TypeCode */ typedef enum { - // The type code of other types are compatible with DLPack. - // The next few fields are extension types - // that is used by TVM API calls. + kTVMArgInt = kDLInt, + kTVMArgFloat = kDLFloat, kTVMOpaqueHandle = 3U, kTVMNullptr = 4U, kTVMDataType = 5U, @@ -115,9 +127,7 @@ typedef enum { // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, kTVMExtEnd = 128U, - // The rest of the space is used for custom, user-supplied datatypes - kTVMCustomBegin = 129U, -} TVMTypeCode; +} TVMArgTypeCode; /*! * \brief The Device information, abstract away common device types. @@ -179,7 +189,7 @@ TVM_DLL void TVMAPISetLastError(const char* msg); * this function is threadsafe and can be called by different thread * \return error info */ -TVM_DLL const char *TVMGetLastError(void); +TVM_DLL const char* TVMGetLastError(void); /*! * \brief Load module from file. * \param file_name The file name to load the module from. @@ -190,9 +200,7 @@ TVM_DLL const char *TVMGetLastError(void); * \note The resulting module do not contain import relation. * It can be reconstructed by TVMModImport. */ -TVM_DLL int TVMModLoadFromFile(const char* file_name, - const char* format, - TVMModuleHandle* out); +TVM_DLL int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out); /*! * \brief Add dep to mod's dependency. @@ -202,8 +210,7 @@ TVM_DLL int TVMModLoadFromFile(const char* file_name, * \param dep The dependent module to be imported. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMModImport(TVMModuleHandle mod, - TVMModuleHandle dep); +TVM_DLL int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep); /*! * \brief Get function from the module. @@ -213,10 +220,8 @@ TVM_DLL int TVMModImport(TVMModuleHandle mod, * \param out The result function, can be NULL if it is not available. * \return 0 when no error is thrown, -1 when failure happens */ -TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *out); +TVM_DLL int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* out); /*! * \brief Free the Module @@ -258,12 +263,8 @@ TVM_DLL int TVMFuncFree(TVMFunctionHandle func); * The front-end need to call free function (e.g. TVMFuncFree) * to free these handles. */ -TVM_DLL int TVMFuncCall(TVMFunctionHandle func, - TVMValue* arg_values, - int* type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code); +TVM_DLL int TVMFuncCall(TVMFunctionHandle func, TVMValue* arg_values, int* type_codes, int num_args, + TVMValue* ret_val, int* ret_type_code); /*! * \brief Set the return value of TVMPackedCFunc. 
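// Rough usage sketch for the C function-calling API above ("my.add" is a hypothetical
// registered global; return-code checks are omitted for brevity).
TVMFunctionHandle fn;
TVMFuncGetGlobal("my.add", &fn);
TVMValue args[2], ret;
int type_codes[2] = {kTVMArgInt, kTVMArgInt}, ret_code;
args[0].v_int64 = 1;
args[1].v_int64 = 2;
TVMFuncCall(fn, args, type_codes, /*num_args=*/2, &ret, &ret_code);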
@@ -276,10 +277,7 @@ TVM_DLL int TVMFuncCall(TVMFunctionHandle func, * \param type_code The type of the value to be returned. * \param num_ret Number of return values, for now only 1 is supported. */ -TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret); +TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret); /*! * \brief Inplace translate callback argument value to return value. @@ -304,12 +302,8 @@ TVM_DLL int TVMCbArgToReturn(TVMValue* value, int* code); * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. * \sa TVMCFuncSetReturn */ -typedef int (*TVMPackedCFunc)( - TVMValue* args, - int* type_codes, - int num_args, - TVMRetValueHandle ret, - void* resource_handle); +typedef int (*TVMPackedCFunc)(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, + void* resource_handle); /*! * \brief C callback to free the resource handle in C packed function. @@ -339,10 +333,8 @@ typedef int (*TVMExtensionFuncDeclarer)(TVMFunctionHandle register_func_handle); * \param out the result function handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out); +TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, + TVMPackedCFuncFinalizer fin, TVMFunctionHandle* out); /*! * \brief Register the function to runtime's global table. @@ -353,8 +345,7 @@ TVM_DLL int TVMFuncCreateFromCFunc(TVMPackedCFunc func, * \param f The function to be registered. * \param override Whether allow override already registered function. */ -TVM_DLL int TVMFuncRegisterGlobal( - const char* name, TVMFunctionHandle f, int override); +TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override); /*! * \brief Get a global function. @@ -373,8 +364,7 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); * \param out_array The array of function names. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMFuncListGlobalNames(int* out_size, - const char*** out_array); +TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array); // Array related apis for quick proptyping /*! @@ -391,14 +381,8 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, * \param out The output handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out); +TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out); /*! * \brief Free the TVM Array. @@ -414,9 +398,7 @@ TVM_DLL int TVMArrayFree(TVMArrayHandle handle); * \param nbytes The number of bytes to copy. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, - void* data, - size_t nbytes); +TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes); /*! * \brief Copy array data to CPU byte array. @@ -425,9 +407,7 @@ TVM_DLL int TVMArrayCopyFromBytes(TVMArrayHandle handle, * \param nbytes The number of bytes to copy. 
* \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, - void* data, - size_t nbytes); +TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes); /*! * \brief Copy the array, both from and to must be valid during the copy. @@ -436,9 +416,7 @@ TVM_DLL int TVMArrayCopyToBytes(TVMArrayHandle handle, * \param stream The stream where the copy happens, can be NULL. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, - TVMArrayHandle to, - TVMStreamHandle stream); +TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream); /*! * \brief Produce an array from the DLManagedTensor that shares data memory @@ -447,8 +425,7 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from, * \param out The output array handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, - TVMArrayHandle* out); +TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out); /*! * \brief Produce a DLMangedTensor from the array that shares data memory with @@ -457,8 +434,7 @@ TVM_DLL int TVMArrayFromDLPack(DLManagedTensor* from, * \param out The DLManagedTensor handle. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, - DLManagedTensor** out); +TVM_DLL int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out); /*! * \brief Delete (free) a DLManagedTensor's data. @@ -518,9 +494,7 @@ TVM_DLL int TVMSynchronize(int device_type, int device_id, TVMStreamHandle strea * \param dst The destination stream to synchronize. * \return 0 when success, -1 when failure happens */ -TVM_DLL int TVMStreamStreamSynchronize(int device_type, - int device_id, - TVMStreamHandle src, +TVM_DLL int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, TVMStreamHandle dst); /*! @@ -540,6 +514,15 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Increase the reference count of an object. + * + * \param obj The object handle. + * \note Internally we increase the reference counter of the object. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectRetain(TVMObjectHandle obj); + /*! * \brief Free the object. * @@ -550,6 +533,56 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); */ TVM_DLL int TVMObjectFree(TVMObjectHandle obj); +/*! + * \brief Allocate a data space on device. + * \param ctx The device context to perform operation. + * \param nbytes The number of bytes in memory. + * \param alignment The alignment of the memory. + * \param type_hint The type of elements. Only needed by certain backends such + * as nbytes & alignment are sufficient for most backends. + * \param out_data The allocated device pointer. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint, void** out_data); + +/*! + * \brief Free a data space on device. + * \param ctx The device context to perform operation. + * \param ptr The data space. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMDeviceFreeDataSpace(TVMContext ctx, void* ptr); + +/*! + * \brief Copy data from one place to another. + * \param from The source array. 
+ * \param from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param num_bytes The size of the memory in bytes + * \param ctx_from The source context + * \param ctx_to The target context + * \param type_hint The type of elements, only neded by certain backends. + * can be useful for cross device endian converison. + * \param stream Optional stream object. + * \return 0 when success, -1 when failure happens. + */ +TVM_DLL int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t num_bytes, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, + TVMStreamHandle stream); + +/*! + * \brief Check that an object is derived from another. + * \param child_type_index The type index of the derived type. + * \param parent_type_index The type index of the parent type. + * \param is_derived A boolean representing whether this predicate holds. + * \return 0 when success, -1 when failure happens. + */ +TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, + int* is_derived); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8f426415ffee..36e2e8f5f276 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -29,8 +29,10 @@ #include #include +#include #include #include +#include #include // We use c++14 std::experimental::string_view for optimizing hash computation // only right now, its usage is limited in this file. Any broader usage of @@ -39,8 +41,7 @@ // string_view: // https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations // https://en.cppreference.com/w/User:D41D8CD98F/feature_testing_macros -#if defined(__cpp_lib_experimental_string_view) && \ - __cpp_lib_experimental_string_view >= 201411 +#if defined(__cpp_lib_experimental_string_view) && __cpp_lib_experimental_string_view >= 201411 #define TVM_USE_CXX14_STRING_VIEW_HASH 1 #else #define TVM_USE_CXX14_STRING_VIEW_HASH 0 @@ -64,7 +65,15 @@ #include #include +namespace llvm { +// String to llvm object compatibility. +class StringRef; +} // namespace llvm + namespace tvm { + +struct ObjectEqual; + namespace runtime { /*! @@ -135,8 +144,7 @@ class InplaceArrayBase { * \brief Destroy the Inplace Array Base object */ ~InplaceArrayBase() { - if (!(std::is_standard_layout::value && - std::is_trivial::value)) { + if (!(std::is_standard_layout::value && std::is_trivial::value)) { size_t size = Self()->GetSize(); for (size_t i = 0; i < size; ++i) { ElemType* fp = reinterpret_cast(AddressOf(i)); @@ -162,7 +170,6 @@ class InplaceArrayBase { new (field_ptr) ElemType(std::forward(args)...); } - private: /*! * \brief Return the self object for the array. * @@ -179,10 +186,10 @@ class InplaceArrayBase { * \return Raw pointer to the element. */ void* AddressOf(size_t idx) const { - static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); size_t kDataStart = sizeof(ArrayType); ArrayType* self = Self(); @@ -191,6 +198,788 @@ class InplaceArrayBase { } }; +/*! + * \brief iterator adapter that adapts TIter to return another type. 
+ * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter& operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter& operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { + ++iter_; + return *this; + } + ReverseIterAdapter& operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter& operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { + return ReverseIterAdapter(iter_ - offset); + } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! \brief array node content in array */ +class ArrayNode : public Object, public InplaceArrayBase { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const ObjectRef at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const ObjectRef* end() const { return begin() + size_; } + + /*! 
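// Sketch of the Converter contract that IterAdapter expects (illustrative, not from the
// patch): a converter exposes ResultType and a static convert() applied on dereference.
// IntConverter and the vector below are made-up names; <vector> is assumed available.
struct IntConverter {
  using ResultType = int;
  static int convert(const int64_t& v) { return static_cast<int>(v); }
};
std::vector<int64_t> raw = {1, 2, 3};
IterAdapter<IntConverter, std::vector<int64_t>::const_iterator> it(raw.begin());
int first = *it;  // IntConverter::convert() runs here, yielding 1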
\brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + CHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + CHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { + ObjectPtr p = ArrayNode::Empty(n); + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) ObjectRef(val); + } + return p; + } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + ObjectRef* MutableBegin() const { + return static_cast(InplaceArrayBase::AddressOf(0)); + } + + /*! \return end mutable iterator */ + ObjectRef* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayNode with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr Empty(int64_t n = kInitSize) { + CHECK_GE(n, 0); + ObjectPtr p = make_inplace_array_object(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template + ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { + ObjectRef* itr = MutableBegin() + idx; + for (; first != last; ++first) { + ObjectRef ref = *first; + new (itr++) ObjectRef(std::move(ref)); + } + return this; + } + + /*! 
+ * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_begin; + ObjectRef* to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_end; + ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) ObjectRef(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayNode* ShrinkBy(int64_t delta) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->ObjectRef::~ObjectRef(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayNode */ + static constexpr int64_t kInitSize = 4; + + /*! \brief Expansion factor of the Array */ + static constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase; + + // Reference class + template + friend class Array; + + // To specialize make_object + friend ObjectPtr make_object<>(); +}; + +/*! + * \brief Array, container representing a contigious sequence of ObjectRefs. + * + * Array implements in-place copy-on-write semantics. + * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. + * \tparam T The content ObjectRef type. + */ +template ::value>::type> +class Array : public ObjectRef { + public: + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayNode::Empty(); } + + /*! + * \brief move constructor + * \param other source + */ + Array(Array&& other) : ObjectRef() { // NOLINT(*) + data_ = std::move(other.data_); + } + + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array& other) : ObjectRef() { // NOLINT(*) + data_ = other.data_; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr n) : ObjectRef(n) {} + + /*! 
+ * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType first, IterType last) { + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } + + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(Array&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(const Array& other) { + data_ = other.data_; + return *this; + } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } + }; + + using iterator = IterAdapter; + using reverse_iterator = ReverseIterAdapter; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayNode()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayNode()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayNode::end() is never nullptr + return reverse_iterator(GetArrayNode()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayNode::begin() is never nullptr + return reverse_iterator(GetArrayNode()->begin() - 1); + } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK(0 <= i && i < p->size_) << "IndexError: indexing " << i << " on an array of size " + << p->size_; + return DowncastNoCheck(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->end() - 1)); + } + + public: + // mutation in std::vector, implements copy-on-write + + /*! 
+ * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayNode* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + CHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) ObjectRef(val); + } + + /*! + * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template + void insert(iterator position, IterType first, IterType last) { + if (first == last) { + return; + } + CHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + CHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; + int64_t size = GetArrayNode()->size_; + CHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + CHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + CHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st + << ", because Array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! + * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + CHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t size = GetArrayNode()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + CHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; + CHECK(0 <= st && st <= size && 0 <= ed && ed <= size) + << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + CHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayNode()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! 
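// Hedged illustration of the copy-on-write mutators above: because `alias` shares its
// ArrayNode with `names`, the first mutation through `alias` copies the container and
// leaves `names` untouched. String is used only as a convenient ObjectRef element type.
Array<String> names{String("conv2d"), String("dense")};
Array<String> alias = names;       // same underlying ArrayNode, reference count > 1
alias.push_back(String("relu"));   // CopyOnWrite(): alias now owns a private copy
// names.size() == 2, alias.size() == 3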
+ * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayNode()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayNode* p = CopyOnWrite(); + p->clear(); + } + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayNode* p = this->CopyOnWrite(); + CHECK(0 <= i && i < p->size_) << "IndexError: indexing " << i << " on an array of size " + << p->size_; + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayNode */ + ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + void MutateByApply(F fmutate) { + if (data_ == nullptr) { + return; + } + struct StackFrame { + ArrayNode* p; + ObjectRef* itr; + int64_t i; + int64_t size; + }; + std::unique_ptr s = std::make_unique(); + s->p = GetArrayNode(); + s->itr = s->p->MutableBegin(); + s->i = 0; + s->size = s->p->size_; + if (!data_.unique()) { + // Loop invariant: keeps iterating when + // 1) data is not unique + // 2) no elements are actually mutated yet + for (; s->i < s->size; ++s->i, ++s->itr) { + T new_elem = fmutate(DowncastNoCheck(*s->itr)); + // do nothing when there is no mutation + if (new_elem.same_as(*s->itr)) { + continue; + } + // loop invariant breaks when the first real mutation happens + // we copy the elements into a new unique array + ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); + s->itr = copy->MutableBegin() + (s->i++); + *s->itr++ = std::move(new_elem); + data_ = std::move(copy); + // make sure `data_` is unique and break + break; + } + } + // when execution comes to this line, it is guaranteed that either + // 1) i == size + // or 2) data_.unique() is true + for (; s->i < s->size; ++s->i, ++s->itr) { + *s->itr = std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); + } + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + void Assign(IterType first, IterType last) { + int64_t cap = std::distance(first, last); + CHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; + ArrayNode* p = GetArrayNode(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayNode::Empty(cap); + p = GetArrayNode(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) ObjectRef(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. 
+ * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayNode* CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayNode::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + private: + /*! + * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayNode pointer to the unique copy + */ + ArrayNode* CopyOnWrite(int64_t reserve_extra) { + ArrayNode* p = GetArrayNode(); + if (p == nullptr) { + // necessary to get around the constexpr address issue before c++17 + const int64_t kInitSize = ArrayNode::kInitSize; + return SwitchContainer(std::max(kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayNode::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayNode to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayNode* SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayNode::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); + } else { + data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); + } + return static_cast(data_.get()); + } +}; + +// Specialize make_object to make sure it is correct. +template <> +inline ObjectPtr make_object() { + return ArrayNode::Empty(); +} + /*! \brief An object representing a structure or enumeration. */ class ADTObj : public Object, public InplaceArrayBase { public: @@ -200,8 +989,8 @@ class ADTObj : public Object, public InplaceArrayBase { uint32_t size; // The fields of the structure follows directly in memory. - static constexpr const uint32_t _type_index = TypeIndex::kVMADT; - static constexpr const char* _type_key = "vm.ADT"; + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeADT; + static constexpr const char* _type_key = "runtime.ADT"; TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); private: @@ -242,8 +1031,7 @@ class ADT : public ObjectRef { * \param fields The fields of the ADT object. * \return The constructed ADT object reference. */ - ADT(int32_t tag, std::vector fields) - : ADT(tag, fields.begin(), fields.end()){}; + ADT(int32_t tag, std::vector fields) : ADT(tag, fields.begin(), fields.end()){}; /*! * \brief construct an ADT object reference. @@ -267,8 +1055,7 @@ class ADT : public ObjectRef { * \param init The initializer list of fields. * \return The constructed ADT object reference. */ - ADT(int32_t tag, std::initializer_list init) - : ADT(tag, init.begin(), init.end()){}; + ADT(int32_t tag, std::initializer_list init) : ADT(tag, init.begin(), init.end()){}; /*! * \brief Access element at index. @@ -276,9 +1063,7 @@ class ADT : public ObjectRef { * \param idx The array index * \return const ObjectRef */ - const ObjectRef& operator[](size_t idx) const { - return operator->()->operator[](idx); - } + const ObjectRef& operator[](size_t idx) const { return operator->()->operator[](idx); } /*! * \brief Return the ADT tag. @@ -314,7 +1099,7 @@ class StringObj : public Object { /*! \brief The length of the string object. 
*/ uint64_t size; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString; static constexpr const char* _type_key = "runtime.String"; TVM_DECLARE_FINAL_OBJECT_INFO(StringObj, Object); @@ -381,45 +1166,80 @@ class String : public ObjectRef { * \param other The value for the new String * */ - inline String operator=(std::string other); + inline String& operator=(std::string other); /*! - * \brief Compare is equal to other std::string + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + */ + inline String& operator=(const char* other); + + /*! + * \brief Compare is less than other std::string * * \param other The other string * * \return the comparison result */ - bool operator==(const std::string& other) const { - return this->compare(other) == 0; - } + bool operator<(const std::string& other) const { return this->compare(other) < 0; } + bool operator<(const String& other) const { return this->compare(other) < 0; } + bool operator<(const char* other) const { return this->compare(other) < 0; } /*! - * \brief Compare is not equal to other std::string + * \brief Compare is greater than other std::string + * + * \param other The other string + * + * \return the comparison result + */ + bool operator>(const std::string& other) const { return this->compare(other) > 0; } + bool operator>(const String& other) const { return this->compare(other) > 0; } + bool operator>(const char* other) const { return this->compare(other) > 0; } + + /*! + * \brief Compare is less than or equal to other std::string + * + * \param other The other string + * + * \return the comparison result + */ + bool operator<=(const std::string& other) const { return this->compare(other) <= 0; } + bool operator<=(const String& other) const { return this->compare(other) <= 0; } + bool operator<=(const char* other) const { return this->compare(other) <= 0; } + + /*! + * \brief Compare is greater than or equal to other std::string * * \param other The other string * * \return the comparison result */ - bool operator!=(const std::string& other) const { return !operator==(other); } + bool operator>=(const std::string& other) const { return this->compare(other) >= 0; } + bool operator>=(const String& other) const { return this->compare(other) >= 0; } + bool operator>=(const char* other) const { return this->compare(other) >= 0; } /*! - * \brief Compare is equal to other char string + * \brief Compare is equal to other std::string * - * \param other The other char string + * \param other The other string * * \return the comparison result */ + bool operator==(const std::string& other) const { return this->compare(other) == 0; } + bool operator==(const String& other) const { return this->compare(other) == 0; } bool operator==(const char* other) const { return compare(other) == 0; } /*! - * \brief Compare is not equal to other char string + * \brief Compare is not equal to other std::string * - * \param other The other char string + * \param other The other string * * \return the comparison result */ - bool operator!=(const char* other) const { return !operator==(other); } + bool operator!=(const std::string& other) const { return this->compare(other) != 0; } + bool operator!=(const String& other) const { return this->compare(other) != 0; } + bool operator!=(const char* other) const { return this->compare(other) != 0; } /*! 
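// Small usage sketch for the String reference type above (illustrative only).
String name = std::string("relay.add");  // payload is copied into a StringObj
name = "relay.multiply";                 // operator=(const char*), newly added in this change
bool hit = (name == "relay.multiply");   // const char* comparison overload shown above
std::string plain = name;                // converts back out to std::string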
* \brief Compares this String object to other @@ -496,12 +1316,29 @@ class String : public ObjectRef { const char* data() const { return get()->data; } /*! - * \brief Convert String to an std::sting object + * \brief Convert String to an std::string object * * \return std::string */ operator std::string() const { return std::string{get()->data, size()}; } + // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h + /*! + * \brief Convert String to an llvm::StringRef object + * + * \return llvm::StringRef + */ + inline operator llvm::StringRef() const; + + /*! + * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String + * \param val The value to be checked + * \return A boolean indicating if val can be converted to String + */ + static bool CanConvertFrom(const TVMArgValue& val) { + return val.type_code() == kTVMStr || val.IsObjectRef(); + } + /*! * \brief Hash the binary bytes * \param data The data pointer @@ -512,19 +1349,14 @@ class String : public ObjectRef { // This function falls back to string copy with c++11 compiler and is // recommended to be compiled with c++14 #if TVM_USE_CXX17_STRING_VIEW_HASH - return std::hash()( - std::string_view(data, size)); + return std::hash()(std::string_view(data, size)); #elif TVM_USE_CXX14_STRING_VIEW_HASH - return std::hash()( - std::experimental::string_view(data, size)); + return std::hash()(std::experimental::string_view(data, size)); #else return std::hash()(std::string(data, size)); #endif } - /*! \return the internal StringObj pointer */ - const StringObj* get() const { return operator->(); } - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); private: @@ -538,8 +1370,9 @@ class String : public ObjectRef { * \return int zero if both char sequences compare equal. negative if this * appear before other, positive otherwise. */ - static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, - size_t rhs_count); + static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count); + + friend struct tvm::ObjectEqual; }; /*! \brief An object representing string moved from std::string. */ @@ -569,14 +1402,24 @@ inline String::String(std::string other) { data_ = std::move(ptr); } -inline String String::operator=(std::string other) { +inline String& String::operator=(std::string other) { String replace{std::move(other)}; data_.swap(replace.data_); - return Downcast(*this); + return *this; +} + +inline String& String::operator=(const char* other) { return operator=(std::string(other)); } + +inline String operator+(const std::string lhs, const String& rhs) { + return lhs + rhs.operator std::string(); } -inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, - size_t rhs_count) { +inline std::ostream& operator<<(std::ostream& out, const String& input) { + out.write(input.data(), input.size()); + return out; +} + +inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { if (lhs == rhs && lhs_count == rhs_count) return 0; for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { @@ -592,7 +1435,7 @@ inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, } } -template<> +template <> struct PackedFuncValueConverter<::tvm::runtime::String> { static String From(const TVMArgValue& val) { if (val.IsObjectRef()) { @@ -611,6 +1454,9 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +/*! \brief Helper to represent nullptr for optional. 
*/ +struct NullOptType {}; + /*! * \brief Optional container that to represent to a Nullable variant of T. * \tparam T The original ObjectRef. @@ -624,12 +1470,11 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { * * \endcode */ -template +template class Optional : public ObjectRef { public: using ContainerType = typename T::ContainerType; - static_assert(std::is_base_of::value, - "Optional is only defined for ObjectRef."); + static_assert(std::is_base_of::value, "Optional is only defined for ObjectRef."); // default constructors. Optional() = default; Optional(const Optional&) = default; @@ -642,6 +1487,8 @@ class Optional : public ObjectRef { * \param ptr */ explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + /*! \brief Nullopt handling */ + Optional(NullOptType) {} // NOLINT(*) // nullptr handling. // disallow implicit conversion as 0 can be implicitly converted to nullptr_t explicit Optional(std::nullptr_t) {} @@ -650,9 +1497,8 @@ class Optional : public ObjectRef { return *this; } // normal value handling. - Optional(T other) // NOLINT(*) - : ObjectRef(std::move(other)) { - } + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) {} Optional& operator=(T other) { ObjectRef::operator=(std::move(other)); return *this; @@ -674,20 +1520,13 @@ class Optional : public ObjectRef { * \return The contained value if the Optional is not null * otherwise return the default_value. */ - T value_or(T default_value) const { - return data_ != nullptr ? T(data_) : default_value; - } + T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; } + /*! \return Whether the container is not nullptr.*/ - explicit operator bool() const { - return *this != nullptr; - } + explicit operator bool() const { return *this != nullptr; } // operator overloadings - bool operator==(std::nullptr_t) const { - return data_ == nullptr; - } - bool operator!=(std::nullptr_t) const { - return data_ != nullptr; - } + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } auto operator==(const Optional& other) const { // support case where sub-class returns a symbolic ref type. using RetType = decltype(value() == other.value()); @@ -716,16 +1555,14 @@ class Optional : public ObjectRef { if (*this != nullptr) return value() == other; return RetType(false); } - auto operator!=(const T& other) const { - return !(*this == other); - } - template + auto operator!=(const T& other) const { return !(*this == other); } + template auto operator==(const U& other) const { using RetType = decltype(value() == other); if (*this == nullptr) return RetType(false); return value() == other; } - template + template auto operator!=(const U& other) const { using RetType = decltype(value() != other); if (*this == nullptr) return RetType(true); @@ -734,7 +1571,7 @@ class Optional : public ObjectRef { static constexpr bool _type_is_nullable = true; }; -template +template struct PackedFuncValueConverter> { static Optional From(const TVMArgValue& val) { if (val.type_code() == kTVMNullptr) return Optional(nullptr); @@ -749,8 +1586,9 @@ struct PackedFuncValueConverter> { } // namespace runtime // expose the functions to the root namespace. 
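// Hedged sketch of the Optional container defined above; NullOpt is the sentinel
// constant introduced just below in the root tvm namespace.
Optional<String> opt = NullOpt;           // empty: operator bool() is false
if (!opt) {
  opt = String("fallback");               // assign a real value
}
String a = opt.value();                   // checked access
String b = opt.value_or(String("none"));  // access with a default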
-using runtime::String; using runtime::Optional; +using runtime::String; +constexpr runtime::NullOptType NullOpt{}; } // namespace tvm namespace std { diff --git a/include/tvm/runtime/crt/memory.h b/include/tvm/runtime/crt/memory.h index 3e47060a86c4..7b88b3123644 100644 --- a/include/tvm/runtime/crt/memory.h +++ b/include/tvm/runtime/crt/memory.h @@ -32,7 +32,7 @@ static int vleak_size = 0; * \param size The size of memory * \return The virtual address */ -void * vmalloc(size_t size); +void* vmalloc(size_t size); /*! * \brief Reallocate memory from manager @@ -40,13 +40,13 @@ void * vmalloc(size_t size); * \param size The size of memory * \return The virtual address */ -void * vrealloc(void * ptr, size_t size); +void* vrealloc(void* ptr, size_t size); /*! * \brief Free the memory. * \param ptr The pointer to the memory to deallocate * \return The virtual address */ -void vfree(void * ptr); +void vfree(void* ptr); #endif // TVM_RUNTIME_CRT_MEMORY_H_ diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 44385d63263b..b12938bd751a 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -24,10 +24,11 @@ #ifndef TVM_RUNTIME_DATA_TYPE_H_ #define TVM_RUNTIME_DATA_TYPE_H_ -#include #include -#include +#include + #include +#include namespace tvm { namespace runtime { @@ -39,12 +40,20 @@ namespace runtime { */ class DataType { public: - /*! \brief Type code for the DataType. */ + /*! + * \brief Type code for the DataType. + * + * DLPack consistency: + * 1) kInt is consistent with kDLInt + * 2) kUInt is consistent with kDLUInt + * 3) kFloat is consistent with kDLFloat + */ enum TypeCode { kInt = kDLInt, kUInt = kDLUInt, kFloat = kDLFloat, - kHandle = TVMTypeCode::kTVMOpaqueHandle, + kHandle = TVMArgTypeCode::kTVMOpaqueHandle, + kCustomBegin = 129 }; /*! \brief default constructor */ DataType() {} @@ -52,8 +61,7 @@ class DataType { * \brief Constructor * \param dtype The DLDataType */ - explicit DataType(DLDataType dtype) - : data_(dtype) {} + explicit DataType(DLDataType dtype) : data_(dtype) {} /*! * \brief Constructor * \param code The type code. @@ -66,106 +74,70 @@ class DataType { data_.lanes = static_cast(lanes); } /*! \return The type code. */ - int code() const { - return static_cast(data_.code); - } + int code() const { return static_cast(data_.code); } /*! \return number of bits in the data. */ - int bits() const { - return static_cast(data_.bits); - } + int bits() const { return static_cast(data_.bits); } /*! \return number of bytes to store each scalar. */ - int bytes() const { - return (bits() + 7) / 8; - } + int bytes() const { return (bits() + 7) / 8; } /*! \return number of lanes in the data. */ - int lanes() const { - return static_cast(data_.lanes); - } + int lanes() const { return static_cast(data_.lanes); } /*! \return whether type is a scalar type. */ - bool is_scalar() const { - return lanes() == 1; - } + bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ - bool is_bool() const { - return code() == DataType::kUInt && bits() == 1; - } + bool is_bool() const { return code() == DataType::kUInt && bits() == 1; } /*! \return whether type is a float type. */ - bool is_float() const { - return code() == DataType::kFloat; - } + bool is_float() const { return code() == DataType::kFloat; } /*! \return whether type is a float16 type. */ - bool is_float16() const { - return is_float() && bits() == 16; - } + bool is_float16() const { return is_float() && bits() == 16; } /*! 
\return whether type is an int type. */ - bool is_int() const { - return code() == DataType::kInt; - } + bool is_int() const { return code() == DataType::kInt; } /*! \return whether type is an uint type. */ - bool is_uint() const { - return code() == DataType::kUInt; - } + bool is_uint() const { return code() == DataType::kUInt; } /*! \return whether type is a handle type. */ - bool is_handle() const { - return code() == DataType::kHandle; - } + bool is_handle() const { return code() == DataType::kHandle && !is_void(); } /*! \return whether type is a vector type. */ - bool is_vector() const { - return lanes() > 1; - } + bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ - bool is_vector_bool() const { - return is_vector() && bits() == 1; - } + bool is_vector_bool() const { return is_vector() && bits() == 1; } + /*! \return whether type is a Void type. */ + bool is_void() const { return code() == DataType::kHandle && bits() == 0 && lanes() == 0; } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. * \return the result type. */ - DataType with_lanes(int lanes) const { - return DataType(data_.code, data_.bits, lanes); - } + DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } /*! * \brief Create a new data type by change bits to a specified value. * \param bits The target number of bits. * \return the result type. */ - DataType with_bits(int bits) const { - return DataType(data_.code, bits, data_.lanes); - } + DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } /*! * \brief Get the scalar version of the type. * \return the result type. */ - DataType element_of() const { - return with_lanes(1); - } + DataType element_of() const { return with_lanes(1); } /*! * \brief Equal comparator. * \param other The data type to compre against. * \return The comparison resilt. */ bool operator==(const DataType& other) const { - return - data_.code == other.data_.code && - data_.bits == other.data_.bits && - data_.lanes == other.data_.lanes; + return data_.code == other.data_.code && data_.bits == other.data_.bits && + data_.lanes == other.data_.lanes; } /*! * \brief NotEqual comparator. * \param other The data type to compre against. * \return The comparison resilt. */ - bool operator!=(const DataType& other) const { - return !operator==(other); - } + bool operator!=(const DataType& other) const { return !operator==(other); } /*! * \brief Converter to DLDataType * \return the result. */ - operator DLDataType () const { - return data_; - } + operator DLDataType() const { return data_; } /*! * \brief Construct an int type. @@ -173,44 +145,39 @@ class DataType { * \param lanes The number of lanes. * \return The constructed data type. */ - static DataType Int(int bits, int lanes = 1) { - return DataType(kDLInt, bits, lanes); - } + static DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType UInt(int bits, int lanes = 1) { - return DataType(kDLUInt, bits, lanes); - } + static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. 
*/ - static DataType Float(int bits, int lanes = 1) { - return DataType(kDLFloat, bits, lanes); - } + static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Bool(int lanes = 1) { - return DataType::UInt(1, lanes); - } + static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); } /*! * \brief Construct a handle type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ - static DataType Handle(int bits = 64, int lanes = 1) { - return DataType(kHandle, bits, lanes); - } + static DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } + /*! + * \brief Construct a Void type. + * \return The constructed data type. + */ + static DataType Void() { return DataType(kHandle, 0, 0); } /*! * \brief Get the corresponding type of TVMShapeIndex. * \return The type of TVM shape index. @@ -235,14 +202,11 @@ class DataType { inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist - if (dtype == DataType::Bool() || - dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || + if (dtype == DataType::Bool() || dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { return 1; } - CHECK_EQ(data_bits % 8, 0U) - << "Need to load/store by multiple of bytes"; + CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } @@ -292,7 +256,7 @@ TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); * \param type_code The type code . * \return The name of type code. */ -inline const char* TypeCode2Str(int type_code); +inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code); /*! * \brief convert a string to TVM type. 
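Before the string/stream conversion hunks below, a short sketch of the DataType helpers shown above, in particular the new Void() constructor and the tightened is_handle()/is_void() distinction. Illustrative only; it assumes the patched <tvm/runtime/data_type.h> is on the include path.

```cpp
// Exercises the DataType helpers from the hunks above; not part of the patch.
#include <tvm/runtime/data_type.h>

#include <iostream>

int main() {
  using tvm::runtime::DataType;

  DataType f32x4 = DataType::Float(32, 4);  // float32 with 4 lanes
  DataType handle = DataType::Handle();     // opaque 64-bit handle
  DataType void_ty = DataType::Void();      // new: kHandle code, 0 bits, 0 lanes

  std::cout << f32x4.bits() << " bits x " << f32x4.lanes() << " lanes, "
            << f32x4.bytes() << " bytes per element\n";

  // Void shares the handle type code but is now excluded from is_handle().
  std::cout << std::boolalpha
            << "handle.is_handle() = " << handle.is_handle() << "\n"
            << "void.is_handle()   = " << void_ty.is_handle() << "\n"
            << "void.is_void()     = " << void_ty.is_void() << "\n";

  // element_of() strips the lanes; with_lanes() derives a vector type again.
  std::cout << (f32x4.element_of() == DataType::Float(32)) << "\n";
  return 0;
}
```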
@@ -309,34 +273,32 @@ inline DLDataType String2DLDataType(std::string s); inline std::string DLDataType2String(DLDataType t); // implementation details -inline const char* TypeCode2Str(int type_code) { - switch (type_code) { - case kDLInt: return "int"; - case kDLUInt: return "uint"; - case kDLFloat: return "float"; - case kTVMStr: return "str"; - case kTVMBytes: return "bytes"; - case kTVMOpaqueHandle: return "handle"; - case kTVMNullptr: return "NULL"; - case kTVMDLTensorHandle: return "ArrayHandle"; - case kTVMDataType: return "DLDataType"; - case kTVMContext: return "TVMContext"; - case kTVMPackedFuncHandle: return "FunctionHandle"; - case kTVMModuleHandle: return "ModuleHandle"; - case kTVMNDArrayHandle: return "NDArrayContainer"; - case kTVMObjectHandle: return "Object"; - case kTVMObjectRValueRefArg: return "ObjectRValueRefArg"; - default: LOG(FATAL) << "unknown type_code=" - << static_cast(type_code); return ""; +inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { + switch (static_cast(type_code)) { + case kDLInt: + return "int"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case DataType::kHandle: + return "handle"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + return ""; } } inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { - os << "bool"; return os; + os << "bool"; + return os; } - if (t.code < kTVMCustomBegin) { - os << TypeCode2Str(t.code); + if (DataType(t).is_void()) { + return os << "void"; + } + if (t.code < DataType::kCustomBegin) { + os << DLDataTypeCode2Str(static_cast(t.code)); } else { os << "custom[" << GetCustomTypeName(t.code) << "]"; } @@ -348,7 +310,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) return os; } -inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) +inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*) return os << dtype.operator DLDataType(); } @@ -361,19 +323,23 @@ inline std::string DLDataType2String(DLDataType t) { inline DLDataType String2DLDataType(std::string s) { DLDataType t; - // handle None type + // handle void type if (s.length() == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + t = DataType::Void(); return t; } - t.bits = 32; t.lanes = 1; + t.bits = 32; + t.lanes = 1; const char* scan; if (s.substr(0, 3) == "int") { - t.code = kDLInt; scan = s.c_str() + 3; + t.code = kDLInt; + scan = s.c_str() + 3; } else if (s.substr(0, 4) == "uint") { - t.code = kDLUInt; scan = s.c_str() + 4; + t.code = kDLUInt; + scan = s.c_str() + 4; } else if (s.substr(0, 5) == "float") { - t.code = kDLFloat; scan = s.c_str() + 5; + t.code = kDLFloat; + scan = s.c_str() + 5; } else if (s.substr(0, 6) == "handle") { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index f2ddc84e9f98..421811a52c3b 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -26,6 +26,7 @@ #include #include + #include namespace tvm { @@ -85,9 +86,7 @@ class TVM_DLL DeviceAPI { * as OpenGL, as nbytes & alignment are sufficient for most backends. * \return The allocated device pointer. */ - virtual void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + virtual void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) = 0; /*! 
* \brief Free a data space on device. @@ -108,16 +107,10 @@ class TVM_DLL DeviceAPI { * can be useful for cross device endian converison. * \param stream Optional stream object. */ - virtual void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t num_bytes, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, - TVMStreamHandle stream) = 0; - /*! + virtual void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) = 0; + /*! * \brief Create a new stream of execution. * * \param ctx The context of allocation. @@ -156,9 +149,8 @@ class TVM_DLL DeviceAPI { * \param event_src The source stream to synchronize. * \param event_dst The destination stream to synchronize. */ - virtual void SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, - TVMStreamHandle event_dst); + virtual void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, + TVMStreamHandle event_dst); /*! * \brief Allocate temporal workspace for backend execution. * @@ -175,9 +167,7 @@ class TVM_DLL DeviceAPI { * \param type_hint The type of elements. Only needed by certain backends such * as OpenGL, as nbytes is sufficient for most backends. */ - virtual void* AllocWorkspace(TVMContext ctx, - size_t nbytes, - DLDataType type_hint = {}); + virtual void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}); /*! * \brief Free temporal workspace in backend execution. * @@ -187,7 +177,7 @@ class TVM_DLL DeviceAPI { virtual void FreeWorkspace(TVMContext ctx, void* ptr); /*! - * \brief Get device API base don context. + * \brief Get device API based on context. * \param ctx The context * \param allow_missing Whether allow missing * \return The corresponding device API. @@ -214,21 +204,37 @@ constexpr int kRPCSessMask = 128; */ inline const char* DeviceName(int type) { switch (type) { - case kDLCPU: return "cpu"; - case kDLGPU: return "gpu"; - case kDLCPUPinned: return "cpu_pinned"; - case kDLOpenCL: return "opencl"; - case kDLSDAccel: return "sdaccel"; - case kDLAOCL: return "aocl"; - case kDLVulkan: return "vulkan"; - case kDLMetal: return "metal"; - case kDLVPI: return "vpi"; - case kDLROCM: return "rocm"; - case kOpenGL: return "opengl"; - case kDLExtDev: return "ext_dev"; - case kDLMicroDev: return "micro_dev"; - case kDLHexagon: return "hexagon"; - default: LOG(FATAL) << "unknown type =" << type; return "Unknown"; + case kDLCPU: + return "cpu"; + case kDLGPU: + return "gpu"; + case kDLCPUPinned: + return "cpu_pinned"; + case kDLOpenCL: + return "opencl"; + case kDLSDAccel: + return "sdaccel"; + case kDLAOCL: + return "aocl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLExtDev: + return "ext_dev"; + case kDLWebGPU: + return "webgpu"; + case kDLMicroDev: + return "micro_dev"; + case kDLHexagon: + return "hexagon"; + default: + LOG(FATAL) << "unknown type =" << type; + return "Unknown"; } } diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index 121dbdde37a6..1199c420f212 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -24,9 +24,10 @@ #define TVM_RUNTIME_MEMORY_H_ #include + #include -#include #include +#include namespace tvm { namespace runtime { @@ -36,7 +37,7 @@ namespace runtime { * \tparam T the node type. * \return The ObjectPtr to the allocated object. 
*/ -template +template inline ObjectPtr make_object(Args&&... args); // Detail implementations after this @@ -55,7 +56,7 @@ inline ObjectPtr make_object(Args&&... args); * * \tparam Derived The derived class. */ -template +template class ObjAllocatorBase { public: /*! @@ -64,13 +65,11 @@ class ObjAllocatorBase { * \tparam Args The constructor signature. * \param args The arguments. */ - template + template inline ObjectPtr make_object(Args&&... args) { using Handler = typename Derived::template Handler; - static_assert(std::is_base_of::value, - "make can only be used to create Object"); - T* ptr = Handler::New(static_cast(this), - std::forward(args)...); + static_assert(std::is_base_of::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast(this), std::forward(args)...); ptr->type_index_ = T::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); @@ -83,14 +82,13 @@ class ObjAllocatorBase { * \param num_elems The number of array elements. * \param args The arguments. */ - template + template inline ObjectPtr make_inplace_array(size_t num_elems, Args&&... args) { using Handler = typename Derived::template ArrayHandler; static_assert(std::is_base_of::value, "make_inplace_array can only be used to create Object"); - ArrayType* ptr = Handler::New(static_cast(this), - num_elems, - std::forward(args)...); + ArrayType* ptr = + Handler::New(static_cast(this), num_elems, std::forward(args)...); ptr->type_index_ = ArrayType::RuntimeTypeIndex(); ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); @@ -98,15 +96,14 @@ class ObjAllocatorBase { }; // Simple allocator that uses new/delete. -class SimpleObjAllocator : - public ObjAllocatorBase { +class SimpleObjAllocator : public ObjAllocatorBase { public: - template + template class Handler { public: using StorageType = typename std::aligned_storage::type; - template + template static T* New(SimpleObjAllocator*, Args&&... args) { // NOTE: the first argument is not needed for SimpleObjAllocator // It is reserved for special allocators that needs to recycle @@ -126,9 +123,7 @@ class SimpleObjAllocator : return reinterpret_cast(data); } - static Object::FDeleter Deleter() { - return Deleter_; - } + static Object::FDeleter Deleter() { return Deleter_; } private: static void Deleter_(Object* objptr) { @@ -146,16 +141,16 @@ class SimpleObjAllocator : }; // Array handler that uses new/delete. - template + template class ArrayHandler { public: using StorageType = typename std::aligned_storage::type; // for now only support elements that aligns with array header. static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && - sizeof(ArrayType) % alignof(ElemType) == 0, + sizeof(ArrayType) % alignof(ElemType) == 0, "element alignment constraint"); - template + template static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { // NOTE: the first argument is not needed for ArrayObjAllocator // It is reserved for special allocators that needs to recycle @@ -177,9 +172,7 @@ class SimpleObjAllocator : return reinterpret_cast(data); } - static Object::FDeleter Deleter() { - return Deleter_; - } + static Object::FDeleter Deleter() { return Deleter_; } private: static void Deleter_(Object* objptr) { @@ -193,20 +186,20 @@ class SimpleObjAllocator : // call a virtual destructor(which may not be available and is not required). tptr->ArrayType::~ArrayType(); StorageType* p = reinterpret_cast(tptr); - delete []p; + delete[] p; } }; }; -template +template inline ObjectPtr make_object(Args&&... 
args) { return SimpleObjAllocator().make_object(std::forward(args)...); } -template +template inline ObjectPtr make_inplace_array_object(size_t num_elems, Args&&... args) { - return SimpleObjAllocator().make_inplace_array( - num_elems, std::forward(args)...); + return SimpleObjAllocator().make_inplace_array(num_elems, + std::forward(args)...); } } // namespace runtime diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index ee50f71f451f..0e7cd2b08784 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -27,15 +27,14 @@ #define TVM_RUNTIME_MODULE_H_ #include - #include -#include #include +#include #include -#include #include #include +#include namespace tvm { namespace runtime { @@ -50,8 +49,7 @@ class Module : public ObjectRef { public: Module() {} // constructor from container. - explicit Module(ObjectPtr n) - : ObjectRef(n) {} + explicit Module(ObjectPtr n) : ObjectRef(n) {} /*! * \brief Get packed function from current module by name. * @@ -82,17 +80,7 @@ class Module : public ObjectRef { * \note This function won't load the import relationship. * Re-create import relationship by calling Import. */ - TVM_DLL static Module LoadFromFile(const std::string& file_name, - const std::string& format = ""); - /*! - * \brief Return whether the Module::node_ is a nullptr. - * This is necessary to check after compiling a model using TVM with - * an external accelerator, e.g. TensorRT. When all the operators are - * supported in TensorRT, there is no code generation for any operators - * in the network and thus, the Module is empty. - */ - TVM_DLL bool IsEmpty() const; - + TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = ""); // refer to the corresponding container. using ContainerType = ModuleNode; friend class ModuleNode; @@ -146,16 +134,14 @@ class TVM_DLL ModuleNode : public Object { * If the function need resource from the module(e.g. late linking), * it should capture sptr_to_self. */ - virtual PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) = 0; + virtual PackedFunc GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) = 0; /*! * \brief Save the module to file. * \param file_name The file to be saved to. * \param format The format of the file. */ - virtual void SaveToFile(const std::string& file_name, - const std::string& format); + virtual void SaveToFile(const std::string& file_name, const std::string& format); /*! * \brief Save the module to binary stream. * \param stream The binary stream to save to. @@ -197,9 +183,7 @@ class TVM_DLL ModuleNode : public Object { */ const PackedFunc* GetFuncFromEnv(const std::string& name); /*! \return The module it imports from */ - const std::vector& imports() const { - return imports_; - } + const std::vector& imports() const { return imports_; } // integration with the existing components. static constexpr const uint32_t _type_index = TypeIndex::kRuntimeModule; @@ -216,8 +200,7 @@ class TVM_DLL ModuleNode : public Object { private: /*! \brief Cache used by GetImport */ - std::unordered_map > import_cache_; + std::unordered_map > import_cache_; }; /*! @@ -247,13 +230,9 @@ constexpr const char* tvm_module_main = "__tvm_main__"; // implementations of inline functions. 
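A brief usage sketch of the runtime::Module interface from the hunk above (LoadFromFile and function lookup). The module path, function name, and the query-imports flag are illustrative assumptions, not taken from the patch.

```cpp
// Hypothetical Module usage; file name and function name are placeholders.
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>

#include <iostream>

int main() {
  using tvm::runtime::Module;
  using tvm::runtime::PackedFunc;

  // Load a compiled module; the format is deduced from the file extension.
  Module mod = Module::LoadFromFile("compiled_model.so");

  // Look up an exported PackedFunc by name. The second argument is assumed
  // to also search modules pulled in via Import().
  PackedFunc fn = mod.GetFunction("add_one", true);
  if (fn == nullptr) {
    std::cerr << "function not found\n";
    return 1;
  }

  int out = fn(41);  // arguments pass through the packed calling convention
  std::cout << "add_one(41) = " << out << "\n";
  return 0;
}
```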
-inline void Module::Import(Module other) { - return (*this)->Import(other); -} +inline void Module::Import(Module other) { return (*this)->Import(other); } -inline ModuleNode* Module::operator->() { - return static_cast(get_mutable()); -} +inline ModuleNode* Module::operator->() { return static_cast(get_mutable()); } inline const ModuleNode* Module::operator->() const { return static_cast(get()); @@ -263,4 +242,4 @@ inline const ModuleNode* Module::operator->() const { } // namespace tvm #include // NOLINT(*) -#endif // TVM_RUNTIME_MODULE_H_ +#endif // TVM_RUNTIME_MODULE_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 17f81a2a8b68..e69d802652fd 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,12 +25,13 @@ #define TVM_RUNTIME_NDARRAY_H_ #include +#include #include #include #include -#include #include +#include namespace tvm { namespace runtime { @@ -53,8 +54,7 @@ class NDArray : public ObjectRef { * \brief constructor. * \param data ObjectPtr to the data container. */ - explicit NDArray(ObjectPtr data) - : ObjectRef(data) {} + explicit NDArray(ObjectPtr data) : ObjectRef(data) {} /*! \brief reset the content of NDArray to be nullptr */ inline void reset(); @@ -76,13 +76,13 @@ class NDArray : public ObjectRef { inline void CopyFrom(const DLTensor* other); inline void CopyFrom(const NDArray& other); /*! - * \brief Copy data content from a byte buffer. - * \param data The source bytes to be copied from. - * \param nbytes The size of the buffer in bytes - * Must be equal to the size of the NDArray. - * \note The copy may happen asynchronously if it involves a GPU context. - * TVMSynchronize is necessary. - */ + * \brief Copy data content from a byte buffer. + * \param data The source bytes to be copied from. + * \param nbytes The size of the buffer in bytes + * Must be equal to the size of the NDArray. + * \note The copy may happen asynchronously if it involves a GPU context. + * TVMSynchronize is necessary. + */ TVM_DLL void CopyFromBytes(const void* data, size_t nbytes); /*! * \brief Copy data content into another array. @@ -124,8 +124,7 @@ class NDArray : public ObjectRef { * \param dtype The data type of the new array. * \note The memory size of new array must be smaller than the current one. */ - TVM_DLL NDArray CreateView( - std::vector shape, DLDataType dtype); + TVM_DLL NDArray CreateView(std::vector shape, DLDataType dtype); /*! * \brief Create a reference view of NDArray that * represents as DLManagedTensor. @@ -139,9 +138,7 @@ class NDArray : public ObjectRef { * \param ctx The context of the Array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, - DLDataType dtype, - DLContext ctx); + TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! * \brief Create a NDArray backed by a dlpack tensor. * @@ -160,10 +157,11 @@ class NDArray : public ObjectRef { * \param to The target array. * \param stream The stream used in copy. */ - TVM_DLL static void CopyFromTo( - const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); + TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, + TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; + TVM_DLL runtime::DataType DataType() const; // internal namespace struct Internal; @@ -244,9 +242,7 @@ class NDArray::ContainerBase { * \brief Object container class that backs NDArray. * \note do not use this function directly, use NDArray. 
*/ -class NDArray::Container : - public Object, - public NDArray::ContainerBase { +class NDArray::Container : public Object, public NDArray::ContainerBase { public: /*! \brief default constructor */ Container() { @@ -259,10 +255,7 @@ class NDArray::Container : dl_tensor.byte_offset = 0; } - Container(void* data, - std::vector shape, - DLDataType dtype, - DLContext ctx) { + Container(void* data, std::vector shape, DLDataType dtype, DLContext ctx) { // Initialize the type index. type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; @@ -278,9 +271,7 @@ class NDArray::Container : * \brief Set the deleter field. * \param deleter The deleter. */ - void SetDeleter(FDeleter deleter) { - deleter_ = deleter; - } + void SetDeleter(FDeleter deleter) { deleter_ = deleter; } // Expose DecRef and IncRef as public function // NOTE: they are only for developer purposes only. @@ -288,10 +279,10 @@ class NDArray::Container : using Object::IncRef; // Information for object protocol. - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeNDArray; static constexpr const uint32_t _type_child_slots = 0; static constexpr const uint32_t _type_child_slots_can_overflow = true; - static constexpr const char* _type_key = "NDArray"; + static constexpr const char* _type_key = "runtime.NDArray"; TVM_DECLARE_BASE_OBJECT_INFO(NDArray::Container, Object); protected: @@ -360,53 +351,44 @@ inline void NDArray::CopyTo(const NDArray& other) const { inline NDArray NDArray::CopyTo(const DLContext& ctx) const { CHECK(data_ != nullptr); const DLTensor* dptr = operator->(); - NDArray ret = Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), - dptr->dtype, ctx); + NDArray ret = + Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, ctx); this->CopyTo(ret); return ret; } -inline int NDArray::use_count() const { - return data_.use_count(); -} +inline int NDArray::use_count() const { return data_.use_count(); } -inline const DLTensor* NDArray::operator->() const { - return &(get_mutable()->dl_tensor); -} +inline const DLTensor* NDArray::operator->() const { return &(get_mutable()->dl_tensor); } inline NDArray::Container* NDArray::get_mutable() const { return static_cast(data_.get()); } inline ObjectPtr NDArray::FFIDataFromHandle(TVMArrayHandle handle) { - return GetObjectPtr(static_cast( - reinterpret_cast(handle))); + return GetObjectPtr( + static_cast(reinterpret_cast(handle))); } inline TVMArrayHandle NDArray::FFIGetHandle(const ObjectRef& nd) { // NOTE: it is necessary to cast to container then to base // so that the FFI handle uses the ContainerBase address. - return reinterpret_cast( - static_cast( - static_cast( - const_cast(nd.get())))); + return reinterpret_cast(static_cast( + static_cast(const_cast(nd.get())))); } inline void NDArray::FFIDecRef(TVMArrayHandle handle) { - static_cast( - reinterpret_cast(handle))->DecRef(); + static_cast(reinterpret_cast(handle))->DecRef(); } inline Object* TVMArrayHandleToObjectHandle(TVMArrayHandle handle) { - return static_cast( - reinterpret_cast(handle)); + return static_cast(reinterpret_cast(handle)); } /*! 
\brief Magic number for NDArray file */ constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; -inline bool SaveDLTensor(dmlc::Stream* strm, - const DLTensor* tensor) { +inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor) { uint64_t header = kTVMNDArrayMagic, reserved = 0; strm->Write(header); strm->Write(reserved); @@ -435,16 +417,15 @@ inline bool SaveDLTensor(dmlc::Stream* strm, int64_t data_byte_size = type_bytes * num_elems; strm->Write(data_byte_size); - if (DMLC_IO_NO_ENDIAN_SWAP && - tensor->ctx.device_type == kDLCPU && - tensor->strides == nullptr && + if (DMLC_IO_NO_ENDIAN_SWAP && tensor->ctx.device_type == kDLCPU && tensor->strides == nullptr && tensor->byte_offset == 0) { // quick path strm->Write(tensor->data, data_byte_size); } else { std::vector bytes(data_byte_size); - CHECK_EQ(TVMArrayCopyToBytes( - const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), 0) + CHECK_EQ( + TVMArrayCopyToBytes(const_cast(tensor), dmlc::BeginPtr(bytes), data_byte_size), + 0) << TVMGetLastError(); if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); @@ -454,33 +435,23 @@ inline bool SaveDLTensor(dmlc::Stream* strm, return true; } -inline void NDArray::Save(dmlc::Stream* strm) const { - SaveDLTensor(strm, operator->()); -} +inline void NDArray::Save(dmlc::Stream* strm) const { SaveDLTensor(strm, operator->()); } inline bool NDArray::Load(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved)) - << "Invalid DLTensor file format"; - CHECK(header == kTVMNDArrayMagic) - << "Invalid DLTensor file format"; + CHECK(strm->Read(&header)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; + CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; DLContext ctx; int ndim; DLDataType dtype; - CHECK(strm->Read(&ctx)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&ndim)) - << "Invalid DLTensor file format"; - CHECK(strm->Read(&dtype)) - << "Invalid DLTensor file format"; - CHECK_EQ(ctx.device_type, kDLCPU) - << "Invalid DLTensor context: can only save as CPU tensor"; + CHECK(strm->Read(&ctx)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&ndim)) << "Invalid DLTensor file format"; + CHECK(strm->Read(&dtype)) << "Invalid DLTensor file format"; + CHECK_EQ(ctx.device_type, kDLCPU) << "Invalid DLTensor context: can only save as CPU tensor"; std::vector shape(ndim); if (ndim != 0) { - CHECK(strm->ReadArray(&shape[0], ndim)) - << "Invalid DLTensor file format"; + CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } NDArray ret = NDArray::Empty(shape, dtype, ctx); int64_t num_elems = 1; @@ -489,12 +460,13 @@ inline bool NDArray::Load(dmlc::Stream* strm) { num_elems *= ret->shape[i]; } int64_t data_byte_size; - CHECK(strm->Read(&data_byte_size)) - << "Invalid DLTensor file format"; - CHECK(data_byte_size == num_elems * elem_bytes) - << "Invalid DLTensor file format"; - CHECK(strm->Read(ret->data, data_byte_size)) - << "Invalid DLTensor file format"; + CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; + CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; + auto read_ret = strm->Read(ret->data, data_byte_size); + // Only check non-empty data + if (ndim > 0 && shape[0] != 0) { + CHECK(read_ret) << "Invalid DLTensor file format"; + } if (!DMLC_IO_NO_ENDIAN_SWAP) { dmlc::ByteSwap(ret->data, elem_bytes, num_elems); } 
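Before the object.h changes, a short sketch of the NDArray surface touched above (Empty, CopyFromBytes, Shape, the new DataType accessor, CopyTo). Illustrative only; it assumes the patched <tvm/runtime/ndarray.h> and a CPU context.

```cpp
// Illustrative NDArray round trip on CPU; not part of the patch.
#include <tvm/runtime/ndarray.h>

#include <iostream>
#include <vector>

int main() {
  using tvm::runtime::DataType;
  using tvm::runtime::NDArray;

  DLContext cpu{kDLCPU, 0};
  // Empty() allocates an uninitialized tensor with the given shape/dtype/ctx.
  NDArray a = NDArray::Empty({2, 3}, DataType::Float(32), cpu);

  // CopyFromBytes expects exactly shape-product * element-size bytes.
  std::vector<float> host = {0, 1, 2, 3, 4, 5};
  a.CopyFromBytes(host.data(), host.size() * sizeof(float));

  // Shape() and the newly added DataType() accessor describe the tensor.
  std::cout << "ndim = " << a.Shape().size()
            << ", bits = " << a.DataType().bits() << "\n";

  // CopyTo() copies into an existing array of matching size.
  NDArray b = NDArray::Empty({2, 3}, DataType::Float(32), cpu);
  a.CopyTo(b);
  return 0;
}
```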
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index edca925baeb0..483ad6b63794 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -25,8 +25,9 @@ #include #include -#include + #include +#include #include /*! @@ -46,17 +47,33 @@ namespace tvm { namespace runtime { -/*! \brief list of the type index. */ -enum TypeIndex { - /*! \brief Root object type. */ - kRoot = 0, - kClosure = 1, - kVMADT = 2, - kRuntimeModule = 3, - kStaticIndexEnd, - /*! \brief Type index is allocated during runtime. */ - kDynamic = kStaticIndexEnd -}; +/*! + * \brief Namespace for the list of type index. + * \note Use struct so that we have to use TypeIndex::ENumName to refer to + * the constant, but still able to use enum. + */ +struct TypeIndex { + enum { + /*! \brief Root object type. */ + kRoot = 0, + // Standard static index assignments, + // Frontends can take benefit of these constants. + /*! \brief runtime::Module. */ + kRuntimeModule = 1, + /*! \brief runtime::NDArray. */ + kRuntimeNDArray = 2, + /*! \brief runtime::String. */ + kRuntimeString = 3, + /*! \brief runtime::Array. */ + kRuntimeArray = 4, + // static assignments that may subject to change. + kRuntimeClosure, + kRuntimeADT, + kStaticIndexEnd, + /*! \brief Type index is allocated during runtime. */ + kDynamic = kStaticIndexEnd + }; +}; // namespace TypeIndex /*! * \brief base class of all object containers. @@ -71,7 +88,7 @@ enum TypeIndex { * The unique string identifier of tyep type. * - _type_final: * Whether the type is terminal type(there is no subclass of the type in the object system). - * This field is automatically set by marco TVM_DECLARE_FINAL_OBJECT_INFO + * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO * It is still OK to sub-class a terminal object type T and construct it using make_object. * But IsInstance check will only show that the object type is T(instead of the sub-class). * @@ -86,8 +103,8 @@ enum TypeIndex { * Recommendation: set to estimate number of children needed. * - _type_child_slots_can_overflow: * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check global type table will be used. - * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * exceeds the _type_child_slots. A fallback mechanism to check global type table will be + * used. Recommendation: set to false for optimal runtime speed if we know exact number of children. * * Two macros are used to declare helper functions in the object: * - Use TVM_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. @@ -149,28 +166,22 @@ class Object { */ typedef void (*FDeleter)(Object* self); /*! \return The internal runtime type index of the object. */ - uint32_t type_index() const { - return type_index_; - } + uint32_t type_index() const { return type_index_; } /*! * \return the type key of the object. * \note this operation is expensive, can be used for error reporting. */ - std::string GetTypeKey() const { - return TypeIndex2Key(type_index_); - } + std::string GetTypeKey() const { return TypeIndex2Key(type_index_); } /*! * \return A hash value of the return of GetTypeKey. */ - size_t GetTypeKeyHash() const { - return TypeIndex2KeyHash(type_index_); - } + size_t GetTypeKeyHash() const { return TypeIndex2KeyHash(type_index_); } /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. 
* \return Whether the target type is true. */ - template + template inline bool IsInstance() const; /*! @@ -198,14 +209,10 @@ class Object { using RefCounterType = int32_t; #endif - static constexpr const char* _type_key = "Object"; + static constexpr const char* _type_key = "runtime.Object"; - static uint32_t _GetOrAllocRuntimeTypeIndex() { - return TypeIndex::kRoot; - } - static uint32_t RuntimeTypeIndex() { - return TypeIndex::kRoot; - } + static uint32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kRoot; } + static uint32_t RuntimeTypeIndex() { return TypeIndex::kRoot; } // Default object type properties for sub-classes static constexpr bool _type_final = false; @@ -220,7 +227,6 @@ class Object { // The type index of Object is TypeIndex::kRoot static constexpr uint32_t _type_index = TypeIndex::kDynamic; - // Default constructor and copy constructor Object() {} // Override the copy and assign constructors to do nothing. @@ -232,10 +238,10 @@ class Object { } Object(Object&& other) { // NOLINT(*) } - Object& operator=(const Object& other) { //NOLINT(*) + Object& operator=(const Object& other) { // NOLINT(*) return *this; } - Object& operator=(Object&& other) { //NOLINT(*) + Object& operator=(Object&& other) { // NOLINT(*) return *this; } @@ -253,7 +259,7 @@ class Object { FDeleter deleter_ = nullptr; // Invariant checks. static_assert(sizeof(int32_t) == sizeof(RefCounterType) && - alignof(int32_t) == sizeof(RefCounterType), + alignof(int32_t) == sizeof(RefCounterType), "RefCounter ABI check."); /*! @@ -273,12 +279,10 @@ class Object { * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. * \return The allocated type index. */ - TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex( - const std::string& key, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t type_child_slots, - bool type_child_slots_can_overflow); + TVM_DLL static uint32_t GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, + uint32_t parent_tindex, + uint32_t type_child_slots, + bool type_child_slots_can_overflow); // reference counter related operations /*! \brief developer function, increases reference counter. */ @@ -302,9 +306,9 @@ class Object { */ TVM_DLL bool DerivedFrom(uint32_t parent_tindex) const; // friend classes - template + template friend class ObjAllocatorBase; - template + template friend class ObjectPtr; friend class TVMRetValue; friend class ObjectInternal; @@ -384,9 +388,7 @@ class ObjectPtr { other.data_ = nullptr; } /*! \brief destructor */ - ~ObjectPtr() { - this->reset(); - } + ~ObjectPtr() { this->reset(); } /*! * \brief Swap this array with another Object * \param other The other Object @@ -397,15 +399,11 @@ class ObjectPtr { /*! * \return Get the content of the pointer */ - T* get() const { - return static_cast(data_); - } + T* get() const { return static_cast(data_); } /*! * \return The pointer */ - T* operator->() const { - return get(); - } + T* operator->() const { return get(); } /*! * \return The reference */ @@ -441,29 +439,17 @@ class ObjectPtr { } } /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_ != nullptr ? data_->use_count() : 0; - } + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } /*! \return whether the reference is unique */ - bool unique() const { - return data_ != nullptr && data_->use_count() == 1; - } + bool unique() const { return data_ != nullptr && data_->use_count() == 1; } /*! 
\return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr& other) const { - return data_ == other.data_; - } + bool operator==(const ObjectPtr& other) const { return data_ == other.data_; } /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr& other) const { - return data_ != other.data_; - } + bool operator!=(const ObjectPtr& other) const { return data_ != other.data_; } /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { - return data_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return data_ == nullptr; } /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return data_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } private: /*! \brief internal pointer field */ @@ -491,10 +477,10 @@ class ObjectPtr { // friend classes friend class Object; friend class ObjectRef; - friend struct ObjectHash; - template + friend struct ObjectPtrHash; + template friend class ObjectPtr; - template + template friend class ObjAllocatorBase; friend class TVMPODValue_; friend class TVMArgsSetter; @@ -519,55 +505,37 @@ class ObjectRef { * \param other Another object ref. * \return the compare result. */ - bool same_as(const ObjectRef& other) const { - return data_ == other.data_; - } + bool same_as(const ObjectRef& other) const { return data_ == other.data_; } /*! * \brief Comparator * \param other Another object ref. * \return the compare result. */ - bool operator==(const ObjectRef& other) const { - return data_ == other.data_; - } + bool operator==(const ObjectRef& other) const { return data_ == other.data_; } /*! * \brief Comparator * \param other Another object ref. * \return the compare result. */ - bool operator!=(const ObjectRef& other) const { - return data_ != other.data_; - } + bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } /*! * \brief Comparator * \param other Another object ref by address. * \return the compare result. */ - bool operator<(const ObjectRef& other) const { - return data_.get() < other.data_.get(); - } + bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } /*! * \return whether the object is defined(not null). */ - bool defined() const { - return data_ != nullptr; - } + bool defined() const { return data_ != nullptr; } /*! \return the internal object pointer */ - const Object* get() const { - return data_.get(); - } + const Object* get() const { return data_.get(); } /*! \return the internal object pointer */ - const Object* operator->() const { - return get(); - } + const Object* operator->() const { return get(); } /*! \return whether the reference is unique */ - bool unique() const { - return data_.unique(); - } + bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { - return data_.use_count(); - } + int use_count() const { return data_.use_count(); } /*! * \brief Try to downcast the internal Object to a * raw pointer of a corresponding type. @@ -591,16 +559,14 @@ class ObjectRef { /*! \brief Internal pointer that backs the reference. */ ObjectPtr data_; /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object* get_mutable() const { - return data_.get(); - } + Object* get_mutable() const { return data_.get(); } /*! * \brief Internal helper function downcast a ref without check. 
* \note Only used for internal dev purposes. * \tparam T The target reference type. * \return The casted result. */ - template + template static T DowncastNoCheck(ObjectRef ref) { return T(std::move(ref.data_)); } @@ -609,21 +575,19 @@ class ObjectRef { * after we successfully moved the field. * \param ref The reference data. */ - static void FFIClearAfterMove(ObjectRef* ref) { - ref->data_.data_ = nullptr; - } + static void FFIClearAfterMove(ObjectRef* ref) { ref->data_.data_ = nullptr; } /*! * \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \note only used for internal dev purpose. * \tparam ObjectType The corresponding object type. * \return the corresponding type. */ - template + template static ObjectPtr GetDataPtr(const ObjectRef& ref) { return ObjectPtr(ref.data_.data_); } // friend classes. - friend struct ObjectHash; + friend struct ObjectPtrHash; friend class TVMRetValue; friend class TVMArgsSetter; template @@ -642,64 +606,57 @@ template inline ObjectPtr GetObjectPtr(ObjectType* ptr); /*! \brief ObjectRef hash functor */ -struct ObjectHash { - size_t operator()(const ObjectRef& a) const { - return operator()(a.data_); - } +struct ObjectPtrHash { + size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } - template + template size_t operator()(const ObjectPtr& a) const { return std::hash()(a.get()); } }; - /*! \brief ObjectRef equal functor */ -struct ObjectEqual { - bool operator()(const ObjectRef& a, const ObjectRef& b) const { - return a.same_as(b); - } +struct ObjectPtrEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } - template + template size_t operator()(const ObjectPtr& a, const ObjectPtr& b) const { return a == b; } }; - /*! * \brief helper macro to declare a base object type that can be inheritated. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ -#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ - static uint32_t RuntimeTypeIndex() { \ - if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ - return TypeName::_type_index; \ - } \ - return _GetOrAllocRuntimeTypeIndex(); \ - } \ - static uint32_t _GetOrAllocRuntimeTypeIndex() { \ - static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ - TypeName::_type_key, \ - TypeName::_type_index, \ - ParentType::_GetOrAllocRuntimeTypeIndex(), \ - TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow); \ - return tidx; \ - } \ +#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static_assert(!ParentType::_type_final, "ParentObj maked as final"); \ + static uint32_t RuntimeTypeIndex() { \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \ + return TypeName::_type_index; \ + } \ + return _GetOrAllocRuntimeTypeIndex(); \ + } \ + static uint32_t _GetOrAllocRuntimeTypeIndex() { \ + static uint32_t tidx = Object::GetOrAllocRuntimeTypeIndex( \ + TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \ + TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \ + return tidx; \ + } /*! * \brief helper macro to declare type information in a final class. 
- * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ - static const constexpr bool _type_final = true; \ - static const constexpr int _type_child_slots = 0; \ - TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr bool _type_final = true; \ + static const constexpr int _type_child_slots = 0; \ + TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) /*! \brief helper macro to supress unused warning */ #if defined(__GNUC__) @@ -711,8 +668,7 @@ struct ObjectEqual { #define TVM_STR_CONCAT_(__x, __y) __x##__y #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) -#define TVM_OBJECT_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid +#define TVM_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid /*! * \brief Helper macro to register the object type to runtime. @@ -720,20 +676,18 @@ struct ObjectEqual { * * Use this macro in the cc file for each terminal class. */ -#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ - TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ - TypeName::_GetOrAllocRuntimeTypeIndex() - +#define TVM_REGISTER_OBJECT_TYPE(TypeName) \ + TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = TypeName::_GetOrAllocRuntimeTypeIndex() /* * \brief Define the default copy/move constructor and assign opeator * \param TypeName The class typename. */ -#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - TypeName(const TypeName& other) = default; \ - TypeName(TypeName&& other) = default; \ - TypeName& operator=(const TypeName& other) = default; \ - TypeName& operator=(TypeName&& other) = default; \ +#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; /* * \brief Define object reference methods. @@ -741,15 +695,12 @@ struct ObjectEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. */ -#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ +#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ using ContainerType = ObjectName; /* @@ -759,15 +710,12 @@ struct ObjectEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. 
*/ -#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ - static constexpr bool _type_is_nullable = false; \ +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { return static_cast(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ + static constexpr bool _type_is_nullable = false; \ using ContainerType = ObjectName; /* @@ -778,15 +726,11 @@ struct ObjectEqual { * \note We recommend making objects immutable when possible. * This macro is only reserved for objects that stores runtime states. */ -#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - explicit TypeName( \ - ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ - : ParentType(n) {} \ - ObjectName* operator->() const { \ - return static_cast(data_.get()); \ - } \ +#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + ObjectName* operator->() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /*! @@ -808,23 +752,21 @@ struct ObjectEqual { * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - ObjectName* CopyOnWrite() { \ - CHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object(*(operator->())); \ - ObjectPtr(std::move(n)).swap(data_); \ - } \ - return static_cast(data_.get()); \ - } +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + ObjectName* CopyOnWrite() { \ + CHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object(*(operator->())); \ + ObjectPtr(std::move(n)).swap(data_); \ + } \ + return static_cast(data_.get()); \ + } // Implementations details below // Object reference counting. 
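To make the macro changes above concrete, here is a minimal sketch of how a custom object and its reference class fit together: a final object declared with TVM_DECLARE_FINAL_OBJECT_INFO, a reference defined with TVM_DEFINE_OBJECT_REF_METHODS, allocation through make_object, and recovery via Downcast. The demo namespace, type key, and field names are hypothetical and not part of the patch.

```cpp
// Sketch of the object-system macros in use; names below are placeholders.
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

#include <iostream>
#include <utility>

namespace demo {

using namespace tvm::runtime;

// The payload derives from Object and declares its type info.
// TVM_REGISTER_OBJECT_TYPE(CounterObj) would normally go in a .cc file;
// here the dynamic type index is allocated lazily on first use.
class CounterObj : public Object {
 public:
  int64_t value{0};

  static constexpr const char* _type_key = "demo.Counter";  // hypothetical key
  TVM_DECLARE_FINAL_OBJECT_INFO(CounterObj, Object);
};

// The reference class wraps an ObjectPtr to the payload.
class Counter : public ObjectRef {
 public:
  explicit Counter(int64_t v) {
    auto node = make_object<CounterObj>();
    node->value = v;
    data_ = std::move(node);
  }
  TVM_DEFINE_OBJECT_REF_METHODS(Counter, ObjectRef, CounterObj);
};

}  // namespace demo

int main() {
  demo::Counter c(42);
  // operator-> and the new get() expose the underlying CounterObj.
  std::cout << c->value << " (" << c->GetTypeKey() << ")\n";

  // Downcast performs a checked recovery from a type-erased ObjectRef.
  tvm::runtime::ObjectRef erased = c;
  auto back = tvm::runtime::Downcast<demo::Counter>(erased);
  std::cout << back->value << "\n";
  return 0;
}
```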
#if TVM_OBJECT_ATOMIC_REF_COUNTER -inline void Object::IncRef() { - ref_counter_.fetch_add(1, std::memory_order_relaxed); -} +inline void Object::IncRef() { ref_counter_.fetch_add(1, std::memory_order_relaxed); } inline void Object::DecRef() { if (ref_counter_.fetch_sub(1, std::memory_order_release) == 1) { @@ -835,15 +777,11 @@ inline void Object::DecRef() { } } -inline int Object::use_count() const { - return ref_counter_.load(std::memory_order_relaxed); -} +inline int Object::use_count() const { return ref_counter_.load(std::memory_order_relaxed); } #else -inline void Object::IncRef() { - ++ref_counter_; -} +inline void Object::IncRef() { ++ref_counter_; } inline void Object::DecRef() { if (--ref_counter_ == 0) { @@ -853,13 +791,11 @@ inline void Object::DecRef() { } } -inline int Object::use_count() const { - return ref_counter_; -} +inline int Object::use_count() const { return ref_counter_; } #endif // TVM_OBJECT_ATOMIC_REF_COUNTER -template +template inline bool Object::IsInstance() const { const Object* self = this; // NOTE: the following code can be optimized by @@ -893,11 +829,9 @@ inline bool Object::IsInstance() const { } } - template inline const ObjectType* ObjectRef::as() const { - if (data_ != nullptr && - data_->IsInstance()) { + if (data_ != nullptr && data_->IsInstance()) { return static_cast(data_.get()); } else { return nullptr; @@ -925,12 +859,11 @@ template inline SubRef Downcast(BaseRef ref) { if (ref.defined()) { CHECK(ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + << "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key + << " failed."; } else { - CHECK(SubRef::_type_is_nullable) - << "Downcast from nullptr to not nullable reference of " - << SubRef::ContainerType::_type_key; + CHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 3d5a7e865303..e82b97a5a2d4 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,25 +26,33 @@ #include #include +#include #include #include -#include #include + #include -#include -#include -#include #include #include -#include +#include +#include #include - +#include +#include // Whether use TVM runtime in header only mode. #ifndef TVM_RUNTIME_HEADER_ONLY #define TVM_RUNTIME_HEADER_ONLY 0 #endif +// Always inline macro only use in template +// expansion cases where we know inline is important. +#ifdef _MSC_VER +#define TVM_ALWAYS_INLINE __forceinline inline +#else +#define TVM_ALWAYS_INLINE inline __attribute__((always_inline)) +#endif + namespace tvm { namespace runtime { @@ -83,7 +91,7 @@ class PackedFunc { * } * \endcode */ - using FType = std::function; + using FType = std::function; /*! \brief default constructor */ PackedFunc() {} /*! \brief constructor from null */ @@ -107,8 +115,8 @@ class PackedFunc { * } * \endcode */ - template - inline TVMRetValue operator()(Args&& ...args) const; + template + inline TVMRetValue operator()(Args&&... args) const; /*! * \brief Call the function in packed format. * \param args The arguments @@ -118,13 +126,9 @@ class PackedFunc { /*! \return the internal body function */ inline FType body() const; /*! 
\return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { - return body_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return body_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return body_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return body_ != nullptr; } private: /*! \brief internal container of packed function */ @@ -134,7 +138,7 @@ class PackedFunc { /*! * \brief Please refer to \ref TypedPackedFuncAnchor "TypedPackedFunc" */ -template +template class TypedPackedFunc; /*! @@ -169,7 +173,7 @@ class TypedPackedFunc; * \tparam R The return value of the function. * \tparam Args The argument signature of the function. */ -template +template class TypedPackedFunc { public: /*! \brief short hand for this function type */ @@ -226,11 +230,9 @@ class TypedPackedFunc { * \param typed_lambda typed lambda function. * \tparam FLambda the type of the lambda function. */ - template - >::value>::type> + template >::value>::type> TypedPackedFunc(const FLambda& typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); } @@ -250,11 +252,9 @@ class TypedPackedFunc { * \tparam FLambda the type of the lambda function. * \returns reference to self. */ - template - >::value>::type> + template >::value>::type> TSelf& operator=(FLambda typed_lambda) { // NOLINT(*) this->AssignTypedLambda(typed_lambda); return *this; @@ -273,28 +273,20 @@ class TypedPackedFunc { * \param args The arguments * \returns The return value. */ - inline R operator()(Args ...args) const; + TVM_ALWAYS_INLINE R operator()(Args... args) const; /*! * \brief convert to PackedFunc * \return the internal PackedFunc */ - operator PackedFunc() const { - return packed(); - } + operator PackedFunc() const { return packed(); } /*! * \return reference the internal PackedFunc */ - const PackedFunc& packed() const { - return packed_; - } + const PackedFunc& packed() const { return packed_; } /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { - return packed_ == nullptr; - } + bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { - return packed_ != nullptr; - } + bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } private: friend class TVMRetValue; @@ -307,7 +299,7 @@ class TypedPackedFunc { * \tparam FLambda The lambda function type. * \note We capture the lambda when possible for maximum efficiency. */ - template + template inline void AssignTypedLambda(FLambda flambda); }; @@ -323,12 +315,8 @@ class TVMArgs { * \param type_codes The argument type codes * \param num_args number of arguments. */ - TVMArgs(const TVMValue* values, - const int* type_codes, - int num_args) - : values(values), - type_codes(type_codes), - num_args(num_args) { } + TVMArgs(const TVMValue* values, const int* type_codes, int num_args) + : values(values), type_codes(type_codes), num_args(num_args) {} /*! \return size of the arguments */ inline int size() const; /*! @@ -339,16 +327,22 @@ class TVMArgs { inline TVMArgValue operator[](int i) const; }; +/*! + * \brief Convert argument type code to string. + * \param type_code The input type code. + * \return The corresponding string repr. + */ +inline const char* ArgTypeCode2Str(int type_code); + // macro to check type code. 
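Before the type-checking macro below, a sketch of the packed calling convention these classes implement: a type-erased PackedFunc that reads TVMArgs and writes a TVMRetValue, wrapped by a TypedPackedFunc that restores a concrete signature. It assumes the usual PackedFunc constructor from the std::function body and is illustrative only.

```cpp
// Illustrative PackedFunc / TypedPackedFunc usage; not part of the patch.
#include <tvm/runtime/packed_func.h>

#include <iostream>

using tvm::runtime::PackedFunc;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::runtime::TypedPackedFunc;

int main() {
  // A PackedFunc takes type-erased arguments and writes its result into *rv.
  PackedFunc add([](TVMArgs args, TVMRetValue* rv) {
    int a = args[0];   // TVMArgValue converts back to POD types
    int b = args[1];
    *rv = a + b;
  });

  // operator() forwards arguments through TVMArgsSetter and returns TVMRetValue.
  int sum = add(3, 4);
  std::cout << "packed: " << sum << "\n";

  // TypedPackedFunc layers a concrete signature on top of a PackedFunc.
  TypedPackedFunc<int(int, int)> typed_add(add);
  std::cout << "typed:  " << typed_add(5, 6) << "\n";
  return 0;
}
```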
-#define TVM_CHECK_TYPE_CODE(CODE, T) \ - CHECK_EQ(CODE, T) << " expected " \ - << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE) \ +#define TVM_CHECK_TYPE_CODE(CODE, T) \ + CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE) /*! * \brief Type traits for runtime type check during FFI conversion. * \tparam T the type to be checked. */ -template +template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; @@ -402,61 +396,53 @@ class TVMPODValue_ { return value_.v_handle; } operator DLTensor*() const { - if (type_code_ == kTVMDLTensorHandle || - type_code_ == kTVMNDArrayHandle) { + if (type_code_ == kTVMDLTensorHandle || type_code_ == kTVMNDArrayHandle) { return static_cast(value_.v_handle); } else { if (type_code_ == kTVMNullptr) return nullptr; LOG(FATAL) << "Expect " - << "DLTensor* or NDArray but get " - << TypeCode2Str(type_code_); + << "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_); return nullptr; } } operator NDArray() const { if (type_code_ == kTVMNullptr) return NDArray(ObjectPtr(nullptr)); TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - return NDArray(NDArray::FFIDataFromHandle( - static_cast(value_.v_handle))); + return NDArray(NDArray::FFIDataFromHandle(static_cast(value_.v_handle))); } operator Module() const { if (type_code_ == kTVMNullptr) { return Module(ObjectPtr(nullptr)); } TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); - return Module( - ObjectPtr(static_cast(value_.v_handle))); + return Module(ObjectPtr(static_cast(value_.v_handle))); } operator TVMContext() const { TVM_CHECK_TYPE_CODE(type_code_, kTVMContext); return value_.v_ctx; } - int type_code() const { - return type_code_; - } + int type_code() const { return type_code_; } /*! * \brief return handle as specific pointer type. * \tparam T the data type. * \return The pointer type. */ - template + template T* ptr() const { return static_cast(value_.v_handle); } // ObjectRef handling - template::value>::type> + template ::value>::type> inline bool IsObjectRef() const; - template + template inline TObjectRef AsObjectRef() const; protected: friend class TVMArgsSetter; friend class TVMRetValue; TVMPODValue_() : type_code_(kTVMNullptr) {} - TVMPODValue_(TVMValue value, int type_code) - : value_(value), type_code_(type_code) {} + TVMPODValue_(TVMValue value, int type_code) : value_(value), type_code_(type_code) {} /*! \brief The value */ TVMValue value_; @@ -479,9 +465,7 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) - : TVMPODValue_(value, type_code) { - } + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -493,8 +477,8 @@ class TVMArgValue : public TVMPODValue_ { using TVMPODValue_::operator NDArray; using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator Module; - using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; // conversion operator. 
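// Illustrative aside (not part of the patch): a sketch of the TVMArgValue
// conversions declared above, assuming the TVM runtime headers. A packed
// function body pulls each positional argument out of TVMArgs; a mismatched
// type code trips TVM_CHECK_TYPE_CODE, whose message is rendered through
// ArgTypeCode2Str. The function name "DescribeTensor" is hypothetical.
#include <tvm/runtime/packed_func.h>

#include <string>

inline void DescribeTensor(tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
  DLTensor* tensor = args[0];              // accepts DLTensor* and NDArray handles
  std::string name = args[1];              // accepts str/bytes arguments
  tvm::runtime::DataType dtype = args[2];  // goes through operator DLDataType()
  (void)dtype;
  *rv = name + ": ndim=" + std::to_string(tensor->ndim);
}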
operator std::string() const { @@ -515,31 +499,27 @@ class TVMArgValue : public TVMPODValue_ { // None type if (type_code_ == kTVMNullptr) { DLDataType t; - t.code = kTVMOpaqueHandle; t.bits = 0; t.lanes = 0; + t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; return t; } TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); return value_.v_type; } - operator DataType() const { - return DataType(operator DLDataType()); - } + operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); return *ptr(); } - template + template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } - const TVMValue& value() const { - return value_; - } + const TVMValue& value() const { return value_; } - template::value>::type> + template ::value>::type> inline operator T() const; }; @@ -555,9 +535,7 @@ class TVMArgValue : public TVMPODValue_ { */ class TVMMovableArgValue_ : public TVMArgValue { public: - TVMMovableArgValue_(TVMValue value, int type_code) - : TVMArgValue(value, type_code) { - } + TVMMovableArgValue_(TVMValue value, int type_code) : TVMArgValue(value, type_code) {} // reuse converter from parent using TVMArgValue::operator double; using TVMArgValue::operator int64_t; @@ -576,9 +554,8 @@ class TVMMovableArgValue_ : public TVMArgValue { * Try to move out an argument if possible, * fall back to normal argument conversion rule otherwise. */ - template::value>::type> + template ::value>::type> inline operator T() const; }; @@ -598,15 +575,12 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from anoter return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) - : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! 
\brief destructor */ - ~TVMRetValue() { - this->Clear(); - } + ~TVMRetValue() { this->Clear(); } // reuse converter from parent using TVMPODValue_::operator double; using TVMPODValue_::operator int64_t; @@ -618,12 +592,10 @@ class TVMRetValue : public TVMPODValue_ { using TVMPODValue_::operator TVMContext; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; - using TVMPODValue_::IsObjectRef; using TVMPODValue_::AsObjectRef; + using TVMPODValue_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { - this->Assign(other); - } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -641,15 +613,13 @@ class TVMRetValue : public TVMPODValue_ { TVM_CHECK_TYPE_CODE(type_code_, kTVMDataType); return value_.v_type; } - operator DataType() const { - return DataType(operator DLDataType()); - } + operator DataType() const { return DataType(operator DLDataType()); } operator PackedFunc() const { if (type_code_ == kTVMNullptr) return PackedFunc(); TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); return *ptr(); } - template + template operator TypedPackedFunc() const { return TypedPackedFunc(operator PackedFunc()); } @@ -696,9 +666,7 @@ class TVMRetValue : public TVMPODValue_ { value_.v_type = t; return *this; } - TVMRetValue& operator=(const DataType& other) { - return operator=(other.operator DLDataType()); - } + TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kDLInt); value_.v_int64 = value; @@ -728,10 +696,14 @@ class TVMRetValue : public TVMPODValue_ { return *this; } TVMRetValue& operator=(PackedFunc f) { - this->SwitchToClass(kTVMPackedFuncHandle, f); + if (f == nullptr) { + this->SwitchToPOD(kTVMNullptr); + } else { + this->SwitchToClass(kTVMPackedFuncHandle, f); + } return *this; } - template + template TVMRetValue& operator=(const TypedPackedFunc& f) { return operator=(f.packed()); } @@ -756,8 +728,7 @@ class TVMRetValue : public TVMPODValue_ { * \param ret_value The return value. * \param ret_type_code The return type code. */ - void MoveToCHost(TVMValue* ret_value, - int* ret_type_code) { + void MoveToCHost(TVMValue* ret_value, int* ret_type_code) { // cannot move str; need specially handle. CHECK(type_code_ != kTVMStr && type_code_ != kTVMBytes); *ret_value = value_; @@ -771,11 +742,9 @@ class TVMRetValue : public TVMPODValue_ { * \param type_code The type code. * \return The created TVMRetValue. */ - static TVMRetValue MoveFromCHost(TVMValue value, - int type_code) { + static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - CHECK(type_code <= kTVMPackedFuncHandle || - type_code == kTVMNDArrayHandle); + CHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -783,24 +752,20 @@ class TVMRetValue : public TVMPODValue_ { } /*! 
\return The value field, if the data is POD */ const TVMValue& value() const { - CHECK(type_code_ != kTVMObjectHandle && - type_code_ != kTVMPackedFuncHandle && - type_code_ != kTVMModuleHandle && - type_code_ != kTVMStr) << "TVMRetValue.value can only be used for POD data"; + CHECK(type_code_ != kTVMObjectHandle && type_code_ != kTVMPackedFuncHandle && + type_code_ != kTVMModuleHandle && type_code_ != kTVMStr) + << "TVMRetValue.value can only be used for POD data"; return value_; } // ObjectRef handling - template::value>::type> + template ::value>::type> inline TVMRetValue& operator=(TObjectRef other); - template::value>::type> + template ::value>::type> inline operator T() const; private: - template + template void Assign(const T& other) { switch (other.type_code()) { case kTVMStr: { @@ -825,9 +790,8 @@ class TVMRetValue : public TVMPODValue_ { } case kTVMObjectHandle: { // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject( - kTVMObjectHandle, GetObjectPtr( - static_cast(other.value_.v_handle))); + SwitchToObject(kTVMObjectHandle, + GetObjectPtr(static_cast(other.value_.v_handle))); break; } case kTVMObjectRValueRefArg: { @@ -848,7 +812,7 @@ class TVMRetValue : public TVMPODValue_ { type_code_ = type_code; } } - template + template void SwitchToClass(int type_code, T v) { if (type_code_ != type_code) { this->Clear(); @@ -872,8 +836,13 @@ class TVMRetValue : public TVMPODValue_ { void Clear() { if (type_code_ == kTVMNullptr) return; switch (type_code_) { - case kTVMStr: case kTVMBytes: delete ptr(); break; - case kTVMPackedFuncHandle: delete ptr(); break; + case kTVMStr: + case kTVMBytes: + delete ptr(); + break; + case kTVMPackedFuncHandle: + delete ptr(); + break; case kTVMNDArrayHandle: { NDArray::FFIDecRef(static_cast(value_.v_handle)); break; @@ -900,24 +869,20 @@ class TVMRetValue : public TVMPODValue_ { * * \tparam TObjectRef the specific ObjectRefType. */ -template +template struct PackedFuncValueConverter { /*! * \brief Convert a TObjectRef from an argument value. * \param val The argument value. * \return the converted result. */ - static TObjectRef From(const TVMArgValue& val) { - return val.AsObjectRef(); - } + static TObjectRef From(const TVMArgValue& val) { return val.AsObjectRef(); } /*! * \brief Convert a TObjectRef from a return value. * \param val The argument value. * \return the converted result. */ - static TObjectRef From(const TVMRetValue& val) { - return val.AsObjectRef(); - } + static TObjectRef From(const TVMRetValue& val) { return val.AsObjectRef(); } }; /*! 
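// Illustrative aside (not part of the patch): a sketch of specializing the
// PackedFuncValueConverter trait shown above so a custom ObjectRef can also be
// built from a plain integer argument, assuming the TVM runtime headers.
// Everything under the "example" namespace is hypothetical; only the trait
// itself comes from TVM.
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>

namespace example {

class MyIntNode : public tvm::runtime::Object {
 public:
  int64_t value{0};
  static constexpr const char* _type_key = "example.MyInt";
  TVM_DECLARE_FINAL_OBJECT_INFO(MyIntNode, tvm::runtime::Object);
};

class MyInt : public tvm::runtime::ObjectRef {
 public:
  explicit MyInt(int64_t value) {
    auto node = tvm::runtime::make_object<MyIntNode>();
    node->value = value;
    data_ = std::move(node);
  }
  TVM_DEFINE_OBJECT_REF_METHODS(MyInt, tvm::runtime::ObjectRef, MyIntNode);
};

}  // namespace example

namespace tvm {
namespace runtime {

template <>
struct PackedFuncValueConverter<example::MyInt> {
  static example::MyInt From(const TVMArgValue& val) {
    // Accept a bare integer in addition to an example.MyInt object.
    if (val.type_code() == kDLInt) return example::MyInt(val.operator int64_t());
    return val.AsObjectRef<example::MyInt>();
  }
  static example::MyInt From(const TVMRetValue& val) {
    if (val.type_code() == kDLInt) return example::MyInt(val.operator int64_t());
    return val.AsObjectRef<example::MyInt>();
  }
};

}  // namespace runtime
}  // namespace tvm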
@@ -939,29 +904,22 @@ struct PackedFuncValueConverter { * * \endcode */ -#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code); \ - int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code) { \ - try { \ - ::tvm::runtime::TVMRetValue rv; \ - Function(::tvm::runtime::TVMArgs( \ - args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::runtime_error& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_PACKED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code); \ + int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code) { \ + try { \ + ::tvm::runtime::TVMRetValue rv; \ + Function(::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ + rv.MoveToCHost(out_value, out_type_code); \ + return 0; \ + } catch (const ::std::runtime_error& _except_) { \ + TVMAPISetLastError(_except_.what()); \ + return -1; \ + } \ + } \ } /*! @@ -999,181 +957,208 @@ struct PackedFuncValueConverter { * * \endcode */ -#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - extern "C" { \ - TVM_DLL int ExportName(TVMValue* args, \ - int* type_code, \ - int num_args, \ - TVMValue* out_value, \ - int* out_type_code) { \ - try { \ - auto f = Function; \ - using FType = ::tvm::runtime::detail:: \ - function_signature::FType; \ - ::tvm::runtime::TVMRetValue rv; \ - ::tvm::runtime::detail::unpack_call_by_signature::run( \ - f, \ - ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ - rv.MoveToCHost(out_value, out_type_code); \ - return 0; \ - } catch (const ::std::runtime_error& _except_) { \ - TVMAPISetLastError(_except_.what()); \ - return -1; \ - } \ - } \ +#define TVM_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + extern "C" { \ + TVM_DLL int ExportName(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, \ + int* out_type_code) { \ + try { \ + auto f = Function; \ + using FType = ::tvm::runtime::detail::function_signature::FType; \ + ::tvm::runtime::TVMRetValue rv; \ + ::tvm::runtime::detail::unpack_call_by_signature::run( \ + f, ::tvm::runtime::TVMArgs(args, type_code, num_args), &rv); \ + rv.MoveToCHost(out_value, out_type_code); \ + return 0; \ + } catch (const ::std::runtime_error& _except_) { \ + TVMAPISetLastError(_except_.what()); \ + return -1; \ + } \ + } \ } - inline TVMArgValue TVMArgs::operator[](int i) const { - CHECK_LT(i, num_args) - << "not enough argument passed, " - << num_args << " passed" - << " but request arg[" << i << "]."; + CHECK_LT(i, num_args) << "not enough argument passed, " << num_args << " passed" + << " but request arg[" << i << "]."; return TVMArgValue(values[i], type_codes[i]); } -inline int TVMArgs::size() const { - return num_args; -} +inline int TVMArgs::size() const { return num_args; } -inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { - body_(args, rv); -} +inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(args, rv); } -inline PackedFunc::FType PackedFunc::body() const { - return body_; -} +inline PackedFunc::FType PackedFunc::body() const { return body_; } // internal namespace +inline const char* ArgTypeCode2Str(int 
type_code) { + switch (type_code) { + case kDLInt: + return "int"; + case kDLUInt: + return "uint"; + case kDLFloat: + return "float"; + case kTVMStr: + return "str"; + case kTVMBytes: + return "bytes"; + case kTVMOpaqueHandle: + return "handle"; + case kTVMNullptr: + return "NULL"; + case kTVMDLTensorHandle: + return "ArrayHandle"; + case kTVMDataType: + return "DLDataType"; + case kTVMContext: + return "TVMContext"; + case kTVMPackedFuncHandle: + return "FunctionHandle"; + case kTVMModuleHandle: + return "ModuleHandle"; + case kTVMNDArrayHandle: + return "NDArrayContainer"; + case kTVMObjectHandle: + return "Object"; + case kTVMObjectRValueRefArg: + return "ObjectRValueRefArg"; + default: + LOG(FATAL) << "unknown type_code=" << static_cast(type_code); + return ""; + } +} + namespace detail { -template +template struct for_each_dispatcher { - template + template static void run(const F& f, T&& value, Args&&... args) { // NOLINT(*) f(I, std::forward(value)); - for_each_dispatcher - ::run(f, std::forward(args)...); + for_each_dispatcher::run(f, std::forward(args)...); } }; -template -struct for_each_dispatcher { +template +struct for_each_dispatcher { static void run(const F& f) {} // NOLINT(*) }; -template +template inline void for_each(const F& f, Args&&... args) { // NOLINT(*) - for_each_dispatcher - ::run(f, std::forward(args)...); + for_each_dispatcher::run(f, std::forward(args)...); } -template +template struct func_signature_helper { using FType = void; }; -template +template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; -template +template struct func_signature_helper { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; /*! * \brief template class to get function signature of a function or functor. * \tparam T The funtion/functor type. */ -template +template struct function_signature { using FType = typename func_signature_helper::FType; }; // handle case of function. -template +template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; // handle case of function ptr. 
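// Illustrative aside (not part of the patch): a sketch of the
// TVM_DLL_EXPORT_TYPED_FUNC macro reformatted above, assuming the TVM runtime
// headers. It wraps an ordinary C++ function (here the hypothetical AddOne_)
// in an extern "C" entry point that unpacks TVMArgs according to the inferred
// function_signature and reports C++ exceptions through TVMAPISetLastError.
#include <tvm/runtime/packed_func.h>

static int AddOne_(int x) { return x + 1; }

// Exposes the C symbol "AddOne" with the packed calling convention.
TVM_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_);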
-template +template struct function_signature { using FType = R(Args...); + static_assert(!std::is_reference::value, "TypedPackedFunc return reference"); }; } // namespace detail /* \brief argument settter to PackedFunc */ class TVMArgsSetter { public: - TVMArgsSetter(TVMValue* values, int* type_codes) - : values_(values), type_codes_(type_codes) {} + TVMArgsSetter(TVMValue* values, int* type_codes) : values_(values), type_codes_(type_codes) {} // setters for POD types - template::value>::type> - void operator()(size_t i, T value) const { + template ::value>::type> + TVM_ALWAYS_INLINE void operator()(size_t i, T value) const { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } - void operator()(size_t i, uint64_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); - CHECK_LE(value, - static_cast(std::numeric_limits::max())); + CHECK_LE(value, static_cast(std::numeric_limits::max())); type_codes_[i] = kDLInt; } - void operator()(size_t i, double value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, double value) const { values_[i].v_float64 = value; type_codes_[i] = kDLFloat; } - void operator()(size_t i, std::nullptr_t value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, std::nullptr_t value) const { values_[i].v_handle = value; type_codes_[i] = kTVMNullptr; } - void operator()(size_t i, const TVMArgValue& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMArgValue& value) const { values_[i] = value.value_; type_codes_[i] = value.type_code_; } - void operator()(size_t i, void* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, void* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMOpaqueHandle; } - void operator()(size_t i, DLTensor* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLTensor* value) const { values_[i].v_handle = value; type_codes_[i] = kTVMDLTensorHandle; } - void operator()(size_t i, TVMContext value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, TVMContext value) const { values_[i].v_ctx = value; type_codes_[i] = kTVMContext; } - void operator()(size_t i, DLDataType value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DLDataType value) const { values_[i].v_type = value; type_codes_[i] = kTVMDataType; } - void operator()(size_t i, DataType dtype) const { + TVM_ALWAYS_INLINE void operator()(size_t i, DataType dtype) const { operator()(i, dtype.operator DLDataType()); } - void operator()(size_t i, const char* value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const char* value) const { values_[i].v_str = value; type_codes_[i] = kTVMStr; } // setters for container types - void operator()(size_t i, const std::string& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const std::string& value) const { values_[i].v_str = value.c_str(); type_codes_[i] = kTVMStr; } - void operator()(size_t i, const TVMByteArray& value) const { + TVM_ALWAYS_INLINE void operator()(size_t i, const TVMByteArray& value) const { values_[i].v_handle = const_cast(&value); type_codes_[i] = kTVMBytes; } - void operator()(size_t i, const PackedFunc& value) const { - values_[i].v_handle = const_cast(&value); - type_codes_[i] = kTVMPackedFuncHandle; + TVM_ALWAYS_INLINE void operator()(size_t i, const PackedFunc& value) const { + if (value != nullptr) { + values_[i].v_handle = const_cast(&value); + type_codes_[i] = kTVMPackedFuncHandle; + } else { + values_[i].v_handle = nullptr; + type_codes_[i] = kTVMNullptr; + } 
} - template - void operator()(size_t i, const TypedPackedFunc& value) const { + template + TVM_ALWAYS_INLINE void operator()(size_t i, const TypedPackedFunc& value) const { operator()(i, value.packed()); } void operator()(size_t i, const TVMRetValue& value) const { @@ -1187,25 +1172,21 @@ class TVMArgsSetter { } } // ObjectRef handling - template::value> - ::type> - void operator()(size_t i, const TObjectRef& value) const { + template ::value>::type> + TVM_ALWAYS_INLINE void operator()(size_t i, const TObjectRef& value) const { this->SetObject(i, value); } - template::type>::value> - ::type> - void operator()(size_t i, TObjectRef&& value) const { + template ::type>::value>::type> + TVM_ALWAYS_INLINE void operator()(size_t i, TObjectRef&& value) const { this->SetObject(i, std::forward(value)); } private: - template + template inline void SetObject(size_t i, TObjectRef&& value) const; /*! \brief The values fields */ TVMValue* values_; @@ -1213,128 +1194,120 @@ class TVMArgsSetter { int* type_codes_; }; -template -inline TVMRetValue PackedFunc::operator()(Args&& ...args) const { +template +inline TVMRetValue PackedFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; - detail::for_each(TVMArgsSetter(values, type_codes), - std::forward(args)...); + detail::for_each(TVMArgsSetter(values, type_codes), std::forward(args)...); TVMRetValue rv; body_(TVMArgs(values, type_codes, kNumArgs), &rv); return rv; } namespace detail { -template +template struct unpack_call_dispatcher { - template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. - unpack_call_dispatcher - ::run(f, args_pack, rv, - std::forward(unpacked_args)..., - TVMMovableArgValue_(args_pack.values[index], - args_pack.type_codes[index])); + unpack_call_dispatcher::run( + f, args_pack, rv, std::forward(unpacked_args)..., + TVMMovableArgValue_(args_pack.values[index], args_pack.type_codes[index])); } }; -template +template struct unpack_call_dispatcher { - template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { - *rv = R(f(std::forward(unpacked_args)...)); + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + Args&&... unpacked_args) { + using RetType = decltype(f(std::forward(unpacked_args)...)); + if (std::is_same::value) { + *rv = f(std::forward(unpacked_args)...); + } else { + *rv = R(f(std::forward(unpacked_args)...)); + } } }; -template +template struct unpack_call_dispatcher { - template - static void run(const F& f, - const TVMArgs& args_pack, - TVMRetValue* rv, - Args&&... unpacked_args) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args_pack, TVMRetValue* rv, + Args&&... 
unpacked_args) { f(std::forward(unpacked_args)...); } }; -template -inline void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { +template +TVM_ALWAYS_INLINE void unpack_call(const F& f, const TVMArgs& args, TVMRetValue* rv) { + CHECK_EQ(nargs, args.size()) << "Expect " << nargs << " arguments but get " << args.size(); unpack_call_dispatcher::run(f, args, rv); } -template -struct unpack_call_by_signature { -}; +template +struct unpack_call_by_signature {}; -template +template struct unpack_call_by_signature { - template - static void run(const F& f, - const TVMArgs& args, - TVMRetValue* rv) { + template + TVM_ALWAYS_INLINE static void run(const F& f, const TVMArgs& args, TVMRetValue* rv) { unpack_call(f, args, rv); } }; -template -inline R call_packed(const PackedFunc& pf, Args&& ...args) { +template +TVM_ALWAYS_INLINE R call_packed(const PackedFunc& pf, Args&&... args) { return R(pf(std::forward(args)...)); } -template +template struct typed_packed_call_dispatcher { - template - static inline R run(const PackedFunc& pf, Args&& ...args) { + template + TVM_ALWAYS_INLINE static R run(const PackedFunc& pf, Args&&... args) { return pf(std::forward(args)...); } }; -template<> +template <> struct typed_packed_call_dispatcher { - template - static inline void run(const PackedFunc& pf, Args&& ...args) { + template + TVM_ALWAYS_INLINE static void run(const PackedFunc& pf, Args&&... args) { pf(std::forward(args)...); } }; } // namespace detail -template -TypedPackedFunc::TypedPackedFunc(PackedFunc packed) - : packed_(packed) {} +template +TypedPackedFunc::TypedPackedFunc(PackedFunc packed) : packed_(packed) {} -template +template TypedPackedFunc::TypedPackedFunc(const TVMRetValue& value) : packed_(value.operator PackedFunc()) {} -template +template TypedPackedFunc::TypedPackedFunc(const TVMArgValue& value) : packed_(value.operator PackedFunc()) {} -template +template TypedPackedFunc::TypedPackedFunc(TVMMovableArgValue_&& value) : packed_(value.operator PackedFunc()) {} -template -template +template +template inline void TypedPackedFunc::AssignTypedLambda(FType flambda) { packed_ = PackedFunc([flambda](const TVMArgs& args, TVMRetValue* rv) { - detail::unpack_call(flambda, args, rv); - }); + detail::unpack_call(flambda, args, rv); + }); } -template -inline R TypedPackedFunc::operator()(Args... args) const { - return detail::typed_packed_call_dispatcher - ::run(packed_, std::forward(args)...); +template +TVM_ALWAYS_INLINE R TypedPackedFunc::operator()(Args... args) const { + return detail::typed_packed_call_dispatcher::run(packed_, std::forward(args)...); } // ObjectRef related conversion handling @@ -1342,18 +1315,18 @@ inline R TypedPackedFunc::operator()(Args... args) const { // kTVMNDArrayHandle, kTVMModuleHandle, kTVMObjectHandle // // We use type traits to eliminate un-necessary checks. 
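// Illustrative aside (not part of the patch): a sketch of the behavioral
// change this patch makes to TVMArgsSetter and TVMRetValue, assuming the TVM
// runtime headers. A null PackedFunc passed as an argument (or assigned to a
// TVMRetValue) is now encoded with type code kTVMNullptr rather than as a
// handle to a null function, so callees can detect it directly.
#include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>

inline void NullPackedFuncSketch() {
  using namespace tvm::runtime;
  PackedFunc probe([](TVMArgs args, TVMRetValue* rv) {
    CHECK_EQ(args[0].type_code(), kTVMNullptr);  // null callback argument
    *rv = PackedFunc(nullptr);                   // null return value
  });
  TVMRetValue ret = probe(PackedFunc(nullptr));
  CHECK_EQ(ret.type_code(), kTVMNullptr);
}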
-template +template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { - using TObjectRef = typename std::remove_reference::type; + using ContainerType = typename std::remove_reference::type::ContainerType; if (value.defined()) { Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && + } else if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; @@ -1369,53 +1342,53 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { } } -template +template inline bool TVMPODValue_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + if (std::is_base_of::value) { return type_code_ == kTVMNDArrayHandle && - TVMArrayHandleToObjectHandle( - static_cast(value_.v_handle))->IsInstance(); + TVMArrayHandleToObjectHandle(static_cast(value_.v_handle)) + ->IsInstance(); } - if (std::is_base_of::value) { + if (std::is_base_of::value) { return type_code_ == kTVMModuleHandle && - static_cast(value_.v_handle)->IsInstance(); + static_cast(value_.v_handle)->IsInstance(); } // NOTE: we don't pass NDArray and runtime::Module as RValue ref. if (type_code_ == kTVMObjectRValueRefArg) { - return ObjectTypeChecker::Check( - *static_cast(value_.v_handle)); - } - return - (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) || - (std::is_base_of::value && type_code_ == kTVMModuleHandle) || - (type_code_ == kTVMObjectHandle && - ObjectTypeChecker::Check(static_cast(value_.v_handle))); + return ObjectTypeChecker::Check(*static_cast(value_.v_handle)); + } + return (std::is_base_of::value && + type_code_ == kTVMNDArrayHandle) || + (std::is_base_of::value && + type_code_ == kTVMModuleHandle) || + (type_code_ == kTVMObjectHandle && + ObjectTypeChecker::Check(static_cast(value_.v_handle))); } -template +template inline TObjectRef TVMPODValue_::AsObjectRef() const { - static_assert( - std::is_base_of::value, - "Conversion only works for ObjectRef"); + static_assert(std::is_base_of::value, + "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; + if (type_code_ == kTVMNullptr) { CHECK(TObjectRef::_type_is_nullable) << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + if (std::is_base_of::value) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); - ObjectPtr data = NDArray::FFIDataFromHandle( - static_cast(value_.v_handle)); + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); CHECK(data->IsInstance()) << "Expect " << ContainerType::_type_key << " but get " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + if (std::is_base_of::value) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -1427,22 +1400,22 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { // normal object type check. 
Object* ptr = static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); + << "Expect " << ObjectTypeChecker::TypeName() << " but get " + << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); } else if (type_code_ == kTVMObjectRValueRefArg) { Object* ptr = *static_cast(value_.v_handle); CHECK(ObjectTypeChecker::Check(ptr)) - << "Expect " << ObjectTypeChecker::TypeName() - << " but get " << ptr->GetTypeKey(); + << "Expect " << ObjectTypeChecker::TypeName() << " but get " + << ptr->GetTypeKey(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && + } else if (std::is_base_of::value && type_code_ == kTVMNDArrayHandle) { // Casting to a base class that NDArray can sub-class - ObjectPtr data = NDArray::FFIDataFromHandle( - static_cast(value_.v_handle)); + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); return TObjectRef(data); - } else if (std::is_base_of::value && + } else if (std::is_base_of::value && type_code_ == kTVMModuleHandle) { // Casting to a base class that Module can sub-class return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); @@ -1452,17 +1425,18 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { } } -template +template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { + using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { return operator=(NDArray(std::move(other.data_))); } - if (std::is_base_of::value || - (std::is_base_of::value && + if (std::is_base_of::value || + (std::is_base_of::value && ptr->IsInstance())) { return operator=(Module(std::move(other.data_))); } @@ -1473,13 +1447,12 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { return *this; } - -template +template inline TVMArgValue::operator T() const { return PackedFuncValueConverter::From(*this); } -template +template inline TVMMovableArgValue_::operator T() const { if (type_code_ == kTVMObjectRValueRefArg) { auto** ref = static_cast(value_.v_handle); @@ -1491,7 +1464,7 @@ inline TVMMovableArgValue_::operator T() const { return PackedFuncValueConverter::From(*this); } -template +template inline TVMRetValue::operator T() const { return PackedFuncValueConverter::From(*this); } diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 6faa7b7c84d7..4a5a21088222 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -44,9 +44,10 @@ #define TVM_RUNTIME_REGISTRY_H_ #include + #include -#include #include +#include namespace tvm { namespace runtime { @@ -68,7 +69,8 @@ class Registry { } /*! * \brief set the body of the function to the given function. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -88,14 +90,15 @@ class Registry { * \param f The function to forward to. * \tparam FLambda The signature of the function. */ - template + template Registry& set_body_typed(FLambda f) { using FType = typename detail::function_signature::FType; return set_body(TypedPackedFunc(std::move(f)).packed()); } /*! * \brief set the body of the function to be the passed method pointer. 
- * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -113,9 +116,9 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template + template Registry& set_body_method(R (T::*f)(Args...)) { - auto fwrap =[f](T target, Args... params) -> R { + auto fwrap = [f](T target, Args... params) -> R { // call method pointer return (target.*f)(params...); }; @@ -124,7 +127,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -142,7 +146,7 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template + template Registry& set_body_method(R (T::*f)(Args...) const) { auto fwrap = [f](const T target, Args... params) -> R { // call method pointer @@ -154,7 +158,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. * Used when calling a method on a Node subclass through a ObjectRef subclass. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -181,8 +186,8 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template ::value>::type> Registry& set_body_method(R (TNode::*f)(Args...)) { auto fwrap = [f](TObjectRef ref, Args... params) { TNode* target = ref.operator->(); @@ -195,7 +200,8 @@ class Registry { /*! * \brief set the body of the function to be the passed method pointer. * Used when calling a method on a Node subclass through a ObjectRef subclass. - * Note that this will ignore default arg values and always require all arguments to be provided. + * Note that this will ignore default arg values and always require all arguments to be + * provided. * * \code * @@ -222,8 +228,8 @@ class Registry { * \tparam R the return type of the function (inferred). * \tparam Args the argument types of the function (inferred). */ - template::value>::type> + template ::value>::type> Registry& set_body_method(R (TNode::*f)(Args...) const) { auto fwrap = [f](TObjectRef ref, Args... params) { const TNode* target = ref.operator->(); @@ -270,8 +276,7 @@ class Registry { friend struct Manager; }; -#define TVM_FUNC_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_ ## TVM +#define TVM_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& __mk_##TVM /*! * \brief Register a function globally. 
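// Illustrative aside (not part of the patch): a sketch of registering global
// functions through the Registry interface reformatted above, assuming the
// TVM runtime headers. The "example.*" names are hypothetical.
#include <tvm/runtime/registry.h>

#include <string>

// Typed body: arguments are converted according to the lambda signature, and
// every argument must be supplied by the caller (defaults are ignored).
TVM_REGISTER_GLOBAL("example.add_one").set_body_typed([](int x) { return x + 1; });

// Packed body: the lambda receives the raw TVMArgs / TVMRetValue pair.
TVM_REGISTER_GLOBAL("example.concat")
    .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
      std::string a = args[0];
      std::string b = args[1];
      *rv = a + b;
    });

// Lookup mirrors registration: Registry::Get returns nullptr for unknown names.
inline int CallAddOne(int x) {
  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("example.add_one");
  return (f != nullptr) ? static_cast<int>((*f)(x)) : x;
}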
@@ -281,9 +286,8 @@ class Registry { * }); * \endcode */ -#define TVM_REGISTER_GLOBAL(OpName) \ - TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::runtime::Registry::Register(OpName) +#define TVM_REGISTER_GLOBAL(OpName) \ + TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName) } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h index 37bb95f54655..f40c87ee07ec 100644 --- a/include/tvm/runtime/serializer.h +++ b/include/tvm/runtime/serializer.h @@ -33,14 +33,14 @@ namespace dmlc { namespace serializer { -template<> +template <> struct Handler { - inline static void Write(Stream *strm, const DLDataType& dtype) { + inline static void Write(Stream* strm, const DLDataType& dtype) { Handler::Write(strm, dtype.code); Handler::Write(strm, dtype.bits); Handler::Write(strm, dtype.lanes); } - inline static bool Read(Stream *strm, DLDataType* dtype) { + inline static bool Read(Stream* strm, DLDataType* dtype) { if (!Handler::Read(strm, &(dtype->code))) return false; if (!Handler::Read(strm, &(dtype->bits))) return false; if (!Handler::Read(strm, &(dtype->lanes))) return false; @@ -48,14 +48,14 @@ struct Handler { } }; -template<> +template <> struct Handler { - inline static void Write(Stream *strm, const DLContext& ctx) { + inline static void Write(Stream* strm, const DLContext& ctx) { int32_t device_type = static_cast(ctx.device_type); Handler::Write(strm, device_type); Handler::Write(strm, ctx.device_id); } - inline static bool Read(Stream *strm, DLContext* ctx) { + inline static bool Read(Stream* strm, DLContext* ctx) { int32_t device_type = 0; if (!Handler::Read(strm, &(device_type))) return false; ctx->device_type = static_cast(device_type); diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h index f1984013e6a9..95a64049fd45 100644 --- a/include/tvm/runtime/threading_backend.h +++ b/include/tvm/runtime/threading_backend.h @@ -40,26 +40,25 @@ class ThreadGroup { public: class Impl; - /*! - * \brief Creates a collection of threads which run a provided function. - * - * \param num_workers The total number of worker threads in this group. - Includes main thread if `exclude_worker0 = true` - * \param worker_callback A callback which is run in its own thread. - Receives the worker_id as an argument. - * \param exclude_worker0 Whether to use the main thread as a worker. - * If `true`, worker0 will not be launched in a new thread and - * `worker_callback` will only be called for values >= 1. This - * allows use of the main thread as a worker. - */ - ThreadGroup(int num_workers, - std::function worker_callback, + /*! + * \brief Creates a collection of threads which run a provided function. + * + * \param num_workers The total number of worker threads in this group. + Includes main thread if `exclude_worker0 = true` + * \param worker_callback A callback which is run in its own thread. + Receives the worker_id as an argument. + * \param exclude_worker0 Whether to use the main thread as a worker. + * If `true`, worker0 will not be launched in a new thread and + * `worker_callback` will only be called for values >= 1. This + * allows use of the main thread as a worker. + */ + ThreadGroup(int num_workers, std::function worker_callback, bool exclude_worker0 = false); ~ThreadGroup(); - /*! - * \brief Blocks until all non-main threads in the pool finish. - */ + /*! + * \brief Blocks until all non-main threads in the pool finish. 
+ */ void Join(); enum AffinityMode : int { @@ -95,7 +94,6 @@ void Yield(); */ int MaxConcurrency(); - } // namespace threading } // namespace runtime } // namespace tvm diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 43c222d0994a..552edc5f19db 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -24,10 +24,11 @@ #ifndef TVM_RUNTIME_VM_H_ #define TVM_RUNTIME_VM_H_ -#include #include +#include #include #include + #include #include #include @@ -44,8 +45,8 @@ namespace vm { */ class ClosureObj : public Object { public: - static constexpr const uint32_t _type_index = TypeIndex::kClosure; - static constexpr const char* _type_key = "Closure"; + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure; + static constexpr const char* _type_key = "runtime.Closure"; TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object); }; @@ -135,6 +136,8 @@ struct Instruction { struct /* AllocTensor Operands */ { /*! \brief The storage to allocate from. */ RegName storage; + /*! \brief The offset into the storage to allocate from. */ + Index offset; /*! \brief The number of dimensions. */ uint32_t ndim; /*! \brief The shape of tensor. */ @@ -145,6 +148,8 @@ struct Instruction { struct /* AllocTensorReg Operands */ { /*! \brief The storage to allocate from. */ RegName storage; + /*! \brief The offset into the storage to allocate from. */ + Index offset; /*! \brief The register to read the shape out of. */ RegName shape_register; /*! \brief The datatype of tensor to be allocated. */ @@ -266,23 +271,25 @@ struct Instruction { /*! * \brief Construct an allocate tensor instruction with constant shape. * \param storage The storage to allocate out of. + * \param offset The offset to allocate at. * \param shape The shape of the tensor. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName storage, - const std::vector& shape, DLDataType dtype, RegName dst); + static Instruction AllocTensor(RegName storage, Index offset, const std::vector& shape, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate tensor instruction with register. * \param storage The storage to allocate out of. + * \param offset The offset into the storage to allocate from. * \param shape_register The register containing the shape. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensorReg(RegName storage, - RegName shape_register, DLDataType dtype, RegName dst); + static Instruction AllocTensorReg(RegName storage, Index offset, RegName shape_register, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate datatype instruction. * \param tag The datatype tag. @@ -379,8 +386,8 @@ struct Instruction { * \param dst The destination to place the storage. * \return The alloc storage instruction. 
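// Illustrative aside (not part of the patch): a sketch of the ThreadGroup
// interface whose comments are re-indented above, assuming the TVM runtime
// headers. With exclude_worker0 = true the constructor spawns workers
// 1..num_workers-1 and leaves worker 0 to run on the calling thread.
#include <tvm/runtime/threading_backend.h>

#include <atomic>

inline void ThreadGroupSketch() {
  using tvm::runtime::threading::ThreadGroup;
  int num_workers = tvm::runtime::threading::MaxConcurrency();
  std::atomic<int> counter{0};
  auto worker = [&counter](int worker_id) { counter.fetch_add(worker_id); };
  ThreadGroup group(num_workers, worker, /*exclude_worker0=*/true);
  worker(0);     // run worker 0 on the main thread
  group.Join();  // wait for the spawned workers to finish
}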
*/ - static Instruction AllocStorage(RegName size, RegName alignment, - DLDataType dtype_hint, RegName dst); + static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint, + RegName dst); Instruction(); Instruction(const Instruction& instr); @@ -407,8 +414,7 @@ struct VMFunction { Index register_file_size; VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, - Index register_file_size) + const std::vector& instructions, Index register_file_size) : name(name), params(params), instructions(instructions), @@ -473,8 +479,7 @@ class Executable : public ModuleNode { * * \return PackedFunc or nullptr when it is not available. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief Serialize the executable into global section, constant section, and @@ -559,9 +564,7 @@ class Executable : public ModuleNode { virtual ~Executable() {} - const char* type_key() const final { - return "VMExecutable"; - } + const char* type_key() const final { return "VMExecutable"; } /*! \brief The runtime module/library that contains both the host and also the device * code when executing on non-CPU devices. */ @@ -668,14 +671,11 @@ class VirtualMachine : public runtime::ModuleNode { * If the function needs resource from the module(e.g. late linking), * it should capture sptr_to_self. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); virtual ~VirtualMachine() {} - const char* type_key() const final { - return "VirtualMachine"; - } + const char* type_key() const final { return "VirtualMachine"; } VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {} @@ -763,11 +763,8 @@ class VirtualMachine : public runtime::ModuleNode { * * \note The return value will be stored in the last output_size slots of args. */ - virtual void InvokePacked(Index packed_index, - const PackedFunc& func, - Index arg_count, - Index output_size, - const std::vector& args); + virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args); /*! * \brief Initialize the virtual machine for a set of contexts. diff --git a/include/tvm/support/logging.h b/include/tvm/support/logging.h index 44b990e0d7db..c318b89e5c51 100644 --- a/include/tvm/support/logging.h +++ b/include/tvm/support/logging.h @@ -59,8 +59,8 @@ * a = ... * b = ... * // if quit_on_assertion is true, if a==b, continue, otherwise quit. - * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default behaviour) - * COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" + * // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default + * behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting" * ... * for (int i = 0; i < N; i++) { * a = ... @@ -84,29 +84,24 @@ // Not supposed to be used by users directly. 
#define COND_CHECK_OP(quit_on_assert, x, y, what, op) \ - if (!quit_on_assert) { \ - if (!((x) op (y))) \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + if (!((x)op(y))) what; \ + } else /* NOLINT(*) */ \ CHECK_##op(x, y) #define COND_CHECK_EQ_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, ==) #define COND_CHECK_GE_4(quit_on_assert, x, y, what) COND_CHECK_OP(quit_on_assert, x, y, what, >=) #define COND_CHECK_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - if (!(x)) \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + if (!(x)) what; \ + } else /* NOLINT(*) */ \ CHECK(x) #define COND_LOG_3(quit_on_assert, x, what) \ - if (!quit_on_assert) { \ - what; \ - } \ - else /* NOLINT(*) */ \ + if (!quit_on_assert) { \ + what; \ + } else /* NOLINT(*) */ \ LOG(x) #define COND_CHECK_EQ_3(quit_on_assert, x, y) COND_CHECK_EQ_4(quit_on_assert, x, y, return false) @@ -114,4 +109,4 @@ #define COND_CHECK_2(quit_on_assert, x) COND_CHECK_3(quit_on_assert, x, return false) #define COND_LOG_2(quit_on_assert, x) COND_LOG_3(quit_on_assert, x, return false) -#endif // TVM_SUPPORT_LOGGING_H_ +#endif // TVM_SUPPORT_LOGGING_H_ diff --git a/include/tvm/support/with.h b/include/tvm/support/with.h index 46b091a68f34..90c82c4f3a06 100644 --- a/include/tvm/support/with.h +++ b/include/tvm/support/with.h @@ -26,6 +26,7 @@ #define TVM_SUPPORT_WITH_H_ #include + #include namespace tvm { @@ -52,22 +53,19 @@ namespace tvm { * * \tparam ContextType Type of the context object. */ -template +template class With { public: /*! * \brief constructor. * Enter the scope of the context. */ - template - explicit With(Args&& ...args) - : ctx_(std::forward(args)...) { + template + explicit With(Args&&... args) : ctx_(std::forward(args)...) { ctx_.EnterWithScope(); } /*! \brief destructor, leaves the scope of the context. */ - ~With() DMLC_THROW_EXCEPTION { - ctx_.ExitWithScope(); - } + ~With() DMLC_THROW_EXCEPTION { ctx_.ExitWithScope(); } private: /*! \brief internal context type. */ diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 4b7ea56e705d..e89d44dd4eb1 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -24,14 +24,13 @@ #ifndef TVM_TARGET_CODEGEN_H_ #define TVM_TARGET_CODEGEN_H_ -#include #include -#include +#include #include +#include #include - namespace tvm { /*! \brief namespace for target translation and codegen. */ namespace codegen { @@ -71,8 +70,7 @@ std::string PackImportsToC(const runtime::Module& m, bool system_lib); * \param target_triple LLVM target triple * \return runtime::Module The generated LLVM module. */ -runtime::Module PackImportsToLLVM(const runtime::Module& m, - bool system_lib, +runtime::Module PackImportsToLLVM(const runtime::Module& m, bool system_lib, const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/include/tvm/target/generic_func.h b/include/tvm/target/generic_func.h index f2a361b3afaf..a310173fa6ea 100644 --- a/include/tvm/target/generic_func.h +++ b/include/tvm/target/generic_func.h @@ -24,14 +24,14 @@ #ifndef TVM_TARGET_GENERIC_FUNC_H_ #define TVM_TARGET_GENERIC_FUNC_H_ -#include #include +#include #include -#include #include -#include #include +#include +#include namespace tvm { @@ -52,8 +52,7 @@ class GenericFunc : public ObjectRef { * false, an error will be logged if the call would override a previously registered function. * \return reference to self. 
*/ - TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, - bool allow_override = false); + TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value, bool allow_override = false); /*! * \brief Register a specialized function * \param tags The tags for this specialization @@ -63,8 +62,7 @@ class GenericFunc : public ObjectRef { * \return reference to self. */ TVM_DLL GenericFunc& register_func(const std::vector& tags, - const runtime::PackedFunc value, - bool allow_override = false); + const runtime::PackedFunc value, bool allow_override = false); /*! * \brief Call generic function by directly passing in unpacked format. * \param args Arguments to be passed. @@ -79,16 +77,15 @@ class GenericFunc : public ObjectRef { * } * \endcode */ - template - inline runtime::TVMRetValue operator()(Args&& ...args) const; + template + inline runtime::TVMRetValue operator()(Args&&... args) const; /*! * \brief Invoke the relevant function for the current target context, set by set_target_context. * Arguments are passed in packed format. * \param args The arguments to pass to the function. * \param ret The return value */ - TVM_DLL void CallPacked(runtime::TVMArgs args, - runtime::TVMRetValue* ret) const; + TVM_DLL void CallPacked(runtime::TVMArgs args, runtime::TVMRetValue* ret) const; /*! * \brief Find or register the GenericFunc instance corresponding to the give name @@ -120,14 +117,14 @@ class GenericFunc : public ObjectRef { friend struct Manager; }; -template -inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const { +template +inline runtime::TVMRetValue GenericFunc::operator()(Args&&... args) const { const int kNumArgs = sizeof...(Args); const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; TVMValue values[kArraySize]; int type_codes[kArraySize]; runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes), - std::forward(args)...); + std::forward(args)...); runtime::TVMRetValue rv; CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv); return rv; @@ -155,8 +152,7 @@ inline GenericFuncNode* GenericFunc::operator->() { return static_cast(get_mutable()); } -#define TVM_GENERIC_FUNC_REG_VAR_DEF \ - static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM +#define TVM_GENERIC_FUNC_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_##TVM /*! * \def TVM_REGISTER_GENERIC_FUNC @@ -165,9 +161,8 @@ inline GenericFuncNode* GenericFunc::operator->() { * * \param name The name of the function */ -#define TVM_REGISTER_GENERIC_FUNC(name) \ - TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \ - ::tvm::GenericFunc::Get(#name) +#define TVM_REGISTER_GENERIC_FUNC(name) \ + TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::GenericFunc::Get(#name) } // namespace tvm #endif // TVM_TARGET_GENERIC_FUNC_H_ diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 59aa955b0d9e..c85349d0da60 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -24,14 +24,15 @@ #ifndef TVM_TARGET_TARGET_H_ #define TVM_TARGET_TARGET_H_ -#include -#include #include +#include +#include +#include #include -#include #include #include +#include namespace tvm { /*! @@ -98,9 +99,9 @@ class Target : public ObjectRef { Target() {} explicit Target(ObjectPtr n) : ObjectRef(n) {} /*! 
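// Illustrative aside (not part of the patch): a sketch of GenericFunc
// dispatch as documented above, assuming the TVM headers. The name
// "example.my_op" and both bodies are hypothetical.
#include <tvm/runtime/packed_func.h>
#include <tvm/target/generic_func.h>

inline void GenericFuncSketch() {
  using tvm::GenericFunc;
  using namespace tvm::runtime;

  GenericFunc fgeneric = GenericFunc::Get("example.my_op");
  fgeneric.set_default(PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    int x = args[0];
    *rv = x + 1;  // generic fallback
  }));
  fgeneric.register_func({"cuda"}, PackedFunc([](TVMArgs args, TVMRetValue* rv) {
                           int x = args[0];
                           *rv = x + 2;  // specialization picked under a cuda target
                         }));
  // With no target context entered, the default body runs.
  int r = fgeneric(10);  // r == 11
  (void)r;
}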
- * \brief Create a Target given a string - * \param target_str the string to parse - */ + * \brief Create a Target given a string + * \param target_str the string to parse + */ TVM_DLL static Target Create(const std::string& target_str); /*! * \brief Get the current target context from thread local storage. @@ -112,12 +113,11 @@ class Target : public ObjectRef { */ TVM_DLL static tvm::Target Current(bool allow_not_defined = true); - const TargetNode* operator->() const { - return static_cast(get()); - } + const TargetNode* operator->() const { return static_cast(get()); } using ContainerType = TargetNode; class Internal; + private: // enable with syntax. friend class Internal; @@ -139,174 +139,38 @@ class Target : public ObjectRef { namespace target { /*! \return A target for LLVM */ -TVM_DLL Target llvm(const std::vector& options = - std::vector()); +TVM_DLL Target llvm(const std::vector& options = std::vector()); /*! \return A target for CUDA */ -TVM_DLL Target cuda(const std::vector& options = - std::vector()); +TVM_DLL Target cuda(const std::vector& options = std::vector()); /*! \return A target for ROCm */ -TVM_DLL Target rocm(const std::vector& options = - std::vector()); +TVM_DLL Target rocm(const std::vector& options = std::vector()); /*! \return A target for OpenCL */ -TVM_DLL Target opencl(const std::vector& options = - std::vector()); +TVM_DLL Target opencl(const std::vector& options = std::vector()); /*! \return A target for Metal */ -TVM_DLL Target metal(const std::vector& options = - std::vector()); +TVM_DLL Target metal(const std::vector& options = std::vector()); /*! \return A target for rasp */ -TVM_DLL Target rasp(const std::vector& options = - std::vector()); +TVM_DLL Target rasp(const std::vector& options = std::vector()); /*! \return A target for Mali */ -TVM_DLL Target mali(const std::vector& options = - std::vector()); +TVM_DLL Target mali(const std::vector& options = std::vector()); /*! \return A target for Intel Graphics */ -TVM_DLL Target intel_graphics(const std::vector& options = - std::vector()); +TVM_DLL Target intel_graphics(const std::vector& options = std::vector()); /*! \return A target for stackvm */ -TVM_DLL Target stackvm(const std::vector& options = - std::vector()); +TVM_DLL Target stackvm(const std::vector& options = std::vector()); /*! \return A target for external device */ -TVM_DLL Target ext_dev(const std::vector& options = - std::vector()); +TVM_DLL Target ext_dev(const std::vector& options = std::vector()); /*! \return A target for hexagon */ -TVM_DLL Target hexagon(const std::vector& options = - std::vector()); +TVM_DLL Target hexagon(const std::vector& options = std::vector()); } // namespace target -/*! - * \brief Container for build configuration options - */ -class BuildConfigNode : public Object { - public: - /*! - * \brief The data alignment to use when constructing buffers. If this is set to - * -1, then TVM's internal default will be used - */ - int data_alignment = -1; - /*! - * \brief The offset factor to use when constructing buffers. If this is set to - * 0, then the offset field is not used. - */ - int offset_factor = 0; - - /*! - * \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be - * done. Otherwise, a split will be done with this factor and the inner loop will be unrolled. - */ - int double_buffer_split_loop = 1; - /*! \brief Threshold of number of steps in the loop to be automatically unrolled */ - int auto_unroll_max_step = 0; - /*! 
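// Illustrative aside (not part of the patch): a sketch combining the With<>
// RAII helper and the Target API shown above, assuming the TVM headers and
// the usual friend declaration that lets With<Target> enter the scope.
#include <tvm/support/with.h>
#include <tvm/target/target.h>

inline void TargetScopeSketch() {
  tvm::Target target = tvm::Target::Create("llvm");
  {
    tvm::With<tvm::Target> scope(target);          // EnterWithScope()
    tvm::Target current = tvm::Target::Current();  // the llvm target
    (void)current;
  }  // ExitWithScope() restores the previous context on destruction
}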
\brief The maximum nested level of loops that can be automatically unrolled */ - int auto_unroll_max_depth = 8; - /*! \brief The maximum extent of loop that will be unrolled */ - int auto_unroll_max_extent = 0; - /*! - * \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will - * be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma. - */ - bool unroll_explicit = true; - - /*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */ - bool restricted_func = true; - - /*! \brief Whether to detect global barrier */ - bool detect_global_barrier = false; - - /*! \brief Whether to partition const loop */ - bool partition_const_loop = false; - - /*! \brief Whether to dump the IR of each pass (only when building from python) */ - std::vector< std::pair > add_lower_pass; - - /*! \brief Whether to dump the IR of each pass (only when building from python) */ - bool dump_pass_ir = false; - - /*! \brief Whether to instrument loads and stores with check for out of the bounds. */ - bool instrument_bound_checkers = false; - - /*! \brief Whether to disable select rewriting. */ - bool disable_select_rewriting = false; - - /*! \brief Whether to disable loop vectorization. */ - bool disable_vectorize = false; - - /*! \brief Whether to disable assert stmt generation. */ - bool disable_assert = false; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("data_alignment", &data_alignment); - v->Visit("offset_factor", &offset_factor); - v->Visit("double_buffer_split_loop", &double_buffer_split_loop); - v->Visit("auto_unroll_max_step", &auto_unroll_max_step); - v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth); - v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent); - v->Visit("unroll_explicit", &unroll_explicit); - v->Visit("restricted_func", &restricted_func); - v->Visit("detect_global_barrier", &detect_global_barrier); - v->Visit("partition_const_loop", &partition_const_loop); - v->Visit("dump_pass_ir", &dump_pass_ir); - v->Visit("instrument_bound_checkers", &instrument_bound_checkers); - v->Visit("disable_select_rewriting", &disable_select_rewriting); - v->Visit("disable_vectorize", &disable_vectorize); - v->Visit("disable_assert", &disable_assert); - } - - static constexpr const char* _type_key = "BuildConfig"; - TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object); -}; - -/*! - * \brief Build configuration for compilations. - */ -class BuildConfig : public ::tvm::ObjectRef { - public: - BuildConfig() {} - explicit BuildConfig(ObjectPtr n) : ObjectRef(n) {} - const BuildConfigNode* operator->() const { - return static_cast(get()); - } - BuildConfigNode* operator->() { - return static_cast(get_mutable()); - } - /*! - * \brief Construct a BuildConfig containing a empty build config node. - * \return The new BuildConfig - */ - TVM_DLL static BuildConfig Create(); - /*! - * \brief Get the current BuildConfig context from thread local storage, or a default - * configuration if a BuildConfig scope has not been entered. - * \return The configuration that is the current context. - */ - TVM_DLL static BuildConfig Current(); - - using ContainerType = BuildConfigNode; - class Internal; - - private: - // Enable with syntax. - friend class With; - /*! - * \brief Push a new BuildConfig context onto the thread local stack. - */ - TVM_DLL void EnterWithScope(); - - /*! - * \brief Pop a build config off the thread local context stack, - * restoring the previous configuration as the current context. 
- */ - TVM_DLL void ExitWithScope(); -}; - } // namespace tvm #endif // TVM_TARGET_TARGET_H_ diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index 4466476a18de..1de15a5bd526 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -25,6 +25,7 @@ #define TVM_TARGET_TARGET_INFO_H_ #include + #include namespace tvm { diff --git a/include/tvm/te/autodiff.h b/include/tvm/te/autodiff.h index 180ec0bf676c..e2d379969c65 100644 --- a/include/tvm/te/autodiff.h +++ b/include/tvm/te/autodiff.h @@ -27,6 +27,7 @@ #include #include + #include "tensor.h" namespace tvm { @@ -59,8 +60,8 @@ Tensor Jacobian(const Tensor& output, const Tensor& input); * * Differentiate \p output wrt \p input and multiply the result by \p head on the left using tensor * dot product. \p input must be an immediate dependency of \p output (must be called from within - * the body of \p output). That is, the function will compute one summand of the adjoint for \p input - * given the adjoint for \p output (which is called \p head here). + * the body of \p output). That is, the function will compute one summand of the adjoint for \p + * input given the adjoint for \p output (which is called \p head here). * * \param output The tensor to differentiate. * \param input The input tensor, which \p output should directly use. @@ -68,7 +69,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input); * \return The tensor of shape `prefix + input.shape` * representing the partial adjoint of \p input wrt one of its consumers (output) */ -Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head); +Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head); /*! * \brief Perform reverse mode automatic differentiation. @@ -82,14 +83,12 @@ Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Te * wrt all tensors the output depends on. * \param head The adjoint of the output, in other words, some tensor, by which the Jacobians * will be multiplied (using tensordot axes=`output.shape`). - * Its shape must be of the form `prefix + output.shape`. If the null pointer is provided, - * the identity tensor of shape `output.shape + output.shape` will be used. - * \return An array of adjoints corresponding to \p inputs. + * Its shape must be of the form `prefix + output.shape`. If the null pointer is + * provided, the identity tensor of shape `output.shape + output.shape` will be used. \return An + * array of adjoints corresponding to \p inputs. */ -TVM_DLL Array Gradient( - const Tensor& output, - const Array& inputs, - const Tensor& head = Tensor()); +TVM_DLL Array Gradient(const Tensor& output, const Array& inputs, + const Tensor& head = Tensor()); } // namespace te } // namespace tvm diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 205589928f01..dbd07fa4cf69 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -25,16 +25,15 @@ #define TVM_TE_OPERATION_H_ #include -#include #include - +#include +#include #include #include -#include #include -#include #include +#include namespace tvm { /*! \brief Tensor expression language DSL. */ @@ -46,8 +45,7 @@ namespace te { */ struct TensorDom { // constructor - explicit TensorDom(int ndim) - : data(ndim) {} + explicit TensorDom(int ndim) : data(ndim) {} /*! \brief The domain data */ std::vector > data; }; @@ -55,18 +53,18 @@ struct TensorDom { /*! 
* \brief Base class of all operation nodes */ -class OperationNode : public tir::FunctionBaseNode { +class OperationNode : public Object { public: /*! \brief optional name of the operation */ std::string name; /*! \brief optional tag of the operation */ std::string tag; /*! \brief additional attributes of the operation*/ - Map attrs; - /*! \return name of the operation */ - const std::string& func_name() const final { - return name; - } + Map attrs; + // virtual destructor. + virtual ~OperationNode() {} + /*! \return number of outputs */ + virtual int num_outputs() const = 0; /*! * \return The list of iteration variable at root * \note root_iter_vars decides the shape of the outputs. @@ -96,9 +94,8 @@ class OperationNode : public tir::FunctionBaseNode { * \param rmap The replacement map. * \return self if nothing is replaced, otherwise return replaced op. */ - virtual Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const = 0; + virtual Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const = 0; /*! * \brief Propagate the bounds to inputs * \param self The reference to self. @@ -108,11 +105,9 @@ class OperationNode : public tir::FunctionBaseNode { * The function is only asked to fill the bounds for Tensors that * is already in the out_dom_map */ - virtual void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const = 0; + virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const = 0; /*! * \brief Gather the bound from output tensor. * Set the range of each root_iter_vars in the op to out_dom_map @@ -121,10 +116,9 @@ class OperationNode : public tir::FunctionBaseNode { * \param tensor_dom Domain map of Tensor->access set of each dimension. * \param out_dom_map The output domain map of each IterVar to be setted. */ - virtual void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const = 0; + virtual void GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const = 0; /*! * \brief Build the Realize statement that realizes * the op's output tensors. @@ -133,10 +127,9 @@ class OperationNode : public tir::FunctionBaseNode { * \param body The body that is going to get * \return A realization statement that wraps body. */ - virtual Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + virtual Stmt BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -144,10 +137,8 @@ class OperationNode : public tir::FunctionBaseNode { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return A statement that add production and wraps consumer. 
*/ - virtual Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const = 0; + virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const = 0; static constexpr const char* _type_key = "Operation"; @@ -169,26 +160,17 @@ class PlaceholderOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -197,14 +179,22 @@ class PlaceholderOpNode : public OperationNode { v->Visit("shape", &shape); v->Visit("dtype", &dtype); } - static Operation make(std::string name, - Array shape, - DataType dtype); static constexpr const char* _type_key = "PlaceholderOp"; TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; +/*! + * \brief Managed reference to PlaceholderOpNode + * \sa PlaceholderOpNode + */ +class PlaceholderOp : public Operation { + public: + TVM_DLL PlaceholderOp(std::string name, Array shape, DataType dtype); + + TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode); +}; + /*! * \brief A Compute op that compute a tensor on certain domain. * This is the base class for ComputeOp (operating on a scalar at a time) and @@ -219,21 +209,16 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { // override functions Array root_iter_vars() const final; Array output_shape(size_t idx) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; - /*! * \brief A Compute op that compute a tensor on certain domain. 
*/ @@ -247,18 +232,13 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { int num_outputs() const final; DataType output_dtype(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; void VisitAttrs(AttrVisitor* v) { @@ -269,16 +249,23 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { v->Visit("reduce_axis", &reduce_axis); v->Visit("body", &body); } - static Operation make(std::string name, - std::string tag, - Map attrs, - Array axis, - Array body); static constexpr const char* _type_key = "ComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; +/*! + * \brief Managed reference to ComputeOpNode + * \sa ComputeOpNode + */ +class ComputeOp : public Operation { + public: + TVM_DLL ComputeOp(std::string name, std::string tag, Map attrs, + Array axis, Array body); + + TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode); +}; + /*! * \brief A TenorCompute op that compute a tensor with an tensor intrinsic. */ @@ -300,18 +287,13 @@ class TensorComputeOpNode : public BaseComputeOpNode { int num_outputs() const final; DataType output_dtype(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; size_t num_schedulable_dims() const final; void VisitAttrs(AttrVisitor* v) { @@ -325,20 +307,25 @@ class TensorComputeOpNode : public BaseComputeOpNode { v->Visit("input_regions", &input_regions); v->Visit("scalar_inputs", &scalar_inputs); } - static Operation make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - int schedulable_ndim, - TensorIntrin intrin, - Array tensors, - Array regions, - Array scalar_inputs); static constexpr const char* _type_key = "TensorComputeOp"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode); }; +/*! 
+ * \brief Managed reference to TensorComputeOpNode + * \sa TensorComputeOpNode + */ +class TensorComputeOp : public Operation { + public: + TVM_DLL TensorComputeOp(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, + Array scalar_inputs); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode); +}; + /*! * \brief Symbolic scan. */ @@ -375,26 +362,17 @@ class ScanOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -407,19 +385,24 @@ class ScanOpNode : public OperationNode { v->Visit("inputs", &inputs); v->Visit("spatial_axis_", &spatial_axis_); } - static Operation make(std::string name, - std::string tag, - Map attrs, - IterVar axis, - Array init, - Array update, - Array state_placeholder, - Array input); static constexpr const char* _type_key = "ScanOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; +/*! + * \brief Managed reference to ScanOpNode + * \sa ScanOpNode + */ +class ScanOp : public Operation { + public: + TVM_DLL ScanOp(std::string name, std::string tag, Map attrs, IterVar axis, + Array init, Array update, Array state_placeholder, + Array input); + + TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode); +}; + /*! * \brief External computation that cannot be splitted. 
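The managed reference classes introduced in this hunk (PlaceholderOp, ComputeOp, TensorComputeOp, ScanOp) move construction from the old static make functions onto constructors. A minimal sketch of the new path, assuming the elided shape type is Array<PrimExpr> as in upstream TVM; ComputeOp and the other references follow the same pattern:

#include <tvm/te/operation.h>

tvm::te::Tensor DeclareInput() {
  using namespace tvm;
  using namespace tvm::te;
  // Previously: Operation op = PlaceholderOpNode::make("data", {16, 32}, DataType::Float(32));
  // The managed reference now carries the constructor directly.
  PlaceholderOp op("data", {16, 32}, DataType::Float(32));
  // Operation::output(i) still yields the i-th output tensor.
  return op.output(0);
}
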
*/ @@ -442,26 +425,17 @@ class ExternOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -472,18 +446,24 @@ class ExternOpNode : public OperationNode { v->Visit("output_placeholders", &output_placeholders); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array input_placeholders, - Array output_placeholders, - Stmt body); static constexpr const char* _type_key = "ExternOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; +/*! + * \brief Managed reference to ExternOpNode + * \sa ExternOpNode + */ +class ExternOp : public Operation { + public: + TVM_DLL ExternOp(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode); +}; + /*! * \brief A computation operator that generated by hybrid script. 
*/ @@ -510,26 +490,17 @@ class HybridOpNode : public OperationNode { DataType output_dtype(size_t i) const final; Array output_shape(size_t i) const final; Array InputTensors() const final; - Operation ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const final; - void PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const final; - void GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const final; - Stmt BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const final; - Stmt BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const final; + Operation ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const final; + void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const final; + void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const final; + Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, + const Stmt& body) const final; + Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -540,17 +511,23 @@ class HybridOpNode : public OperationNode { v->Visit("axis", &axis); v->Visit("body", &body); } - TVM_DLL static Operation make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body); static constexpr const char* _type_key = "HybridOp"; TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode); }; +/*! + * \brief Managed reference to HybridOpNode + * \sa HybridOpNode + */ +class HybridOp : public Operation { + public: + TVM_DLL HybridOp(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode); +}; + /*! * \brief Construct a new Var expression * \param name_hint The name hint for the expression @@ -575,10 +552,10 @@ TVM_DLL IterVar thread_axis(Range dom, std::string tag); TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv"); /*! \brief The compute function to specify the input source of a Tensor */ -using FCompute = std::function& i)>; +using FCompute = std::function& i)>; /*! \brief The compute function to specify the inputs source of Tensors */ -using FBatchCompute = std::function (const Array& i)>; +using FBatchCompute = std::function(const Array& i)>; /*! * \brief create a place holder tensor. @@ -586,8 +563,7 @@ using FBatchCompute = std::function (const Array& i)>; * \param dtype the data type of the tensor. * \param name The name of the Tensor. */ -TVM_DLL Tensor placeholder(Array shape, - DataType dtype = DataType::Float(32), +TVM_DLL Tensor placeholder(Array shape, DataType dtype = DataType::Float(32), std::string name = "placeholder"); /*! @@ -599,11 +575,8 @@ TVM_DLL Tensor placeholder(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. 
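Because the rendering above drops the template arguments of these declarations, a hedged usage sketch may help: it assumes the shape parameter is Array<PrimExpr> and FCompute is std::function<PrimExpr(const Array<Var>&)>, as in upstream TVM, and builds an element-wise addition from two placeholders.

#include <tvm/te/operation.h>

tvm::te::Tensor VectorAdd(int n) {
  using namespace tvm;
  using namespace tvm::te;
  // Two symbolic 1-D float32 inputs.
  Tensor a = placeholder({n}, DataType::Float(32), "a");
  Tensor b = placeholder({n}, DataType::Float(32), "b");
  // The single-axis inline overload of compute wraps the FCompute form.
  return compute(
      {n}, [&](Var i) { return a(i) + b(i); }, "c", "elemwise");
}
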
*/ -TVM_DLL Tensor compute(Array shape, - FCompute fcompute, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}); +TVM_DLL Tensor compute(Array shape, FCompute fcompute, std::string name = "tensor", + std::string tag = "", Map attrs = {}); /*! * \brief Construct a new tensor by computing over shape, @@ -614,11 +587,9 @@ TVM_DLL Tensor compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array compute(Array shape, - FBatchCompute fcompute, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}); +TVM_DLL Array compute(Array shape, FBatchCompute fcompute, + std::string name = "tensor", std::string tag = "", + Map attrs = {}); /*! * \brief Construct new tensors by scan. @@ -632,45 +603,34 @@ TVM_DLL Array compute(Array shape, * \param tag The optional tag of the tensor. * \param attrs Optional additional attributes of the compute. */ -TVM_DLL Array scan(Array init, - Array update, - Array state_placeholder, - Array inputs = Array(), - std::string name = "scan", - std::string tag = "", - Map attrs = {}); +TVM_DLL Array scan(Array init, Array update, + Array state_placeholder, Array inputs = Array(), + std::string name = "scan", std::string tag = "", + Map attrs = {}); // same as compute, specialized for different fcompute function -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0]); }; +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", + Map attrs = {}) { + FCompute fc = [f](const Array& i) { return f(i[0]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1]); }; +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", + Map attrs = {}) { + FCompute fc = [f](const Array& i) { return f(i[0], i[1]); }; return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2]); }; - return compute(shape, fc, name, tag, attrs); +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", + Map attrs = {}) { + FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2]); }; + return compute(shape, fc, name, tag, attrs); } -inline Tensor compute(Array shape, - std::function f, - std::string name = "tensor", - std::string tag = "", - Map attrs = {}) { - FCompute fc = [f] (const Array& i) { return f(i[0], i[1], i[2], i[3]); }; +inline Tensor compute(Array shape, std::function f, + std::string name = "tensor", std::string tag = "", + Map attrs = {}) { + FCompute fc = [f](const Array& i) { return f(i[0], i[1], i[2], i[3]); }; return compute(shape, fc, name, tag, attrs); } diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index a8a02365fbda..ee4fb33349f7 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -25,10 +25,10 @@ #ifndef TVM_TE_SCHEDULE_H_ #define TVM_TE_SCHEDULE_H_ -#include +#include #include #include -#include +#include #include #include @@ -84,12 +84,12 @@ class Stage : public ObjectRef { * \param scope 
The iteration point to carry the schedule. * \return reference to self. */ - TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) + TVM_DLL Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*) /*! * \brief Compute the function inline. * \return reference to self. */ - TVM_DLL Stage& compute_inline(); // NOLINT(*) + TVM_DLL Stage& compute_inline(); // NOLINT(*) /*! * \brief Compute the function at group root. * \return reference to self. @@ -131,7 +131,8 @@ class Stage : public ObjectRef { * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split(IterVar parent, PrimExpr factor, IterVar* p_outer, + IterVar* p_inner); // NOLINT(*) /*! * \brief Split the iteration with given number of parts. * @@ -141,7 +142,8 @@ class Stage : public ObjectRef { * \param p_inner The result inner domain. * \return reference to self. */ - TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*) + TVM_DLL Stage& split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, + IterVar* p_inner); // NOLINT(*) /*! * \brief Fuse the inner outer domain to the target * \param outer The outer domain to be fused. @@ -169,7 +171,7 @@ class Stage : public ObjectRef { * \param order The order of iteration variable. * \return reference to self. */ - TVM_DLL Stage& reorder(const Array& order); // NOLINT(*) + TVM_DLL Stage& reorder(const Array& order); // NOLINT(*) /*! * \brief Perform tiling on two dimensions * The final loop order from outmost to inner most are @@ -185,16 +187,15 @@ class Stage : public ObjectRef { * \param p_y_inner Inner axis of y dimension * \return reference to self. */ - TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) - PrimExpr x_factor, PrimExpr y_factor, - IterVar* p_x_outer, IterVar* p_y_outer, - IterVar* p_x_inner, IterVar* p_y_inner); + TVM_DLL Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*) + PrimExpr x_factor, PrimExpr y_factor, IterVar* p_x_outer, IterVar* p_y_outer, + IterVar* p_x_inner, IterVar* p_y_inner); /*! * \brief Vectorize iteration. * \param var The axis to be vectorized. * \return reference to self. */ - TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*) + TVM_DLL Stage& vectorize(IterVar var); // NOLINT(*) /*! * \brief Replace computation of the current stage by tensor intrinsic f. * \param var The axis marks beginning of tensorization. @@ -202,19 +203,19 @@ class Stage : public ObjectRef { * \param f The Tensor compute intrinsics. * \return reference to self. */ - TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*) + TVM_DLL Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*) /*! * \brief Unroll iteration. * \param var The axis to be unrolled. * \return reference to self. */ - TVM_DLL Stage& unroll(IterVar var); // NOLINT(*) + TVM_DLL Stage& unroll(IterVar var); // NOLINT(*) /*! * \brief Parallelize iteration. * \param var The axis to be parallelized. * \return reference to self. */ - TVM_DLL Stage& parallel(IterVar var); // NOLINT(*) + TVM_DLL Stage& parallel(IterVar var); // NOLINT(*) /*! * \brief Annotate the iteration with pragma * @@ -224,9 +225,8 @@ class Stage : public ObjectRef { * * \return reference to self. 
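The chainable Stage primitives in this hunk are easiest to read in use. A short sketch, reusing the hypothetical VectorAdd tensor from the earlier sketch and assuming create_schedule takes Array<Operation> as in upstream TVM:

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

void ScheduleVectorAdd(tvm::te::Tensor c) {
  using namespace tvm::te;
  Schedule sch = create_schedule({c->op});
  IterVar outer, inner;
  // Split the single spatial axis by a factor of 64, then bind the pieces.
  sch[c].split(c->op.as<ComputeOpNode>()->axis[0], 64, &outer, &inner);
  sch[c].parallel(outer);
  sch[c].vectorize(inner);
}
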
*/ - TVM_DLL Stage& pragma(IterVar var, - const std::string& pragma_type, - const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*) + TVM_DLL Stage& pragma(IterVar var, const std::string& pragma_type, + const PrimExpr& pragma_value = PrimExpr()); // NOLINT(*) /*! * \brief Fetch data in advance. * \param domain the tensor to be prefetched @@ -234,7 +234,7 @@ class Stage : public ObjectRef { * \param offset the number of iterations be to fetched in advance * \return reference to self */ - TVM_DLL Stage& prefetch(const Tensor &domain, IterVar var, PrimExpr offset); //NOLINT(*) + TVM_DLL Stage& prefetch(const Tensor& domain, IterVar var, PrimExpr offset); // NOLINT(*) /*! * \brief Set alignment requirement for specific dimension. * @@ -245,17 +245,12 @@ class Stage : public ObjectRef { * \param offset The required offset factor. * \return reference to self */ - TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*) + TVM_DLL Stage& storage_align(IterVar axis, int factor, int offset); // NOLINT(*) /*! * \brief Compute current stage with double buffering. * \return reference to self. */ - TVM_DLL Stage& double_buffer(); // NOLINT(*) - /*! - * \brief Schedule for OpenGL fragment shader. - * \return reference to self. - */ - Stage& opengl(); // NOLINT(*) + TVM_DLL Stage& double_buffer(); // NOLINT(*) /*! * \brief whether the stage has been scheduled. * \return whether the stage has been scheduled. @@ -282,6 +277,12 @@ class Schedule : public ObjectRef { public: Schedule() {} explicit Schedule(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief Create a schedule for array of ops(and their dependencies). + * \param ops The ops to be scheduled. + * \return sch The created Schedule. + */ + TVM_DLL explicit Schedule(Array ops); /*! * \brief Get a copy of current schedule. * \return The copied schedule. @@ -297,9 +298,7 @@ class Schedule : public ObjectRef { * \param tensor The tensor * \return The stage corresponding to the tensor's op */ - TVM_DLL Stage operator[](const Tensor& tensor) { - return this->operator[](tensor->op); - } + TVM_DLL Stage operator[](const Tensor& tensor) { return this->operator[](tensor->op); } /*! * \brief Create a new stage group for all intermediate * operations between inputs and outputs. @@ -309,9 +308,8 @@ class Schedule : public ObjectRef { * \param include_inputs Whether include inputs if they are reachable from outputs. * \return The new grouped stage. */ - TVM_DLL Stage create_group(const Array& outputs, - const Array& inputs, - bool include_inputs = false); + TVM_DLL Stage create_group(const Array& outputs, const Array& inputs, + bool include_inputs = false); /*! * \brief create a cache read of original tensor for readers. * This will mutate the body of the readers. @@ -321,9 +319,8 @@ class Schedule : public ObjectRef { * \param readers The readers to redirect to the tensor. * \return The created tensor. */ - TVM_DLL Tensor cache_read(const Tensor& tensor, - const std::string& scope, - const Array& readers); + TVM_DLL Tensor cache_read(const Tensor& tensor, const std::string& scope, + const Array& readers); /*! * \brief Create a cache write tensor for producing tensor. * The the tensor will take over body of original tensor op. @@ -371,9 +368,7 @@ class Schedule : public ObjectRef { * \param factor_axis The position where the new axis is placed. * \return The created factored tensors. 
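A brief sketch of the caching and rfactor entry points above, assuming the elided element types (Array<Operation> for readers, Array<Tensor> for the rfactor result) follow upstream TVM; the "shared" scope string is only an illustrative choice:

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

void CacheAndFactor(tvm::te::Schedule sch, tvm::te::Tensor data, tvm::te::Tensor out,
                    tvm::tir::IterVar reduce_iv) {
  using namespace tvm;
  using namespace tvm::te;
  // Redirect readers of `data` through a cached copy in shared memory.
  Tensor cached = sch.cache_read(data, "shared", {out->op});
  // Factor the reduction axis into a new intermediate tensor (factor_axis defaults to 0).
  Array<Tensor> partial = sch.rfactor(out, reduce_iv);
  (void)cached;
  (void)partial;
}
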
*/ - TVM_DLL Array rfactor(const Tensor& tensor, - const IterVar& axis, - int factor_axis = 0); + TVM_DLL Array rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis = 0); /*! * \brief Normalize the schedule. * This is needed before bound inference. @@ -484,8 +479,6 @@ class StageNode : public Object { std::string scope; /*! \brief Whether this is an output stage */ bool is_output{false}; - /*! \brief Whether this is an OpenGL stage */ - bool is_opengl{false}; /*! \brief Whether apply double buffer optimization to this stage */ bool double_buffer{false}; /*! @@ -509,7 +502,6 @@ class StageNode : public Object { v->Visit("attach_stage", &attach_stage); v->Visit("scope", &scope); v->Visit("is_output", &is_output); - v->Visit("is_opengl", &is_opengl); v->Visit("double_buffer", &double_buffer); v->Visit("group", &group); v->Visit("num_child_stages", &num_child_stages); @@ -565,16 +557,7 @@ class ScheduleNode : public Object { * \param tensor The candidate tensor. * \return true if the schedule has the tensor. Otherwise, false. */ - TVM_DLL bool Contain(const Tensor& tensor) const { - return Contain(tensor->op); - } - - /*! - * \brief Create a schedule for array of ops(and their dependencies). - * \param ops The ops to be scheduled. - * \return sch The created Schedule. - */ - TVM_DLL static Schedule make(Array ops); + TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); } static constexpr const char* _type_key = "Schedule"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object); @@ -585,9 +568,7 @@ class ScheduleNode : public Object { * \param ops The ops to be scheduled. * \return sch The created Schedule. */ -inline Schedule create_schedule(Array ops) { - return ScheduleNode::make(ops); -} +inline Schedule create_schedule(Array ops) { return Schedule(ops); } /*! \brief node container for IterVar attr */ class IterVarAttrNode : public Object { @@ -666,16 +647,21 @@ class SplitNode : public IterVarRelationNode { v->Visit("nparts", &nparts); } - static IterVarRelation make(IterVar parent, - IterVar outer, - IterVar inner, - PrimExpr factor, - PrimExpr nparts); - static constexpr const char* _type_key = "Split"; TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to SplitNode + * \sa SplitNode + */ +class Split : public IterVarRelation { + public: + TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts); + + TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode); +}; + /*! * \brief Fuse two domains into one domain. */ @@ -694,13 +680,21 @@ class FuseNode : public IterVarRelationNode { v->Visit("fused", &fused); } - static IterVarRelation make( - IterVar outer, IterVar inner, IterVar fused); - static constexpr const char* _type_key = "Fuse"; TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to FuseNode + * \sa FuseNode + */ +class Fuse : public IterVarRelation { + public: + TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused); + + TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode); +}; + /*! * \brief Rebase the iteration to make min to be 0. * This is useful to normalize the Schedule @@ -718,12 +712,20 @@ class RebaseNode : public IterVarRelationNode { v->Visit("rebased", &rebased); } - static IterVarRelation make(IterVar parent, IterVar rebased); - static constexpr const char* _type_key = "Rebase"; TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode); }; +/*! 
+ * \brief Managed reference to RebaseNode + * \sa RebaseNode + */ +class Rebase : public IterVarRelation { + public: + TVM_DLL Rebase(IterVar parent, IterVar rebased); + + TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode); +}; /*! * \brief Singleton iterator [0, 1) @@ -733,16 +735,23 @@ class SingletonNode : public IterVarRelationNode { /*! \brief The singleton iterator */ IterVar iter; - void VisitAttrs(AttrVisitor* v) { - v->Visit("iter", &iter); - } - - static IterVarRelation make(IterVar iter); + void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); } static constexpr const char* _type_key = "Singleton"; TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode); }; +/*! + * \brief Managed reference to SingletonNode + * \sa SingletonNode + */ +class Singleton : public IterVarRelation { + public: + TVM_DLL explicit Singleton(IterVar iter); + + TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode); +}; + /*! \brief Container for specialization conditions. */ class SpecializedConditionNode : public Object { public: @@ -753,9 +762,7 @@ class SpecializedConditionNode : public Object { */ Array clauses; - void VisitAttrs(AttrVisitor* v) { - v->Visit("clauses", &clauses); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("clauses", &clauses); } static constexpr const char* _type_key = "SpecializedCondition"; TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object); @@ -792,19 +799,13 @@ class SpecializedCondition : public ObjectRef { }; // implementations -inline const StageNode* Stage::operator->() const { - return static_cast(get()); -} -inline StageNode* Stage::operator->() { - return static_cast(get_mutable()); -} +inline const StageNode* Stage::operator->() const { return static_cast(get()); } +inline StageNode* Stage::operator->() { return static_cast(get_mutable()); } inline const ScheduleNode* Schedule::operator->() const { return static_cast(get()); } -inline ScheduleNode* Schedule::operator->() { - return static_cast(get_mutable()); -} +inline ScheduleNode* Schedule::operator->() { return static_cast(get_mutable()); } inline const IterVarRelationNode* IterVarRelation::operator->() const { return static_cast(get()); diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h index b3ecbf8c08e1..a4efa7a94990 100644 --- a/include/tvm/te/schedule_pass.h +++ b/include/tvm/te/schedule_pass.h @@ -29,10 +29,28 @@ #define TVM_TE_SCHEDULE_PASS_H_ #include +#include namespace tvm { namespace te { +/*! + * \brief To automatically inline the element-wise operations. + * + * \param sch The schedule to be inlined. + */ +void AutoInlineElemWise(Schedule sch); + +/*! + * \brief To automatically inline operations with injective writes + * (i.e. writes without reduction or sequential loops). Note + * that in this case, guarantees about contiguity, transpose, stride, + * alignemnt and memory footprint in general do not hold. + * + * \param sch The schedule to be inlined. + */ +TVM_DLL void AutoInlineInjective(Schedule sch); + /*! * \brief Infer the bound of all iteration variables relates to the schedule. * @@ -41,6 +59,15 @@ namespace te { */ Map InferBound(const Schedule& sch); +/*! + * \brief Verify if there is any argument bound to compact buffer. + * + * \param stmt The stmt to be verified. + * \return true if there is any buffer_bind_scope attribute found, + * otherwise, false. + */ +bool VerifyCompactBuffer(const Stmt& stmt); + /*! * \brief Schedule s' dependent operations. 
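Taken together, the schedule_pass declarations above describe the te lowering pipeline: optional auto-inlining, bound inference, then statement generation (and, further below, post-processing into a PrimFunc). A hedged sketch of how they compose, assuming the elided map type of InferBound is Map<IterVar, Range> as in upstream TVM:

#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>

tvm::tir::Stmt LowerToStmt(tvm::te::Schedule sch) {
  using namespace tvm;
  using namespace tvm::te;
  // Inline injective stages and normalize before inferring bounds.
  AutoInlineInjective(sch);
  sch = sch.normalize();
  // Infer the Range of every IterVar the schedule touches.
  Map<IterVar, Range> bounds = InferBound(sch);
  // Emit the loop nest; drop trivial extent-1 loops from the output.
  return ScheduleOps(sch, bounds, /*debug_keep_trivial_loop=*/false);
}
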
* @@ -55,21 +82,35 @@ Map InferBound(const Schedule& sch); Stmt ScheduleOps(Schedule s, Map dom_map, bool debug_keep_trivial_loop); /*! - * \brief To automatically inline the element-wise operations. + * \brief Try to modify the AST generated by ScheduleOps to support TensorCore. * - * \param sch The schedule to be inlined. + * \param stmt The stmt to be trasnformed. + * \param schedule The original schedule. + * \param extern_buffer Map specifies external + * buffer assignment of input and outputs. + * \return Transformed stmt. */ -void AutoInlineElemWise(Schedule sch); +Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, + Map extern_buffer); /*! - * \brief To automatically inline operations with injective writes - * (i.e. writes without reduction or sequential loops). Note - * that in this case, guarantees about contiguity, transpose, stride, - * alignemnt and memory footprint in general do not hold. + * \brief Postprocessing the Stmt generated by ScheduleOps to create + * a PrimFunc that can then be used for further TIR optimizations. * - * \param sch The schedule to be inlined. + * Perform this translation before running any TIR optimizations. + * + * List of actions taken by the function: + * - Remove occurences of te::Tensor, te::Operation in the IR + * and replace them by corresponding IR nodes via tir::Buffer. + * - Add annotation of extern buffers using the buffer_map field + * in the PrimFunc type. + * + * \param arg_list Array of Tensor/Var/Buffer arguments to the function. + * \param body The body of the function. + * \param bindings potential Tensor to Buffer bindings for the Tensors in the body. */ -TVM_DLL void AutoInlineInjective(Schedule sch); +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, + Optional> bindings); } // namespace te } // namespace tvm diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index c247dca3ff45..2f9fa2f534c5 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -24,15 +24,15 @@ #ifndef TVM_TE_TENSOR_H_ #define TVM_TE_TENSOR_H_ -#include #include +#include #include #include #include -#include -#include #include +#include +#include namespace tvm { namespace te { @@ -40,25 +40,69 @@ namespace te { using arith::IntSet; using namespace tvm::tir; -// Internal node container of Tensor -class TensorNode; // internal node container for Operation class OperationNode; +class Tensor; -/*! - * \brief Tensor structure representing a possible input, - * or intermediate computation result. - */ -class Tensor : public ObjectRef { +/*! \brief Operation that produces tensors */ +class Operation : public ObjectRef { public: - /*! \brief default constructor, used internally */ - Tensor() {} - explicit Tensor(ObjectPtr n) : ObjectRef(n) {} + /*! \brief default constructor */ + Operation() {} + explicit Operation(ObjectPtr n) : ObjectRef(n) {} /*! * \brief access the internal node container * \return the pointer to the internal node container */ - inline const TensorNode* operator->() const; + inline const OperationNode* operator->() const; + /*! + * \brief get the i-th output of the operation. + * \param i the output index. + * \return The i-th output. + */ + TVM_DLL Tensor output(size_t i) const; + /*! \brief specify container node */ + using ContainerType = OperationNode; +}; + +/*! \brief Node to represent a tensor */ +class TensorNode : public DataProducerNode { + public: + /*! \brief The shape of the tensor */ + Array shape; + /*! \brief data type in the content of the tensor */ + DataType dtype; + /*! 
\brief the source operation, can be None */ + Operation op; + /*! \brief the output index from source operation */ + int value_index{0}; + /*! \brief constructor */ + TensorNode() {} + + void VisitAttrs(AttrVisitor* v) { + v->Visit("shape", &shape); + v->Visit("dtype", &dtype); + v->Visit("op", &op); + v->Visit("value_index", &value_index); + } + + Array GetShape() const final { return shape; } + + DataType GetDataType() const final { return dtype; } + + TVM_DLL String GetNameHint() const final; + + static constexpr const char* _type_key = "Tensor"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode); +}; + +/*! + * \brief Tensor structure representing a possible input, + * or intermediate computation result. + */ +class Tensor : public DataProducer { + public: + TVM_DLL Tensor(Array shape, DataType dtype, Operation op, int value_index); /*! * \brief check if two tensors equals each other. * \param other tensor to be checked. @@ -78,8 +122,8 @@ class Tensor : public ObjectRef { * \param args The indices * \return the result expression representing tensor read. */ - template - inline PrimExpr operator()(Args&& ...args) const { + template + inline PrimExpr operator()(Args&&... args) const { Array indices{std::forward(args)...}; return operator()(indices); } @@ -119,9 +163,7 @@ class Tensor : public ObjectRef { * This is only valid when all the coordinates are fully specified. * \return the corresponding expression of this slice. */ - inline operator PrimExpr() const { - return tensor_(indices_); - } + inline operator PrimExpr() const { return tensor_(indices_); } private: const Tensor& tensor_; @@ -132,105 +174,41 @@ class Tensor : public ObjectRef { * \param i the index of the coordinate * \return the subsequent slice. */ - inline Slice operator[](PrimExpr i) const { - return Slice(*this, {i}); - } - /*! \brief specify container node */ - using ContainerType = TensorNode; -}; + inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); } -/*! \brief Operation that produces tensors */ -class Operation : public tir::FunctionRef { - public: - /*! \brief default constructor */ - Operation() {} - explicit Operation(ObjectPtr n) : FunctionRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const OperationNode* operator->() const; - /*! - * \brief get the i-th output of the operation. - * \param i the output index. - * \return The i-th output. - */ - TVM_DLL Tensor output(size_t i) const; - /*! \brief specify container node */ - using ContainerType = OperationNode; + TVM_DEFINE_OBJECT_REF_METHODS(Tensor, DataProducer, TensorNode); }; -/*! \brief Node to represent a tensor */ -class TensorNode : public Object { - public: - /*! \brief The shape of the tensor */ - Array shape; - /*! \brief data type in the content of the tensor */ - DataType dtype; - /*! \brief the source operation, can be None */ - Operation op; - /*! \brief the output index from source operation */ - int value_index{0}; - /*! 
\brief constructor */ - TensorNode() {} - - void VisitAttrs(AttrVisitor* v) { - v->Visit("shape", &shape); - v->Visit("dtype", &dtype); - v->Visit("op", &op); - v->Visit("value_index", &value_index); - } - TVM_DLL static Tensor make(Array shape, - DataType dtype, - Operation op, - int value_index); - - static constexpr const char* _type_key = "Tensor"; - TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, Object); -}; - - // Implementations of inline functions -inline const TensorNode* Tensor::operator->() const { - return static_cast(get()); -} - -inline size_t Tensor::ndim() const { - return (*this)->shape.size(); -} +inline size_t Tensor::ndim() const { return (*this)->shape.size(); } inline bool Tensor::operator==(const Tensor& other) const { if (get() == other.get()) return true; if (get() == nullptr || other.get() == nullptr) return false; if ((*this)->op.defined() || other->op.defined()) { - return (*this)->op == other->op && - (*this)->value_index == other->value_index; + return (*this)->op == other->op && (*this)->value_index == other->value_index; } else { return false; } } -inline bool Tensor::operator!=(const Tensor& other) const { - return !(*this == other); -} +inline bool Tensor::operator!=(const Tensor& other) const { return !(*this == other); } // macro to turn every operation of slice to expression -#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ - inline PrimExpr operator Op (const Tensor::Slice& a) { \ - return Op a.operator PrimExpr() ; \ - } \ - -#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ - template \ - inline PrimExpr operator Op (const Tensor::Slice& a, const T& b) { \ - return a.operator PrimExpr() Op b; \ - } \ - template \ - inline PrimExpr operator Op (const T& a, const Tensor::Slice& b) { \ - return a Op b.operator PrimExpr(); \ - } \ - inline PrimExpr operator Op (const Tensor::Slice& a, const Tensor::Slice& b) { \ - return a.operator PrimExpr() Op b.operator PrimExpr(); \ +#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \ + inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); } + +#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \ + template \ + inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \ + return a.operator PrimExpr() Op b; \ + } \ + template \ + inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \ + return a Op b.operator PrimExpr(); \ + } \ + inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \ + return a.operator PrimExpr() Op b.operator PrimExpr(); \ } DEFINE_OVERLOAD_SLICE_UNARY_OP(!); @@ -254,16 +232,15 @@ DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*) namespace std { template <> -struct hash<::tvm::te::Operation> : public ::tvm::ObjectHash { -}; +struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {}; template <> struct hash<::tvm::te::Tensor> { std::size_t operator()(const ::tvm::te::Tensor& k) const { - ::tvm::ObjectHash hasher; + ::tvm::ObjectPtrHash hasher; if (k.defined() && k->op.defined()) { return hasher(k->op); - } else{ + } else { return hasher(k); } } diff --git a/include/tvm/te/tensor_intrin.h b/include/tvm/te/tensor_intrin.h index c964d3e5491b..22f29defbb64 100644 --- a/include/tvm/te/tensor_intrin.h +++ b/include/tvm/te/tensor_intrin.h @@ -32,24 +32,6 @@ namespace tvm { namespace te { -// Internal node container of tensor intrinsics. -class TensorIntrinNode; - -/*! \brief Tensor intrinsic node. */ -class TensorIntrin : public ObjectRef { - public: - TensorIntrin() {} - explicit TensorIntrin(ObjectPtr n) : ObjectRef(n) {} - /*! 
- * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorIntrinNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = TensorIntrinNode; -}; - /*! \brief Node to represent a Tensor intrinsic operator */ class TensorIntrinNode : public Object { public: @@ -100,39 +82,20 @@ class TensorIntrinNode : public Object { v->Visit("reduce_update", &reduce_update); } - TVM_DLL static TensorIntrin make(std::string name, - Operation op, - Array inputs, - Array buffers, - Array scalar_params, - Stmt body, - Stmt reduce_init, - Stmt reduce_update); - static constexpr const char* _type_key = "TensorIntrin"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object); }; -inline const TensorIntrinNode* TensorIntrin::operator->() const { - return static_cast(get()); -} - -// Internal node container of tensor intrinsic calling. -class TensorIntrinCallNode; - -/*! \brief Tensor intrinsic calling node. */ -class TensorIntrinCall : public ObjectRef { +/*! + * \brief Managed reference to TensorIntrinNode + * \sa TensorIntrinNode + */ +class TensorIntrin : public ObjectRef { public: - TensorIntrinCall() {} - explicit TensorIntrinCall(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const TensorIntrinCallNode* operator->() const; + TVM_DLL TensorIntrin(std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update); - /*! \brief specify container node */ - using ContainerType = TensorIntrinCallNode; + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode); }; class TensorIntrinCallNode : public Object { @@ -144,7 +107,6 @@ class TensorIntrinCallNode : public Object { /*! \brief regions of input tensors */ Array regions; - /*! * \brief IterVar on each reduction axis, if the * intrin will use the reduce axis @@ -161,19 +123,22 @@ class TensorIntrinCallNode : public Object { v->Visit("reduce_axis", &reduce_axis); v->Visit("scalar_inputs", &scalar_inputs); } - static TensorIntrinCall make(TensorIntrin intrin, - Array tensors, - Array regions, - Array reduce_axis, - Array scalar_inputs); static constexpr const char* _type_key = "TensorIntrinCall"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object); }; -inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const { - return static_cast(get()); -} +/*! + * \brief Managed reference to TensorIntrinCallNode + * \sa TensorIntrinCallNode + */ +class TensorIntrinCall : public ObjectRef { + public: + TVM_DLL TensorIntrinCall(TensorIntrin intrin, Array tensors, Array regions, + Array reduce_axis, Array scalar_inputs); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrinCall, ObjectRef, TensorIntrinCallNode); +}; } // namespace te } // namespace tvm diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 6af99586d2f9..6e7ed418b17a 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -25,10 +25,12 @@ #define TVM_TIR_ANALYSIS_H_ #include +#include #include #include #include +#include namespace tvm { namespace tir { @@ -53,14 +55,49 @@ struct ExprDeepEqual { TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; }; - /*! * \brief Find undefined vars in the statment. * \param stmt The function to be checked. * \param defs The vars that is defined. * \return Array of undefined vars. 
*/ -Array UndefinedVars(const Stmt& stmt, const Array& defs); +TVM_DLL Array UndefinedVars(const Stmt& stmt, const Array& defs); + +/*! + * \brief Whether the expression have side effect. + * \param expr The expression to be checked. + * \return whether expression have side effect + */ +TVM_DLL bool HasSideEffect(const PrimExpr& expr); + +/*! + * \brief Whether e expression used any var in variable set.. + * \param expr The expression to be checked. + * \param vset_contains The check function to see if var is in the vset. + * \return Whether e uses vset. + */ +TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function vset_contains); + +/*! + * \brief Whether e expression used var. + * \param expr The expression to be checked. + * \param var The variable. + * \return Whether e uses v. + */ +inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { + return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); +} + +/*! + * \brief Verifies whether the IR stmt or Expr is in SSA form. + * That is: each Var is defined and assigned once(in Let/For) + * + * \param func The function to be verified. + * \return Whether IR is in SSA form. + * + * \note All passes in TIR consume and produce SSA form. + */ +TVM_DLL bool VerifySSA(const PrimFunc& func); /*! * \brief Verify if memory accesses are legal for a specific target device type. @@ -69,11 +106,66 @@ Array UndefinedVars(const Stmt& stmt, const Array& defs); * threads, CPU code is generated that tries to access GPU memory, * which is illegal. This pass performs verification for this case. * - * \param mod The module to be verified. + * \param func The function to be verified. * \return Success of memory verification. */ -void VerifyMemory(const IRModule& mod); +TVM_DLL bool VerifyMemory(const PrimFunc& func); + +/*! + * \brief Verify the correctness of a GPU code + * It will check the whether the amount of memory usage or the number of threads + * in a block exceeds the limit + * \param func The function to be checked + * \param constraints The dict to specify constraints to check. + * Possible keys are + * + * "max_local_memory_per_block": Total amount of local memory per block (in bytes). + * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). + * "max_threads_per_block": Maximum number of threads per block. + * "max_thread_x": Maximum length of threadIdx.x. + * "max_thread_y": Maximum length of threadIdx.y. + * "max_thread_z": Maximum length of threadIdx.z. + * + * If one key is missing in this argument, the pass won't check for that item. + * \return valid Whether it is a valid GPU code + * + */ +TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); + +// Pass variants of verification analysis +// directly throws RuntimeError when verification fails. +namespace transform { + +using tvm::transform::Pass; +using tvm::transform::PassContext; + +/*! + * \brief Pass variant of VerifySSA. + * + * \returns The pass. + * \sa tvm::tir::VerifySSA + */ +TVM_DLL Pass VerifySSA(); + +/*! + * \brief Pass variant of VerifyMemory. + * + * \returns The pass. + * \sa tvm::tir::VerifyMemory + */ +TVM_DLL Pass VerifyMemory(); + +/*! + * \brief Pass variant of VerifyGPUCode. + * + * \param constraints The dict to specify constraints to check. + * + * \returns The pass. 
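A small sketch of the analysis entry points declared above. It assumes the elided value type of the VerifyGPUCode constraint map accepts plain integers (PrimExpr-compatible), as in upstream TVM; the thread limit used here is only illustrative, and the transform:: variants below behave the same but raise an error when verification fails.

#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>

bool CheckFunc(const tvm::tir::PrimFunc& func, const tvm::tir::Var& v,
               const tvm::PrimExpr& e) {
  using namespace tvm;
  using namespace tvm::tir;
  // Direct analysis calls return plain booleans.
  bool ssa_ok = VerifySSA(func);
  bool mem_ok = VerifyMemory(func);
  bool uses_v = ExprUseVar(e, v);
  bool gpu_ok = VerifyGPUCode(func, {{"max_threads_per_block", 1024}});
  return ssa_ok && mem_ok && gpu_ok && !uses_v;
}
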
+ * \sa tvm::tir::VerifyGPUCode + */ +TVM_DLL Pass VerifyGPUCode(Map constraints); +} // namespace transform } // namespace tir } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 08a8e69a4532..e150ff38041b 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -24,16 +24,14 @@ #ifndef TVM_TIR_BUFFER_H_ #define TVM_TIR_BUFFER_H_ -#include #include +#include #include -#include +#include namespace tvm { namespace tir { -// Internal node container Buffer -class BufferNode; // forward declare Stmt class Stmt; @@ -45,63 +43,6 @@ enum BufferType : int { kAutoBroadcast = 2, }; -/*! - * \brief Buffer is a symbolic n-darray structure. - * It is a composition of primitive symbolic types, - * used to specify the memory layout of the Tensor used in program input. - */ -class Buffer : public ObjectRef { - public: - Buffer() {} - explicit Buffer(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief Return a new buffer that is equivalent with current one - * but always add stride field. - * \return The strided version of the buffer. - */ - TVM_DLL Buffer MakeStrideView() const; - /*! - * \brief Make a new symbolic buffer representing a slice of the buffer. - * \param begins The beginning position of each dimension. - * \param extents The extent of each dimension. - * \note This function will make target buffer as compact as possible. - * If stride is not needed in the slice, it won't be presented - * \return the result buffer. - */ - TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; - /*! - * \brief Get access ptr to the entire buffer. - * \param access_mask The access mask - * \param ptr_type The type of the pointer. - * \param content_lanes The number of lanes for the (data) type. - * \param offset The offset of ptr. - */ - TVM_DLL PrimExpr access_ptr(int access_mask, - DataType ptr_type = DataType::Handle(), - int content_lanes = 1, - PrimExpr offset = IntImm(DataType::Int(32), 0)) const; - /*! - * \brief Create an Expr that does a vector load at begin index. - * \param begin The beginning index - * \param dtype The data type to be loaded. - */ - TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; - /*! - * \brief Create a Stmt that does a vector store at begin index. - * \param begin The beginning index - * \param value The value to be stored. - */ - TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const BufferNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = BufferNode; -}; - /*! \brief Node to represent a buffer */ class BufferNode : public Object { public: @@ -124,9 +65,9 @@ class BufferNode : public Object { PrimExpr elem_offset; // Meta data /*! \brief optional name of the buffer */ - std::string name; + String name; /*! \brief storage scope of the buffer, if other than global */ - std::string scope; + String scope; /*! \brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -155,15 +96,10 @@ class BufferNode : public Object { bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const { // Use DefEqual as buffer can define variables // in its semantics, skip name as name is not important. 
- return - equal.DefEqual(data, other->data) && - equal(dtype, other->dtype) && - equal.DefEqual(shape, other->shape) && - equal.DefEqual(strides, other->strides) && - equal.DefEqual(elem_offset, other->elem_offset) && - equal(scope, other->scope) && - equal(data_alignment, other->data_alignment) && - equal(buffer_type, other->buffer_type); + return equal.DefEqual(data, other->data) && equal(dtype, other->dtype) && + equal.DefEqual(shape, other->shape) && equal.DefEqual(strides, other->strides) && + equal.DefEqual(elem_offset, other->elem_offset) && equal(scope, other->scope) && + equal(data_alignment, other->data_alignment) && equal(buffer_type, other->buffer_type); } void SHashReduce(SHashReducer hash_reduce) const { @@ -182,28 +118,65 @@ class BufferNode : public Object { return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); } - // User can specify data_alignment and offset_factor to be 0 - // A default value will be picked. - TVM_DLL static Buffer make(Var ptr, - DataType dtype, - Array shape, - Array strides, - PrimExpr elem_offset, - std::string name, - std::string scope, - int data_alignment, - int offset_factor, - BufferType buffer_type); - - static constexpr const char* _type_key = "Buffer"; + static constexpr const char* _type_key = "tir.Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object); }; -inline const BufferNode* Buffer::operator->() const { - return static_cast(get()); -} +/*! + * \brief Buffer is a symbolic n-darray structure. + * It is a composition of primitive symbolic types, + * used to specify the memory layout of the Tensor used in program input. + */ +class Buffer : public ObjectRef { + public: + // User can specify data_alignment and offset_factor to be 0 + // A default value will be picked. + TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, String name, String scope, int data_alignment, + int offset_factor, BufferType buffer_type); + + /*! + * \brief Return a new buffer that is equivalent with current one + * but always add stride field. + * \return The strided version of the buffer. + */ + TVM_DLL Buffer MakeStrideView() const; + /*! + * \brief Make a new symbolic buffer representing a slice of the buffer. + * \param begins The beginning position of each dimension. + * \param extents The extent of each dimension. + * \note This function will make target buffer as compact as possible. + * If stride is not needed in the slice, it won't be presented + * \return the result buffer. + */ + TVM_DLL Buffer MakeSlice(Array begins, Array extents) const; + /*! + * \brief Get access ptr to the entire buffer. + * \param access_mask The access mask + * \param ptr_type The type of the pointer. + * \param content_lanes The number of lanes for the (data) type. + * \param offset The offset of ptr. + */ + TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), + int content_lanes = 1, + PrimExpr offset = IntImm(DataType::Int(32), 0)) const; + /*! + * \brief Create an Expr that does a vector load at begin index. + * \param begin The beginning index + * \param dtype The data type to be loaded. + */ + TVM_DLL PrimExpr vload(Array begin, DataType dtype) const; + /*! + * \brief Create a Stmt that does a vector store at begin index. + * \param begin The beginning index + * \param value The value to be stored. 
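[Editor's note, not part of the patch] Call sites that used the removed BufferNode::make now invoke the Buffer constructor directly. The argument values below are made up, and the shape/strides template parameters stripped by formatting are assumed to be Array<PrimExpr>.

#include <tvm/tir/buffer.h>

tvm::tir::Buffer MakeExampleBuffer() {
  using namespace tvm;
  using namespace tvm::tir;
  Var data("A", DataType::Handle());
  Array<PrimExpr> shape = {1024};
  // Was: BufferNode::make(data, DataType::Float(32), shape, ...);
  return Buffer(data, DataType::Float(32), shape, /*strides=*/Array<PrimExpr>(),
                /*elem_offset=*/PrimExpr(), "A", /*scope=*/"",
                /*data_alignment=*/64, /*offset_factor=*/1, kDefault);
}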
+ */ + TVM_DLL Stmt vstore(Array begin, PrimExpr value) const; + + TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode); +}; /*! * \brief Construct a new buffer given shape, and dtype. @@ -211,11 +184,65 @@ inline const BufferNode* Buffer::operator->() const { * \param dtype The content data type. * \param name The name of the buffer * \return The created buffer. - * \sa BufferNode::make for complete constructor. + * \sa Buffer for complete constructor. + */ +TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), + String name = "buffer"); + +/*! + * \brief Base node for data producers. + * + * A DataProducer stores necessary information(e.g. a tensor expression) to produce + * a multi-dimensional array. The stored information is opaque to the TIR. + * DataProducer can appear in high-level DSLs that are built on top of the TIR. + * + * A valid TIR PrimFunc should not contain any DataProducer, high level DSLs should lower + * all DataProducers to Buffers before TIR transformations. + * + * \sa tvm::te::Tensor + */ +class DataProducerNode : public Object { + public: + /*! \brief destructor. */ + virtual ~DataProducerNode() {} + /*! + * \brief Get the shape of the result. + * \return The shape. + */ + virtual Array GetShape() const = 0; + /*! + * \brief Get the data type of the result. + * \return The data type. + */ + virtual DataType GetDataType() const = 0; + /*! + * \brief Get the name hint of the data producer. + * \return The data type. + */ + virtual String GetNameHint() const = 0; + + bool SEqualReduce(const DataProducerNode* other, SEqualReducer equal) const { + // because buffer producer is opaque, we just do pointer equality. + return this == other; + } + + void SHashReduce(SHashReducer hash_reduce) const {} + + static constexpr const char* _type_key = "tir.DataProducer"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object); +}; + +/*! + * \brief Managed reference to DataProducerNode. + * \sa DataProducerNode */ -TVM_DLL Buffer decl_buffer(Array shape, - DataType dtype = DataType::Float(32), - std::string name = "buffer"); +class DataProducer : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DataProducer, ObjectRef, DataProducerNode); +}; + } // namespace tir } // namespace tvm #endif // TVM_TIR_BUFFER_H_ diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 434337057167..d3a77cc81063 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -25,20 +25,20 @@ #ifndef TVM_TIR_DATA_LAYOUT_H_ #define TVM_TIR_DATA_LAYOUT_H_ - #include #include -#include +#include #include -#include +#include #include -#include - +#include namespace tvm { namespace tir { +class Layout; + class LayoutAxis { public: static const LayoutAxis& Get(const char name); @@ -47,7 +47,7 @@ class LayoutAxis { static const LayoutAxis& Get(const tir::IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). - static const LayoutAxis& make(const std::string& name); + static const LayoutAxis& Get(const std::string& name); inline bool IsPrimal() const { return name_ >= 'A' && name_ <= 'Z'; } inline std::string name() const { return std::string(1, name_); } @@ -63,18 +63,12 @@ class LayoutAxis { } // return the primal axis. If it is already primal, return itself. - const LayoutAxis& ToPrimal() const { - return IsPrimal() ? 
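[Editor's note, not part of the patch] A small sketch of the buffer helpers: declare a buffer with decl_buffer, then build a symbolic load and store through it. The indices are illustrative.

#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>

void BufferSketch() {
  using namespace tvm;
  using namespace tvm::tir;
  // 1-D float32 buffer named "B" with 256 elements.
  Buffer buf = decl_buffer({256}, DataType::Float(32), "B");
  // Symbolic scalar load of B[7].
  PrimExpr load = buf.vload({7}, DataType::Float(32));
  // Symbolic store of that value into B[0].
  Stmt store = buf.vstore({0}, load);
  (void)store;
}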
*this : ToDual(); - } + const LayoutAxis& ToPrimal() const { return IsPrimal() ? *this : ToDual(); } // return the subordinate axis. If it is already subordinate, return itself. - const LayoutAxis& ToSubordinate() const { - return IsPrimal() ? ToDual() : *this; - } + const LayoutAxis& ToSubordinate() const { return IsPrimal() ? ToDual() : *this; } - inline bool operator==(const LayoutAxis& rhs) const { - return name_ == rhs.name_; - } + inline bool operator==(const LayoutAxis& rhs) const { return name_ == rhs.name_; } friend std::ostream& operator<<(std::ostream& os, const LayoutAxis& l) { os << l.name(); @@ -91,12 +85,20 @@ class LayoutAxis { const char name_; }; -class Layout; -// Internal node container Buffer +/*! + * \brief Layout is to describe how data is organized within an N-dimention tensor. + * It is composed of upper cases, lower cases and numbers, + * where upper case indicates a primal axis and + * the corresponding lower case with factor size indicates the subordinate axis. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). + * Layout for scalar is defined, while both its name and axes have size 0. + */ class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ - std::string name; + String name; /*! \brief specify each axis of the layout, * in which the variable name is the name of the axis. * The IterVar's extent indicates the size of the axis, @@ -110,33 +112,20 @@ class LayoutNode : public Object { v->Visit("axes", &axes); } - TVM_DLL static Layout make(const std::string& layout); - - static constexpr const char* _type_key = "Layout"; + static constexpr const char* _type_key = "tir.Layout"; TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object); }; /*! - * \brief Layout is to describe how data is organized within an N-dimention tensor. - * It is composed of upper cases, lower cases and numbers, - * where upper case indicates a primal axis and - * the corresponding lower case with factor size indicates the subordinate axis. - * For example, NCHW16c can describe a 5-D tensor of - * [batch_size, channel, height, width, channel_block]. - * Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). - * Layout for scalar is defined, while both its name and axes have size 0. + * \brief Managed reference to LayoutNode + * \sa LayoutNode */ class Layout : public ObjectRef { public: - explicit Layout(ObjectPtr n) : ObjectRef(n) {} - - /*! \brief default constructor */ - Layout() = default; - explicit Layout(const Array& axes); /*! \brief construct from a string */ - Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) + Layout(const char* name) : Layout(std::string(name)) {} // NOLINT(*) /*! * \brief construct from a string. @@ -146,23 +135,13 @@ class Layout : public ObjectRef { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - Layout(const std::string& name); // NOLINT(*) - - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - const LayoutNode* operator->() const { - return static_cast(get()); - } + TVM_DLL Layout(const std::string& name); // NOLINT(*) /*! 
* \brief access the internal node container * \return the pointer to the internal node container */ - LayoutNode* operator->() { - return static_cast(get_mutable()); - } + LayoutNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Return an undefined layout. @@ -190,8 +169,7 @@ class Layout : public ObjectRef { * \param factor size of the sub-dimension. * \return A newly constructed Layout object. */ - Layout Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const; - + Layout Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const; /*! \return number of dimensions */ inline size_t ndim() const { @@ -292,9 +270,7 @@ class Layout : public ObjectRef { * \param rhs Another layout. * \return whether the two layouts are equal. */ - inline bool Equals(const Layout &rhs) const { - return name() == rhs.name(); - } + inline bool Equals(const Layout& rhs) const { return name() == rhs.name(); } /*! * \brief allow output string of layout to ostream @@ -307,10 +283,9 @@ class Layout : public ObjectRef { return os; } - using ContainerType = LayoutNode; + TVM_DEFINE_OBJECT_REF_METHODS(Layout, ObjectRef, LayoutNode); }; -class BijectiveLayout; // Internal node container BijectiveLayout class BijectiveLayoutNode : public Object { public: @@ -333,7 +308,7 @@ class BijectiveLayoutNode : public Object { v->Visit("backward_rule", &backward_rule); } - static constexpr const char* _type_key = "BijectiveLayout"; + static constexpr const char* _type_key = "tir.BijectiveLayout"; TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object); }; @@ -344,8 +319,6 @@ class BijectiveLayoutNode : public Object { */ class BijectiveLayout : public ObjectRef { public: - BijectiveLayout() = default; - explicit BijectiveLayout(ObjectPtr n) : ObjectRef(n) {} /*! * \brief The constructor * \param src_layout The source layout @@ -362,19 +335,9 @@ class BijectiveLayout : public ObjectRef { // Given the destination indices, recover the source indices. TVM_DLL Array BackwardIndex(const Array& dst_index) const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const BijectiveLayoutNode* operator->() const; - - /*! \brief specify container node */ - using ContainerType = BijectiveLayoutNode; + TVM_DEFINE_OBJECT_REF_METHODS(BijectiveLayout, ObjectRef, BijectiveLayoutNode); }; -inline const BijectiveLayoutNode* BijectiveLayout::operator->() const { - return static_cast(get()); -} } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index a1603d5e7bda..1518d1ff548e 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -25,20 +25,20 @@ #ifndef TVM_TIR_EXPR_H_ #define TVM_TIR_EXPR_H_ -#include +#include #include #include +#include #include #include -#include -#include #include +#include -#include #include -#include #include #include +#include +#include #include namespace tvm { @@ -51,7 +51,7 @@ using FloatImmNode = tvm::FloatImmNode; class StringImmNode : public PrimExprNode { public: /*! \brief The constant value content. 
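[Editor's note, not part of the patch] For orientation, a usage sketch of the layout classes above; the index values are illustrative.

#include <tvm/tir/data_layout.h>

void LayoutSketch() {
  using namespace tvm;
  using namespace tvm::tir;
  // NCHW16c: primal axes N, C, H, W plus a subordinate axis c with factor 16.
  Layout packed("NCHW16c");
  Layout plain("NCHW");
  // Bijective index mapping between the two layouts.
  BijectiveLayout rule(plain, packed);
  // Map NCHW indices to NCHW16c indices symbolically.
  Array<PrimExpr> dst = rule.ForwardIndex({0, 32, 8, 8});
  (void)dst;
}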
*/ - std::string value; + String value; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -62,18 +62,19 @@ class StringImmNode : public PrimExprNode { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } - - TVM_DLL PrimExpr static make(std::string value); + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "StringImm"; + static constexpr const char* _type_key = "tir.StringImm"; TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode); }; +/*! + * \brief Managed reference to StringImmNode. + * \sa StringImmNode + */ class StringImm : public PrimExpr { public: + TVM_DLL StringImm(String value); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); }; @@ -100,17 +101,25 @@ class CastNode : public PrimExprNode { hash_reduce(value); } - TVM_DLL static PrimExpr make(DataType t, PrimExpr v); - - static constexpr const char* _type_key = "Cast"; + static constexpr const char* _type_key = "tir.Cast"; TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode); }; +/*! + * \brief Managed reference to CastNode + * \sa CastNode + */ +class Cast : public PrimExpr { + public: + TVM_DLL Cast(DataType dtype, PrimExpr value); + TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); +}; + /*! * \brief Base template to implement binary ops. * \tparam T The type of the child class. */ -template +template class BinaryOpNode : public PrimExprNode { public: /*! \brief The left operand. */ @@ -125,10 +134,7 @@ class BinaryOpNode : public PrimExprNode { } bool SEqualReduce(const T* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -137,36 +143,55 @@ class BinaryOpNode : public PrimExprNode { hash_reduce(b); } - static PrimExpr make(PrimExpr a, PrimExpr b) { - CHECK(a.defined()) << "ValueError: a is undefined\n"; - CHECK(b.defined()) << "ValueError: b is undefined\n"; - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - ObjectPtr node = make_object(); - node->dtype = a.dtype(); - node->a = std::move(a); - node->b = std::move(b); - return PrimExpr(node); - } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; /*! \brief a + b */ class AddNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Add"; + static constexpr const char* _type_key = "tir.Add"; +}; + +/*! + * \brief Managed reference to AddNode + * \sa AddNode + */ +class Add : public PrimExpr { + public: + TVM_DLL Add(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); }; /*! \brief a - b */ class SubNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Sub"; + static constexpr const char* _type_key = "tir.Sub"; +}; + +/*! + * \brief Managed reference to SubNode + * \sa SubNode + */ +class Sub : public PrimExpr { + public: + TVM_DLL Sub(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); }; /*! \brief a * b */ class MulNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Mul"; + static constexpr const char* _type_key = "tir.Mul"; +}; + +/*! + * \brief Managed reference to MulNode + * \sa MulNode + */ +class Mul : public PrimExpr { + public: + TVM_DLL Mul(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); }; /*! 
@@ -175,7 +200,17 @@ class MulNode : public BinaryOpNode { */ class DivNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Div"; + static constexpr const char* _type_key = "tir.Div"; +}; + +/*! + * \brief Managed reference to DivNode + * \sa DivNode + */ +class Div : public PrimExpr { + public: + TVM_DLL Div(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); }; /*! @@ -184,38 +219,88 @@ class DivNode : public BinaryOpNode { */ class ModNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Mod"; + static constexpr const char* _type_key = "tir.Mod"; +}; + +/*! + * \brief Managed reference to ModNode + * \sa ModNode + */ +class Mod : public PrimExpr { + public: + TVM_DLL Mod(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); }; /*! \brief Floor division, floor(a/b) */ class FloorDivNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "FloorDiv"; + static constexpr const char* _type_key = "tir.FloorDiv"; +}; + +/*! + * \brief Managed reference to FloorDivNode + * \sa FloorDivNode + */ +class FloorDiv : public PrimExpr { + public: + TVM_DLL FloorDiv(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); }; /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "FloorMod"; + static constexpr const char* _type_key = "tir.FloorMod"; +}; + +/*! + * \brief Managed reference to FloorModNode + * \sa FloorModNode + */ +class FloorMod : public PrimExpr { + public: + TVM_DLL FloorMod(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); }; /*! \brief min(a, b) */ class MinNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Min"; + static constexpr const char* _type_key = "tir.Min"; +}; + +/*! + * \brief Managed reference to MinNode + * \sa MinNode + */ +class Min : public PrimExpr { + public: + TVM_DLL Min(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); }; /*! \brief max(a, b) */ class MaxNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "Max"; + static constexpr const char* _type_key = "tir.Max"; +}; + +/*! + * \brief Managed reference to MaxNode + * \sa MaxNode + */ +class Max : public PrimExpr { + public: + TVM_DLL Max(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); }; /*! * \brief Base template to implement comparison ops. * \tparam T The type of the child class. */ -template +template class CmpOpNode : public PrimExprNode { public: /*! \brief The left operand. 
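[Editor's note, not part of the patch] The recurring change in this header is that each node type loses its static make() and gains a managed reference class with a real constructor. A sketch of what expression construction looks like afterwards:

#include <tvm/tir/expr.h>

tvm::PrimExpr ArithmeticSketch() {
  using namespace tvm;
  using namespace tvm::tir;
  Var x("x", DataType::Int(32));
  Var y("y", DataType::Int(32));
  // Was AddNode::make(x, y), MulNode::make(...), and so on.
  PrimExpr a = Add(x, y);
  PrimExpr b = Mul(a, FloorDiv(y, IntImm(DataType::Int(32), 2)));
  return Min(b, Max(x, y));
}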
*/ @@ -230,10 +315,7 @@ class CmpOpNode : public PrimExprNode { } bool SEqualReduce(const T* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -242,54 +324,103 @@ class CmpOpNode : public PrimExprNode { hash_reduce(b); } - static PrimExpr make(PrimExpr a, PrimExpr b) { - CHECK(a.defined()) << "ValueError: a is undefined\n"; - CHECK(b.defined()) << "ValueError: b is undefined\n"; - CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; - ObjectPtr node = make_object(); - node->dtype = DataType::Bool(a.dtype().lanes()); - node->a = std::move(a); - node->b = std::move(b); - return PrimExpr(node); - } - TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode); }; /*! \brief a == b */ class EQNode : public CmpOpNode { public: - static constexpr const char* _type_key = "EQ"; + static constexpr const char* _type_key = "tir.EQ"; +}; + +/*! + * \brief Managed reference to EQNode + * \sa EQNode + */ +class EQ : public PrimExpr { + public: + TVM_DLL EQ(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); }; /*! \brief a != b */ class NENode : public CmpOpNode { public: - static constexpr const char* _type_key = "NE"; + static constexpr const char* _type_key = "tir.NE"; +}; + +/*! + * \brief Managed reference to NENode + * \sa NENode + */ +class NE : public PrimExpr { + public: + TVM_DLL NE(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); }; /*! \brief a < b */ class LTNode : public CmpOpNode { public: - static constexpr const char* _type_key = "LT"; + static constexpr const char* _type_key = "tir.LT"; +}; + +/*! + * \brief Managed reference to LTNode + * \sa LTNode + */ +class LT : public PrimExpr { + public: + TVM_DLL LT(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); }; /*! \brief a <= b */ struct LENode : public CmpOpNode { public: - static constexpr const char* _type_key = "LE"; + static constexpr const char* _type_key = "tir.LE"; +}; + +/*! + * \brief Managed reference to LENode + * \sa LENode + */ +class LE : public PrimExpr { + public: + TVM_DLL LE(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); }; /*! \brief a > b */ class GTNode : public CmpOpNode { public: - static constexpr const char* _type_key = "GT"; + static constexpr const char* _type_key = "tir.GT"; +}; + +/*! + * \brief Managed reference to GTNode + * \sa GTNode + */ +class GT : public PrimExpr { + public: + TVM_DLL GT(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); }; /*! \brief a >= b */ class GENode : public CmpOpNode { public: - static constexpr const char* _type_key = "GE"; + static constexpr const char* _type_key = "tir.GE"; +}; + +/*! + * \brief Managed reference to GENode + * \sa GENode + */ +class GE : public PrimExpr { + public: + TVM_DLL GE(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); }; /*! 
\brief a && b */ @@ -307,10 +438,7 @@ class AndNode : public PrimExprNode { } bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -319,12 +447,20 @@ class AndNode : public PrimExprNode { hash_reduce(b); } - TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); - - static constexpr const char* _type_key = "And"; + static constexpr const char* _type_key = "tir.And"; TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode); }; +/*! + * \brief Managed reference to AndNode + * \sa AndNode + */ +class And : public PrimExpr { + public: + TVM_DLL And(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); +}; + /*! \brief a || b */ class OrNode : public PrimExprNode { public: @@ -340,10 +476,7 @@ class OrNode : public PrimExprNode { } bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(a, other->a) && - equal(b, other->b); + return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b); } void SHashReduce(SHashReducer hash_reduce) const { @@ -352,12 +485,20 @@ class OrNode : public PrimExprNode { hash_reduce(b); } - TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b); - - static constexpr const char* _type_key = "Or"; + static constexpr const char* _type_key = "tir.Or"; TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode); }; +/*! + * \brief Managed reference to OrNode + * \sa OrNode + */ +class Or : public PrimExpr { + public: + TVM_DLL Or(PrimExpr a, PrimExpr b); + TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); +}; + /*! \brief !a */ class NotNode : public PrimExprNode { public: @@ -378,12 +519,20 @@ class NotNode : public PrimExprNode { hash_reduce(a); } - TVM_DLL static PrimExpr make(PrimExpr a); - - static constexpr const char* _type_key = "Not"; + static constexpr const char* _type_key = "tir.Not"; TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode); }; +/*! + * \brief Managed reference to NotNode + * \sa NotNode + */ +class Not : public PrimExpr { + public: + TVM_DLL Not(PrimExpr a); + TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); +}; + /*! * \brief return true_value if condition is true, otherwise return false_value. * \note Both true_value and false_value could be evaluated @@ -408,11 +557,8 @@ class SelectNode : public PrimExprNode { } bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(condition, other->condition) && - equal(true_value, other->true_value) && - equal(false_value, other->false_value); + return equal(dtype, other->dtype) && equal(condition, other->condition) && + equal(true_value, other->true_value) && equal(false_value, other->false_value); } void SHashReduce(SHashReducer hash_reduce) const { @@ -422,12 +568,21 @@ class SelectNode : public PrimExprNode { hash_reduce(false_value); } - TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); - - static constexpr const char* _type_key = "Select"; + static constexpr const char* _type_key = "tir.Select"; TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode); }; +/*! 
+ * \brief Managed reference to SelectNode + * \sa SelectNode + */ +class Select : public PrimExpr { + public: + TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value); + + TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); +}; + /*! * \brief Load value from the high dimension buffer. * @@ -452,10 +607,8 @@ class BufferLoadNode : public PrimExprNode { } bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(buffer, other->buffer) && - equal(indices, other->indices); + return equal(dtype, other->dtype) && equal(buffer, other->buffer) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -464,17 +617,68 @@ class BufferLoadNode : public PrimExprNode { hash_reduce(indices); } - static constexpr const char* _type_key = "BufferLoad"; + static constexpr const char* _type_key = "tir.BufferLoad"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode); }; +/*! + * \brief Managed reference to BufferLoadNode. + * \sa BufferLoadNode + */ class BufferLoad : public PrimExpr { public: - TVM_DLL explicit BufferLoad(Buffer buffer, - Array indices); + TVM_DLL explicit BufferLoad(Buffer buffer, Array indices); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); }; +/*! + * \brief Load value from the result produced by the producer. + * + * \note This node only appears in high-level DSLs that are built on top of the TIR. + * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower + * this node before TIR transformations. + * + * \sa ProducerLoad, DataProducerNode + */ +class ProducerLoadNode : public PrimExprNode { + public: + /*! \brief The buffer producer. */ + DataProducer producer; + /*! \brief The location arguments. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &(this->dtype)); + v->Visit("producer", &producer); + v->Visit("indices", &indices); + } + + bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(producer, other->producer) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(producer); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "tir.ProducerLoad"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode); +}; + +/*! + * \brief Managed reference to ProducerLoadNode. + * \sa ProducerLoadNode + */ +class ProducerLoad : public PrimExpr { + public: + TVM_DLL explicit ProducerLoad(DataProducer producer, Array indices); + + TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); +}; + /*! * \brief Load the value from buffer_var. 
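[Editor's note, not part of the patch] A sketch combining the reference classes above, assuming a float32 buffer: guard a BufferLoad with a Select so that out-of-range indices yield zero.

#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>

tvm::PrimExpr GuardedLoad(const tvm::tir::Buffer& buffer, tvm::PrimExpr i) {
  using namespace tvm;
  using namespace tvm::tir;
  PrimExpr in_range = GE(i, IntImm(i.dtype(), 0));
  PrimExpr value = BufferLoad(buffer, {i});
  // Both branches of Select must share a dtype, hence the float32 assumption.
  return Select(in_range, value, FloatImm(DataType::Float(32), 0.0));
}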
* @@ -507,11 +711,8 @@ class LoadNode : public PrimExprNode { } bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(buffer_var, other->buffer_var) && - equal(index, other->index) && - equal(predicate, other->predicate); + return equal(dtype, other->dtype) && equal(buffer_var, other->buffer_var) && + equal(index, other->index) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { @@ -521,12 +722,20 @@ class LoadNode : public PrimExprNode { hash_reduce(predicate); } - TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); - - static constexpr const char* _type_key = "Load"; + static constexpr const char* _type_key = "tir.Load"; TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode); }; +/*! + * \brief Managed reference to LoadNode + * \sa LoadNode + */ +class Load : public PrimExpr { + public: + TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate); + TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode); +}; + /*! * \brief Construct a vector with lanes elements * where its i-th element equals base + i * stride. @@ -553,11 +762,8 @@ class RampNode : public PrimExprNode { } bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(base, other->base) && - equal(stride, other->stride) && - equal(lanes, other->lanes); + return equal(dtype, other->dtype) && equal(base, other->base) && equal(stride, other->stride) && + equal(lanes, other->lanes); } void SHashReduce(SHashReducer hash_reduce) const { @@ -567,12 +773,20 @@ class RampNode : public PrimExprNode { hash_reduce(lanes); } - TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes); - - static constexpr const char* _type_key = "Ramp"; + static constexpr const char* _type_key = "tir.Ramp"; TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode); }; +/*! + * \brief Managed reference to RampNode + * \sa RampNode + */ +class Ramp : public PrimExpr { + public: + TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes); + TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); +}; + /*! \brief Create a vector where all the elements are value. */ class BroadcastNode : public PrimExprNode { public: @@ -588,10 +802,7 @@ class BroadcastNode : public PrimExprNode { } bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(value, other->value) && - equal(lanes, other->lanes); + return equal(dtype, other->dtype) && equal(value, other->value) && equal(lanes, other->lanes); } void SHashReduce(SHashReducer hash_reduce) const { @@ -600,12 +811,20 @@ class BroadcastNode : public PrimExprNode { hash_reduce(lanes); } - TVM_DLL static PrimExpr make(PrimExpr value, int lanes); - - static constexpr const char* _type_key = "Broadcast"; + static constexpr const char* _type_key = "tir.Broadcast"; TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode); }; +/*! + * \brief Managed reference to BroadcastNode + * \sa BroadcastNode + */ +class Broadcast : public PrimExpr { + public: + TVM_DLL Broadcast(PrimExpr value, int lanes); + TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); +}; + /*! * \brief Let binding. Bind var to value then evaluate body. 
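[Editor's note, not part of the patch] A short vector-expression sketch using the Ramp and Broadcast constructors above; base is assumed to be a scalar integer expression and four lanes are chosen arbitrarily.

#include <tvm/tir/expr.h>

tvm::PrimExpr VectorSketch(tvm::PrimExpr base) {
  using namespace tvm;
  using namespace tvm::tir;
  // Lanes: base, base + 1, base + 2, base + 3.
  PrimExpr idx = Ramp(base, IntImm(base.dtype(), 1), /*lanes=*/4);
  // A 4-lane vector whose elements are all 2.
  PrimExpr two = Broadcast(IntImm(base.dtype(), 2), /*lanes=*/4);
  return Mul(idx, two);
}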
*/ @@ -626,11 +845,8 @@ class LetNode : public PrimExprNode { } bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal(dtype, other->dtype) && equal.DefEqual(var, other->var) && + equal(value, other->value) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -640,45 +856,18 @@ class LetNode : public PrimExprNode { hash_reduce(body); } - TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body); - - static constexpr const char* _type_key = "Let"; + static constexpr const char* _type_key = "tir.Let"; TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode); }; -// Call node, represent a function call or a multi-dimensional array load. -// -// TODO(tvm-team): -// Refactor call with more explicit property registrations. -// rather than calling a string symbol. -// We should move most information into function itself and remove name. - -/*! \brief Base node of internal functions. */ -class FunctionBaseNode : public Object { - public: - /*! \brief virtual destructor */ - virtual ~FunctionBaseNode() {} - /*! \return the name of the function */ - virtual const std::string& func_name() const = 0; - /*! \return the number of outputs of this function */ - virtual int num_outputs() const = 0; - - // fall back to pointer equality now before refactor. - bool SEqualReduce(const FunctionBaseNode* other, SEqualReducer equal) const { - return this == other; - } - - void SHashReduce(SHashReducer hash_reduce) const { - } - - static constexpr const bool _type_has_method_sequal_reduce = true; - static constexpr const bool _type_has_method_shash_reduce = true; -}; - -/*! \brief reference to a function */ -class FunctionRef : public ObjectRef { +/*! + * \brief Managed reference to LetNode + * \sa LetNode + */ +class Let : public PrimExpr { public: - TVM_DEFINE_OBJECT_REF_METHODS(FunctionRef, ObjectRef, FunctionBaseNode); + TVM_DLL Let(Var var, PrimExpr value, PrimExpr body); + TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); }; /*! @@ -694,41 +883,28 @@ class CallNode : public PrimExprNode { ExternCPlusPlus = 1, /*! \brief Extern "C" without side-effect. */ PureExtern = 2, - /*! \brief Halide-style call, evaluates func(args). */ - Halide = 3, /*! \brief Intrinsic functions. */ Intrinsic = 4, /*! \brief Intrinsic functions that are pure. */ PureIntrinsic = 5 }; /*! \brief The name of the function/intrinsic. */ - std::string name; + String name; /*! \brief The arguments. */ Array args; /*! \brief Type of calls. */ CallType call_type; - /*! \brief The function to be called. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. 
*/ - int value_index{0}; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("name", &name); v->Visit("args", &args); v->Visit("call_type", &call_type); - v->Visit("func", &func); - v->Visit("value_index", &value_index); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(name, other->name) && - equal(args, other->args) && - equal(call_type, other->call_type) && - equal(func, other->func) && - equal(value_index, other->value_index); + return equal(dtype, other->dtype) && equal(name, other->name) && equal(args, other->args) && + equal(call_type, other->call_type); } void SHashReduce(SHashReducer hash_reduce) const { @@ -736,39 +912,23 @@ class CallNode : public PrimExprNode { hash_reduce(name); hash_reduce(args); hash_reduce(call_type); - hash_reduce(func); - hash_reduce(value_index); } - TVM_DLL static PrimExpr make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func = FunctionRef(), - int value_index = 0); - /*! \return Whether call node is pure. */ - bool is_pure() const { - return (call_type == PureExtern || - call_type == PureIntrinsic || - call_type == Halide); - } + bool is_pure() const { return (call_type == PureExtern || call_type == PureIntrinsic); } /*! * \return Whether call node corresponds to a defined intrinsic. * \param intrin_name The name of the intrinsic. */ bool is_intrinsic(const char* intrin_name) const { - return - ((call_type == Intrinsic || - call_type == PureIntrinsic) && - name == intrin_name); + return ((call_type == Intrinsic || call_type == PureIntrinsic) && name == intrin_name); } /*! \return Whether call node can be vectorized. */ bool is_vectorizable() const; - static constexpr const char* _type_key = "Call"; + static constexpr const char* _type_key = "tir.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode); // Build-in intrinsics @@ -781,7 +941,6 @@ class CallNode : public PrimExprNode { static constexpr const char* shift_right = "shift_right"; static constexpr const char* popcount = "popcount"; static constexpr const char* likely = "likely"; - static constexpr const char* glsl_texture_store = "glsl_texture_store"; static constexpr const char* prefetch = "prefetch"; static constexpr const char* isnan = "isnan"; static constexpr const char* isfinite = "isfinite"; @@ -791,6 +950,18 @@ class CallNode : public PrimExprNode { static const char* vectorizable_intrinsics[]; }; +/*! + * \brief Managed reference to CallNode + * \sa CallNode + */ +class Call : public PrimExpr { + public: + using CallType = CallNode::CallType; + + TVM_DLL Call(DataType dtype, String name, Array args, CallType call_type); + TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); +}; + /*! * \brief Shuffle instruction. 
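[Editor's note, not part of the patch] Given the slimmed-down CallNode (the Halide call type and the func/value_index fields are gone), an intrinsic call is now assembled roughly like this:

#include <tvm/tir/expr.h>

tvm::PrimExpr PopcountSketch(tvm::PrimExpr x) {
  using namespace tvm;
  using namespace tvm::tir;
  // Was CallNode::make(dtype, name, args, call_type, func, value_index).
  return Call(x.dtype(), CallNode::popcount, {x}, CallNode::PureIntrinsic);
}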
* vec = concat(vectors) @@ -809,10 +980,8 @@ class ShuffleNode : public PrimExprNode { } bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { - return - equal(dtype, other->dtype) && - equal(vectors, other->vectors) && - equal(indices, other->indices); + return equal(dtype, other->dtype) && equal(vectors, other->vectors) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -821,35 +990,24 @@ class ShuffleNode : public PrimExprNode { hash_reduce(indices); } - TVM_DLL static PrimExpr make(Array vectors, Array indices); - TVM_DLL static PrimExpr make_concat(Array vectors); - TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index); - - static constexpr const char* _type_key = "Shuffle"; + static constexpr const char* _type_key = "tir.Shuffle"; TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode); }; -// Reduce operator -class CommReducerNode; - -class CommReducer : public ObjectRef { +/*! + * \brief Managed reference to ShuffleNode + * \sa ShuffleNode + */ +class Shuffle : public PrimExpr { public: - CommReducer() {} - explicit CommReducer(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const CommReducerNode* get() const; - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const CommReducerNode* operator->() const; - /*! \brief type indicate the container type */ - using ContainerType = CommReducerNode; + TVM_DLL Shuffle(Array vectors, Array indices); + TVM_DLL static PrimExpr Concat(Array vectors); + TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index); + + TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); }; +// Reduce operator /*! * \brief A commutative reducer node to represent a commutative * binary operator with identity element @@ -870,11 +1028,6 @@ class CommReducerNode : public Object { Array identity_element; /*! \brief Function call operator to combine a and b */ Array operator()(Array a, Array b) const; - /*! \brief construct CommReducer from args, result and identity_element */ - TVM_DLL static CommReducer make(Array lhs, - Array rhs, - Array result, - Array identity_element); void VisitAttrs(AttrVisitor* v) { v->Visit("lhs", &lhs); @@ -884,11 +1037,8 @@ class CommReducerNode : public Object { } bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { - return - equal.DefEqual(lhs, other->lhs) && - equal.DefEqual(rhs, other->rhs) && - equal(result, other->result) && - equal(identity_element, other->identity_element); + return equal.DefEqual(lhs, other->lhs) && equal.DefEqual(rhs, other->rhs) && + equal(result, other->result) && equal(identity_element, other->identity_element); } void SHashReduce(SHashReducer hash_reduce) const { @@ -898,18 +1048,23 @@ class CommReducerNode : public Object { hash_reduce(identity_element); } - static constexpr const char* _type_key = "CommReducer"; + static constexpr const char* _type_key = "tir.CommReducer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object); }; -inline const CommReducerNode* CommReducer::get() const { - return static_cast(data_.get()); -} -inline const CommReducerNode* CommReducer::operator->() const { - return get(); -} +/*! 
+ * \brief Managed reference to CommReducerNode + * \sa CommReducerNode + */ +class CommReducer : public ObjectRef { + public: + TVM_DLL CommReducer(Array lhs, Array rhs, Array result, + Array identity_element); + + TVM_DEFINE_OBJECT_REF_METHODS(CommReducer, ObjectRef, CommReducerNode); +}; /*! \brief Reduction operator operator */ class ReduceNode : public PrimExprNode { @@ -928,13 +1083,6 @@ class ReduceNode : public PrimExprNode { /*! \brief the index of this reduce node */ int value_index; - /*! \brief construct expr from op and rdom */ - TVM_DLL static PrimExpr make(CommReducer combiner, - Array src, - Array rdom, - PrimExpr condition, - int value_index); - void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); v->Visit("combiner", &combiner); @@ -946,13 +1094,9 @@ class ReduceNode : public PrimExprNode { bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { // check axis first so IterVars can define the necessary variables. - return - equal(dtype, other->dtype) && - equal(axis, other->axis) && - equal(combiner, other->combiner) && - equal(source, other->source) && - equal(condition, other->condition) && - equal(value_index, other->value_index); + return equal(dtype, other->dtype) && equal(axis, other->axis) && + equal(combiner, other->combiner) && equal(source, other->source) && + equal(condition, other->condition) && equal(value_index, other->value_index); } void SHashReduce(SHashReducer hash_reduce) const { @@ -964,33 +1108,48 @@ class ReduceNode : public PrimExprNode { hash_reduce(value_index); } - static constexpr const char* _type_key = "Reduce"; + static constexpr const char* _type_key = "tir.Reduce"; TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode); }; +/*! + * \brief Managed reference to ReduceNode + * \sa ReduceNode + */ +class Reduce : public PrimExpr { + public: + TVM_DLL Reduce(CommReducer combiner, Array src, Array rdom, PrimExpr condition, + int value_index); + + TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); +}; + /*! \brief Any shape. */ class AnyNode : public PrimExprNode { public: void VisitAttrs(AttrVisitor* v) {} - bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { - return true; - } + bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const { return true; } - void SHashReduce(SHashReducer hash_reduce) const { - } + void SHashReduce(SHashReducer hash_reduce) const {} /*! \brief Convert to var. */ - Var ToVar() const { - return Var("any_dim", DataType::Int(32)); - } - - TVM_DLL static PrimExpr make(); + Var ToVar() const { return Var("any_dim", DataType::Int(32)); } - static constexpr const char* _type_key = "Any"; + static constexpr const char* _type_key = "tir.Any"; TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode); }; +/*! + * \brief Managed reference to AnyNode + * \sa AnyNode + */ +class Any : public PrimExpr { + public: + TVM_DLL Any(); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode); +}; /* * \brief Template function to convert Map to unordered_map @@ -1000,7 +1159,7 @@ class AnyNode : public PrimExprNode { * \tparam K the key of the Map. * \tparam V the value of the Map. */ -template +template inline std::unordered_map as_unordered_map(const Map& dmap) { std::unordered_map ret; for (auto kv : dmap) { @@ -1167,7 +1326,7 @@ constexpr const char* tvm_call_packed = "tvm_call_packed"; * return 0; * } */ -constexpr const char *tvm_call_trace_packed = "tvm_call_trace_packed"; +constexpr const char* tvm_call_trace_packed = "tvm_call_trace_packed"; /*! 
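[Editor's note, not part of the patch] A hedged sketch of building a sum reduction with the constructor-style CommReducer and Reduce. It assumes make_zero and const_true from tvm/tir/op.h are available, and takes the reduction IterVar as an argument rather than constructing one here.

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

tvm::PrimExpr SumReduceSketch(tvm::PrimExpr source, tvm::tir::IterVar rv) {
  using namespace tvm;
  using namespace tvm::tir;
  Var lhs("x", source.dtype());
  Var rhs("y", source.dtype());
  // Combiner: result = x + y, with identity element 0.
  CommReducer sum({lhs}, {rhs}, {Add(lhs, rhs)}, {make_zero(source.dtype())});
  // Reduce source over rv with no extra condition; output index 0.
  return Reduce(sum, {source}, {rv}, const_true(), /*value_index=*/0);
}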
* \brief See pesudo code * Mark the content as thread local context, can get optimized @@ -1214,8 +1373,7 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered"; * TVMRetValue(value_stack + end, tcode_stack + end)); * } */ -constexpr const char *tvm_call_trace_packed_lowered = - "tvm_call_trace_packed_lowered"; +constexpr const char* tvm_call_trace_packed_lowered = "tvm_call_trace_packed_lowered"; /*! * \brief See pseudo code * @@ -1225,14 +1383,43 @@ constexpr const char *tvm_call_trace_packed_lowered = * } */ constexpr const char* tvm_storage_sync = "tvm_storage_sync"; + /*! * \brief See pseudo code * - * Type tvm_warp_shuffle(Type value, warp_id) { - * return (value passed in by warp indicated by warp_id); + * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id); * } + * + * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id - offset); + * } + * + * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) { + * return (value passed in by warp indicated by this_warp_id + offset); + * } + * + * unsigned tvm_warp_activemask() { + * return (32-bit mask of currently active threads in the calling warp); + * } + * + * Parameter warp_id indicates the source thread ID in a warp. + * + * Parameter offset indicates the relative distance to this_warp_id. + * + * Parameter width indicates the number of threads involved in one + * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync, + * __shfl_down_sync and __activemask. + * + * Parameter warp_size is the size of a warp, which helps a backend + * to determine wheter the width paramter is legal. + * */ constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle"; +constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up"; +constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down"; +constexpr const char* tvm_warp_activemask = "tvm_warp_activemask"; + /*! * \brief Initialize the global barrier. * Call this at beginning of kernel that need global barrier. @@ -1330,7 +1517,7 @@ enum TVMStructFieldKind : int { kTVMValueContent, kTVMValueKindBound_ }; -} // namespace intrinsic +} // namespace intrinsic } // namespace tir } // namespace tvm @@ -1339,7 +1526,7 @@ namespace tvm { namespace runtime { // Additional implementattion overloads for PackedFunc. -template<> +template <> struct PackedFuncValueConverter { // common rule for RetValue and ArgValue static tvm::Integer From(const TVMPODValue_& val) { @@ -1357,7 +1544,6 @@ struct PackedFuncValueConverter { namespace std { template <> -struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectHash { -}; -} +struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; +} // namespace std #endif // TVM_TIR_EXPR_H_ diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index dcf04c346454..a6c90b36a49c 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -71,22 +71,19 @@ namespace tir { * \tparam FType function signiture * This type if only defined for FType with function signiture R(const Expr&, Args...) */ -template +template class ExprFunctor; // functions to be overriden. 
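[Editor's note, not part of the patch] For concreteness, a lowered warp-shuffle call following the argument order documented above might be formed as below. This is purely illustrative: the mask/width/warp_size values assume a CUDA-like 32-lane warp, and the choice of call type is an assumption.

#include <tvm/tir/expr.h>

tvm::PrimExpr WarpShuffleSketch(tvm::PrimExpr value, tvm::PrimExpr src_lane) {
  using namespace tvm;
  using namespace tvm::tir;
  PrimExpr mask = IntImm(DataType::UInt(32), 0xFFFFFFFF);  // all lanes active
  PrimExpr width = IntImm(DataType::Int(32), 32);
  PrimExpr warp_size = IntImm(DataType::Int(32), 32);
  // tvm_warp_shuffle(mask, value, warp_id, width, warp_size)
  return Call(value.dtype(), intrinsic::tvm_warp_shuffle,
              {mask, value, src_lane, width, warp_size}, CallNode::Intrinsic);
}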
-#define EXPR_FUNCTOR_DEFAULT { \ - return VisitExprDefault_(op, std::forward(args)...); \ - } +#define EXPR_FUNCTOR_DEFAULT \ + { return VisitExprDefault_(op, std::forward(args)...); } -#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitExpr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define IR_EXPR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); \ + }); -template +template class ExprFunctor { private: using TSelf = ExprFunctor; @@ -122,6 +119,7 @@ class ExprFunctor { return VisitExpr_(static_cast(op), std::forward(args)...); } virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -152,7 +150,7 @@ class ExprFunctor { virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExprDefault_(const Object* op, Args ...) { + virtual R VisitExprDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -166,6 +164,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); IR_EXPR_FUNCTOR_DISPATCH(LoadNode); IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode); + IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode); IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(CallNode); IR_EXPR_FUNCTOR_DISPATCH(AddNode); @@ -205,8 +204,7 @@ class ExprFunctor { /*! * \brief ExprVisitor */ -class TVM_DLL ExprVisitor : - public ExprFunctor { +class TVM_DLL ExprVisitor : public ExprFunctor { public: using ExprFunctor::operator(); @@ -217,6 +215,7 @@ class TVM_DLL ExprVisitor : void VisitExpr_(const SizeVarNode* op) override; void VisitExpr_(const LoadNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; + void VisitExpr_(const ProducerLoadNode* op) override; void VisitExpr_(const LetNode* op) override; void VisitExpr_(const CallNode* op) override; void VisitExpr_(const AddNode* op) override; @@ -251,8 +250,7 @@ class TVM_DLL ExprVisitor : /*! * \brief ExprMutator that mutates expressions. 
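[Editor's note, not part of the patch] Since ProducerLoadNode now participates in the functor dispatch above, custom visitors can hook it like any other node. A minimal sketch with a hypothetical class name:

#include <tvm/tir/expr_functor.h>

// Counts ProducerLoad nodes in an expression tree.
class ProducerLoadCounter : public tvm::tir::ExprVisitor {
 public:
  int count = 0;

  void VisitExpr_(const tvm::tir::ProducerLoadNode* op) final {
    ++count;
    // Continue walking into the indices of the load.
    ExprVisitor::VisitExpr_(op);
  }
};

// Usage: ProducerLoadCounter c; c(expr); afterwards c.count holds the total.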
*/ -class TVM_DLL ExprMutator : - protected ExprFunctor { +class TVM_DLL ExprMutator : protected ExprFunctor { public: using ExprFunctor::operator(); @@ -263,6 +261,7 @@ class TVM_DLL ExprMutator : PrimExpr VisitExpr_(const SizeVarNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; + PrimExpr VisitExpr_(const ProducerLoadNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; PrimExpr VisitExpr_(const CallNode* op) override; PrimExpr VisitExpr_(const AddNode* op) override; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1866f2f1f891..919391e36b96 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -25,11 +25,11 @@ #define TVM_TIR_FUNCTION_H_ #include -#include #include +#include #include -#include +#include namespace tvm { namespace tir { @@ -104,12 +104,9 @@ class PrimFuncNode : public BaseFuncNode { bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const { // visit params and buffer_map first as they contains defs. - return - equal.DefEqual(params, other->params) && - equal(buffer_map, other->buffer_map) && - equal(ret_type, other->ret_type) && - equal(body, other->body) && - equal(attrs, other->attrs); + return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) && + equal(ret_type, other->ret_type) && equal(body, other->body) && + equal(attrs, other->attrs); } void SHashReduce(SHashReducer hash_reduce) const { @@ -146,9 +143,7 @@ class PrimFunc : public BaseFunc { * \param buffer_map The buffer map for parameter buffer unpacking. * \param attrs Additional function attributes. */ - TVM_DLL PrimFunc(Array params, - Stmt body, - Type ret_type = VoidType(), + TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = NullValue>(), DictAttrs attrs = NullValue()); diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h deleted file mode 100644 index e228ce32adab..000000000000 --- a/include/tvm/tir/ir_pass.h +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/ir_pass.h - * \brief Collection of IR pass functions - * - * When the pass functions in this file are for Stmt, - * we can use PassFunction(Evaluate(expr)) to apply it to Expr - */ -#ifndef TVM_TIR_IR_PASS_H_ -#define TVM_TIR_IR_PASS_H_ - -#include -#include -#include -#include - -#include -#include -#include -#include - - -namespace tvm { -namespace tir { - -/*! - * \brief Simplify the expression. - * \param expr The expression to be simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -TVM_DLL PrimExpr Simplify(PrimExpr expr, Map vrange = Map()); - -/*! 
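[Editor's note, not part of the patch] To tie things together, the updated PrimFunc constructor can be used roughly as follows; the body statement is taken as an argument and the single buffer parameter is illustrative.

#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>

tvm::tir::PrimFunc MakeTrivialFunc(tvm::tir::Stmt body) {
  using namespace tvm;
  using namespace tvm::tir;
  // One handle parameter that is unpacked into a 16-element float32 buffer.
  Var handle("a", DataType::Handle());
  Buffer buf = decl_buffer({16}, DataType::Float(32), "a_buf");
  Map<Var, Buffer> buffer_map;
  buffer_map.Set(handle, buf);
  // params, body, ret_type, buffer_map; attrs keeps its default value.
  return PrimFunc({handle}, body, VoidType(), buffer_map);
}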
- * \brief Simplify the statement. - * \param stmt The statement to be simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -Stmt Simplify(Stmt stmt, Map vrange = Map()); - -/*! - * \brief Simplify by applying canonical form. - * \param stmt The statement to be canonically simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -Stmt CanonicalSimplify(Stmt stmt, - Map vrange = Map()); - -/*! - * \brief Simplify by applying canonical form. - * \param expr The statement to be canonically simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized expression. - */ -TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, - Map vrange = Map()); - -/*! - * \brief verifies whether the IR stmt or Expr is in SSA form. - * That is: each VarExpr is defined and assigned once(in Let/For) - * - * \param ir The root of the IR DAG. - * \return Whether IR is in SSA form. - * \note All the passes in this file uses SSA form and outputs SSA form. - */ -TVM_DLL bool VerifySSA(const Stmt& ir); - -/*! - * \brief Whether the expression have side effect. - * \return whether expression have side effect - */ -TVM_DLL bool HasSideEffect(const PrimExpr& e); - -/*! - * \brief Whether e expression used var. - * \param e The expression to be checked. - * \param v The variable. - * \return Whether e uses v. - */ -bool ExprUseVar(const PrimExpr& e, const Var& v); - -/*! - * \brief Whether e expression used any var in variable set.. - * \param e The expression to be checked. - * \param vset The variable set. - * \return Whether e uses vset. - */ -bool ExprUseVar(const PrimExpr& e, const std::unordered_set& vset); - -/*! - * \brief Convert a IR node to be SSA form. - * \param stmt The source statement to be converted. - * \return The converted form. - */ -TVM_DLL Stmt ConvertSSA(Stmt stmt); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param stmt The source statement to be substituted - * \param value_map The map of new values. - * \return The converted form. - */ -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param expr The source expression to be substituted - * \param value_map The map of new values. - * \return The converted expression. - */ -PrimExpr Substitute(PrimExpr expr, - const std::unordered_map& value_map); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param stmt The source statement to be substituted - * \param value_map The map of new values. - * \return The converted form. - */ -Stmt Substitute(Stmt stmt, const Map& value_map); - -/*! - * \brief Substitute the var specified in key->var to be value. - * \param expr The source expression to be substituted - * \param value_map The map of new values. - * \return The converted expression. - */ -PrimExpr Substitute(PrimExpr expr, const Map& value_map); - -/*! - * \brief inline all calls of f in stmt. - * - * \param stmt The statement to apply inline optimization. - * \param f The function reference to be inlined - * \param args The arguments variable of the function. - * \param body The definition body of the function. - * \return The result stmt - * - * \note All the passes in this file uses SSA form and outputs SSA form. - */ -Stmt Inline(Stmt stmt, - FunctionRef f, - Array args, - PrimExpr body); - -/*! 
- * \brief Flatten the multi-dimensional read/write - * to single dimensional Load/Store - * - * \param stmt The stmt to be trasnformed. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \param cache_line_size The size of CPU cache line. - * \param create_bound_attribute Whether to create bound attributes. - * \return Transformed stmt. - */ -Stmt StorageFlatten(Stmt stmt, - Map extern_buffer, - int cache_line_size, - bool create_bound_attribute = false); - -/*! - * \brief Try to modify the AST to support TensorCore - * - * \param stmt The stmt to be trasnformed. - * \param schedule The original schedule. - * \param extern_buffer Map specifies external - * buffer assignment of input and outputs. - * \return Transformed stmt. - */ -Stmt RewriteForTensorCore(Stmt stmt, - te::Schedule schedule, - Map extern_buffer); - -/*! - * \brief Verify if there is any argument bound to compact buffer. - * - * \param stmt The stmt to be verified. - * \return true if there is any buffer_bind_scope attribute found, - * otherwise, false. - */ -bool VerifyCompactBuffer(Stmt stmt); - -/*! - * \brief Remove No Op from the Stmt. - * \param stmt The stmt to be trasnformed - * \return Transformed stmt. - */ -Stmt RemoveNoOp(Stmt stmt); - -/*! - * \brief unroll the constant loop marked by unroll. - * This pass also automatically attach pragma unroll tag to loops which meets the standard. - * - * \param stmt The statment to be unrolled. - * \param auto_max_step The maximum step before stop attach automatic unroll - * \param auto_max_depth The maximum depth before stop attach automatic unroll - * \param auto_max_extent The maximum extent of the loop we can unroll, - * this is an legacy option that do not take the loop total steps into account. - * \param explicit_unroll Whether explicitly unroll the loop, or leave unroll annotation to codegen. - * \return Transformed stmt. - */ -Stmt UnrollLoop(Stmt stmt, - int auto_max_step, - int auto_max_depth, - int auto_max_extent, - bool explicit_unroll); - -/*! - * \brief vectorize the constant loops - * \param stmt The statement to be vectorized. - * \return Transformed stmt. - */ -Stmt VectorizeLoop(Stmt stmt); - -/*! - * \brief convert vectorized loops into serialized loops - * \param stmt The statement to skip vectorization on. - * \return Transformed stmt. - */ -Stmt SkipVectorize(Stmt stmt); - -/*! -* \brief instruments bound checkers. -* \param stmt The statement to be instrumented. -* \return Instrumented stmt. -*/ -Stmt InstrumentBoundCheckers(Stmt stmt); - -/*! - * \brief Inject virtual thread loops into stmt. - * \param stmt The statement to be transformed. - * \return Transformed stmt. - */ -Stmt InjectVirtualThread(Stmt stmt); - -/*! - * \brief Inject prefetch instructions into stmt. - * \param stmt The statement to be transformed. - * \return Transformed stmt. - */ -Stmt InjectPrefetch(Stmt stmt); - -/*! - * \brief Inject double buffer into stmt. - * \param stmt The statement to be transformed. - * \param split_loop Loop splitting factor. - * \return Transformed stmt. - */ -Stmt InjectDoubleBuffer(Stmt stmt, int split_loop); - -/*! - * \brief Inject copy intrinsics with optional pad. - * - * \param stmt The statement to be transformed. - * \param pragma_key The pragma key for hint of copy. - * \param fintrin The function with signature - * - * Stmt fintrin(Buffer src, - * Buffer dst, - * Array pad_before, - * Array pad_after, - * Expr pad_value) - * \return Transformed stmt. 
- */ -Stmt InjectCopyIntrin(Stmt stmt, - const std::string& pragma_key, - const runtime::PackedFunc& fintrin); - -/*! - * \brief Rewrite storage allocation pattern. - * Moves the allocation to outer most possible scope. - * Trying to share space between allocations to make - * a static allocation plan when possible. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt StorageRewrite(Stmt stmt); - -/*! - * \brief partition loops in the stmt - * \param stmt The stmt to do loop partition - * \param split_const_loop flag to enable partition for const loop - * \return Transformed stmt. - */ -Stmt LoopPartition(Stmt stmt, bool split_const_loop); - -/*! - * \brief Detect and insert sync points to co-processor. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt CoProcSync(Stmt stmt); - -/*! - * \brief Lift common attrs with attr_key to outer scope. - * - * \param stmt The stmt to be transformed - * \param attr_key The attribute key to be checked. - * \return Transformed stmt. - */ -Stmt LiftAttrScope(Stmt stmt, std::string attr_key); - -/*! - * \brief Detect and rewrite unsafe select that contains memory access. - * \param stmt The statement to be rewritten. - * \return Transformed stmt. - */ -Stmt RewriteUnsafeSelect(Stmt stmt); - -/*! - * \brief Lower attached storage access information. - * Do this pass after all storage access analysis finish. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt LowerStorageAccessInfo(Stmt stmt); - -/*! - * \brief Decorate the stmt with a device scope, this is helpful for - * hardware accelerator without thread blocks. - * - * \param stmt The stmt to be transformed - * \return Transformed stmt. - */ -Stmt DecorateDeviceScope(Stmt stmt); - -/*! - * \brief Loop invariant code motion which locates and hoists if statements. - * \param stmt The stmt to do if statement hoisting. - * \return Transformed stmt. - */ -Stmt HoistIfThenElse(Stmt stmt); - -/*! - * \brief Narrow down PrimExpr datatype in stmt to target_bits. - * \note Run this pass after StorageFlatten. - * \param stmt The stmt to do datatype rewrite - * \param target_bits the bit of target datatype - * \return Transformed stmt. - */ -Stmt NarrowDataType(Stmt stmt, int target_bits); - -/*! - * \brief Rewrite the pointer content type of arguments, - * as well as Alloc internal to the function to use - * the most frequently accessed type for load/store - * to avoid pointer casting in backend when possible. - * - * \note implemeneted in storage_rewrite.cc - * \param f The function to be trasnformed - * \return Transformed function. - */ -PrimFunc PointerValueTypeRewrite(PrimFunc f); - -/*! - * \brief Verify the correctness of a GPU code - * It will check the whether the amount of memory usage or the number of threads - * in a block exceeds the limit - * \param stmt The statement to be checked - * \param constraints The dict to specify constraints to check. - * Possible keys are - * - * "max_local_memory_per_block": Total amount of local memory per block (in bytes). - * "max_shared_memory_per_block": Total amount of shared memory per block (in bytes). - * "max_threads_per_block": Maximum number of threads per block. - * "max_thread_x": Maximum length of threadIdx.x. - * "max_thread_y": Maximum length of threadIdx.y. - * "max_thread_z": Maximum length of threadIdx.z. - * - * If one key is missing in this argument, the pass won't check for that item. 
- * \return valid Whether it is a valid GPU code - * - */ -bool VerifyGPUCode(Stmt stmt, - Map constraints); - -} // namespace tir -} // namespace tvm -#endif // TVM_TIR_IR_PASS_H_ diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index b54aa9aaf7cc..71e9ac4c3e22 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -33,9 +33,8 @@ #include #include -#include #include - +#include namespace tvm { @@ -464,6 +463,7 @@ TVM_DLL PrimExpr isinf(PrimExpr x); * \brief sum of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr sum(PrimExpr source, Array axis); @@ -478,6 +478,7 @@ TVM_DLL PrimExpr all(PrimExpr source, Array axis); * \brief logical Or of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr any(PrimExpr source, Array axis); @@ -485,6 +486,7 @@ TVM_DLL PrimExpr any(PrimExpr source, Array axis); * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr max(PrimExpr source, Array axis); @@ -492,6 +494,7 @@ TVM_DLL PrimExpr max(PrimExpr source, Array axis); * \brief max of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr min(PrimExpr source, Array axis); @@ -499,6 +502,7 @@ TVM_DLL PrimExpr min(PrimExpr source, Array axis); * \brief product of of source expression over axis * \param source The source expression. * \param axis List of iteration variables that will be used for reduction. + * \return The result. */ TVM_DLL PrimExpr prod(PrimExpr source, Array axis); @@ -548,10 +552,10 @@ TVM_DLL PrimExpr trunc(PrimExpr x); TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high); // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY(OpName) \ - inline PrimExpr OpName(PrimExpr x) { \ - return tir::CallNode::make(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY(OpName) \ + inline PrimExpr OpName(PrimExpr x) { \ + return tir::Call(x.dtype(), #OpName, {x}, tir::CallNode::PureIntrinsic); \ + } TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(exp2); @@ -570,7 +574,12 @@ TVM_DECLARE_INTRIN_UNARY(cos); TVM_DECLARE_INTRIN_UNARY(cosh); TVM_DECLARE_INTRIN_UNARY(sin); TVM_DECLARE_INTRIN_UNARY(sinh); +TVM_DECLARE_INTRIN_UNARY(asin); +TVM_DECLARE_INTRIN_UNARY(acos); TVM_DECLARE_INTRIN_UNARY(atan); +TVM_DECLARE_INTRIN_UNARY(acosh); +TVM_DECLARE_INTRIN_UNARY(asinh); +TVM_DECLARE_INTRIN_UNARY(atanh); namespace tir { /*! @@ -580,8 +589,8 @@ namespace tir { * \return the result expression. * \tparam ValueType The constant value type */ -template::value>::type> +template ::value>::type> inline PrimExpr make_const(DataType t, ValueType value); /*! * \brief Make a const zero expr. @@ -594,17 +603,13 @@ inline PrimExpr make_zero(DataType t); * \param lanes The number of lanes in the bool * \return The result expression. */ -inline PrimExpr const_true(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 1); -} +inline PrimExpr const_true(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 1); } /*! * \brief Make a constant false expression. 
* \param lanes The number of lanes in the bool * \return The result expression. */ -inline PrimExpr const_false(int lanes = 1) { - return make_const(DataType::UInt(1, lanes), 0); -} +inline PrimExpr const_false(int lanes = 1) { return make_const(DataType::UInt(1, lanes), 0); } /*! * \brief Get x as constant int expression. * \param x The expression @@ -641,9 +646,7 @@ inline bool is_no_op(const tir::Stmt& stmt); * \note This only return true for integer types. * \return whether x is constant 1 */ -inline bool is_one(const PrimExpr& x) { - return is_const_int(x, 1); -} +inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } /*! * \brief Check whether x is a constant integer 0 @@ -651,9 +654,7 @@ inline bool is_one(const PrimExpr& x) { * \return whether x is constant 0 * \note This only return true for integer types. */ -inline bool is_zero(const PrimExpr& x) { - return is_const_int(x, 0); -} +inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } /*! * \brief Check whether x is a constant. @@ -662,6 +663,17 @@ inline bool is_zero(const PrimExpr& x) { */ inline bool is_const(const PrimExpr& x); +/*! + * \brief Left fold. + * \param freduce The reduction function. + * \param init_value The initial value. + * \param values The values to be folded. + * \return The result. + * \tparam FReduce The type of the reduction. + */ +template +inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array& values); + /*! * \brief Check whether x is a constant power of two * If x is power of two, write the power to the shift. @@ -724,7 +736,7 @@ inline bool is_no_op(const tir::Stmt& stmt) { return false; } -template +template inline PrimExpr MakeConstScalar(DataType t, ValueType value) { if (t.is_int()) return IntImm(t, static_cast(value)); if (t.is_uint()) { @@ -744,20 +756,19 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { // datatypes lowering pass, we will lower the value to its true representation in the format // specified by the datatype. // TODO(gus) when do we need to start worrying about doubles not being precise enough? 
- if (static_cast(t.code()) >= static_cast(kTVMCustomBegin)) { + if (static_cast(t.code()) >= static_cast(DataType::kCustomBegin)) { return FloatImm(t, static_cast(value)); } LOG(FATAL) << "cannot make const for type " << t; return PrimExpr(); } -template +template inline PrimExpr make_const(DataType t, ValueType value) { if (t.lanes() == 1) { return MakeConstScalar(t, value); } else { - return tir::BroadcastNode::make( - MakeConstScalar(t.element_of(), value), t.lanes()); + return tir::Broadcast(MakeConstScalar(t.element_of(), value), t.lanes()); } } @@ -767,47 +778,46 @@ inline PrimExpr make_zero(DataType t) { } return make_const(t, 0); } + +template +inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array& values) { + for (PrimExpr val : values) { + init_value = freduce(init_value, val); + } + return init_value; +} + } // namespace tir // additional const expression overloading -#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ - inline PrimExpr Name(PrimExpr& a, PrimExpr b) {\ - a = OpFunc(a, b); \ - return a; \ +#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ + inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \ + a = OpFunc(a, b); \ + return a; \ } -#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, float b) { \ - return Name(a, PrimExpr(b)); \ - } \ - inline PrimExpr Name(float a, const PrimExpr& b) { \ - return Name(PrimExpr(a), b); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ - } \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, tir::make_const(DataType::Float(64), b)); \ +#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ + inline PrimExpr Name(int a, const PrimExpr& b) { \ + return Name(tir::make_const(b.dtype(), a), b); \ + } \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tir::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(const PrimExpr& a, double b) { \ + return Name(a, tir::make_const(DataType::Float(64), b)); \ } -#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, bool b) { \ - return Name(a, PrimExpr(b)); \ - } \ - inline PrimExpr Name(bool a, const PrimExpr& b) { \ - return Name(PrimExpr(a), b); \ - } +#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, bool b) { return Name(a, PrimExpr(b)); } \ + inline PrimExpr Name(bool a, const PrimExpr& b) { return Name(PrimExpr(a), b); } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ - } +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tir::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-); @@ -829,8 +839,8 @@ TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv); 
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod); -TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) -TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); @@ -843,7 +853,7 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); * \note The call to this function will always results in a compiler error. * \tparam TA Any class type. */ -template +template inline void DivAmbiguityError(const TA& a) { constexpr bool div_ambiguity = !std::is_class::value; static_assert(div_ambiguity, @@ -859,19 +869,19 @@ inline void DivAmbiguityError(const TA& a) { // to use the specific division function. // The second template argument is necessary to make sure the // code compiles lazily by the compiler during invocation. -template +template inline PrimExpr operator/(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } -template +template inline PrimExpr operator/=(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; } -template +template inline PrimExpr operator%(const PrimExpr& a, const TB& b) { DivAmbiguityError(a); return a; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 5bc492fcefb8..be1c567198d9 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -20,16 +20,16 @@ * \file tvm/tir/stmt.h * \brief TIR statements. */ -// Acknowledgement: Mnay low-level stmts originate from Halide. +// Acknowledgement: Many low-level stmts originate from Halide. #ifndef TVM_TIR_STMT_H_ #define TVM_TIR_STMT_H_ #include -#include #include -#include +#include #include +#include namespace tvm { namespace tir { @@ -37,9 +37,10 @@ namespace tir { /*! \brief Base node of all statements. */ class StmtNode : public Object { public: - static constexpr const char* _type_key = "Stmt"; + static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; + static constexpr const uint32_t _type_child_slots = 15; TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object); }; @@ -68,10 +69,8 @@ class LetStmtNode : public StmtNode { } bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { - return - equal.DefEqual(var, other->var) && - equal(value, other->value) && - equal(body, other->body); + return equal.DefEqual(var, other->var) && equal(value, other->value) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -80,12 +79,21 @@ class LetStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body); - - static constexpr const char* _type_key = "LetStmt"; + static constexpr const char* _type_key = "tir.LetStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode); }; +/*! + * \brief Managed reference to LetStmtNode. + * \sa LetStmtNode + */ +class LetStmt : public Stmt { + public: + TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode); +}; + /*! * \brief Define certain auxiliary attribute for the body to be a symbolic value. * This provide auxiliary information for IR passes that transforms body. 
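For orientation, a minimal sketch of what the statement refactor above means for client code: the removed *Node::make static factories are replaced by constructors on the new managed reference classes (LetStmt, AttrStmt, AssertStmt, Store, Evaluate, ...). The snippet below is illustrative only and assumes the declarations introduced in this change.

#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/var.h>

using namespace tvm;
using namespace tvm::tir;

// Illustrative sketch: build "let x = 0 in evaluate(x + 1)" with the
// constructor-style API that replaces LetStmtNode::make / EvaluateNode::make.
Stmt MakeExampleStmt() {
  Var x("x", DataType::Int(32));
  return LetStmt(x, make_const(DataType::Int(32), 0), Evaluate(x + 1));
}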
@@ -101,7 +109,7 @@ class AttrStmtNode : public StmtNode { /*! \brief this is attribute about certain node */ ObjectRef node; /*! \brief the type key of the attribute */ - std::string attr_key; + String attr_key; /*! \brief The attribute value, value is well defined at current scope. */ PrimExpr value; /*! \brief The body statement to be executed */ @@ -115,11 +123,8 @@ class AttrStmtNode : public StmtNode { } bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { - return - equal(node, other->node) && - equal(attr_key, other->attr_key) && - equal(value, other->value) && - equal(body, other->body); + return equal(node, other->node) && equal(attr_key, other->attr_key) && + equal(value, other->value) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -129,15 +134,21 @@ class AttrStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(ObjectRef node, - std::string type_key, - PrimExpr value, - Stmt body); - - static constexpr const char* _type_key = "AttrStmt"; + static constexpr const char* _type_key = "tir.AttrStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode); }; +/*! + * \brief Managed reference to AttrStmtNode. + * \sa AttrStmtNode + */ +class AttrStmt : public Stmt { + public: + TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); +}; + /*! * \brief Assert condition, if an error occurs, return the error message. */ @@ -160,10 +171,8 @@ class AssertStmtNode : public StmtNode { } bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { - return - equal(condition, other->condition) && - equal(message, other->message) && - equal(body, other->body); + return equal(condition, other->condition) && equal(message, other->message) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -172,12 +181,21 @@ class AssertStmtNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body); - - static constexpr const char* _type_key = "AssertStmt"; + static constexpr const char* _type_key = "tir.AssertStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); }; +/*! + * \brief Managed reference to AssertStmtNode. + * \sa AssertStmtNode + */ +class AssertStmt : public Stmt { + public: + TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode); +}; + /*! * \brief Store value to the buffer. * @@ -215,11 +233,8 @@ class StoreNode : public StmtNode { } bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const { - return - equal(buffer_var, other->buffer_var) && - equal(value, other->value) && - equal(index, other->index) && - equal(predicate, other->predicate); + return equal(buffer_var, other->buffer_var) && equal(value, other->value) && + equal(index, other->index) && equal(predicate, other->predicate); } void SHashReduce(SHashReducer hash_reduce) const { @@ -229,15 +244,21 @@ class StoreNode : public StmtNode { hash_reduce(predicate); } - TVM_DLL static Stmt make(Var buffer_var, - PrimExpr value, - PrimExpr index, - PrimExpr predicate); - - static constexpr const char* _type_key = "Store"; + static constexpr const char* _type_key = "tir.Store"; TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode); }; +/*! + * \brief Managed reference to StoreNode. 
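+ *
+ * A minimal usage sketch; buffer_var, value and index are assumed to be
+ * pre-built, and the predicate is the always-true mask:
+ * \code
+ *   // scalar or vector store: buffer_var[index] = value
+ *   Stmt s = Store(buffer_var, value, index, const_true(value.dtype().lanes()));
+ * \endcode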
+ * \sa StoreNode + */ +class Store : public Stmt { + public: + TVM_DLL Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate); + + TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode); +}; + /*! * \brief Store value to the high dimension buffer. * @@ -248,7 +269,6 @@ class StoreNode : public StmtNode { * \endcode * \sa BufferLoad */ -class BufferStore; class BufferStoreNode : public StmtNode { public: /*! \brief The buffer variable. */ @@ -265,10 +285,8 @@ class BufferStoreNode : public StmtNode { } bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { - return - equal(buffer, other->buffer) && - equal(value, other->value) && - equal(indices, other->indices); + return equal(buffer, other->buffer) && equal(value, other->value) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { @@ -277,61 +295,186 @@ class BufferStoreNode : public StmtNode { hash_reduce(indices); } - static constexpr const char* _type_key = "BufferStore"; + static constexpr const char* _type_key = "tir.BufferStore"; TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode); }; +/*! + * \brief Managed reference to BufferStoreNode. + * \sa BufferStoreNode + */ class BufferStore : public Stmt { public: - TVM_DLL explicit BufferStore(Buffer buffer, - PrimExpr value, - Array indices); + TVM_DLL explicit BufferStore(Buffer buffer, PrimExpr value, Array indices); + TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); }; /*! - * \brief Store value into mult-dimensional array defined by func. + * \brief Annotate the region where the buffer need to + * be read and write in the body. + * We only need to allocate the space for the corresponding region. + * + * \note There should be at most one BufferRealize for each buffer. + * BufferRealize is not necessary for external buffers, + * since they are assumed to be fully allocated. + * + * \sa BufferLoad, BufferStore + */ +class BufferRealizeNode : public StmtNode { + public: + /*! \brief The buffer variable. */ + Buffer buffer; + /*! \brief Bounds to be realized */ + Array bounds; + /*! \brief Only realize if condition holds. */ + PrimExpr condition; + /*! \brief The body of realization. */ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("bounds", &bounds); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(bounds); + hash_reduce(condition); + hash_reduce(body); + } + + BufferRealizeNode() = default; + BufferRealizeNode(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) + : buffer(buffer), bounds(bounds), condition(condition), body(body) {} + + static constexpr const char* _type_key = "tir.BufferRealize"; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode); +}; + +/*! + * \brief Managed reference to BufferRealizeNode. + * \sa BufferRealizeNode */ -class ProvideNode : public StmtNode { +class BufferRealize : public Stmt { public: - /*! \brief The function to be updated. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. 
*/ - int value_index{0}; + TVM_DLL explicit BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt, BufferRealizeNode); +}; + +/*! + * \brief Store value into mult-dimensional array that will be read by the consumer + * of the producer. + * + * \note This node only appears in high-level DSLs that are built on top of the TIR. + * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower + * this node before TIR transformations. + * + * \sa DataProducer + */ +class ProducerStoreNode : public StmtNode { + public: + /*! \brief The producer to store the results into. */ + DataProducer producer; /*! \brief The value to be stored. */ PrimExpr value; /*! \brief The index arguments of the function. */ - Array args; + Array indices; void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); + v->Visit("producer", &producer); v->Visit("value", &value); - v->Visit("args", &args); + v->Visit("indices", &indices); } - bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(value, other->value) && - equal(args, other->args); + bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const { + return equal(producer, other->producer) && equal(value, other->value) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); + hash_reduce(producer); hash_reduce(value); - hash_reduce(args); + hash_reduce(indices); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - PrimExpr value, - Array args); + static constexpr const char* _type_key = "tir.ProducerStore"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); +}; - static constexpr const char* _type_key = "Provide"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProvideNode, StmtNode); +/*! + * \brief Managed reference to ProducerStoreNode. + * \sa ProducerStoreNode + */ +class ProducerStore : public Stmt { + public: + TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array indices); + + TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); +}; + +/*! + * \brief Annotate the bounds where the data produced by the producer + * need to be written and read in body. + * We will need to allocate space for the corresponding regions. + * + * \note This node only appears in high-level DSLs that are built on top of the TIR. + * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower + * this node before TIR transformations. + * + * \sa DataProducer + */ +class ProducerRealizeNode : public StmtNode { + public: + /*! \brief The producer that produces the data. */ + DataProducer producer; + /*! \brief Bounds to be realized. */ + Region bounds; + /*! \brief Only realize if condition holds. */ + PrimExpr condition; + /*! \brief The body of realization. 
*/ + Stmt body; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("producer", &producer); + v->Visit("bounds", &bounds); + v->Visit("condition", &condition); + v->Visit("body", &body); + } + + bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { + return equal(producer, other->producer) && equal(bounds, other->bounds) && + equal(condition, other->condition) && equal(body, other->body); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(producer); + hash_reduce(bounds); + hash_reduce(condition); + hash_reduce(body); + } + + static constexpr const char* _type_key = "tir.ProducerRealize"; + TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); +}; + +/*! + * \brief Managed reference to ProducerRealizeNode. + * \sa ProducerRealizeNode + */ +class ProducerRealize : public Stmt { + public: + TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); }; /*! @@ -359,12 +502,9 @@ class AllocateNode : public StmtNode { } bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { - return - equal.DefEqual(buffer_var, other->buffer_var) && - equal(dtype, other->dtype) && - equal(extents, other->extents) && - equal(condition, other->condition) && - equal(body, other->body); + return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && + equal(extents, other->extents) && equal(condition, other->condition) && + equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -375,114 +515,63 @@ class AllocateNode : public StmtNode { hash_reduce(body); } - TVM_DLL static Stmt make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body); - /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \return The result. */ - int32_t constant_allocation_size() const { - return constant_allocation_size(extents); - } + int32_t constant_allocation_size() const { return constant_allocation_size(extents); } /*! * \brief If the buffer size is constant, return the size. * Otherwise return 0. * \param extents The extents of the buffer. * \return The result. */ - TVM_DLL static int32_t constant_allocation_size( - const Array& extents); + TVM_DLL static int32_t constant_allocation_size(const Array& extents); - static constexpr const char* _type_key = "Allocate"; + static constexpr const char* _type_key = "tir.Allocate"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode); }; +/*! + * \brief Managed reference to AllocateNode. + * \sa AllocateNode + */ +class Allocate : public Stmt { + public: + TVM_DLL Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, + Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode); +}; + /*! \brief Free the resources in the buffer before the scope ends. */ class FreeNode : public StmtNode { public: /*! \brief The buffer variable. 
*/ Var buffer_var; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); } bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const { - return - equal(buffer_var, other->buffer_var); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer_var); + return equal(buffer_var, other->buffer_var); } - TVM_DLL static Stmt make(Var buffer_var); + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); } - static constexpr const char* _type_key = "Free"; + static constexpr const char* _type_key = "tir.Free"; TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode); }; /*! - * \brief Annotate the bounds where func need to be written and read in body. - * We will need to allocate space for the corresponding regions. + * \brief Managed reference to FreeNode. + * \sa FreeNode */ -class RealizeNode : public StmtNode { +class Free : public Stmt { public: - /*! \brief The function to be realized. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index; - /*! \brief The data type of the array. */ - DataType dtype; - /*! \brief Bounds to be realized. */ - Region bounds; - /*! \brief Only realize if condition holds. */ - PrimExpr condition; - /*! \brief The body of realization. */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("dtype", &dtype); - v->Visit("bounds", &bounds); - v->Visit("condition", &condition); - v->Visit("body", &body); - } - - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds, - PrimExpr condition, - Stmt body); - - bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(dtype, other->dtype) && - equal(bounds, other->bounds) && - equal(condition, other->condition) && - equal(body, other->body); - } + TVM_DLL Free(Var buffer_var); - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); - hash_reduce(dtype); - hash_reduce(bounds); - hash_reduce(condition); - hash_reduce(body); - } - - static constexpr const char* _type_key = "Realize"; - TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode); + TVM_DEFINE_OBJECT_REF_METHODS(Free, Stmt, FreeNode); }; /*! @@ -495,29 +584,21 @@ class SeqStmtNode : public StmtNode { Array seq; /*! \return get the size of the sequence */ - size_t size() const { - return seq.size(); - } + size_t size() const { return seq.size(); } /*! * \brief Get the index-th element in the sequence. */ - Stmt operator[](size_t index) const { - return seq[index]; - } + Stmt operator[](size_t index) const { return seq[index]; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("seq", &seq); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("seq", &seq); } bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { return equal(seq, other->seq); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(seq); - } + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); } - static constexpr const char* _type_key = "SeqStmt"; + static constexpr const char* _type_key = "tir.SeqStmt"; TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode); }; @@ -531,15 +612,11 @@ class SeqStmt : public Stmt { TVM_DLL explicit SeqStmt(Array seq); /*! 
\return get the size of the sequence */ - size_t size() const { - return operator->()->size(); - } + size_t size() const { return operator->()->size(); } /*! * \brief Get the index-th element in the sequence. */ - Stmt operator[](size_t index) const { - return (*(operator->()))[index]; - } + Stmt operator[](size_t index) const { return (*(operator->()))[index]; } /*! * \brief Construct a sequence statement by flattening * all the arrays and sequences in the arguments @@ -556,19 +633,17 @@ class SeqStmt : public Stmt { * \tparam Args arguments * \return The constructed statement */ - template + template static Stmt Flatten(Args&&... seq_args) { Array seq; - runtime::detail::for_each( - Flattener(&seq), std::forward(seq_args)...); + runtime::detail::for_each(Flattener(&seq), std::forward(seq_args)...); if (seq.size() == 1) return seq[0]; return SeqStmt(seq); } /*! \brief Helper class to flatten sequence of arguments into Array. */ class Flattener { public: - explicit Flattener(Array* seq) - : seq_(seq) {} + explicit Flattener(Array* seq) : seq_(seq) {} void operator()(size_t i, const Stmt& stmt) const { if (!stmt.defined()) return; @@ -579,7 +654,7 @@ class SeqStmt : public Stmt { } } - template + template void operator()(size_t i, const T& seq) const { for (auto v : seq) { this->operator()(0, v); @@ -612,10 +687,8 @@ class IfThenElseNode : public StmtNode { } bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { - return - equal(condition, other->condition) && - equal(then_case, other->then_case) && - equal(else_case, other->else_case); + return equal(condition, other->condition) && equal(then_case, other->then_case) && + equal(else_case, other->else_case); } void SHashReduce(SHashReducer hash_reduce) const { @@ -624,12 +697,21 @@ class IfThenElseNode : public StmtNode { hash_reduce(else_case); } - TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); - - static constexpr const char* _type_key = "IfThenElse"; + static constexpr const char* _type_key = "tir.IfThenElse"; TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode); }; +/*! + * \brief Managed reference to IfThenElseNode. + * \sa IfThenElseNode + */ +class IfThenElse : public Stmt { + public: + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt()); + + TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); +}; + /*! * \brief Evaluates an expression. * This is mostly used for putting a Call node into Stmt. @@ -641,24 +723,31 @@ class EvaluateNode : public StmtNode { /*! \brief The expression to be evaluated. */ PrimExpr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("value", &value); } bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { return equal(value, other->value); } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - } - - TVM_DLL static Stmt make(PrimExpr v); + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); } - static constexpr const char* _type_key = "Evaluate"; + static constexpr const char* _type_key = "tir.Evaluate"; TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode); }; +/*! + * \brief Managed reference to EvaluateNode. + * \sa EvaluateNode + */ +class Evaluate : public Stmt { + public: + TVM_DLL explicit Evaluate(PrimExpr value); + + explicit Evaluate(int value) : Evaluate(PrimExpr(value)) {} + + TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); +}; + /*! 
\brief Additional annotation of for loop. */ enum class ForType : int { /*! \brief serial execution. */ @@ -674,9 +763,7 @@ enum class ForType : int { // Kevice api of for loop // kept for backward compatibility // consider refactor and remove later. -enum class DeviceAPI: int { - None = 0 -}; +enum class DeviceAPI : int { None = 0 }; /*! * \brief A for loop, with poissible type annotations. @@ -706,13 +793,6 @@ class ForNode : public StmtNode { /*! \brief The body of the for loop. */ Stmt body; - TVM_DLL static Stmt make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body); - void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); v->Visit("min", &min); @@ -723,13 +803,9 @@ class ForNode : public StmtNode { } bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { - return - equal.DefEqual(loop_var, other->loop_var) && - equal(min, other->min) && - equal(extent, other->extent) && - equal(for_type, other->for_type) && - equal(device_api, other->device_api) && - equal(body, other->body); + return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && + equal(extent, other->extent) && equal(for_type, other->for_type) && + equal(device_api, other->device_api) && equal(body, other->body); } void SHashReduce(SHashReducer hash_reduce) const { @@ -741,72 +817,62 @@ class ForNode : public StmtNode { hash_reduce(body); } - - static constexpr const char* _type_key = "For"; + static constexpr const char* _type_key = "tir.For"; TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode); }; /*! - * \brief A prefetch hint of func. + * \brief Managed reference to ForNode. + * \sa ForNode + */ +class For : public Stmt { + public: + TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, + Stmt body); + + TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); +}; + +/*! + * \brief A prefetch hint for abuffer */ class PrefetchNode : public StmtNode { public: /*! \brief The function to be prefetched. */ - FunctionRef func; - /*! \brief The output value index if func's value is a tuple. */ - int value_index; - /*! \brief The data type of the array. */ - DataType dtype; + Buffer buffer; /*! \brief Bounds to be prefetched. */ - Region bounds; + Array bounds; void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("value_index", &value_index); - v->Visit("dtype", &dtype); + v->Visit("buffer", &buffer); v->Visit("bounds", &bounds); } bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(value_index, other->value_index) && - equal(dtype, other->dtype) && - equal(bounds, other->bounds); + return equal(buffer, other->buffer) && equal(bounds, other->bounds); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(value_index); - hash_reduce(dtype); + hash_reduce(buffer); hash_reduce(bounds); } - TVM_DLL static Stmt make(FunctionRef func, - int value_index, - DataType dtype, - Region bounds); + PrefetchNode() = default; + PrefetchNode(Buffer buffer, Array bounds) : buffer(buffer), bounds(bounds) {} - static constexpr const char* _type_key = "Prefetch"; + static constexpr const char* _type_key = "tir.Prefetch"; TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); }; /*! - * \brief Auxiliary data structure used in IR Pass to indicate a tensor. + * \brief Managed reference to PrefetchNode. 
+ * \sa PrefetchNode */ -struct TensorKey { - FunctionRef f; - int value_index; +class Prefetch : public Stmt { + public: + TVM_DLL explicit Prefetch(Buffer buffer, Array bounds); - inline bool operator==(const TensorKey& other) const { - return f == other.f && value_index == other.value_index; - } - inline std::string GetName() const { - if (f->num_outputs() == 1) return f->func_name(); - std::ostringstream os; - os << f->func_name() << ".v" << value_index; - return os.str(); - } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); }; /*! \brief namespace of possible attribute sin AttrStmt.attr_key */ @@ -852,6 +918,8 @@ constexpr const char* loop_scope = "loop_scope"; constexpr const char* reduce_scope = "reduce_scope"; /*! \brief Mark region is guarded by the pragma extension */ constexpr const char* pragma_scope_prefix = "pragma_"; +/*! \brief Import C source or file into the final code gen module */ +constexpr const char* pragma_import_c = "pragma_import_c"; /*! \brief Import llvm source or file into the final code gen module */ constexpr const char* pragma_import_llvm = "pragma_import_llvm"; /*! \brief Try to modify the AST to support Tensor Core */ @@ -905,13 +973,6 @@ constexpr const char* channel_write_advance = "channel_write_advance"; constexpr const char* pipeline_stage_scope = "pipeline_stage_scope"; /*! \brief pipeline execution scope, implies the scope can be pipelined. */ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope"; -/*! - * \brief Mark that this stage is an OpenGL shader. Since OpenGL shader only - * allows writing out to one element of the output texture, the Provide node - * gets translated to a special Call::glsl_texture_store statement instead of a - * Store statement. - */ -constexpr const char* opengl_stage_scope = "opengl_stage_scope"; /*! * \brief Mark that it is in the device scope. @@ -944,9 +1005,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { * \return Expr a expression with dtype. */ inline PrimExpr TypeAnnotation(DataType dtype) { - return tir::CallNode::make(dtype, - "type_annotation", {}, - tir::CallNode::PureIntrinsic); + return tir::Call(dtype, "type_annotation", {}, tir::CallNode::PureIntrinsic); } // overload printing of for type. @@ -954,17 +1013,4 @@ TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); } // namespace tir } // namespace tvm - -namespace std { -template <> -struct hash<::tvm::tir::TensorKey> { - std::size_t operator()(const ::tvm::tir::TensorKey& k) const { - size_t lhs = ::tvm::ObjectHash()(k.f); - size_t rhs = static_cast(k.value_index); - lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); - return lhs; - } -}; -} // namespace std - #endif // TVM_TIR_STMT_H_ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index f93e9080a377..f037de7d2ba8 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -20,16 +20,19 @@ /*! * \file tvm/tir/stmt_functor.h * - * \brief Functors for tir stmts. + * \brief Functors for tir stmts + * utility functions to call common functors. */ #ifndef TVM_TIR_STMT_FUNCTOR_H_ #define TVM_TIR_STMT_FUNCTOR_H_ +#include #include #include -#include #include +#include +#include #include namespace tvm { @@ -39,22 +42,18 @@ namespace tir { * \tparam FType The function signature. 
* \sa ExprFunctor */ -template +template class StmtFunctor; -#define STMT_FUNCTOR_DEFAULT { \ - return VisitStmtDefault_(op, std::forward(args)...); \ - } - -#define IR_STMT_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... args) { \ - return self->VisitStmt_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define STMT_FUNCTOR_DEFAULT \ + { return VisitStmtDefault_(op, std::forward(args)...); } +#define IR_STMT_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitStmt_(static_cast(n.get()), std::forward(args)...); \ + }); -template +template class StmtFunctor { private: using TSelf = StmtFunctor; @@ -71,9 +70,7 @@ class StmtFunctor { * \param args Additional arguments. * \return The result of the call */ - R operator()(const Stmt& n, Args... args) { - return VisitStmt(n, std::forward(args)...); - } + R operator()(const Stmt& n, Args... args) { return VisitStmt(n, std::forward(args)...); } /*! * \brief The functor call. * \param n The stmt node. @@ -92,14 +89,15 @@ class StmtFunctor { virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const RealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmtDefault_(const Object* op, Args ...) { + virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); } @@ -116,11 +114,13 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(FreeNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(ProvideNode); - IR_STMT_FUNCTOR_DISPATCH(RealizeNode); + IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); + IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode); IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); + IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); + IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); return vtable; } }; @@ -131,8 +131,7 @@ class StmtFunctor { /*! * \brief StmtVisitor. 
*/ -class TVM_DLL StmtVisitor : - protected StmtFunctor { +class TVM_DLL StmtVisitor : protected StmtFunctor { public: using StmtFunctor::operator(); @@ -154,10 +153,11 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const ProvideNode* op) override; - void VisitStmt_(const RealizeNode* op) override; + void VisitStmt_(const ProducerStoreNode* op) override; + void VisitStmt_(const ProducerRealizeNode* op) override; void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; @@ -166,8 +166,7 @@ class TVM_DLL StmtVisitor : /*! * \brief StmtMutator that mutates the statements. */ -class TVM_DLL StmtMutator : - protected StmtFunctor { +class TVM_DLL StmtMutator : protected StmtFunctor { public: /*! * \brief Mutate stmt. @@ -203,7 +202,7 @@ class TVM_DLL StmtMutator : * * \return The result object pointer. */ - template + template ObjectPtr CopyOnWrite(const TNode* node) { if (allow_copy_on_write_) { // return the old node. @@ -237,9 +236,7 @@ class TVM_DLL StmtMutator : * or have a class sub-class both StmtMutator and ExprMutator * and redirect Mutate to ExprMutator::Mutate(Expr) */ - virtual PrimExpr VisitExpr(const PrimExpr& e) { - return e; - } + virtual PrimExpr VisitExpr(const PrimExpr& e) { return e; } // statement visitor Stmt VisitStmt_(const AttrStmtNode* op) override; Stmt VisitStmt_(const IfThenElseNode* op) override; @@ -248,10 +245,11 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; + Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const ProvideNode* op) override; - Stmt VisitStmt_(const RealizeNode* op) override; + Stmt VisitStmt_(const ProducerStoreNode* op) override; + Stmt VisitStmt_(const ProducerRealizeNode* op) override; Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; @@ -267,8 +265,7 @@ class TVM_DLL StmtMutator : * \param fmutate The mutate function, can be nullptr, which defaults to Visit. * \return The mutated result. */ - Stmt VisitSeqStmt_(const SeqStmtNode* op, - bool flatten_before_visit, + Stmt VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate = nullptr); // internal helper. class Internal; @@ -277,45 +274,37 @@ class TVM_DLL StmtMutator : /*! * \brief Visitor that recursively visit stmts and exprs on them. */ -class StmtExprVisitor : - public StmtVisitor, - public ExprVisitor { +class StmtExprVisitor : public StmtVisitor, public ExprVisitor { public: using StmtVisitor::operator(); using ExprVisitor::operator(); protected: - using StmtVisitor::VisitStmt; using ExprVisitor::VisitExpr; + using StmtVisitor::VisitStmt; - void VisitExpr(const PrimExpr& e) override { - return ExprVisitor::VisitExpr(e); - } + void VisitExpr(const PrimExpr& e) override { return ExprVisitor::VisitExpr(e); } }; /*! * \brief Mutator that recursively mutates stmts and exprs on them. 
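 *
 * A typical extension pattern (sketch only; MyRewriter is a hypothetical pass):
 * \code
 *   class MyRewriter : public StmtExprMutator {
 *    public:
 *     PrimExpr VisitExpr_(const VarNode* op) final {
 *       // inspect or rewrite the variable here, then fall back to the default
 *       return StmtExprMutator::VisitExpr_(op);
 *     }
 *   };
 * \endcode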
*/ -class StmtExprMutator : - public StmtMutator, - public ExprMutator { +class StmtExprMutator : public StmtMutator, public ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); protected: - using StmtMutator::VisitExpr; using ExprMutator::VisitExpr; + using StmtMutator::VisitExpr; - PrimExpr VisitExpr(const PrimExpr& e) override { - return ExprMutator::VisitExpr(e); - } + PrimExpr VisitExpr(const PrimExpr& e) override { return ExprMutator::VisitExpr(e); } }; /*! - * \brief recursively visit the ir in post DFS order node, and transform it + * \brief recursively visit the ir nodes in post DFS order, and transform it * - * \param node The ir to be transformed. + * \param stmt The ir to be transformed. * \param preorder The function called in before recursive mutation * If preorder returns None, then the transform will proceed to recursive call. * If preorder returns a not None Stmt/Expr, the transformer will simply return it and @@ -323,23 +312,72 @@ class StmtExprMutator : * \param postorder The function called after recursive mutation. * The recursive mutation result is passed to postorder for further mutation. * \param only_enable List of runtime::String. - * If it is empty, all IRNode will call preorder/postorder - * If it is not empty, preorder/postorder will only be called + * If it is null, all IRNode will call preorder/postorder + * If it is not null, preorder/postorder will only be called * when the IRNode's type key is in the list. */ -TVM_DLL Stmt IRTransform(Stmt node, - const runtime::PackedFunc& preorder, +TVM_DLL Stmt IRTransform(Stmt stmt, const runtime::PackedFunc& preorder, const runtime::PackedFunc& postorder, - const Array& only_enable = {}); + Optional> only_enable = NullOpt); /*! - * \brief recursively visit the ir in post DFS order node, apply fvisit + * \brief Recursively visit the ir in post DFS order node, apply fvisit * Each node is guaranteed to be visited only once. * \param node The ir to be visited. * \param fvisit The visitor function to be applied. */ TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function fvisit); +/*! + * \brief Substitute the var specified by vmap. + * \param stmt The source statement to be substituted + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. + * \return The converted form. + */ +TVM_DLL Stmt Substitute(Stmt stmt, std::function(const Var& var)> vmap); + +/*! + * \brief Substitute the var specified by vmap. + * \param expr The source statement to be substituted + * \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr. + * \return The result. + */ +TVM_DLL PrimExpr Substitute(PrimExpr expr, std::function(const Var& var)> vmap); + +/*! + * \brief Sugar for substitute via a given map. + * \param input The input to be updated. + * \param value_map The map of new values. + * \return The result. + * \tparam T the input type, can be PrimExpr or Stmt. + */ +template +inline auto Substitute(T input, const Map& value_map) { + auto vmap = [&](const Var& var) -> Optional { + auto it = value_map.find(var); + if (it != value_map.end()) return (*it).second; + return Optional(nullptr); + }; + return Substitute(std::move(input), vmap); +} + +/*! + * \brief Sugar for substitute via a given map. + * \param input The input to be updated. + * \param value_map The map of new values. + * \return The result. + * \tparam T the input type, can be PrimExpr or Stmt. 
+ */ +template +inline T Substitute(T input, const std::unordered_map& value_map) { + auto vmap = [&](const Var& var) -> Optional { + auto it = value_map.find(var.get()); + if (it != value_map.end()) return (*it).second; + return Optional(nullptr); + }; + return Substitute(std::move(input), vmap); +} + } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 23c195563ac2..a794c12b55ee 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -35,11 +35,11 @@ namespace tir { namespace transform { using tvm::transform::Pass; -using tvm::transform::PassNode; -using tvm::transform::PassInfo; -using tvm::transform::PassInfoNode; using tvm::transform::PassContext; using tvm::transform::PassContextNode; +using tvm::transform::PassInfo; +using tvm::transform::PassInfoNode; +using tvm::transform::PassNode; using tvm::transform::Sequential; /* @@ -52,11 +52,134 @@ using tvm::transform::Sequential; * * \return The created function pass. */ -TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< - PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); +TVM_DLL Pass CreatePrimFuncPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required); + +/*! + * \brief Inject prefetch instructions into stmt. + * + * \return The pass. + */ +TVM_DLL Pass InjectPrefetch(); + +// TODO(tvm-team): consolidate configs to the PassContext +/*! + * \brief Flatten the multi-dimensional read/write + * to single dimensional Load/Store + * + * \param cache_line_size The size of CPU cache line. + * \param create_bound_attribute Whether to create bound attributes. + * + * \return The Pass + */ +TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); + +/*! + * \brief Inject copy intrinsics with optional pad. + * + * \param pragma_key The pragma key for hint of copy. + * \param fintrin The function with signature + * + * Stmt fintrin(Buffer src, + * Buffer dst, + * Array pad_before, + * Array pad_after, + * Expr pad_value) + * \return The pass. + */ +TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin); + +/*! + * \brief Detect and insert sync points to co-processor. + * + * \return The pass. + */ +TVM_DLL Pass CoProcSync(); + +/*! + * \brief Lift common attrs with attr_key to outer scope. + * + * \param attr_key The attribute key to be checked. + * \return The pass. + */ +TVM_DLL Pass LiftAttrScope(String attr_key); + +/*! + * \brief partition loops in the stmt. + * + * \return The pass. + */ +TVM_DLL Pass LoopPartition(); + +/*! + * \brief Lower vectorization loops. + * + * \param enable_vectorize Whether vectorization is enabled. + * + * \return The pass. + */ +TVM_DLL Pass VectorizeLoop(bool enable_vectorize = true); + +/*! + * \brief Inject virtual thread loops. + * + * \return The pass. + */ +TVM_DLL Pass InjectVirtualThread(); + +/*! + * \brief Inject double buffer statements. + * + * \return The pass. + */ +TVM_DLL Pass InjectDoubleBuffer(); + +/*! + * \brief Rewrite storage allocation pattern. + * Moves the allocation to outer most possible scope. + * Trying to share space between allocations to make + * a static allocation plan when possible. + * + * \return The pass. + */ +TVM_DLL Pass StorageRewrite(); + +/*! + * \brief unroll the constant loop marked by unroll. + * This pass also automatically attach pragma unroll tag to loops which meets the standard. 
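+ *
+ * A usage sketch, assuming an IRModule mod containing PrimFuncs:
+ * \code
+ *   mod = tir::transform::UnrollLoop()(mod);
+ * \endcode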
+ * + * \return The pass. + */ +TVM_DLL Pass UnrollLoop(); + +/*! + * \brief Remove No Op from the Stmt. + * + * \return The pass. + */ +TVM_DLL Pass RemoveNoOp(); + +/*! + * \brief Detect and rewrite unsafe select that contains memory access. + * + * \return The pass. + */ +TVM_DLL Pass RewriteUnsafeSelect(); + +/*! + * \brief Run arithmetic simplifications on the statements and expressions. + * + * \return The pass. + */ +TVM_DLL Pass Simplify(); + +/*! + * \brief Instruments bound checkers. + * + * \return The pass. + */ +TVM_DLL Pass InstrumentBoundCheckers(); /*! * \brief Transform the high-level PrimFunc to a low-level version @@ -89,7 +212,6 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< */ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); - /*! * \brief Remap the thread axis * @@ -100,8 +222,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); * * \return The pass. */ -TVM_DLL Pass RemapThreadAxis(Map axis_map); - +TVM_DLL Pass RemapThreadAxis(Map axis_map); /*! * \brief Lower custom datatypes. @@ -112,6 +233,13 @@ TVM_DLL Pass RemapThreadAxis(Map axis_map); */ TVM_DLL Pass LowerCustomDatatypes(); +/*! + * \brief Decorate all the function's body as device function. + * + * \return The pass. + */ +TVM_DLL Pass DecorateDeviceScope(); + /*! * \brief Split the function into a host function and device functions. * @@ -132,8 +260,7 @@ TVM_DLL Pass SkipAssert(); * \param storage_scope The storage scope considered. * \return The pass. */ -TVM_DLL Pass ThreadSync(std::string storage_scope); - +TVM_DLL Pass ThreadSync(String storage_scope); /*! * \brief Lower cross thread alleduce. @@ -184,7 +311,6 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo(); */ TVM_DLL Pass CombineContextCall(); - /*! * \brief Narrow down PrimExpr datatype in stmt to target_bits. * @@ -195,6 +321,16 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass NarrowDataType(int target_bits); +/*! + * \brief Rewrite the pointer content type of arguments, + * as well as Alloc internal to the function to use + * the most frequently accessed type for load/store + * to avoid pointer casting in backend when possible. + * + * \return The pass. + */ +TVM_DLL Pass PointerValueTypeRewrite(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 19c904a1230f..f1651c118010 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -24,9 +24,10 @@ #ifndef TVM_TIR_VAR_H_ #define TVM_TIR_VAR_H_ +#include #include #include -#include + #include namespace tvm { @@ -49,7 +50,7 @@ class VarNode : public PrimExprNode { * \brief The hint to the variable name. * \note Each variable is uniquely identified by its address. */ - std::string name_hint; + String name_hint; /*! * \brief type annotaion of the variable. * @@ -78,10 +79,11 @@ class VarNode : public PrimExprNode { } static constexpr const char* _type_key = "tir.Var"; + static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); }; -/*! \brief a named variable in TVM */ +/*! \brief a named variable in TIR */ class Var : public PrimExpr { public: explicit Var(ObjectPtr n) : PrimExpr(n) {} @@ -90,34 +92,30 @@ class Var : public PrimExpr { * \param name_hint variable name * \param dtype data type */ - TVM_DLL explicit Var(std::string name_hint = "v", - DataType dtype = DataType::Int(32)); + TVM_DLL explicit Var(String name_hint = "v", DataType dtype = DataType::Int(32)); /*! * \brief Constructor which provides a more detailed type annotation. 
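As a rough sketch of how the pass constructors declared in include/tvm/tir/transform.h above are meant to be composed; the particular pass selection, ordering, and the module argument are assumptions made for illustration, not something this patch prescribes.

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/tir/transform.h>

// Chain several of the newly exposed TIR passes into one pipeline and run it on a module.
tvm::IRModule RunExamplePipeline(tvm::IRModule mod) {
  namespace tir = tvm::tir;
  tvm::transform::Sequential pipeline({tir::transform::StorageFlatten(64),
                                       tir::transform::LoopPartition(),
                                       tir::transform::VectorizeLoop(),
                                       tir::transform::StorageRewrite(),
                                       tir::transform::Simplify()});
  return pipeline(std::move(mod));  // Pass::operator() applies the whole sequence
}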
* \param name_hint variable name. * \param type_annotation The type annotation. */ - TVM_DLL explicit Var(std::string name_hint, Type type_annotation); + TVM_DLL explicit Var(String name_hint, Type type_annotation); /*! * \brief Make a new copy of var with same type, append suffix * \param suffix The suffix to be appended. * \return the new Var copy */ - TVM_DLL Var copy_with_suffix(const std::string& suffix) const; + TVM_DLL Var copy_with_suffix(const String& suffix) const; + /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* operator->() const { - return get(); - } + const VarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* get() const { - return static_cast(data_.get()); - } + const VarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = VarNode; }; @@ -141,30 +139,21 @@ class SizeVar : public Var { * \param name_hint variable name * \param t data type */ - TVM_DLL explicit SizeVar(std::string name_hint = "s", - DataType t = DataType::Int(32)); + TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32)); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const SizeVarNode* operator->() const { - return get(); - } + const SizeVarNode* operator->() const { return get(); } /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const SizeVarNode* get() const { - return static_cast(data_.get()); - } + const SizeVarNode* get() const { return static_cast(data_.get()); } /*! \brief type indicate the container type */ using ContainerType = SizeVarNode; }; - -/*! \brief container class of iteration variable. */ -class IterVarNode; - using Region = Array; /*! @@ -187,7 +176,7 @@ enum IterVarType : int { /*! * \brief The IterVar itself is a thread-index * of a fixed thread launching group. - * Note that this is already assumed to be paralellized. + * Note that this is already assumed to be parallelized. * * Disallow: split/fuse/vectorize/parallel */ @@ -237,31 +226,6 @@ enum IterVarType : int { kTensorized = 8 }; -/*! - * \brief Iteration Variable, - * represents an iteration over an integer interval. - */ -class IterVar : public ObjectRef { - public: - // construct a new iter var without a domain - IterVar() {} - // construct from shared ptr. - explicit IterVar(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const IterVarNode* operator->() const; - /*! - * \return the corresponding var in the IterVar. - */ - inline operator PrimExpr() const; - /*! \brief specify container node */ - using ContainerType = IterVarNode; -}; - -using Domain = Array; - /*! * \brief An iteration variable representing an iteration * over a one dimensional interval. @@ -281,7 +245,7 @@ class IterVarNode : public Object { * \brief additional tag on the iteration variable, * set this if this is binded already to a known thread tag. 
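A short sketch of the String-based constructors that replace the std::string overloads above; names and dtypes are arbitrary, and this is illustration rather than code from the patch.

#include <tvm/tir/var.h>

// Construct TIR variables whose name hints now travel as tvm::String.
void MakeExampleVars() {
  tvm::tir::Var i("i", tvm::DataType::Int(32));
  tvm::tir::Var i_outer = i.copy_with_suffix(".outer");  // new Var named "i.outer", same dtype
  tvm::tir::SizeVar n("n");                              // extent variable, defaults to int32
  (void)i_outer;
  (void)n;
}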
*/ - std::string thread_tag; + String thread_tag; void VisitAttrs(AttrVisitor* v) { v->Visit("dom", &dom); @@ -291,11 +255,8 @@ class IterVarNode : public Object { } bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { - return - equal(dom, other->dom) && - equal.DefEqual(var, other->var) && - equal(iter_type, other->iter_type) && - equal(thread_tag, other->thread_tag); + return equal(dom, other->dom) && equal.DefEqual(var, other->var) && + equal(iter_type, other->iter_type) && equal(thread_tag, other->thread_tag); } void SHashReduce(SHashReducer hash_reduce) const { @@ -305,36 +266,50 @@ class IterVarNode : public Object { hash_reduce(thread_tag); } - TVM_DLL static IterVar make(Range dom, Var var, - IterVarType iter_type, - std::string thread_tag = ""); - - static constexpr const char* _type_key = "IterVar"; + static constexpr const char* _type_key = "tir.IterVar"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object); }; -// inline implementations -inline const IterVarNode* IterVar::operator->() const { - return static_cast(data_.get()); -} +/*! + * \brief Iteration Variable, + * represents an iteration over an integer interval. + */ +class IterVar : public ObjectRef { + public: + TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = ""); + /*! + * \return the corresponding var in the IterVar. + */ + inline operator PrimExpr() const; -inline IterVar::operator PrimExpr() const { - return (*this)->var; -} + TVM_DEFINE_OBJECT_REF_METHODS(IterVar, ObjectRef, IterVarNode); +}; + +// inline implementations +inline IterVar::operator PrimExpr() const { return (*this)->var; } inline const char* IterVarType2String(IterVarType t) { switch (t) { - case kDataPar: return "DataPar"; - case kThreadIndex: return "ThreadIndex"; - case kCommReduce: return "CommReduce"; - case kOrdered: return "Ordered"; - case kOpaque: return "Opaque"; - case kUnrolled: return "Unrolled"; - case kVectorized: return "Vectorized"; - case kParallelized: return "Parallelized"; - case kTensorized: return "Tensorized"; + case kDataPar: + return "DataPar"; + case kThreadIndex: + return "ThreadIndex"; + case kCommReduce: + return "CommReduce"; + case kOrdered: + return "Ordered"; + case kOpaque: + return "Opaque"; + case kUnrolled: + return "Unrolled"; + case kVectorized: + return "Vectorized"; + case kParallelized: + return "Parallelized"; + case kTensorized: + return "Tensorized"; } return "Unknown"; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TypeCode.java b/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java similarity index 95% rename from jvm/core/src/main/java/org/apache/tvm/TypeCode.java rename to jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java index 2d21e4afa6b4..b3b3da56e72f 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TypeCode.java +++ b/jvm/core/src/main/java/org/apache/tvm/ArgTypeCode.java @@ -18,14 +18,14 @@ package org.apache.tvm; // Type code used in API calls -public enum TypeCode { +public enum ArgTypeCode { INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5), TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9), FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13); public final int id; - private TypeCode(int id) { + private ArgTypeCode(int id) { this.id = id; } diff --git a/jvm/core/src/main/java/org/apache/tvm/Function.java b/jvm/core/src/main/java/org/apache/tvm/Function.java index 
a9ac70722410..df535a87aa85 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Function.java +++ b/jvm/core/src/main/java/org/apache/tvm/Function.java @@ -80,7 +80,7 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a * @param isResident Whether this is a resident function in jvm */ Function(long handle, boolean isResident) { - super(TypeCode.FUNC_HANDLE); + super(ArgTypeCode.FUNC_HANDLE); this.handle = handle; this.isResident = isResident; } @@ -187,7 +187,7 @@ public Function pushArg(String arg) { * @return this */ public Function pushArg(NDArrayBase arg) { - int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; Base._LIB.tvmFuncPushArgHandle(arg.handle, id); return this; } @@ -198,7 +198,7 @@ public Function pushArg(NDArrayBase arg) { * @return this */ public Function pushArg(Module arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id); return this; } @@ -208,7 +208,7 @@ public Function pushArg(Module arg) { * @return this */ public Function pushArg(Function arg) { - Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id); return this; } @@ -249,12 +249,12 @@ private static void pushArgToStack(Object arg) { Base._LIB.tvmFuncPushArgBytes((byte[]) arg); } else if (arg instanceof NDArrayBase) { NDArrayBase nd = (NDArrayBase) arg; - int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id; + int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id; Base._LIB.tvmFuncPushArgHandle(nd.handle, id); } else if (arg instanceof Module) { - Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id); } else if (arg instanceof Function) { - Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id); + Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id); } else if (arg instanceof TVMValue) { TVMValue tvmArg = (TVMValue) arg; switch (tvmArg.typeCode) { diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 1656f8dee6fa..874daa4029dc 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -45,7 +45,7 @@ private static Function getApi(String name) { } Module(long handle) { - super(TypeCode.MODULE_HANDLE); + super(ArgTypeCode.MODULE_HANDLE); this.handle = handle; } @@ -138,7 +138,7 @@ public String typeKey() { */ public static Module load(String path, String fmt) { TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke(); - assert ret.typeCode == TypeCode.MODULE_HANDLE; + assert ret.typeCode == ArgTypeCode.MODULE_HANDLE; return ret.asModule(); } diff --git a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java index 5ac630d3a668..26bb735e1a5b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java +++ b/jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java @@ -27,7 +27,7 @@ public class NDArrayBase extends TVMValue { private boolean isReleased = false; NDArrayBase(long handle, boolean isView) { - super(TypeCode.ARRAY_HANDLE); + 
super(ArgTypeCode.ARRAY_HANDLE); this.handle = handle; this.isView = isView; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java index 92c7623b2dc1..d30cfcc4f30a 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValue.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValue.java @@ -18,9 +18,9 @@ package org.apache.tvm; public class TVMValue { - public final TypeCode typeCode; + public final ArgTypeCode typeCode; - public TVMValue(TypeCode tc) { + public TVMValue(ArgTypeCode tc) { typeCode = tc; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java index 6c7c1c892747..132d88f7622b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java @@ -21,7 +21,7 @@ public class TVMValueBytes extends TVMValue { public final byte[] value; public TVMValueBytes(byte[] value) { - super(TypeCode.BYTES); + super(ArgTypeCode.BYTES); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java index d94b011d7e10..9db4c3bb0e8c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java @@ -21,7 +21,7 @@ public class TVMValueDouble extends TVMValue { public final double value; public TVMValueDouble(double value) { - super(TypeCode.FLOAT); + super(ArgTypeCode.FLOAT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java index 8ab7572d1cfd..b91f55e2f59b 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java @@ -18,13 +18,13 @@ package org.apache.tvm; /** - * Java class related to TVM handles (TypeCode.HANDLE) + * Java class related to TVM handles (ArgTypeCode.HANDLE) */ public class TVMValueHandle extends TVMValue { public final long value; public TVMValueHandle(long value) { - super(TypeCode.HANDLE); + super(ArgTypeCode.HANDLE); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java index 5dba2fd459f6..8a9b157d3961 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java @@ -21,7 +21,7 @@ public class TVMValueLong extends TVMValue { public final long value; public TVMValueLong(long value) { - super(TypeCode.INT); + super(ArgTypeCode.INT); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java index 03c0ea0dbcd4..8c49ee5b3df5 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java @@ -19,6 +19,6 @@ public class TVMValueNull extends TVMValue { public TVMValueNull() { - super(TypeCode.NULL); + super(ArgTypeCode.NULL); } } diff --git a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java index 260803e8e897..46926e7d3fc6 100644 --- a/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java +++ b/jvm/core/src/main/java/org/apache/tvm/TVMValueString.java @@ -21,7 +21,7 @@ public class TVMValueString extends TVMValue { public final String value; public 
TVMValueString(String value) { - super(TypeCode.STR); + super(ArgTypeCode.STR); this.value = value; } diff --git a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java index c31c67f283af..61ff966eaf38 100644 --- a/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java +++ b/jvm/core/src/main/java/org/apache/tvm/contrib/GraphRuntime.java @@ -38,53 +38,14 @@ public class GraphRuntime { * @return Runtime graph module that can be used to execute the graph. */ public static GraphModule create(String graphJson, Module libmod, TVMContext ctx) { - Module graphModule = null; - if (ctx.deviceType >= RPC.RPC_SESS_MASK) { - if (!(ctx instanceof TVMRemoteContext)) { - throw new IllegalArgumentException( - "Looks like you are using remote context with no RPCSession bind." - + "Use session.context instead."); - } - RPCSession rpcSession = ((TVMRemoteContext) ctx).rpcSession; - // check arguments - if (!"rpc".equals(libmod.typeKey())) { - throw new IllegalArgumentException("libmod.typeKey != rpc"); - } - final int sessIndex = (int) ((Function) reflectionStaticCall( - RPC.class, "getApi", "_SessTableIndex")) - .pushArg(libmod).invoke().asLong(); - if (sessIndex != (Integer) reflectionGetField(rpcSession, "tblIndex")) { - throw new IllegalArgumentException(String.format( - "libmod SessTableIndex=%d mismatch rpcSession.tblIndex=%d", - sessIndex, reflectionGetField(rpcSession, "tblIndex"))); - } - - Function rpcModuleHandle = (Function) reflectionStaticCall( - RPC.class, "getApi","_ModuleHandle"); - if (rpcModuleHandle == null) { - throw new RuntimeException("Cannot find global function tvm.rpc._ModuleHandle." - + "Did you compile tvm_runtime with the correct version?"); - } - - Function fcreate = Function.getFunction("tvm.graph_runtime.remote_create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.remote_create." - + "Did you compile tvm_runtime with correct version?"); - } - - TVMValue hmod = rpcModuleHandle.pushArg(libmod).invoke(); - graphModule = fcreate.call(graphJson, hmod, - ctx.deviceType % RPC.RPC_SESS_MASK, ctx.deviceId).asModule(); - } else { - Function fcreate = Function.getFunction("tvm.graph_runtime.create"); - if (fcreate == null) { - throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." - + "Did you compile tvm_runtime with correct version?"); - } - graphModule = fcreate.pushArg(graphJson) - .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) - .invoke().asModule(); + Function fcreate = Function.getFunction("tvm.graph_runtime.create"); + if (fcreate == null) { + throw new RuntimeException("Cannot find global function tvm.graph_runtime.create." + + "Did you compile tvm_runtime with correct version?"); } + Module graphModule = fcreate.pushArg(graphJson) + .pushArg(libmod).pushArg(ctx.deviceType).pushArg(ctx.deviceId) + .invoke().asModule(); return new GraphModule(graphModule, ctx); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java index 5178ac900a36..69321c3b51c8 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/Client.java @@ -29,7 +29,7 @@ public class Client { * @return The connected session. 
*/ public static RPCSession connect(String url, int port, String key) { - Function doConnect = RPC.getApi("_Connect"); + Function doConnect = RPC.getApi("Connect"); if (doConnect == null) { throw new RuntimeException("Please compile with USE_RPC=1"); } diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java index 29a457f39a40..1f3191fb2e8c 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/NativeServerLoop.java @@ -46,7 +46,7 @@ public NativeServerLoop(final Function fsend, final Function frecv) { try { tempDir = serverEnv(); System.err.println("starting server loop..."); - RPC.getApi("_ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); + RPC.getApi("ServerLoop").pushArg(fsend).pushArg(frecv).invoke(); System.err.println("done server loop..."); } catch (IOException e) { e.printStackTrace(); diff --git a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java index 92b328488b40..b9f621473cf4 100644 --- a/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/org/apache/tvm/rpc/RPCSession.java @@ -39,7 +39,7 @@ public class RPCSession { RPCSession(Module sess) { session = sess; - tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + tblIndex = (int) RPC.getApi("SessTableIndex").pushArg(session).invoke().asLong(); } /** @@ -237,7 +237,7 @@ public byte[] download(String path) { * @return The remote module containing remote function. */ public Module loadModule(String path) { - return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + return RPC.getApi("LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); } diff --git a/jvm/native/src/main/native/jni_helper_func.h b/jvm/native/src/main/native/jni_helper_func.h index ce1979c6618b..0f202004f99d 100644 --- a/jvm/native/src/main/native/jni_helper_func.h +++ b/jvm/native/src/main/native/jni_helper_func.h @@ -26,7 +26,7 @@ #define TVM4J_JNI_MAIN_NATIVE_JNI_HELPER_FUNC_H_ // Helper functions for RefXXX getter & setter -jlong getLongField(JNIEnv *env, jobject obj) { +jlong getLongField(JNIEnv* env, jobject obj) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); jlong ret = env->GetLongField(obj, refFid); @@ -34,7 +34,7 @@ jlong getLongField(JNIEnv *env, jobject obj) { return ret; } -jint getIntField(JNIEnv *env, jobject obj) { +jint getIntField(JNIEnv* env, jobject obj) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); jint ret = env->GetIntField(obj, refFid); @@ -42,21 +42,21 @@ jint getIntField(JNIEnv *env, jobject obj) { return ret; } -void setIntField(JNIEnv *env, jobject obj, jint value) { +void setIntField(JNIEnv* env, jobject obj, jint value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefInt"); jfieldID refFid = env->GetFieldID(refClass, "value", "I"); env->SetIntField(obj, refFid, value); env->DeleteLocalRef(refClass); } -void setLongField(JNIEnv *env, jobject obj, jlong value) { +void setLongField(JNIEnv* env, jobject obj, jlong value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefLong"); jfieldID refFid = env->GetFieldID(refClass, "value", "J"); env->SetLongField(obj, refFid, value); env->DeleteLocalRef(refClass); } -void 
setStringField(JNIEnv *env, jobject obj, const char *value) { +void setStringField(JNIEnv* env, jobject obj, const char* value) { jclass refClass = env->FindClass("org/apache/tvm/Base$RefString"); jfieldID refFid = env->GetFieldID(refClass, "value", "Ljava/lang/String;"); env->SetObjectField(obj, refFid, env->NewStringUTF(value)); @@ -64,8 +64,8 @@ void setStringField(JNIEnv *env, jobject obj, const char *value) { } // Helper functions for TVMValue -jlong getTVMValueLongField(JNIEnv *env, jobject obj, - const char *clsname = "org/apache/tvm/TVMValueLong") { +jlong getTVMValueLongField(JNIEnv* env, jobject obj, + const char* clsname = "org/apache/tvm/TVMValueLong") { jclass cls = env->FindClass(clsname); jfieldID fid = env->GetFieldID(cls, "value", "J"); jlong ret = env->GetLongField(obj, fid); @@ -73,7 +73,7 @@ jlong getTVMValueLongField(JNIEnv *env, jobject obj, return ret; } -jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) { +jdouble getTVMValueDoubleField(JNIEnv* env, jobject obj) { jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble"); jfieldID fid = env->GetFieldID(cls, "value", "D"); jdouble ret = env->GetDoubleField(obj, fid); @@ -81,7 +81,7 @@ jdouble getTVMValueDoubleField(JNIEnv *env, jobject obj) { return ret; } -jstring getTVMValueStringField(JNIEnv *env, jobject obj) { +jstring getTVMValueStringField(JNIEnv* env, jobject obj) { jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jfieldID fid = env->GetFieldID(cls, "value", "Ljava/lang/String;"); jstring ret = static_cast(env->GetObjectField(obj, fid)); @@ -89,7 +89,7 @@ jstring getTVMValueStringField(JNIEnv *env, jobject obj) { return ret; } -jobject newTVMValueHandle(JNIEnv *env, jlong value) { +jobject newTVMValueHandle(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueHandle"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -97,7 +97,7 @@ jobject newTVMValueHandle(JNIEnv *env, jlong value) { return object; } -jobject newTVMValueLong(JNIEnv *env, jlong value) { +jobject newTVMValueLong(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueLong"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -105,7 +105,7 @@ jobject newTVMValueLong(JNIEnv *env, jlong value) { return object; } -jobject newTVMValueDouble(JNIEnv *env, jdouble value) { +jobject newTVMValueDouble(JNIEnv* env, jdouble value) { jclass cls = env->FindClass("org/apache/tvm/TVMValueDouble"); jmethodID constructor = env->GetMethodID(cls, "", "(D)V"); jobject object = env->NewObject(cls, constructor, value); @@ -113,7 +113,7 @@ jobject newTVMValueDouble(JNIEnv *env, jdouble value) { return object; } -jobject newTVMValueString(JNIEnv *env, const char *value) { +jobject newTVMValueString(JNIEnv* env, const char* value) { jstring jvalue = env->NewStringUTF(value); jclass cls = env->FindClass("org/apache/tvm/TVMValueString"); jmethodID constructor = env->GetMethodID(cls, "", "(Ljava/lang/String;)V"); @@ -123,10 +123,10 @@ jobject newTVMValueString(JNIEnv *env, const char *value) { return object; } -jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) { +jobject newTVMValueBytes(JNIEnv* env, const TVMByteArray* arr) { jbyteArray jarr = env->NewByteArray(arr->size); env->SetByteArrayRegion(jarr, 0, arr->size, - reinterpret_cast(const_cast(arr->data))); + reinterpret_cast(const_cast(arr->data))); jclass cls = 
env->FindClass("org/apache/tvm/TVMValueBytes"); jmethodID constructor = env->GetMethodID(cls, "", "([B)V"); jobject object = env->NewObject(cls, constructor, jarr); @@ -135,7 +135,7 @@ jobject newTVMValueBytes(JNIEnv *env, const TVMByteArray *arr) { return object; } -jobject newModule(JNIEnv *env, jlong value) { +jobject newModule(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/Module"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -143,7 +143,7 @@ jobject newModule(JNIEnv *env, jlong value) { return object; } -jobject newFunction(JNIEnv *env, jlong value) { +jobject newFunction(JNIEnv* env, jlong value) { jclass cls = env->FindClass("org/apache/tvm/Function"); jmethodID constructor = env->GetMethodID(cls, "", "(J)V"); jobject object = env->NewObject(cls, constructor, value); @@ -151,7 +151,7 @@ jobject newFunction(JNIEnv *env, jlong value) { return object; } -jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { +jobject newNDArray(JNIEnv* env, jlong handle, jboolean isview) { jclass cls = env->FindClass("org/apache/tvm/NDArrayBase"); jmethodID constructor = env->GetMethodID(cls, "", "(JZ)V"); jobject object = env->NewObject(cls, constructor, handle, isview); @@ -159,7 +159,7 @@ jobject newNDArray(JNIEnv *env, jlong handle, jboolean isview) { return object; } -jobject newObject(JNIEnv *env, const char *clsname) { +jobject newObject(JNIEnv* env, const char* clsname) { jclass cls = env->FindClass(clsname); jmethodID constructor = env->GetMethodID(cls, "", "()V"); jobject object = env->NewObject(cls, constructor); @@ -167,7 +167,7 @@ jobject newObject(JNIEnv *env, const char *clsname) { return object; } -void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) { +void fromJavaDType(JNIEnv* env, jobject jdtype, DLDataType* dtype) { jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType"); dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I"))); dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I"))); @@ -175,16 +175,16 @@ void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) { env->DeleteLocalRef(tvmTypeClass); } -void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) { +void fromJavaContext(JNIEnv* env, jobject jctx, TVMContext* ctx) { jclass tvmContextClass = env->FindClass("org/apache/tvm/TVMContext"); - ctx->device_type = static_cast(env->GetIntField(jctx, - env->GetFieldID(tvmContextClass, "deviceType", "I"))); - ctx->device_id = static_cast(env->GetIntField(jctx, - env->GetFieldID(tvmContextClass, "deviceId", "I"))); + ctx->device_type = static_cast( + env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceType", "I"))); + ctx->device_id = + static_cast(env->GetIntField(jctx, env->GetFieldID(tvmContextClass, "deviceId", "I"))); env->DeleteLocalRef(tvmContextClass); } -jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { +jobject tvmRetValueToJava(JNIEnv* env, TVMValue value, int tcode) { switch (tcode) { case kDLUInt: case kDLInt: @@ -204,7 +204,7 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { case kTVMStr: return newTVMValueString(env, value.v_str); case kTVMBytes: - return newTVMValueBytes(env, reinterpret_cast(value.v_handle)); + return newTVMValueBytes(env, reinterpret_cast(value.v_handle)); case kTVMNullptr: return newObject(env, "org/apache/tvm/TVMValueNull"); default: diff --git 
a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc index b59956824d26..6fc316ca8739 100644 --- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc @@ -29,28 +29,28 @@ #include #include #endif -#include #include -#include +#include #include +#include #include "jni_helper_func.h" -JavaVM *_jvm; -void *_tvmHandle = nullptr; +JavaVM* _jvm; +void* _tvmHandle = nullptr; struct TVMFuncArgsThreadLocalEntry { std::vector tvmFuncArgValues; std::vector tvmFuncArgTypes; // for later release - std::vector > tvmFuncArgPushedStrs; - std::vector > tvmFuncArgPushedBytes; + std::vector > tvmFuncArgPushedStrs; + std::vector > tvmFuncArgPushedBytes; }; typedef dmlc::ThreadLocalStore TVMFuncArgsThreadLocalStore; -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit - (JNIEnv *env, jobject obj, jstring jtvmLibFile) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit(JNIEnv* env, jobject obj, + jstring jtvmLibFile) { if (_tvmHandle == NULL && !env->IsSameObject(jtvmLibFile, NULL)) { - const char *tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); + const char* tvmLibFile = env->GetStringUTFChars(jtvmLibFile, 0); _tvmHandle = dlopen(tvmLibFile, RTLD_LAZY | RTLD_GLOBAL); env->ReleaseStringUTFChars(jtvmLibFile, tvmLibFile); if (!_tvmHandle) { @@ -61,70 +61,70 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_nativeLibInit return env->GetJavaVM(&_jvm); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv *env, jobject obj) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_shutdown(JNIEnv* env, jobject obj) { if (_tvmHandle) { dlclose(_tvmHandle); } return 0; } -JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv * env, jobject obj) { +JNIEXPORT jstring JNICALL Java_org_apache_tvm_LibInfo_tvmGetLastError(JNIEnv* env, jobject obj) { return env->NewStringUTF(TVMGetLastError()); } // Function -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong( - JNIEnv *env, jobject obj, jlong arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgLong(JNIEnv* env, jobject obj, + jlong arg) { TVMValue value; value.v_int64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kDLInt); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble( - JNIEnv *env, jobject obj, jdouble arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgDouble(JNIEnv* env, jobject obj, + jdouble arg) { TVMValue value; value.v_float64 = static_cast(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kDLFloat); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString( - JNIEnv *env, jobject obj, jstring arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(JNIEnv* env, jobject obj, + jstring arg) { TVMValue value; jstring garg = reinterpret_cast(env->NewGlobalRef(arg)); value.v_str = env->GetStringUTFChars(garg, 0); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); 
e->tvmFuncArgTypes.push_back(kTVMStr); // release string args later e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle( - JNIEnv *env, jobject obj, jlong arg, jint argType) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgHandle(JNIEnv* env, jobject obj, + jlong arg, jint argType) { TVMValue value; - value.v_handle = reinterpret_cast(arg); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + value.v_handle = reinterpret_cast(arg); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(static_cast(argType)); } -JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( - JNIEnv *env, jobject obj, jbyteArray arg) { +JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(JNIEnv* env, jobject obj, + jbyteArray arg) { jbyteArray garg = reinterpret_cast(env->NewGlobalRef(arg)); - jbyte *data = env->GetByteArrayElements(garg, 0); + jbyte* data = env->GetByteArrayElements(garg, 0); - TVMByteArray *byteArray = new TVMByteArray(); + TVMByteArray* byteArray = new TVMByteArray(); byteArray->size = static_cast(env->GetArrayLength(garg)); - byteArray->data = reinterpret_cast(data); + byteArray->data = reinterpret_cast(data); TVMValue value; - value.v_handle = reinterpret_cast(byteArray); + value.v_handle = reinterpret_cast(byteArray); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); e->tvmFuncArgValues.push_back(value); e->tvmFuncArgTypes.push_back(kTVMBytes); @@ -132,10 +132,10 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( // release (garg, data), byteArray later } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames( - JNIEnv *env, jobject obj, jobject jfuncNames) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames(JNIEnv* env, jobject obj, + jobject jfuncNames) { int outSize; - const char **outArray; + const char** outArray; int ret = TVMFuncListGlobalNames(&outSize, &outArray); if (ret) { @@ -157,24 +157,25 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncListGlobalNames( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMFuncFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal( - JNIEnv *env, jobject obj, jstring jname, jobject jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncGetGlobal(JNIEnv* env, jobject obj, + jstring jname, + jobject jhandle) { TVMFunctionHandle handle; - const char *name = env->GetStringUTFChars(jname, 0); + const char* name = env->GetStringUTFChars(jname, 0); int ret = TVMFuncGetGlobal(name, &handle); env->ReleaseStringUTFChars(jname, name); setLongField(env, jhandle, reinterpret_cast(handle)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( - JNIEnv *env, jobject obj, jlong jhandle, jobject jretVal) { - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall(JNIEnv* env, jobject obj, + jlong jhandle, jobject jretVal) { + TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); int 
numArgs = e->tvmFuncArgValues.size(); TVMValue retVal; @@ -192,8 +193,8 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( e->tvmFuncArgTypes.clear(); e->tvmFuncArgValues.clear(); - int ret = TVMFuncCall(reinterpret_cast(jhandle), - &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode); + int ret = TVMFuncCall(reinterpret_cast(jhandle), &argValues[0], &argTypes[0], + numArgs, &retVal, &retTypeCode); if (ret != 0) { return ret; @@ -204,16 +205,15 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( env->DeleteGlobalRef(iter->first); } for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { - env->ReleaseByteArrayElements(iter->first, - reinterpret_cast(const_cast(iter->second->data)), 0); + env->ReleaseByteArrayElements( + iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); env->DeleteGlobalRef(iter->first); delete iter->second; } // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("org/apache/tvm/Base$RefTVMValue"); - jfieldID refTVMValueFid - = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); + jfieldID refTVMValueFid = env->GetFieldID(refTVMValueCls, "value", "Lorg/apache/tvm/TVMValue;"); env->SetObjectField(jretVal, refTVMValueFid, tvmRetValueToJava(env, retVal, retTypeCode)); @@ -223,16 +223,16 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCall( } // Callback function -extern "C" int funcInvokeCallback(TVMValue *args, - int *typeCodes, int numArgs, TVMRetValueHandle ret, void *resourceHandle) { - JNIEnv *env; - int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); +extern "C" int funcInvokeCallback(TVMValue* args, int* typeCodes, int numArgs, + TVMRetValueHandle ret, void* resourceHandle) { + JNIEnv* env; + int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { - #ifdef TVM4J_ANDROID +#ifdef TVM4J_ANDROID _jvm->AttachCurrentThread(&env, nullptr); - #else - _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - #endif +#else + _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); +#endif } else { CHECK(jniStatus == JNI_OK); } @@ -242,10 +242,8 @@ extern "C" int funcInvokeCallback(TVMValue *args, for (int i = 0; i < numArgs; ++i) { TVMValue arg = args[i]; int tcode = typeCodes[i]; - if (tcode == kTVMObjectHandle || - tcode == kTVMPackedFuncHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMModuleHandle) { + if (tcode == kTVMObjectHandle || tcode == kTVMPackedFuncHandle || + tcode == kTVMObjectRValueRefArg || tcode == kTVMModuleHandle) { TVMCbArgToReturn(&arg, &tcode); } jobject jarg = tvmRetValueToJava(env, arg, tcode); @@ -253,15 +251,16 @@ extern "C" int funcInvokeCallback(TVMValue *args, } jclass clsFunc = env->FindClass("org/apache/tvm/Function"); - jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID(clsFunc, "invokeRegisteredCbFunc", + jmethodID invokeRegisteredCbFunc = env->GetStaticMethodID( + clsFunc, "invokeRegisteredCbFunc", "(Lorg/apache/tvm/Function$Callback;[Lorg/apache/tvm/TVMValue;)Ljava/lang/Object;"); - jmethodID pushArgToStack = env->GetStaticMethodID(clsFunc, "pushArgToStack", - "(Ljava/lang/Object;)V"); + jmethodID pushArgToStack = + env->GetStaticMethodID(clsFunc, "pushArgToStack", "(Ljava/lang/Object;)V"); jobject jretValue = env->CallStaticObjectMethod(clsFunc, invokeRegisteredCbFunc, - reinterpret_cast(resourceHandle), jargs); + reinterpret_cast(resourceHandle), jargs); - TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); 
+ TVMFuncArgsThreadLocalEntry* e = TVMFuncArgsThreadLocalStore::Get(); const size_t prevNumStrArg = e->tvmFuncArgPushedStrs.size(); const size_t prevNumBytesArg = e->tvmFuncArgPushedBytes.size(); @@ -279,16 +278,16 @@ extern "C" int funcInvokeCallback(TVMValue *args, // release allocated strings. if (e->tvmFuncArgPushedStrs.size() > prevNumStrArg) { - const auto &pairArg = e->tvmFuncArgPushedStrs.back(); + const auto& pairArg = e->tvmFuncArgPushedStrs.back(); env->ReleaseStringUTFChars(pairArg.first, pairArg.second); env->DeleteGlobalRef(pairArg.first); e->tvmFuncArgPushedStrs.pop_back(); } // release allocated bytes. if (e->tvmFuncArgPushedBytes.size() > prevNumBytesArg) { - const auto &pairArg = e->tvmFuncArgPushedBytes.back(); - env->ReleaseByteArrayElements(pairArg.first, - reinterpret_cast(const_cast(pairArg.second->data)), 0); + const auto& pairArg = e->tvmFuncArgPushedBytes.back(); + env->ReleaseByteArrayElements( + pairArg.first, reinterpret_cast(const_cast(pairArg.second->data)), 0); env->DeleteGlobalRef(pairArg.first); delete pairArg.second; e->tvmFuncArgPushedBytes.pop_back(); @@ -301,62 +300,64 @@ extern "C" int funcInvokeCallback(TVMValue *args, } // Free callback function -extern "C" void funcFreeCallback(void *resourceHandle) { - JNIEnv *env; - int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); +extern "C" void funcFreeCallback(void* resourceHandle) { + JNIEnv* env; + int jniStatus = _jvm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6); if (jniStatus == JNI_EDETACHED) { - #ifdef TVM4J_ANDROID +#ifdef TVM4J_ANDROID _jvm->AttachCurrentThread(&env, nullptr); - #else - _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); - #endif +#else + _jvm->AttachCurrentThread(reinterpret_cast(&env), nullptr); +#endif } else { CHECK(jniStatus == JNI_OK); } env->DeleteGlobalRef(reinterpret_cast(resourceHandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc( - JNIEnv *env, jobject obj, jobject jfunction, jobject jretHandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncCreateFromCFunc(JNIEnv* env, jobject obj, + jobject jfunction, + jobject jretHandle) { TVMFunctionHandle out; - int ret = TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), - reinterpret_cast(env->NewGlobalRef(jfunction)), - reinterpret_cast(&funcFreeCallback), - &out); + int ret = + TVMFuncCreateFromCFunc(reinterpret_cast(&funcInvokeCallback), + reinterpret_cast(env->NewGlobalRef(jfunction)), + reinterpret_cast(&funcFreeCallback), &out); setLongField(env, jretHandle, reinterpret_cast(out)); return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal( - JNIEnv *env, jobject obj, jstring jname, jlong jhandle, jint joverride) { - const char *name = env->GetStringUTFChars(jname, 0); - int ret = TVMFuncRegisterGlobal( - name, reinterpret_cast(jhandle), reinterpret_cast(joverride)); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFuncRegisterGlobal(JNIEnv* env, jobject obj, + jstring jname, + jlong jhandle, + jint joverride) { + const char* name = env->GetStringUTFChars(jname, 0); + int ret = TVMFuncRegisterGlobal(name, reinterpret_cast(jhandle), + reinterpret_cast(joverride)); env->ReleaseStringUTFChars(jname, name); return ret; } // Module -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMModFree(reinterpret_cast(jhandle)); } -JNIEXPORT 
jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport( - JNIEnv *env, jobject obj, jlong jmod, jlong jdep) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModImport(JNIEnv* env, jobject obj, + jlong jmod, jlong jdep) { return TVMModImport(reinterpret_cast(jmod), reinterpret_cast(jdep)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction( - JNIEnv *env, jobject obj, jlong jhandle, jstring jname, jint jimport, jobject jret) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction(JNIEnv* env, jobject obj, + jlong jhandle, jstring jname, + jint jimport, jobject jret) { TVMFunctionHandle retFunc; - const char *name = env->GetStringUTFChars(jname, 0); - int ret = TVMModGetFunction(reinterpret_cast(jhandle), - name, - reinterpret_cast(jimport), - &retFunc); + const char* name = env->GetStringUTFChars(jname, 0); + int ret = TVMModGetFunction(reinterpret_cast(jhandle), name, + reinterpret_cast(jimport), &retFunc); env->ReleaseStringUTFChars(jname, name); setLongField(env, jret, reinterpret_cast(retFunc)); @@ -365,28 +366,25 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmModGetFunction( } // NDArray -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree( - JNIEnv *env, jobject obj, jlong jhandle) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayFree(JNIEnv* env, jobject obj, + jlong jhandle) { return TVMArrayFree(reinterpret_cast(jhandle)); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( - JNIEnv *env, jobject obj, jlongArray jshape, jint jdtypeCode, - jint jdtypeBits, jint jdtypeLanes, jint jdeviceType, jint jdeviceId, jobject jret) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(JNIEnv* env, jobject obj, + jlongArray jshape, jint jdtypeCode, + jint jdtypeBits, jint jdtypeLanes, + jint jdeviceType, jint jdeviceId, + jobject jret) { int ndim = static_cast(env->GetArrayLength(jshape)); TVMArrayHandle out; - jlong *shapeArray = env->GetLongArrayElements(jshape, NULL); - int ret = TVMArrayAlloc( - reinterpret_cast(shapeArray), - ndim, - static_cast(jdtypeCode), - static_cast(jdtypeBits), - static_cast(jdtypeLanes), - static_cast(jdeviceType), - static_cast(jdeviceId), - &out); + jlong* shapeArray = env->GetLongArrayElements(jshape, NULL); + int ret = TVMArrayAlloc(reinterpret_cast(shapeArray), ndim, + static_cast(jdtypeCode), static_cast(jdtypeBits), + static_cast(jdtypeLanes), static_cast(jdeviceType), + static_cast(jdeviceId), &out); env->ReleaseLongArrayElements(jshape, shapeArray, 0); setLongField(env, jret, reinterpret_cast(out)); @@ -394,10 +392,10 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( - JNIEnv *env, jobject obj, jlong jhandle, jobject jshape) { - DLTensor *array = reinterpret_cast(jhandle); - int64_t *shape = array->shape; +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(JNIEnv* env, jobject obj, + jlong jhandle, jobject jshape) { + DLTensor* array = reinterpret_cast(jhandle); + int64_t* shape = array->shape; int ndim = array->ndim; // fill shape buffer @@ -417,18 +415,19 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( return 0; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo( - JNIEnv *env, jobject obj, jlong jfrom, jlong jto) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromTo(JNIEnv* env, jobject obj, + jlong jfrom, jlong jto) { return 
TVMArrayCopyFromTo(reinterpret_cast(jfrom), reinterpret_cast(jto), NULL); } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( - JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) { - jbyte *data = env->GetByteArrayElements(jarr, NULL); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(JNIEnv* env, jobject obj, + jbyteArray jarr, + jlong jfrom, jlong jto) { + jbyte* data = env->GetByteArrayElements(jarr, NULL); - DLTensor *from = reinterpret_cast(jfrom); - from->data = static_cast(data); + DLTensor* from = reinterpret_cast(jfrom); + from->data = static_cast(data); int ret = TVMArrayCopyFromTo(static_cast(from), reinterpret_cast(jto), NULL); @@ -439,13 +438,14 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( return ret; } -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( - JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) { - DLTensor *from = reinterpret_cast(jfrom); +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(JNIEnv* env, jobject obj, + jlong jfrom, + jbyteArray jarr) { + DLTensor* from = reinterpret_cast(jfrom); int size = static_cast(env->GetArrayLength(jarr)); - jbyte *pdata = env->GetByteArrayElements(jarr, NULL); + jbyte* pdata = env->GetByteArrayElements(jarr, NULL); int ret = 0; - if (memcpy(static_cast(pdata), from->data, size) == NULL) { + if (memcpy(static_cast(pdata), from->data, size) == NULL) { ret = 1; } env->ReleaseByteArrayElements(jarr, pdata, 0); // copy back to java array automatically @@ -453,7 +453,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( } // Context -JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize( - JNIEnv *env, jint deviceType, jint deviceId) { +JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmSynchronize(JNIEnv* env, jint deviceType, + jint deviceId) { return TVMSynchronize(static_cast(deviceType), static_cast(deviceId), NULL); } diff --git a/nnvm/include/nnvm/base.h b/nnvm/include/nnvm/base.h index 678ed4d4a942..b8c5c6c5ed41 100644 --- a/nnvm/include/nnvm/base.h +++ b/nnvm/include/nnvm/base.h @@ -24,13 +24,13 @@ #ifndef NNVM_BASE_H_ #define NNVM_BASE_H_ +#include +#include #include #include -#include -#include #include +#include #include -#include namespace nnvm { @@ -52,7 +52,7 @@ enum TypeFlag { kFloat16 = 2, kUint8 = 3, kInt32 = 4, - kInt8 = 5, + kInt8 = 5, kInt64 = 6, // kBool = 7, // 7 is reserved for kBool, in order to keep consistency with MXNet TypeFlag defined in diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index b35e4da343f7..e6efb79e8626 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -41,11 +41,11 @@ typedef unsigned int nn_uint; /*! \brief handle to a function that takes param and creates symbol */ -typedef void *OpHandle; +typedef void* OpHandle; /*! \brief handle to a symbol that can be bind as operator */ -typedef void *SymbolHandle; +typedef void* SymbolHandle; /*! 
\brief handle to Graph */ -typedef void *GraphHandle; +typedef void* GraphHandle; #ifdef __cplusplus extern "C" { @@ -65,7 +65,7 @@ NNVM_DLL void NNAPISetLastError(const char* msg); * this function is threadsafe and can be called by different thread * \return error info */ -NNVM_DLL const char *NNGetLastError(void); +NNVM_DLL const char* NNGetLastError(void); /*! * \brief list all the available operator names, include entries. @@ -73,16 +73,14 @@ NNVM_DLL const char *NNGetLastError(void); * \param out_array the output operator name array. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNListAllOpNames(nn_uint *out_size, - const char*** out_array); +NNVM_DLL int NNListAllOpNames(nn_uint* out_size, const char*** out_array); /*! * \brief Get operator handle given name. * \param op_name The name of the operator. * \param op_out The returnning op handle. */ -NNVM_DLL int NNGetOpHandle(const char* op_name, - OpHandle* op_out); +NNVM_DLL int NNGetOpHandle(const char* op_name, OpHandle* op_out); /*! * \brief list all the available operators. @@ -93,8 +91,7 @@ NNVM_DLL int NNGetOpHandle(const char* op_name, * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array); +NNVM_DLL int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array); /*! * \brief Get the detailed information about atomic symbol. @@ -109,14 +106,10 @@ NNVM_DLL int NNListUniqueOps(nn_uint *out_size, * \param return_type Return type of the function, if any. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGetOpInfo(OpHandle op, - const char **real_name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type); +NNVM_DLL int NNGetOpInfo(OpHandle op, const char** real_name, const char** description, + nn_uint* num_doc_args, const char*** arg_names, + const char*** arg_type_infos, const char*** arg_descriptions, + const char** return_type); /*! * \brief Create an AtomicSymbol functor. * \param op The operator handle @@ -126,18 +119,15 @@ NNVM_DLL int NNGetOpInfo(OpHandle op, * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out); +NNVM_DLL int NNSymbolCreateAtomicSymbol(OpHandle op, nn_uint num_param, const char** keys, + const char** vals, SymbolHandle* out); /*! * \brief Create a Variable Symbol. * \param name name of the variable * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); +NNVM_DLL int NNSymbolCreateVariable(const char* name, SymbolHandle* out); /*! * \brief Create a Symbol by grouping list of symbols together * \param num_symbols number of symbols to be grouped @@ -145,16 +135,13 @@ NNVM_DLL int NNSymbolCreateVariable(const char *name, SymbolHandle *out); * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out); +NNVM_DLL int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* symbols, SymbolHandle* out); /*! * \brief Add src_dep to the handle as control dep. 
* \param handle The symbol to add dependency edges on. * \param src_dep the source handles. */ -NNVM_DLL int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep); +NNVM_DLL int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep); /*! * \brief Free the symbol handle. * \param symbol the symbol @@ -167,14 +154,14 @@ NNVM_DLL int NNSymbolFree(SymbolHandle symbol); * \param out used to hold the result of copy * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out); +NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Print the content of symbol, used for debug. * \param symbol the symbol * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); +NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char** out_str); /*! * \brief Get string attribute from symbol * \param symbol the source symbol @@ -183,13 +170,11 @@ NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str); * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int *success); +NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success); /*! * \brief Set string attribute from symbol. - * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph. + * NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic + * graph. * * Safe recommendaton: use immutable graph * - Only allow set attributes during creation of new symbol as optional parameter @@ -204,9 +189,7 @@ NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol, * \param values The value to be set * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, +NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** values); /*! * \brief Get all attributes from symbol, including all descendents. @@ -216,9 +199,7 @@ NNVM_DLL int NNSymbolSetAttrs(SymbolHandle symbol, * \param out 2*out_size strings representing key value pairs. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, - int recursive_option, - nn_uint *out_size, +NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, int recursive_option, nn_uint* out_size, const char*** out); /*! @@ -232,9 +213,7 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol, * \param out_sym_array the output array. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, +NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size, SymbolHandle** out_sym_array); /*! @@ -248,10 +227,8 @@ NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array); +NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size, + const char*** out_str_array); /*! 
* \brief List returns names in the symbol. * \param symbol the symbol @@ -259,10 +236,8 @@ NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol, * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array); - +NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, + const char*** out_str_array); /*! * \brief Supply number of outputs of the symbol. @@ -270,8 +245,7 @@ NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol, * \param output_count number of outputs * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count); +NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count); /*! * \brief Get a symbol that contains all the internals. @@ -279,16 +253,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol, * \param out The output symbol whose outputs are all the internals. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Get a symbol that contains only direct children. * \param symbol The symbol * \param out The output symbol whose outputs are the direct children. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out); /*! * \brief Get index-th outputs of the symbol. * \param symbol The symbol @@ -296,9 +268,7 @@ NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol, * \param out The output symbol whose outputs are the index-th symbol. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out); +NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out); /*! * \brief Compose the symbol on other symbols. @@ -314,11 +284,8 @@ NNVM_DLL int NNSymbolGetOutput(SymbolHandle symbol, * \param args arguments to sym * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNSymbolCompose(SymbolHandle sym, - const char* name, - nn_uint num_args, - const char** keys, - SymbolHandle* args); +NNVM_DLL int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, + const char** keys, SymbolHandle* args); // Graph IR API /*! @@ -327,7 +294,7 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym, * \param graph The graph handle created. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph); +NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph); /*! * \brief free the graph handle * \param handle The handle to be freed. @@ -339,7 +306,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle); * \param symbol The corresponding symbol * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); +NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol); /*! * \brief Get Set a attribute in json format. @@ -351,9 +318,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); * Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY. 
* \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value); +NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value); /*! * \brief Get a serialized attrirbute from graph. @@ -367,10 +332,8 @@ NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, * \param success Whether the result is contained in out. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success); +NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, + int* success); /*! * \brief Set a attribute whose type is std::vector in c++ @@ -383,9 +346,7 @@ NNVM_DLL int NNGraphGetJSONAttr(GraphHandle handle, * \param list The symbol whose outputs represents the list of NodeEntry to be passed. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list); +NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list); /*! * \brief Apply passes on the src graph. * \param src The source graph handle. @@ -394,10 +355,8 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, * \param dst The result graph. * \return 0 when success, -1 when failure happens */ -NNVM_DLL int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - GraphHandle *dst); +NNVM_DLL int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names, + GraphHandle* dst); #ifdef __cplusplus } /* end extern "C" */ diff --git a/nnvm/include/nnvm/graph.h b/nnvm/include/nnvm/graph.h index 1911a0337ac2..475494e62c4d 100644 --- a/nnvm/include/nnvm/graph.h +++ b/nnvm/include/nnvm/graph.h @@ -24,13 +24,14 @@ #ifndef NNVM_GRAPH_H_ #define NNVM_GRAPH_H_ -#include -#include -#include #include #include +#include #include #include +#include +#include + #include "base.h" #include "node.h" #include "symbolic.h" @@ -64,7 +65,7 @@ class Graph { * \return the reference to corresponding attribute * \tparam T the type of the attribute. */ - template + template inline const T& GetAttr(const std::string& attr_name) const; /*! * \brief Check whether has a specific attribute. @@ -81,7 +82,7 @@ class Graph { * \return a new copy of the corresponding attribute. * \tparam T the type of the attribute. */ - template + template inline T MoveCopyAttr(const std::string& attr_name); /*! * \brief get a indexed graph of current graph, if not exist, create it on demand @@ -127,13 +128,9 @@ class IndexedGraph { std::weak_ptr weak_ref; }; /*! \return number of nodes in the graph */ - inline size_t num_nodes() const { - return nodes_.size(); - } + inline size_t num_nodes() const { return nodes_.size(); } /*! \return total number of NodeEntry in the graph */ - inline size_t num_node_entries() const { - return entry_rptr_.back(); - } + inline size_t num_node_entries() const { return entry_rptr_.back(); } /*! * \brief Get a unique entry id between 0 to num_node_entries() * for a given IndexedGraph::NodeEntry @@ -150,9 +147,7 @@ class IndexedGraph { * \param e The entry to query for index. * \return the unique index. */ - inline uint32_t entry_id(const NodeEntry& e) const { - return entry_rptr_[e.node_id] + e.index; - } + inline uint32_t entry_id(const NodeEntry& e) const { return entry_rptr_[e.node_id] + e.index; } /*! 
* \brief Get a unique entry id between 0 to num_node_entries() * for a given NodeEntry. @@ -167,42 +162,30 @@ class IndexedGraph { * \param node The Node to query for index. * \return the node index. */ - inline uint32_t node_id(const nnvm::Node* node) const { - return node2index_.at(node); - } + inline uint32_t node_id(const nnvm::Node* node) const { return node2index_.at(node); } /*! * \brief Get the corresponding Node structure for a given node_id. * \param node_id The node id * \return const reference to the corresponding IndexedGraph::Node */ - inline const Node& operator[](uint32_t node_id) const { - return nodes_[node_id]; - } + inline const Node& operator[](uint32_t node_id) const { return nodes_[node_id]; } /*! * \brief Get the corresponding Node structure * \param node The pointer to the Node structure * \return const reference to the corresponding IndexedGraph::Node */ - inline const Node& operator[](const nnvm::Node* node) const { - return nodes_[node_id(node)]; - } + inline const Node& operator[](const nnvm::Node* node) const { return nodes_[node_id(node)]; } /*! \return list of argument nodes */ - inline const std::vector& input_nodes() const { - return input_nodes_; - } + inline const std::vector& input_nodes() const { return input_nodes_; } /*! \return list of mutable nodes */ inline const std::unordered_set& mutable_input_nodes() const { return mutable_input_nodes_; } /*! \return list of output entries */ - inline const std::vector& outputs() const { - return outputs_; - } + inline const std::vector& outputs() const { return outputs_; } /*! \return whether a node exists in the indexed graph */ - inline bool exist(const nnvm::Node* node) const { - return node2index_.count(node); - } + inline bool exist(const nnvm::Node* node) const { return node2index_.count(node); } // disallow copy assign IndexedGraph(const IndexedGraph&) = delete; @@ -239,15 +222,14 @@ class IndexedGraph { * \param fvisit a function of type std::function&)> * \tparam FVisit The function type to perform the visit.
*/ -template +template inline void DFSVisit(const std::vector& heads, FVisit fvisit); // inline function implementations -template +template inline const T& Graph::GetAttr(const std::string& attr_name) const { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; + CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph"; return nnvm::unsafe_get(*it->second); } @@ -256,11 +238,10 @@ inline bool Graph::HasAttr(const std::string& attr_name) const { return it != attrs.end(); } -template +template inline T Graph::MoveCopyAttr(const std::string& attr_name) { auto it = attrs.find(attr_name); - CHECK(it != attrs.end()) - << "Cannot find attribute " << attr_name << " in the graph"; + CHECK(it != attrs.end()) << "Cannot find attribute " << attr_name << " in the graph"; std::shared_ptr sptr = it->second; attrs.erase(it); if (sptr.unique()) { @@ -270,14 +251,10 @@ inline T Graph::MoveCopyAttr(const std::string& attr_name) { } } -template -void PostOrderDFSVisit(const std::vector& heads, - FVisit fvisit, - HashFunc hash, - InDegree indegree, - GetInput getinput) { +template +void PostOrderDFSVisit(const std::vector& heads, FVisit fvisit, HashFunc hash, + InDegree indegree, GetInput getinput) { std::vector > stack; std::unordered_set visited; for (auto& head : heads) { @@ -303,28 +280,20 @@ void PostOrderDFSVisit(const std::vector& heads, } } -template -inline void DFSVisit(const std::vector& heads, - FVisit fvisit) { +template +inline void DFSVisit(const std::vector& heads, FVisit fvisit) { typedef const ObjectPtr* GNode; std::vector head_nodes(heads.size()); std::transform(heads.begin(), heads.end(), head_nodes.begin(), - [](const NodeEntry& e)->GNode { - return &e.node; - }); + [](const NodeEntry& e) -> GNode { return &e.node; }); PostOrderDFSVisit( - head_nodes, - [fvisit](GNode n) { - fvisit(*n); - }, // FVisit - [](GNode n)->Node* { - return n->get(); - }, // HashFunc - [](GNode n)->uint32_t { // InDegree + head_nodes, [fvisit](GNode n) { fvisit(*n); }, // FVisit + [](GNode n) -> Node* { return n->get(); }, // HashFunc + [](GNode n) -> uint32_t { // InDegree if (!(*n)) return 0; return (*n)->inputs.size() + (*n)->control_deps.size(); - }, - [](GNode n, uint32_t index)->GNode { // GetInput + }, + [](GNode n, uint32_t index) -> GNode { // GetInput if (index < (*n)->inputs.size()) { return &(*n)->inputs.at(index).node; } else { diff --git a/nnvm/include/nnvm/graph_attr_types.h b/nnvm/include/nnvm/graph_attr_types.h index acc52a2ae1db..9e0185526eef 100644 --- a/nnvm/include/nnvm/graph_attr_types.h +++ b/nnvm/include/nnvm/graph_attr_types.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,11 +24,12 @@ #ifndef NNVM_GRAPH_ATTR_TYPES_H_ #define NNVM_GRAPH_ATTR_TYPES_H_ -#include #include #include -#include "tuple.h" +#include + #include "layout.h" +#include "tuple.h" namespace nnvm { diff --git a/nnvm/include/nnvm/layout.h b/nnvm/include/nnvm/layout.h index 3a81b84b2487..e2e99784c99e 100644 --- a/nnvm/include/nnvm/layout.h +++ b/nnvm/include/nnvm/layout.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -31,11 +31,12 @@ #define NNVM_LAYOUT_H_ #include -#include + +#include #include -#include +#include #include -#include +#include namespace nnvm { @@ -44,7 +45,7 @@ class Layout { using LayoutDim = char; /*! \brief default constructor */ - Layout() : name_("__undef__") {} // NOLINT(*) + Layout() : name_("__undef__") {} // NOLINT(*) /*! * \brief construct from a string. @@ -54,21 +55,21 @@ class Layout { * indicates the split dimension. * return undefined layout if "__undef__" is passed. */ - inline Layout(const std::string& layout) { // NOLINT(*) + inline Layout(const std::string& layout) { // NOLINT(*) parse(layout); } /*! * \brief copy constructor from another layout * \param s the source layout */ - inline Layout(const Layout& s) { // NOLINT(*) + inline Layout(const Layout& s) { // NOLINT(*) this->parse(s.name_); } /*! * \brief move constructor from Layout * \param src the source layout */ - inline Layout(Layout&& src) { // NOLINT(*) + inline Layout(Layout&& src) { // NOLINT(*) this->swap(src); } /*! @@ -86,7 +87,7 @@ class Layout { * \return reference of self */ inline Layout& operator=(Layout&& src) { - Layout(std::move(src)).swap(*this); // NOLINT(*) + Layout(std::move(src)).swap(*this); // NOLINT(*) return *this; } /*! @@ -102,16 +103,12 @@ class Layout { * \return whether two layout equals * \param s the layout to compare against */ - inline bool operator==(const Layout& s) const { - return name_ == s.name_; - } + inline bool operator==(const Layout& s) const { return name_ == s.name_; } /*! * \return whether two layout not equal * \param s the layout to compare against */ - inline bool operator!=(const Layout& s) const { - return !(*this == s); - } + inline bool operator!=(const Layout& s) const { return !(*this == s); } /*! * \brief Append the current layout by another. @@ -134,18 +131,14 @@ class Layout { * \param dim input dimension * \return Whether a given dimension is a super-dimension. */ - static inline bool is_superdim(LayoutDim dim) { - return dim >= 'A' && dim <= 'Z'; - } + static inline bool is_superdim(LayoutDim dim) { return dim >= 'A' && dim <= 'Z'; } /*! * \brief Check whether a given dimension is a sub-dimension. * \param dim input dimension * \return Whether a given dimension is a sub-dimension. */ - static inline bool is_subdim(LayoutDim dim) { - return dim >= 'a' && dim <= 'z'; - } + static inline bool is_subdim(LayoutDim dim) { return dim >= 'a' && dim <= 'z'; } /*! 
* \brief Convert a given dimension to super-dimension. @@ -200,7 +193,7 @@ class Layout { * \param dst the target layout * \return Whether can be converted to dst layout. */ - inline bool convertible(const Layout &dst) const { + inline bool convertible(const Layout& dst) const { if (!this->defined() || !dst.defined()) return false; for (size_t i = 0; i < kUniqueDim; ++i) { if ((superdim_pos_[i] >= 0 && dst.superdim_pos_[i] < 0) || @@ -258,13 +251,12 @@ class Layout { * \return A newly constructed Layout object. */ inline Layout split(LayoutDim dim, size_t target_pos, uint32_t size) const { - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name_; + CHECK(target_pos <= this->ndim()) + << "Invalid split position " << target_pos << " for layout " << name_; CHECK(is_superdim(dim)) << "Cannot split a sub-dimension " << dim; CHECK(this->contains(dim)) << "Axis " << dim << " does not exist in " << name_; - CHECK(!this->contains(to_subdim(dim))) << "Dimension " << dim - << " has already been split in " - << name_; + CHECK(!this->contains(to_subdim(dim))) + << "Dimension " << dim << " has already been split in " << name_; CHECK(size > 0) << "Invalid split size " << size; std::ostringstream new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { @@ -282,26 +274,16 @@ class Layout { using reverse_iterator = std::vector::const_reverse_iterator; /*! \return begin iterator */ - inline iterator begin() const { - return layout_simplified_.begin(); - } + inline iterator begin() const { return layout_simplified_.begin(); } /*! \return end iterator */ - inline iterator end() const { - return layout_simplified_.end(); - } + inline iterator end() const { return layout_simplified_.end(); } /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return layout_simplified_.rbegin(); - } + inline reverse_iterator rbegin() const { return layout_simplified_.rbegin(); } /*! \return rend iterator */ - inline reverse_iterator rend() const { - return layout_simplified_.rend(); - } + inline reverse_iterator rend() const { return layout_simplified_.rend(); } /*! \return number of dimensions */ - inline size_t ndim() const { - return layout_simplified_.size(); - } + inline size_t ndim() const { return layout_simplified_.size(); } /*! * \brief The description of the \p i-th dimension. @@ -311,8 +293,7 @@ class Layout { * \return the description of the dimension. */ inline std::string at(size_t i) const { - CHECK_LT(i, this->ndim()) << "position " << i - << " exceeds ndim=" << this->ndim(); + CHECK_LT(i, this->ndim()) << "position " << i << " exceeds ndim=" << this->ndim(); std::ostringstream repr; if (is_subdim(layout_simplified_[i])) { auto factor = subsizeof(layout_simplified_[i]); @@ -331,9 +312,12 @@ class Layout { * \return the index or -1 if not found. 
*/ inline int32_t indexof(LayoutDim dim) const { - if (!this->defined()) return -1; - else if (is_superdim(dim)) return superdim_pos_[dim - 'A']; - else if (is_subdim(dim)) return subdim_pos_[dim - 'a']; + if (!this->defined()) + return -1; + else if (is_superdim(dim)) + return superdim_pos_[dim - 'A']; + else if (is_subdim(dim)) + return subdim_pos_[dim - 'a']; return -1; } @@ -359,34 +343,26 @@ class Layout { */ inline bool contains(LayoutDim dim) const { if (is_superdim(dim)) { - return superdim_pos_[dim-'A'] >= 0; + return superdim_pos_[dim - 'A'] >= 0; } else if (is_subdim(dim)) { - return subdim_pos_[dim-'a'] >= 0; + return subdim_pos_[dim - 'a'] >= 0; } return false; } - inline LayoutDim operator[](size_t i) const { - return layout_simplified_[i]; - } + inline LayoutDim operator[](size_t i) const { return layout_simplified_[i]; } /*! \return whether the layout is defined */ - inline bool defined() const { - return name_ != "__undef__"; - } + inline bool defined() const { return name_ != "__undef__"; } /*! \return the string description of the layout */ - inline const std::string& name() const { - return name_; - } + inline const std::string& name() const { return name_; } /*! * \brief Write layout in JSON format. * \param writer JSONWriter */ - inline void Save(dmlc::JSONWriter* writer) const { - writer->Write(name_); - } + inline void Save(dmlc::JSONWriter* writer) const { writer->Write(name_); } /*! * \brief Load layout from JSON. @@ -433,21 +409,20 @@ class Layout { const LayoutDim c = layout.at(i); if (is_superdim(c)) { int pos = c - 'A'; - CHECK_EQ(factor, 0) << "Invalid layout " << layout - << ": invalid factor size " << factor + CHECK_EQ(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor << " before dimension " << c; - CHECK_EQ(superdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; + CHECK_EQ(superdim_pos_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; superdim_pos_[pos] = curr++; layout_simplified_.push_back(c); } else if (is_subdim(c)) { int pos = c - 'a'; - CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " - << factor << " for dimension " << c; - CHECK_EQ(subdim_pos_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; - CHECK_EQ(subdim_size_[pos], -1) << "Invalid layout " << layout - << ": duplicate dimension " << c; + CHECK_GT(factor, 0) << "Invalid layout " << layout << ": invalid factor size " << factor + << " for dimension " << c; + CHECK_EQ(subdim_pos_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; + CHECK_EQ(subdim_size_[pos], -1) + << "Invalid layout " << layout << ": duplicate dimension " << c; subdim_pos_[pos] = curr++; subdim_size_[pos] = factor; layout_simplified_.push_back(c); @@ -461,9 +436,8 @@ class Layout { } CHECK(!layout_simplified_.empty()) << "Invalid layout " << layout; for (LayoutDim dim : layout_simplified_) { - CHECK(is_superdim(dim) || superdim_pos_[dim-'a'] >= 0) - << "Invalid layout " << layout << ": missing axis " - << static_cast(dim - 'a' + 'A'); + CHECK(is_superdim(dim) || superdim_pos_[dim - 'a'] >= 0) + << "Invalid layout " << layout << ": missing axis " << static_cast(dim - 'a' + 'A'); } } }; diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 95a7ce23e4da..1b2dda2bbb69 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -26,12 +26,13 @@ #include #include -#include -#include #include +#include +#include + #include "base.h" 
-#include "op.h" #include "c_api.h" +#include "op.h" namespace nnvm { @@ -49,27 +50,16 @@ using ObjectPtr = std::shared_ptr; /*! \brief an entry that represents output data from a node */ struct NodeEntry { - NodeEntry(ObjectPtr node, uint32_t index, uint32_t version): - node(std::move(node)), - index(index), - version(version) - {} - - explicit NodeEntry(ObjectPtr node): - node(std::move(node)), - index(), - version() - {} + NodeEntry(ObjectPtr node, uint32_t index, uint32_t version) + : node(std::move(node)), index(index), version(version) {} + + explicit NodeEntry(ObjectPtr node) : node(std::move(node)), index(), version() {} /** * MXNet assumes that a node with a null ptr doesn't have a gradient attached. Don't change this * constructor. */ - NodeEntry(): - node(nullptr), - index(), - version() - {} + NodeEntry() : node(nullptr), index(), version() {} /*! \brief the source node of this data */ ObjectPtr node; @@ -79,7 +69,8 @@ struct NodeEntry { * \brief version of input Variable. * This field can only be nonzero when this->node is a Variable node. * version is increased by one each time a Variable get composed to a mutation Op. - * This information can be helpful to decide order of operations when sequence of mutation happens. + * This information can be helpful to decide order of operations when sequence of mutation + * happens. */ uint32_t version; }; @@ -90,9 +81,8 @@ struct NodeEntry { */ struct NodeEntryHash { size_t operator()(const NodeEntry& e) const { - return std::hash()(e.node.get()) ^ - (std::hash()(e.index) << 1 >> 1) ^ - (std::hash()(e.version) << 1); + return std::hash()(e.node.get()) ^ (std::hash()(e.index) << 1 >> 1) ^ + (std::hash()(e.version) << 1); } }; @@ -102,14 +92,12 @@ struct NodeEntryHash { */ struct NodeEntryEqual { size_t operator()(const NodeEntry& a, const NodeEntry& b) const { - return (a.node.get() == b.node.get()) && - (a.index == b.index) && - (a.version == b.version); + return (a.node.get() == b.node.get()) && (a.index == b.index) && (a.version == b.version); } }; /*! use NodeEntry as key in unordered_map */ -template +template using NodeEntryMap = std::unordered_map; /*! @@ -121,7 +109,7 @@ struct NodeAttrs { * \brief The operator this node uses. * For place holder variable, op == nullptr. */ - const Op *op{nullptr}; + const Op* op{nullptr}; /*! \brief name of the node */ std::string name; /*! \brief The dictionary representation of attributes */ @@ -190,7 +178,7 @@ class NNVM_DLL Node { * \brief create a new empty shared_ptr of Node. * \return a created empty node. */ - template + template static ObjectPtr Create(Args&&... args) { return std::make_shared(std::forward(args)...); } @@ -204,12 +192,9 @@ class NNVM_DLL Node { * \param attrs The attributes * \return The created node entry. */ -inline NodeEntry MakeNode( - const char* op_name, - std::string node_name, - std::vector inputs, - std::unordered_map attrs = - std::unordered_map()) { +inline NodeEntry MakeNode(const char* op_name, std::string node_name, std::vector inputs, + std::unordered_map attrs = + std::unordered_map()) { ObjectPtr p = Node::Create(); p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); @@ -222,13 +207,9 @@ inline NodeEntry MakeNode( } // implementation of functions. 
-inline const Op* Node::op() const { - return this->attrs.op; -} +inline const Op* Node::op() const { return this->attrs.op; } -inline bool Node::is_variable() const { - return this->op() == nullptr; -} +inline bool Node::is_variable() const { return this->op() == nullptr; } inline uint32_t Node::num_outputs() const { if (is_variable()) return 1; diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index 84abdc7a9363..d5794d88f705 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,12 +25,14 @@ #define NNVM_OP_H_ #include + +#include +#include #include -#include -#include #include -#include -#include +#include +#include + #include "base.h" #include "c_api.h" @@ -39,7 +41,7 @@ namespace nnvm { // forward declarations class Node; struct NodeAttrs; -template +template class OpMap; class OpGroup; class OpRegistryEntry; @@ -208,15 +210,14 @@ class NNVM_DLL Op { * \param description Description of the argument. * \return reference to self. */ - inline Op& add_argument(const std::string &name, - const std::string &type, - const std::string &description); + inline Op& add_argument(const std::string& name, const std::string& type, + const std::string& description); /*! * \brief Append list if arguments to the end. * \param args Additional list of arguments. * \return reference to self. */ - inline Op& add_arguments(const std::vector &args); + inline Op& add_arguments(const std::vector& args); /*! * \brief Set the num_inputs * \param n The number of inputs to be set. @@ -234,7 +235,7 @@ class NNVM_DLL Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_inputs(std::function fn); // NOLINT(*) + inline Op& set_num_inputs(std::function fn); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. @@ -246,26 +247,13 @@ class NNVM_DLL Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_outputs(std::function fn); // NOLINT(*) - /*! - * \brief Set the fallback field. - * \param fallback The bool value to be set. - * \return reference to self. - */ - Op& set_fallback_device(bool fallback); // NOLINT(*) - /*! - * \brief Set the set_fallback_device function. - * \param fn The function to be set. - * \return reference to self. - */ - Op& set_fallback_device( - std::function fn); // NOLINT(*) + inline Op& set_num_outputs(std::function fn); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. * \return reference to self. */ - inline Op& set_attr_parser(std::function fn); // NOLINT(*) + inline Op& set_attr_parser(std::function fn); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. @@ -279,10 +267,9 @@ class NNVM_DLL Op { * * \tparam ValueType The type of the value to be set. */ - template + template inline Op& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 10); + const ValueType& value, int plevel = 10); /*! * \brief Add another alias to this operator. 
* The same Op can be queried with Op::Get(alias) @@ -312,11 +299,11 @@ class NNVM_DLL Op { * \return An OpMap of specified attr_name. * \tparam ValueType The type of the attribute. */ - template + template static const OpMap& GetAttr(const std::string& attr_name); private: - template + template friend class OpMap; friend class OpGroup; friend class dmlc::Registry; @@ -328,15 +315,13 @@ class NNVM_DLL Op { // get const reference to certain attribute static const any* GetAttrMap(const std::string& key); // update the attribute OpMap - static void UpdateAttrMap(const std::string& key, - std::function updater); + static void UpdateAttrMap(const std::string& key, std::function updater); // add a trigger based on tag matching on certain tag attribute // This will apply trigger on all the op such that // include the corresponding group. // The trigger will also be applied to all future registrations // that calls include - static void AddGroupTrigger(const std::string& group_name, - std::function trigger); + static void AddGroupTrigger(const std::string& group_name, std::function trigger); }; /*! @@ -344,7 +329,7 @@ class NNVM_DLL Op { * and returns ValueType * \tparam ValueType The type of the value stored in map. */ -template +template class OpMap { public: /*! @@ -379,7 +364,7 @@ class OpMap { // internal attribute name std::string attr_name_; // internal data - std::vector > data_; + std::vector> data_; OpMap() = default; }; @@ -404,18 +389,17 @@ class OpGroup { * * \tparam ValueType The type of the value to be set. */ - template + template inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*) - const ValueType& value, - int plevel = 1); + const ValueType& value, int plevel = 1); }; // internal macros to make -#define NNVM_REGISTER_VAR_DEF(OpName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName +#define NNVM_REGISTER_VAR_DEF(OpName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op& __make_##NnvmOp##_##OpName -#define NNVM_REGISTER_GVAR_DEF(TagName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName +#define NNVM_REGISTER_GVAR_DEF(TagName) \ + static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_##NnvmOpGroup##_##TagName /*! * \def NNVM_REGISTER_OP @@ -432,8 +416,8 @@ class OpGroup { * * \endcode */ -#define NNVM_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ +#define NNVM_REGISTER_OP(OpName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) /*! @@ -457,85 +441,72 @@ class OpGroup { * * \endcode */ -#define NNVM_REGISTER_OP_GROUP(GroupName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \ - ::nnvm::OpGroup {#GroupName} +#define NNVM_REGISTER_OP_GROUP(GroupName) \ + DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = ::nnvm::OpGroup { #GroupName } // implementations of template functions after this. 
// member function of Op -template +template inline const OpMap& Op::GetAttr(const std::string& key) { const any* ref = GetAttrMap(key); if (ref == nullptr) { // update the attribute map of the key by creating new empty OpMap UpdateAttrMap(key, [key](any* pmap) { - // use callback so it is in lockscope - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = key; - *pmap = std::move(pm); - } - }); + // use callback so it is in lockscope + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = key; + *pmap = std::move(pm); + } + }); ref = GetAttrMap(key); } - return nnvm::get >(*ref); + return nnvm::get>(*ref); } -template +template inline Op& Op::set_attr( // NOLINT(*) - const std::string& attr_name, - const ValueType& value, - int plevel) { - CHECK_GT(plevel, 0) - << "plevel in set_attr must be greater than 0"; + const std::string& attr_name, const ValueType& value, int plevel) { + CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0"; // update the attribute map of the key by creating new empty if needed. - UpdateAttrMap(attr_name, - [this, attr_name, value, plevel](any* pmap) { - // the callback is in lockscope so is threadsafe. - if (pmap->empty()) { - OpMap pm; - pm.attr_name_ = attr_name; - *pmap = std::move(pm); - } - CHECK(pmap->type() == typeid(OpMap)) - << "Attribute " << attr_name - << " of operator " << this->name - << " is registered as inconsistent types" - << " previously " << pmap->type().name() - << " current " << typeid(OpMap).name(); - std::vector >& vec = - nnvm::get >(*pmap).data_; - // resize the value type. - if (vec.size() <= index_) { - vec.resize(index_ + 1, - std::make_pair(ValueType(), 0)); - } - std::pair& p = vec[index_]; - CHECK(p.second != plevel) - << "Attribute " << attr_name - << " of operator " << this->name - << " is already registered with same plevel=" << plevel; - if (p.second < plevel) { - vec[index_] = std::make_pair(value, plevel); - } - }); + UpdateAttrMap(attr_name, [this, attr_name, value, plevel](any* pmap) { + // the callback is in lockscope so is threadsafe. + if (pmap->empty()) { + OpMap pm; + pm.attr_name_ = attr_name; + *pmap = std::move(pm); + } + CHECK(pmap->type() == typeid(OpMap)) + << "Attribute " << attr_name << " of operator " << this->name + << " is registered as inconsistent types" + << " previously " << pmap->type().name() << " current " << typeid(OpMap).name(); + std::vector>& vec = nnvm::get>(*pmap).data_; + // resize the value type. 
+ if (vec.size() <= index_) { + vec.resize(index_ + 1, std::make_pair(ValueType(), 0)); + } + std::pair& p = vec[index_]; + CHECK(p.second != plevel) << "Attribute " << attr_name << " of operator " << this->name + << " is already registered with same plevel=" << plevel; + if (p.second < plevel) { + vec[index_] = std::make_pair(value, plevel); + } + }); return *this; } - inline Op& Op::describe(const std::string& descr) { // NOLINT(*) this->description = descr; return *this; } -inline Op& Op::add_argument(const std::string &name, - const std::string &type, - const std::string &description) { +inline Op& Op::add_argument(const std::string& name, const std::string& type, + const std::string& description) { arguments.push_back({name, type, type, description}); return *this; } -inline Op& Op::add_arguments(const std::vector &args) { +inline Op& Op::add_arguments(const std::vector& args) { this->arguments.insert(arguments.end(), args.begin(), args.end()); return *this; } @@ -550,7 +521,7 @@ inline Op& Op::set_support_level(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) this->get_num_inputs = fn; return *this; } @@ -560,29 +531,18 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) this->get_num_outputs = fn; return *this; } -inline Op& Op::set_fallback_device(bool fallback) { // NOLINT(*) - this->fallback = fallback; - return *this; -} - -inline Op& Op::set_fallback_device( - std::function fn) { // NOLINT(*) - this->get_fallback_device = fn; - return *this; -} - -inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) +inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) this->attr_parser = fn; return *this; } // member functions of OpMap -template +template inline int OpMap::count(const Op* op) const { if (contains(op)) { return 1; @@ -591,7 +551,7 @@ inline int OpMap::count(const Op* op) const { } } -template +template inline bool OpMap::contains(const Op* op) const { if (op == nullptr) { return false; @@ -600,17 +560,16 @@ inline bool OpMap::contains(const Op* op) const { return idx < data_.size() ? 
(data_[idx].second != 0) : false; } -template +template inline const ValueType& OpMap::operator[](const Op* op) const { CHECK(op != nullptr); const uint32_t idx = op->index_; CHECK(idx < data_.size() && data_[idx].second) - << "Attribute " << attr_name_ - << " has not been registered for Operator " << op->name; + << "Attribute " << attr_name_ << " has not been registered for Operator " << op->name; return data_[idx].first; } -template +template inline const ValueType& OpMap::get(const Op* op, const ValueType& def_value) const { if (op == nullptr) return def_value; const uint32_t idx = op->index_; @@ -621,9 +580,8 @@ inline const ValueType& OpMap::get(const Op* op, const ValueType& def } } -template -inline OpGroup& OpGroup::set_attr(const std::string& attr_name, - const ValueType& value, +template +inline OpGroup& OpGroup::set_attr(const std::string& attr_name, const ValueType& value, int plevel) { auto trigger = [attr_name, value, plevel](Op* op) { op->set_attr(attr_name, value, plevel); diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index bf001e0f1be7..b6db3400418a 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -24,15 +24,16 @@ #ifndef NNVM_OP_ATTR_TYPES_H_ #define NNVM_OP_ATTR_TYPES_H_ -#include -#include -#include #include +#include #include +#include +#include + #include "base.h" +#include "layout.h" #include "node.h" #include "tuple.h" -#include "layout.h" namespace nnvm { @@ -48,7 +49,7 @@ namespace nnvm { * * FListInputNames enables automatic variable creation for missing arguments. */ -using FListInputNames = std::function (const NodeAttrs& attrs)>; +using FListInputNames = std::function(const NodeAttrs& attrs)>; /*! * \brief Return number of visible outputs by the user. @@ -60,7 +61,7 @@ using FListInputNames = std::function (const NodeAttrs& * but the additional outputs can be used to pass information from * forward to gradient pass. */ -using FNumVisibleOutputs = std::function; +using FNumVisibleOutputs = std::function; /*! * \brief Return list of output arguments names of each operator. @@ -71,7 +72,7 @@ using FNumVisibleOutputs = std::function; * * FListOutputNames customized naming for operator outputs. */ -using FListOutputNames = std::function (const NodeAttrs& attrs)>; +using FListOutputNames = std::function(const NodeAttrs& attrs)>; /*! * \brief Check whether operator will mutate k-th input. @@ -81,17 +82,16 @@ using FListOutputNames = std::function (const NodeAttrs * \note Register under "FMutateInputs", default return false * FMutateInputs enables mutation order handling correctly. */ -using FMutateInputs = std::function (const NodeAttrs& attrs)>; +using FMutateInputs = std::function(const NodeAttrs& attrs)>; /*! * \brief Inference function of certain type. * \tparam AttrType The type of the attribute to be infered. * \return whether all attributes are inferred. */ -template -using FInferNodeEntryAttr = std::function *in_attrs, - std::vector *out_attrs)>; +template +using FInferNodeEntryAttr = std::function* in_attrs, std::vector* out_attrs)>; /*! * \brief Get attribute dictionary from node. @@ -100,9 +100,8 @@ using FInferNodeEntryAttr = std::function - (const NodeAttrs& attrs)>; +using FGetAttrDict = + std::function(const NodeAttrs& attrs)>; /*! * \brief Shape inference function. @@ -155,8 +154,7 @@ using TIsGhost = bool; * * \note Register under "FInplaceOption", by default no inplace can happen. 
*/ -using FInplaceOption = std::function< - std::vector > (const NodeAttrs& attrs)>; +using FInplaceOption = std::function >(const NodeAttrs& attrs)>; /*! * \brief Get if the inplace option is an identity @@ -168,7 +166,7 @@ using FInplaceOption = std::function< * * \note Register under "FInplaceIdentity", by default no identities. */ -using FInplaceIdentity = std::function (const NodeAttrs& attrs)>; +using FInplaceIdentity = std::function(const NodeAttrs& attrs)>; /*! * \brief Get list of inputs in the op whose content are actually not used by the operator @@ -179,8 +177,7 @@ using FInplaceIdentity = std::function (const NodeAttrs& attrs * * \note Register under "FIgnoreInputs". */ -using FIgnoreInputs = std::function< - std::vector (const NodeAttrs& attrs)>; +using FIgnoreInputs = std::function(const NodeAttrs& attrs)>; /*! * \brief Get the gradient node of the op node @@ -191,9 +188,8 @@ using FIgnoreInputs = std::function< * * \note Register under "FGradient" */ -using FGradient = std::function( - const ObjectPtr& nodeptr, - const std::vector& out_grads)>; +using FGradient = std::function(const ObjectPtr& nodeptr, + const std::vector& out_grads)>; /*! * \brief Set the attributes of input variable. @@ -202,10 +198,8 @@ using FGradient = std::function( * \param var the input variable * \param index index of var in all inputs */ -using FSetInputVarAttrOnCompose = std::function; +using FSetInputVarAttrOnCompose = + std::function; /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention @@ -226,12 +220,9 @@ using FSetInputVarAttrOnCompose = std::function *ilayouts, - const std::vector *last_ilayouts, - std::vector *olayouts)>; - +using FCorrectLayout = + std::function* ilayouts, + const std::vector* last_ilayouts, std::vector* olayouts)>; /*! * \brief Infer & correct function of node layout. See \p Layout for layout convention @@ -254,12 +245,8 @@ using FCorrectLayout = std::function* ishapes, - std::vector* ilayouts, - const std::vector* last_ilayouts, - std::vector* olayouts)>; - + const NodeAttrs& attrs, std::vector* ishapes, std::vector* ilayouts, + const std::vector* last_ilayouts, std::vector* olayouts)>; /*! * \brief Get a list of inputs that represent graphs instead of data. diff --git a/nnvm/include/nnvm/pass.h b/nnvm/include/nnvm/pass.h index a6158df5ffdf..0bccdccd0791 100644 --- a/nnvm/include/nnvm/pass.h +++ b/nnvm/include/nnvm/pass.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,8 +24,9 @@ #ifndef NNVM_PASS_H_ #define NNVM_PASS_H_ -#include #include +#include + #include "base.h" #include "graph.h" @@ -42,7 +43,7 @@ namespace nnvm { * \param src The graph to be transformed. * \return The generated graph. */ -typedef std::function PassFunction; +typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on the input graph. @@ -50,8 +51,7 @@ typedef std::function PassFunction; * \param passes A list of pass names to be applied. * \return The transformed graph */ -Graph ApplyPasses(Graph src, - const std::vector& passes); +Graph ApplyPasses(Graph src, const std::vector& passes); /*! * \brief Apply one pass to the graph. 
@@ -59,17 +59,12 @@ Graph ApplyPasses(Graph src, * \param pass The name of pass to be applied. * \return The transformed graph. */ -inline Graph ApplyPass(Graph src, const std::string& pass) { - return ApplyPasses(src, {pass}); -} - +inline Graph ApplyPass(Graph src, const std::string& pass) { return ApplyPasses(src, {pass}); } /*! * \brief Registry entry for pass functions. */ -struct PassFunctionReg - : public dmlc::FunctionRegEntryBase { +struct PassFunctionReg : public dmlc::FunctionRegEntryBase { /*! * \brief Whether the pass will change graph structure * If this is false, the pass will only change attributes. @@ -138,7 +133,7 @@ struct PassFunctionReg * }); * \endcode */ -#define NNVM_REGISTER_PASS(name) \ +#define NNVM_REGISTER_PASS(name) \ DMLC_REGISTRY_REGISTER(::nnvm::PassFunctionReg, PassFunctionReg, name) } // namespace nnvm diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index a7893c6fec56..3097e20223d5 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,13 +28,14 @@ #ifndef NNVM_PASS_FUNCTIONS_H_ #define NNVM_PASS_FUNCTIONS_H_ -#include #include -#include +#include #include +#include + #include "base.h" -#include "pass.h" #include "graph_attr_types.h" +#include "pass.h" namespace nnvm { namespace pass { @@ -60,7 +61,6 @@ inline std::string SaveJSON(Graph graph) { return ret.GetAttr("json"); } - /*! * \brief Print graph ir * \param graph The graph to be printed @@ -81,9 +81,7 @@ inline std::string PrintGraphIR(Graph graph) { * \param src The input graph. * \return A graph with proper control flow dependencies added. */ -inline Graph OrderMutation(Graph src) { - return ApplyPass(std::move(src), "OrderMutation"); -} +inline Graph OrderMutation(Graph src) { return ApplyPass(std::move(src), "OrderMutation"); } /*! * \brief Infer shapes in the graph given the information. @@ -94,9 +92,7 @@ inline Graph OrderMutation(Graph src) { * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * The index of ShapeVector is given by graph.indexed_graph().entry_id. */ -inline Graph InferShape(Graph graph, - ShapeVector shape_inputs, - std::string shape_attr_key = "") { +inline Graph InferShape(Graph graph, ShapeVector shape_inputs, std::string shape_attr_key = "") { if (shape_inputs.size() != 0) { graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); } @@ -115,9 +111,7 @@ inline Graph InferShape(Graph graph, * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. * The index of ShapeVector is given by graph.indexed_graph().entry_id. */ -inline Graph InferType(Graph graph, - DTypeVector dtype_inputs, - std::string dtype_attr_key = "") { +inline Graph InferType(Graph graph, DTypeVector dtype_inputs, std::string dtype_attr_key = "") { if (dtype_inputs.size() != 0) { graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); } @@ -141,10 +135,8 @@ inline Graph InferType(Graph graph, * \param device_copy_op The name of copy op to be inserted when cross device copy happened. 
* \return A graph with new attribute "device", cotaining device information of each node. */ -inline Graph PlaceDevice(Graph graph, - std::string device_group_attr_key, - DeviceAssignMap device_assign_map, - std::string device_copy_op) { +inline Graph PlaceDevice(Graph graph, std::string device_group_attr_key, + DeviceAssignMap device_assign_map, std::string device_copy_op) { graph.attrs["device_group_attr_key"] = std::make_shared(std::move(device_group_attr_key)); graph.attrs["device_assign_map"] = std::make_shared(std::move(device_assign_map)); graph.attrs["device_copy_op"] = std::make_shared(std::move(device_copy_op)); @@ -159,22 +151,18 @@ inline Graph PlaceDevice(Graph graph, * \param ys_out_grad The symbol for additional gradient to be propagate back to y. * \param aggregate_fun Aggregation function applied to aggregate the inputs. * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. - * \param zero_ops Optional, list of operators that outputs a single zero array. The first one - * must be zeros_like. - * \param copy_op_str Optional, name of the copy operation required to handle duplicates - * on the edge of the graph - * \return A new graph, whose outputs correspond to inputs of xs. + * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same + * as like. \param zero_ops Optional, list of operators that outputs a single zero array. The first + * one must be zeros_like. \param copy_op_str Optional, name of the copy operation required to + * handle duplicates on the edge of the graph \return A new graph, whose outputs correspond to + * inputs of xs. */ inline Graph Gradient( - Graph graph, - std::vector ys, - std::vector xs, + Graph graph, std::vector ys, std::vector xs, std::vector ys_out_grad, std::function&& inputs)> aggregate_fun = nullptr, std::function mirror_fun = nullptr, - std::function - attr_hint_fun = nullptr, + std::function attr_hint_fun = nullptr, std::vector zero_ops = std::vector(), std::string copy_op_str = std::string()) { graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); @@ -198,7 +186,7 @@ inline Graph Gradient( } if (copy_op_str != std::string()) { - graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); + graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); } return ApplyPass(std::move(graph), "Gradient"); diff --git a/nnvm/include/nnvm/symbolic.h b/nnvm/include/nnvm/symbolic.h index d3555ec726b2..77d385505845 100644 --- a/nnvm/include/nnvm/symbolic.h +++ b/nnvm/include/nnvm/symbolic.h @@ -29,10 +29,10 @@ #define NNVM_SYMBOLIC_H_ #include -#include #include -#include #include +#include +#include #include "base.h" #include "node.h" @@ -81,13 +81,13 @@ class NNVM_DLL Symbol { * \brief Print the symbol info to output stream. * \param os The output stream to print to. */ - void Print(std::ostream &os) const; // NOLINT(*) + void Print(std::ostream& os) const; // NOLINT(*) /*! * \brief Get the index-th element from the returned tuple. * \param index Index of multi output. * \return The symbol corresponds to the indexed element. */ - Symbol operator[] (size_t index) const; + Symbol operator[](size_t index) const; /*! * \brief List the input variable nodes. * @@ -139,9 +139,9 @@ class NNVM_DLL Symbol { * \param name Name of returned symbol. * \return A new Symbol which is the composition of current symbol with its arguments. 
*/ - Symbol operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const; + Symbol operator()(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const; /*! * \brief Add control flow dependencies to the operators in symbols. * @@ -201,16 +201,14 @@ class NNVM_DLL Symbol { * * \return The created attribute in format . */ - std::vector > - ListAttrsRecursive() const; + std::vector > ListAttrsRecursive() const; /*! * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. * \param op The operator. * \param attrs The additional attributes. * \return Symbol that can be used to call compose further. */ - static Symbol CreateFunctor(const Op* op, - std::unordered_map attrs); + static Symbol CreateFunctor(const Op* op, std::unordered_map attrs); /*! * \brief Create symbolic functor(AtomicSymbol) by given node attributes. * \param attrs pre-initialized Node attributes. diff --git a/nnvm/include/nnvm/tuple.h b/nnvm/include/nnvm/tuple.h index a7f2d2603093..c6d6125aa194 100644 --- a/nnvm/include/nnvm/tuple.h +++ b/nnvm/include/nnvm/tuple.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,12 +24,13 @@ #ifndef NNVM_TUPLE_H_ #define NNVM_TUPLE_H_ -#include -#include #include -#include #include #include +#include +#include +#include + #include "base.h" namespace nnvm { @@ -47,29 +48,23 @@ typedef int64_t dim_t; * \tparam ValueType The type of data stored inside tuple. * \sa TShape */ -template +template class Tuple { public: /*! \brief default constructor */ Tuple() = default; /*! \brief destructor */ - inline ~Tuple() { - delete [] data_heap_; - } + inline ~Tuple() { delete[] data_heap_; } /*! * \brief copy constructor from another tuple * \param s the source tuple */ - inline Tuple(const Tuple& s) { - this->assign(s.begin(), s.end()); - } + inline Tuple(const Tuple& s) { this->assign(s.begin(), s.end()); } /*! * \brief constructor from initializer list * \param init the initializer_list */ - inline Tuple(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } + inline Tuple(std::initializer_list init) { this->assign(init.begin(), init.end()); } /*! * \brief constructor from vector * \param init the vector @@ -82,7 +77,7 @@ class Tuple { * \param src the source shape */ - inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) + inline Tuple(Tuple&& src) { // NOLINT(runtime/explicit) this->swap(src); } /*! @@ -91,9 +86,8 @@ class Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline Tuple(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline Tuple(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); } /*! 
@@ -102,9 +96,8 @@ class Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline void assign(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline void assign(RandomAccessIterator begin, RandomAccessIterator end) { this->SetDim(end - begin); std::copy(begin, end, this->begin()); } @@ -141,7 +134,7 @@ class Tuple { * \param init the source initializer list * \return reference of self */ - inline Tuple &operator=(std::initializer_list init) { + inline Tuple& operator=(std::initializer_list init) { this->assign(init.begin(), init.end()); return *this; } @@ -149,7 +142,7 @@ class Tuple { * \return whether two tuple equals * \param s the tuple to compare against */ - inline bool operator==(const Tuple &s) const { + inline bool operator==(const Tuple& s) const { if (ndim_ != s.ndim_) return false; return std::equal(begin(), end(), s.begin()); } @@ -157,45 +150,33 @@ class Tuple { * \return whether two tuple not equal * \param s the tuple to compare against */ - inline bool operator!=(const Tuple &s) const { - return !(*this == s); - } + inline bool operator!=(const Tuple& s) const { return !(*this == s); } /*! \return the begin data pointer to content of the tuple */ - inline const ValueType *begin() const { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } + inline const ValueType* begin() const { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the begin data pointer to content of the tuple */ - inline ValueType *begin() { - return ndim_ <= kStackCache ? data_stack_ : data_heap_; - } + inline ValueType* begin() { return ndim_ <= kStackCache ? data_stack_ : data_heap_; } /*! \return the data pointer to end of the tuple */ inline const ValueType* end() const { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } /*! \return the data pointer to end the tuple */ inline ValueType* end() { - return ndim_ <= kStackCache ? (data_stack_ + ndim_): (data_heap_ + ndim_); + return ndim_ <= kStackCache ? (data_stack_ + ndim_) : (data_heap_ + ndim_); } /*! \return number of dimension of the tuple */ - inline uint32_t ndim() const { - return ndim_; - } + inline uint32_t ndim() const { return ndim_; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ - inline ValueType& operator[](size_t i) { - return begin()[i]; - } + inline ValueType& operator[](size_t i) { return begin()[i]; } /*! * \brief get corresponding index * \param i dimension index * \return the corresponding dimension size */ - inline const ValueType& operator[](size_t i) const { - return begin()[i]; - } + inline const ValueType& operator[](size_t i) const { return begin()[i]; } /*! * \brief Save Tuple to JSON. 
* \param writer JSONWriter @@ -219,7 +200,7 @@ class Tuple { * \param t the tuple * \return the ostream */ - friend std::ostream &operator<<(std::ostream &os, const Tuple &t) { + friend std::ostream& operator<<(std::ostream& os, const Tuple& t) { os << '['; const ValueType* begin = t.begin(); const ValueType* end = t.end(); @@ -236,7 +217,7 @@ class Tuple { * \param t The tuple * \return the istream */ - friend std::istream &operator>>(std::istream &is, Tuple &t) { + friend std::istream& operator>>(std::istream& is, Tuple& t) { // get ( while (true) { char ch = is.peek(); @@ -252,7 +233,7 @@ class Tuple { if (!isspace(ch)) { is.setstate(std::ios::failbit); return is; - } + } } // Handle empty tuple while (isspace(is.peek())) { @@ -278,10 +259,12 @@ class Tuple { while (true) { ch = is.peek(); if (isspace(ch)) { - is.get(); continue; + is.get(); + continue; } if (ch == ')' || ch == ']') { - is.get(); break; + is.get(); + break; } break; } @@ -302,8 +285,8 @@ class Tuple { * \tparam DType data type that save to * \tparam TStream any stream type that have write */ - template - inline void Save(TStream *strm) const; + template + inline void Save(TStream* strm) const; /*! * \brief load the content from binary stream * \param strm the output stream @@ -311,8 +294,8 @@ class Tuple { * \tparam TStream any stream type that have write * \return whether the load is successful */ - template - inline bool Load(TStream *strm); + template + inline bool Load(TStream* strm); protected: // stack cache size @@ -327,9 +310,8 @@ class Tuple { ValueType* data_heap_{nullptr}; // internal function to change the dimension inline void SetDim(uint32_t ndim) { - if (ndim > kStackCache && - ndim > num_heap_allocated_) { - delete [] data_heap_; + if (ndim > kStackCache && ndim > num_heap_allocated_) { + delete[] data_heap_; data_heap_ = new ValueType[ndim]; num_heap_allocated_ = ndim; } @@ -356,16 +338,14 @@ class TShape : public Tuple { * \brief copy constructor of TShape * \param s source shape. */ - inline TShape(const Tuple& s) { // NOLINT(*) + inline TShape(const Tuple& s) { // NOLINT(*) this->assign(s.begin(), s.end()); } /*! * \brief constructor from initializer list * \param init the initializer_list */ - inline TShape(std::initializer_list init) { - this->assign(init.begin(), init.end()); - } + inline TShape(std::initializer_list init) { this->assign(init.begin(), init.end()); } /*! * \brief move constructor. * \param s source shape. @@ -379,9 +359,8 @@ class TShape : public Tuple { * \param end end the end of the iterator * \tparam RandomAccessIterator iterator type */ - template - inline TShape(RandomAccessIterator begin, - RandomAccessIterator end) { + template + inline TShape(RandomAccessIterator begin, RandomAccessIterator end) { this->assign(begin, end); } /*! @@ -399,13 +378,13 @@ class TShape : public Tuple { * \return self. */ inline TShape& operator=(Tuple&& src) { // NOLINT(*) - TShape(std::move(src)).swap(*this); // NOLINT(*) + TShape(std::move(src)).swap(*this); // NOLINT(*) return *this; } /*! 
\return total number of elements in the shape */ inline size_t Size() const { dim_t size = 1; - const dim_t* start = begin(), *fin = end(); + const dim_t *start = begin(), *fin = end(); for (const dim_t* it = start; it != fin; ++it) { size *= *it; } @@ -418,28 +397,24 @@ class TShape : public Tuple { */ inline size_t ProdShape(int dimstart, int dimend) const { dim_t num = 1; - const dim_t *d = this->data(); + const dim_t* d = this->data(); for (int i = dimstart; i < dimend; ++i) { num *= d[i]; } return num; } /*! \return the begin data pointer to content of the tuple */ - inline const dim_t *data() const { - return begin(); - } + inline const dim_t* data() const { return begin(); } /*! \return the begin data pointer to content of the tuple */ - inline dim_t *data() { - return begin(); - } + inline dim_t* data() { return begin(); } #ifdef MSHADOW_XINLINE - template - inline TShape(const mshadow::Shape &s) {// NOLINT(*) + template + inline TShape(const mshadow::Shape& s) { // NOLINT(*) this->assign(s.shape_, s.shape_ + dim); } - template - inline TShape(mshadow::Shape &&s) {// NOLINT(*) + template + inline TShape(mshadow::Shape&& s) { // NOLINT(*) this->assign(s.shape_, s.shape_ + dim); } /*! @@ -448,8 +423,8 @@ class TShape : public Tuple { * \tparam dim shape dimension * \return reference of self */ - template - inline TShape &operator=(const mshadow::Shape &shape) { + template + inline TShape& operator=(const mshadow::Shape& shape) { this->assign(shape.shape_, shape.shape_ + dim); return *this; } @@ -458,11 +433,11 @@ class TShape : public Tuple { * \return the shape requested * \tparam dim dimension of the tensor */ - template + template inline mshadow::Shape get() const { CHECK_EQ(dim, static_cast(ndim())) << "dimension do not match target dimension " << dim << " vs " << ndim(); - const dim_t *d = this->data(); + const dim_t* d = this->data(); mshadow::Shape s; for (int i = 0; i < dim; ++i) { s[i] = d[i]; @@ -476,7 +451,7 @@ class TShape : public Tuple { inline mshadow::Shape<2> FlatTo2D(void) const { mshadow::Shape<2> s; if (ndim() == 0) return mshadow::Shape2(0, 0); - const dim_t *d = this->data(); + const dim_t* d = this->data(); s.shape_[1] = d[ndim() - 1]; dim_t ymax = 1; for (size_t i = 1; i < ndim(); ++i) { @@ -495,7 +470,7 @@ class TShape : public Tuple { CHECK(axis_end >= axis_begin); mshadow::Shape<3> s; if (ndim() == 0) return mshadow::Shape3(0, 0, 0); - const dim_t *d = this->data(); + const dim_t* d = this->data(); s.shape_[0] = 1; s.shape_[1] = 1; s.shape_[2] = 1; @@ -516,25 +491,21 @@ class TShape : public Tuple { * \param axis The axis specified. * \return the flat 3d shape */ - inline mshadow::Shape<3> FlatTo3D(size_t axis) const { - return FlatTo3D(axis, axis); - } - inline bool operator==(const TShape &s) const { + inline mshadow::Shape<3> FlatTo3D(size_t axis) const { return FlatTo3D(axis, axis); } + inline bool operator==(const TShape& s) const { if (ndim() != s.ndim()) return false; return std::equal(begin(), end(), s.begin()); } - inline bool operator!=(const TShape &s) const { - return !(*this == s); - } + inline bool operator!=(const TShape& s) const { return !(*this == s); } /*! * \return whether two shape equals * \param s the shape to compare against * \tparam dim dimension of the shape */ - template - inline bool operator==(const mshadow::Shape &s) const { + template + inline bool operator==(const mshadow::Shape& s) const { if (ndim_ != dim) return false; - const dim_t *d = dim <= kStackCache ? 
data_stack_ : data_heap_; + const dim_t* d = dim <= kStackCache ? data_stack_ : data_heap_; for (size_t i = 0; i < dim; ++i) { if (d[i] != s.shape_[i]) return false; } @@ -545,18 +516,16 @@ class TShape : public Tuple { * \param s the shape to compare against * \tparam dim dimension of the shape */ - template - inline bool operator!=(const mshadow::Shape &s) const { + template + inline bool operator!=(const mshadow::Shape& s) const { return !(*this == s); } #endif }; /*! \brief helper function to cast type of container elements */ -template -inline DstIter ShapeTypeCast(const SrcIter begin, - const SrcIter end, - DstIter dst_begin) { +template +inline DstIter ShapeTypeCast(const SrcIter begin, const SrcIter end, DstIter dst_begin) { typedef typename std::iterator_traits::value_type SrcDType; typedef typename std::iterator_traits::value_type DstDType; auto cast = [](const SrcDType& dim) { return static_cast(dim); }; @@ -564,7 +533,7 @@ inline DstIter ShapeTypeCast(const SrcIter begin, } /*! \brief helper function to transform a container to TShape with type cast */ -template +template inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { size_t ndim = std::distance(begin, end); TShape res(ndim); @@ -573,9 +542,9 @@ inline TShape ShapeTypeCast(const SrcIter begin, const SrcIter end) { } /*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline void Tuple::Save(TStream *strm) const { +template +template +inline void Tuple::Save(TStream* strm) const { strm->Write(&ndim_, sizeof(ndim_)); if (typeid(DType) == typeid(ValueType)) { strm->Write(begin(), sizeof(ValueType) * ndim_); @@ -587,9 +556,9 @@ inline void Tuple::Save(TStream *strm) const { } /*! \tparam ValueType The type of data stored inside tuple. */ -template -template -inline bool Tuple::Load(TStream *strm) { +template +template +inline bool Tuple::Load(TStream* strm) { if (strm->Read(&ndim_, sizeof(ndim_)) != sizeof(ndim_)) return false; this->SetDim(ndim_); size_t nread = sizeof(DType) * ndim_; @@ -607,7 +576,7 @@ inline bool Tuple::Load(TStream *strm) { namespace std { /*! \brief hash function for Tuple. */ -template +template struct hash > { /*! \brief hash a Tuple into unsigned int */ size_t operator()(const nnvm::Tuple& val) const { @@ -621,7 +590,7 @@ struct hash > { }; /*! \brief hash function for TShape. */ -template<> +template <> struct hash { /*! \brief hash a TShape into unsigned int */ size_t operator()(const nnvm::TShape& val) const { @@ -640,11 +609,9 @@ namespace dmlc { DMLC_DECLARE_TYPE_NAME(optional, "Shape or None"); // avoid low version of MSVC #if !defined(_MSC_VER) -template +template struct type_name_helper > { - static inline std::string value() { - return "tuple of <" + type_name() + ">"; - } + static inline std::string value() { return "tuple of <" + type_name() + ">"; } }; #endif } // namespace dmlc diff --git a/nnvm/src/c_api/c_api_common.h b/nnvm/src/c_api/c_api_common.h index b3ff36ae606f..129194715649 100644 --- a/nnvm/src/c_api/c_api_common.h +++ b/nnvm/src/c_api/c_api_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -29,23 +29,34 @@ #include #include #include -#include + #include -#include #include +#include +#include /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { /*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() } catch(dmlc::Error &_except_) { return NNAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END() \ + } \ + catch (dmlc::Error & _except_) { \ + return NNAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ -#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return NNAPIHandleException(_except_); } return 0; // NOLINT(*) - +#define API_END_HANDLE_ERROR(Finalize) \ + } \ + catch (dmlc::Error & _except_) { \ + Finalize; \ + return NNAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! \brief entry to to easily hold returning information */ struct NNAPIThreadLocalEntry { @@ -54,9 +65,9 @@ struct NNAPIThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; /*! \brief result holder for returning handles */ - std::vector ret_handles; + std::vector ret_handles; /*! \brief argument holder to hold symbol */ std::unordered_map kwarg_symbol; }; @@ -69,7 +80,7 @@ typedef dmlc::ThreadLocalStore NNAPIThreadLocalStore; * \param e the exception * \return the return value of API after exception is handled */ -inline int NNAPIHandleException(const dmlc::Error &e) { +inline int NNAPIHandleException(const dmlc::Error& e) { NNAPISetLastError(e.what()); return -1; } diff --git a/nnvm/src/c_api/c_api_error.cc b/nnvm/src/c_api/c_api_error.cc index ba6e1cd37c8a..c2f90b162e1f 100644 --- a/nnvm/src/c_api/c_api_error.cc +++ b/nnvm/src/c_api/c_api_error.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief C error handling */ #include + #include "c_api_common.h" struct ErrorEntry { @@ -30,10 +31,6 @@ struct ErrorEntry { typedef dmlc::ThreadLocalStore NNAPIErrorStore; -const char *NNGetLastError() { - return NNAPIErrorStore::Get()->last_error.c_str(); -} +const char* NNGetLastError() { return NNAPIErrorStore::Get()->last_error.c_str(); } -void NNAPISetLastError(const char* msg) { - NNAPIErrorStore::Get()->last_error = msg; -} +void NNAPISetLastError(const char* msg) { NNAPIErrorStore::Get()->last_error = msg; } diff --git a/nnvm/src/c_api/c_api_graph.cc b/nnvm/src/c_api/c_api_graph.cc index cc5449b0fbbe..a547476e4c7e 100644 --- a/nnvm/src/c_api/c_api_graph.cc +++ b/nnvm/src/c_api/c_api_graph.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,17 +21,18 @@ * \file c_api_graph.cc * \brief C API related to Graph IR. */ +#include #include -#include -#include #include +#include #include -#include +#include + #include "c_api_common.h" using namespace nnvm; -int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph) { +int NNGraphCreate(SymbolHandle symbol, GraphHandle* graph) { Graph* g = new Graph(); API_BEGIN(); g->outputs = static_cast(symbol)->outputs; @@ -45,7 +46,7 @@ int NNGraphFree(GraphHandle handle) { API_END(); } -int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { +int NNGraphGetSymbol(GraphHandle graph, SymbolHandle* symbol) { Symbol* s = new Symbol(); API_BEGIN(); s->outputs = static_cast(graph)->outputs; @@ -53,20 +54,15 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { API_END_HANDLE_ERROR(delete s); } -int NNGraphSetNodeEntryListAttr_(GraphHandle handle, - const char* key, - SymbolHandle list) { +int NNGraphSetNodeEntryListAttr_(GraphHandle handle, const char* key, SymbolHandle list) { API_BEGIN(); Symbol* s = static_cast(list); Graph* g = static_cast(handle); - g->attrs[std::string(key)] - = std::make_shared(s->outputs); + g->attrs[std::string(key)] = std::make_shared(s->outputs); API_END(); } -int NNGraphSetJSONAttr(GraphHandle handle, - const char* key, - const char* json_value) { +int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value) { API_BEGIN(); Graph* g = static_cast(handle); std::string temp(json_value); @@ -78,11 +74,8 @@ int NNGraphSetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphGetJSONAttr(GraphHandle handle, - const char* key, - const char** json_out, - int *success) { - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNGraphGetJSONAttr(GraphHandle handle, const char* key, const char** json_out, int* success) { + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); Graph* g = static_cast(handle); std::string skey(key); @@ -100,10 +93,8 @@ int NNGraphGetJSONAttr(GraphHandle handle, API_END(); } -int NNGraphApplyPasses(GraphHandle src, - nn_uint num_pass, - const char** pass_names, - 
GraphHandle *dst) { +int NNGraphApplyPasses(GraphHandle src, nn_uint num_pass, const char** pass_names, + GraphHandle* dst) { Graph* g = new Graph(); API_BEGIN(); std::vector vpass; diff --git a/nnvm/src/c_api/c_api_symbolic.cc b/nnvm/src/c_api/c_api_symbolic.cc index 7ca56035acae..2127997da05a 100644 --- a/nnvm/src/c_api/c_api_symbolic.cc +++ b/nnvm/src/c_api/c_api_symbolic.cc @@ -24,14 +24,14 @@ #include #include #include + #include "c_api_common.h" using namespace nnvm; -int NNListAllOpNames(nn_uint *out_size, - const char*** out_array) { +int NNListAllOpNames(nn_uint* out_size, const char*** out_array) { API_BEGIN(); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); ret->ret_vec_str = dmlc::Registry::ListAllNames(); ret->ret_vec_charp.resize(0); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); @@ -43,40 +43,31 @@ int NNListAllOpNames(nn_uint *out_size, API_END(); } -int NNGetOpHandle(const char* op_name, - OpHandle* op_out) { +int NNGetOpHandle(const char* op_name, OpHandle* op_out) { API_BEGIN(); *op_out = (OpHandle)Op::Get(op_name); // NOLINT(*) API_END(); } -int NNListUniqueOps(nn_uint *out_size, - OpHandle **out_array) { +int NNListUniqueOps(nn_uint* out_size, OpHandle** out_array) { API_BEGIN(); - auto &vec = dmlc::Registry::List(); + auto& vec = dmlc::Registry::List(); *out_size = static_cast(vec.size()); *out_array = (OpHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } -int NNAddControlDeps(SymbolHandle handle, - SymbolHandle src_dep) { +int NNAddControlDeps(SymbolHandle handle, SymbolHandle src_dep) { API_BEGIN(); - static_cast(handle)->AddControlDeps( - *static_cast(src_dep)); + static_cast(handle)->AddControlDeps(*static_cast(src_dep)); API_END(); } -int NNGetOpInfo(OpHandle handle, - const char **name, - const char **description, - nn_uint *num_doc_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions, - const char **return_type) { - const Op *op = static_cast(handle); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNGetOpInfo(OpHandle handle, const char** name, const char** description, nn_uint* num_doc_args, + const char*** arg_names, const char*** arg_type_infos, + const char*** arg_descriptions, const char** return_type) { + const Op* op = static_cast(handle); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); *name = op->name.c_str(); @@ -100,12 +91,9 @@ int NNGetOpInfo(OpHandle handle, API_END(); } -int NNSymbolCreateAtomicSymbol(OpHandle creator, - nn_uint num_param, - const char **keys, - const char **vals, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCreateAtomicSymbol(OpHandle creator, nn_uint num_param, const char** keys, + const char** vals, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); const Op* op = static_cast(creator); std::unordered_map kwargs; @@ -117,19 +105,17 @@ int NNSymbolCreateAtomicSymbol(OpHandle creator, API_END_HANDLE_ERROR(delete s;); } -int NNSymbolCreateVariable(const char *name, SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCreateVariable(const char* name, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = Symbol::CreateVariable(name); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolCreateGroup(nn_uint num_symbols, - SymbolHandle *symbols, - SymbolHandle *out) { - Symbol *s = new Symbol(); - Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) +int NNSymbolCreateGroup(nn_uint num_symbols, SymbolHandle* 
symbols, SymbolHandle* out) { + Symbol* s = new Symbol(); + Symbol** sym_arr = (Symbol**)symbols; // NOLINT(*) API_BEGIN(); std::vector syms; for (nn_uint i = 0; i < num_symbols; ++i) { @@ -140,28 +126,24 @@ int NNSymbolCreateGroup(nn_uint num_symbols, API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetOutput(SymbolHandle symbol, - nn_uint index, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetOutput(SymbolHandle symbol, nn_uint index, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = (*static_cast(symbol))[index]; *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetInternals(SymbolHandle symbol, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetInternals(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetInternals(); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolGetChildren(SymbolHandle symbol, - SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolGetChildren(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->GetChildren(); *out = s; @@ -174,17 +156,17 @@ int NNSymbolFree(SymbolHandle symbol) { API_END(); } -int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { - Symbol *s = new Symbol(); +int NNSymbolCopy(SymbolHandle symbol, SymbolHandle* out) { + Symbol* s = new Symbol(); API_BEGIN(); *s = static_cast(symbol)->Copy(); *out = s; API_END_HANDLE_ERROR(delete s); } -int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolPrint(SymbolHandle symbol, const char** out_str) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; s->Print(os); @@ -193,12 +175,9 @@ int NNSymbolPrint(SymbolHandle symbol, const char **out_str) { API_END(); } -int NNSymbolGetAttr(SymbolHandle symbol, - const char* key, - const char** out, - int* success) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolGetAttr(SymbolHandle symbol, const char* key, const char** out, int* success) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); if (s->GetAttr(key, &(ret->ret_str))) { *out = (ret->ret_str).c_str(); @@ -210,27 +189,20 @@ int NNSymbolGetAttr(SymbolHandle symbol, API_END(); } -int NNSymbolSetAttrs(SymbolHandle symbol, - nn_uint num_param, - const char** keys, - const char** vals) { - Symbol *s = static_cast(symbol); +int NNSymbolSetAttrs(SymbolHandle symbol, nn_uint num_param, const char** keys, const char** vals) { + Symbol* s = static_cast(symbol); API_BEGIN(); std::vector > kwargs; for (nn_uint i = 0; i < num_param; ++i) { - kwargs.emplace_back( - std::make_pair(std::string(keys[i]), std::string(vals[i]))); + kwargs.emplace_back(std::make_pair(std::string(keys[i]), std::string(vals[i]))); } s->SetAttrs(kwargs); API_END(); } -int NNSymbolListAttrs(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char*** out) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListAttrs(SymbolHandle symbol, int option, nn_uint* out_size, const char*** out) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::unordered_map attr = s->ListAttrs(static_cast(option)); // NOLINT(*) @@ -252,12 
+224,10 @@ int NNSymbolListAttrs(SymbolHandle symbol, API_END(); } -int NNSymbolListInputVariables(SymbolHandle symbol, - int option, - nn_uint *out_size, +int NNSymbolListInputVariables(SymbolHandle symbol, int option, nn_uint* out_size, SymbolHandle** out_sym_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); std::vector vs = s->ListInputs(Symbol::ListInputOption(option)); ret->ret_handles.resize(0); @@ -272,15 +242,12 @@ int NNSymbolListInputVariables(SymbolHandle symbol, API_END(); } -int NNSymbolListInputNames(SymbolHandle symbol, - int option, - nn_uint *out_size, - const char ***out_str_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListInputNames(SymbolHandle symbol, int option, nn_uint* out_size, + const char*** out_str_array) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); - ret->ret_vec_str = - s->ListInputNames(Symbol::ListInputOption(option)); + ret->ret_vec_str = s->ListInputNames(Symbol::ListInputOption(option)); ret->ret_vec_charp.resize(0); ret->ret_vec_charp.reserve(ret->ret_vec_str.size()); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { @@ -291,11 +258,9 @@ int NNSymbolListInputNames(SymbolHandle symbol, API_END(); } -int NNSymbolListOutputNames(SymbolHandle symbol, - nn_uint *out_size, - const char ***out_str_array) { - Symbol *s = static_cast(symbol); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); +int NNSymbolListOutputNames(SymbolHandle symbol, nn_uint* out_size, const char*** out_str_array) { + Symbol* s = static_cast(symbol); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); API_BEGIN(); ret->ret_vec_str = s->ListOutputNames(); ret->ret_vec_charp.resize(0); @@ -308,24 +273,19 @@ int NNSymbolListOutputNames(SymbolHandle symbol, API_END(); } -int NNSymbolGetNumOutputs(SymbolHandle symbol, - nn_uint *output_count) { - Symbol *s = static_cast(symbol); +int NNSymbolGetNumOutputs(SymbolHandle symbol, nn_uint* output_count) { + Symbol* s = static_cast(symbol); API_BEGIN(); *output_count = static_cast(s->outputs.size()); API_END(); } -int NNSymbolCompose(SymbolHandle sym, - const char *name, - nn_uint num_args, - const char** keys, +int NNSymbolCompose(SymbolHandle sym, const char* name, nn_uint num_args, const char** keys, SymbolHandle* args) { API_BEGIN(); - NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get(); + NNAPIThreadLocalEntry* ret = NNAPIThreadLocalStore::Get(); std::string& s_name = ret->ret_str; - std::unordered_map& kwargs - = ret->kwarg_symbol; + std::unordered_map& kwargs = ret->kwarg_symbol; kwargs.clear(); if (name != nullptr) { s_name = name; @@ -335,8 +295,7 @@ int NNSymbolCompose(SymbolHandle sym, Symbol* s = static_cast(sym); if (keys == nullptr && num_args != 0) { kwargs.clear(); - array_view parg( - (Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) + array_view parg((Symbol**)args, (Symbol**)args + num_args); // NOLINT(*) s->Compose(parg, kwargs, s_name); } else { for (nn_uint i = 0; i < num_args; ++i) { diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index c3ae60e99937..fd5b64f4777d 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -23,6 +23,7 @@ */ #include #include + #include namespace nnvm { @@ -39,23 +40,22 @@ const IndexedGraph& Graph::indexed_graph() const { // e.g. 
the main graph is level 0 // subgraphs of the main graph is level 1 // subgraphs of the subgraphs of the main graph is level 2 -static void SubgraphSanityCheck(const std::vector> &subgraphs) { +static void SubgraphSanityCheck(const std::vector>& subgraphs) { std::vector*> curr_level; std::vector*> next_level; std::unordered_map node2level; - for (auto &subgraph : subgraphs) - next_level.push_back(&subgraph->outputs); + for (auto& subgraph : subgraphs) next_level.push_back(&subgraph->outputs); for (uint32_t level = 0; !next_level.empty(); ++level) { curr_level.swap(next_level); next_level.clear(); - for (const std::vector *graph_ptr : curr_level) { - const std::vector &graph = *graph_ptr; + for (const std::vector* graph_ptr : curr_level) { + const std::vector& graph = *graph_ptr; DFSVisit(graph, [&next_level, &node2level, level](const ObjectPtr& n) { - nnvm::Node *node = n.get(); + nnvm::Node* node = n.get(); // if the node is visited, but on a different level, then check failed // if check failed here or before, we stop doing anything, but raise an error CHECK(!node2level.count(node) || node2level[node] == level) - << "A subgraph should not depend on the outputs of nodes on higher levels"; + << "A subgraph should not depend on the outputs of nodes on higher levels"; // otherwise, this node belongs to the current level node2level[node] = level; // subgraphs of current node belongs to next level @@ -68,55 +68,51 @@ static void SubgraphSanityCheck(const std::vector> &subg } // implement constructor from graph -IndexedGraph::IndexedGraph(const Graph &g) { +IndexedGraph::IndexedGraph(const Graph& g) { entry_rptr_.push_back(0); std::vector inputs_rptr{0}, control_rptr{0}; std::vector> subgraphs; - DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs] - (const ObjectPtr& n) { - const auto& is_ghost = Op::GetAttr("TIsGhost"); - if (!n->is_variable() && is_ghost.get(n->op(), false)) return; - CHECK_LT(nodes_.size(), std::numeric_limits::max()); - uint32_t nid = static_cast(nodes_.size()); - CHECK(n); - for (const auto &subgraph : n->attrs.subgraphs) - subgraphs.push_back(subgraph); - // nodes_ - IndexedGraph::Node new_node; - new_node.source = n.get(); - new_node.weak_ref = n; - nodes_.emplace_back(std::move(new_node)); - // arg_nodes_ - if (n->is_variable()) { - input_nodes_.push_back(nid); - } - // node2index_ - node2index_[n.get()] = nid; - // entry rptr - entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); - // input entries - for (const auto& e : n->inputs) { - auto it = node2index_.find(e.node.get()); - if (it == node2index_.end() || it->first != e.node.get()) continue; - input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); - } - inputs_rptr.push_back(input_entries_.size()); - // control deps - for (const auto& nptr : n->control_deps) { - if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; - auto it = node2index_.find(nptr.get()); - CHECK(it != node2index_.end()) << "control dep not found in graph"; - control_deps_.push_back(it->second); - } - control_rptr.push_back(control_deps_.size()); + DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs](const ObjectPtr& n) { + const auto& is_ghost = Op::GetAttr("TIsGhost"); + if (!n->is_variable() && is_ghost.get(n->op(), false)) return; + CHECK_LT(nodes_.size(), std::numeric_limits::max()); + uint32_t nid = static_cast(nodes_.size()); + CHECK(n); + for (const auto& subgraph : n->attrs.subgraphs) subgraphs.push_back(subgraph); + // nodes_ + IndexedGraph::Node new_node; + 
new_node.source = n.get(); + new_node.weak_ref = n; + nodes_.emplace_back(std::move(new_node)); + // arg_nodes_ + if (n->is_variable()) { + input_nodes_.push_back(nid); + } + // node2index_ + node2index_[n.get()] = nid; + // entry rptr + entry_rptr_.push_back(entry_rptr_.back() + n->num_outputs()); + // input entries + for (const auto& e : n->inputs) { + auto it = node2index_.find(e.node.get()); + if (it == node2index_.end() || it->first != e.node.get()) continue; + input_entries_.emplace_back(NodeEntry{it->second, e.index, e.version}); + } + inputs_rptr.push_back(input_entries_.size()); + // control deps + for (const auto& nptr : n->control_deps) { + if (!nptr->is_variable() && is_ghost.get(nptr->op(), false)) continue; + auto it = node2index_.find(nptr.get()); + CHECK(it != node2index_.end()) << "control dep not found in graph"; + control_deps_.push_back(it->second); + } + control_rptr.push_back(control_deps_.size()); }); - if (!subgraphs.empty()) - SubgraphSanityCheck(subgraphs); + if (!subgraphs.empty()) SubgraphSanityCheck(subgraphs); for (const auto& e : g.outputs) { - outputs_.emplace_back(NodeEntry{ - node2index_.at(e.node.get()), e.index, e.version}); + outputs_.emplace_back(NodeEntry{node2index_.at(e.node.get()), e.index, e.version}); } static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -124,10 +120,9 @@ IndexedGraph::IndexedGraph(const Graph &g) { // input_entries_ and control_rptr must not change after this step. const NodeEntry* iptr = dmlc::BeginPtr(input_entries_); for (size_t nid = 0; nid < nodes_.size(); ++nid) { - nodes_[nid].inputs = array_view( - iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); - if (nodes_[nid].source->op() != nullptr && - fmutate_inputs.count(nodes_[nid].source->op())) { + nodes_[nid].inputs = + array_view(iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); + if (nodes_[nid].source->op() != nullptr && fmutate_inputs.count(nodes_[nid].source->op())) { for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) { mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); } @@ -135,8 +130,8 @@ IndexedGraph::IndexedGraph(const Graph &g) { } const uint32_t* cptr = dmlc::BeginPtr(control_deps_); for (size_t nid = 0; nid < nodes_.size(); ++nid) { - nodes_[nid].control_deps = array_view( - cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); + nodes_[nid].control_deps = + array_view(cptr + control_rptr[nid], cptr + control_rptr[nid + 1]); } } diff --git a/nnvm/src/core/op.cc b/nnvm/src/core/op.cc index eb51d4b3cd74..08a11dff9a02 100644 --- a/nnvm/src/core/op.cc +++ b/nnvm/src/core/op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,8 +24,8 @@ #include #include -#include #include +#include #include #include @@ -46,7 +46,7 @@ struct OpManager { // storage of additional attribute table. std::unordered_map > attr; // storage of existing triggers - std::unordered_map > > tmap; + std::unordered_map > > tmap; // group of each operator. 
std::vector > op_group; // get singleton of the @@ -70,14 +70,13 @@ Op& Op::add_alias(const std::string& alias) { // NOLINT(*) // find operator by name const Op* Op::Get(const std::string& name) { const Op* op = dmlc::Registry::Find(name); - CHECK(op != nullptr) - << "Operator " << name << " is not registered"; + CHECK(op != nullptr) << "Operator " << name << " is not registered"; return op; } // Get attribute map by key const any* Op::GetAttrMap(const std::string& key) { - auto& dict = OpManager::Global()->attr; + auto& dict = OpManager::Global()->attr; auto it = dict.find(key); if (it != dict.end()) { return it->second.get(); @@ -87,8 +86,7 @@ const any* Op::GetAttrMap(const std::string& key) { } // update attribute map -void Op::UpdateAttrMap(const std::string& key, - std::function updater) { +void Op::UpdateAttrMap(const std::string& key, std::function updater) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); std::unique_ptr& value = mgr->attr[key]; @@ -96,16 +94,14 @@ void Op::UpdateAttrMap(const std::string& key, if (updater != nullptr) updater(value.get()); } -void Op::AddGroupTrigger(const std::string& group_name, - std::function trigger) { +void Op::AddGroupTrigger(const std::string& group_name, std::function trigger) { OpManager* mgr = OpManager::Global(); std::lock_guard(mgr->mutex); auto& tvec = mgr->tmap[group_name]; tvec.push_back(trigger); auto& op_group = mgr->op_group; for (const Op* op : dmlc::Registry::List()) { - if (op->index_ < op_group.size() && - op_group[op->index_].count(group_name) != 0) { + if (op->index_ < op_group.size() && op_group[op->index_].count(group_name) != 0) { trigger((Op*)op); // NOLINT(*) } } diff --git a/nnvm/src/core/pass.cc b/nnvm/src/core/pass.cc index b43d470f3eb3..974cd2b35918 100644 --- a/nnvm/src/core/pass.cc +++ b/nnvm/src/core/pass.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief Support for pass registry. 
*/ #include + #include namespace dmlc { @@ -31,7 +32,7 @@ DMLC_REGISTRY_ENABLE(nnvm::PassFunctionReg); namespace nnvm { -const PassFunctionReg* FindPassDep(const std::string&attr_name) { +const PassFunctionReg* FindPassDep(const std::string& attr_name) { for (auto* r : dmlc::Registry::List()) { for (auto& s : r->graph_attr_targets) { if (s == attr_name) return r; @@ -40,13 +41,11 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { return nullptr; } -Graph ApplyPasses(Graph g, - const std::vector& pass) { +Graph ApplyPasses(Graph g, const std::vector& pass) { std::vector fpass; for (auto& name : pass) { auto* reg = dmlc::Registry::Find(name); - CHECK(reg != nullptr) - << "Cannot find pass " << name << " in the registry"; + CHECK(reg != nullptr) << "Cannot find pass " << name << " in the registry"; fpass.push_back(reg); } @@ -58,10 +57,8 @@ Graph ApplyPasses(Graph g, if (pass_dep != nullptr) { msg = " The attribute is provided by pass " + pass_dep->name; } - LOG(FATAL) << "Graph attr dependency " << dep - << " is required by pass " << r->name - << " but is not available " - << msg; + LOG(FATAL) << "Graph attr dependency " << dep << " is required by pass " << r->name + << " but is not available " << msg; } } g = r->body(std::move(g)); diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index 86dc7e63c403..12b8675d0bd7 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -22,13 +22,13 @@ * \brief Symbolic graph composition API. */ #include -#include #include +#include namespace nnvm { namespace symbol_constants { -const char *kNamespaceSeparator = "$"; +const char* kNamespaceSeparator = "$"; } // namespace symbol_constants // auxililary version attribute in variable. @@ -48,7 +48,7 @@ ObjectPtr CreateVariableNode(const std::string& name) { // If the node's op mutates a certain input variable, // The version of that varaible will increase // version is used to implicitly order the mutation sequences -inline void UpdateNodeVersion(Node *n) { +inline void UpdateNodeVersion(Node* n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); for (NodeEntry& e : n->inputs) { if (e.node->is_variable()) { @@ -58,16 +58,14 @@ inline void UpdateNodeVersion(Node *n) { if (fmutate_inputs.count(n->op()) != 0) { for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) { NodeEntry& e = n->inputs[i]; - CHECK(e.node->is_variable()) - << "Mutation target can only be Variable"; + CHECK(e.node->is_variable()) << "Mutation target can only be Variable"; // increase the version of the variable. e.version = ++nnvm::get(e.node->attrs.parsed).version; } } } -inline std::string DefaultVarName(const std::string &op_name, - const std::string &arg_name) { +inline std::string DefaultVarName(const std::string& op_name, const std::string& arg_name) { if (op_name.length() == 0) { return arg_name; } else { @@ -75,8 +73,7 @@ inline std::string DefaultVarName(const std::string &op_name, } } -inline void KeywordArgumentMismatch(const char *source, - const std::vector& user_args, +inline void KeywordArgumentMismatch(const char* source, const std::vector& user_args, const array_view& args) { std::unordered_set keys(args.begin(), args.end()); std::ostringstream head, msg; @@ -87,16 +84,13 @@ inline void KeywordArgumentMismatch(const char *source, for (const auto& key : user_args) { if (keys.count(key) == 0) { - LOG(FATAL) << source - << "Keyword argument name " << key << " not found." - << msg.str(); + LOG(FATAL) << source << "Keyword argument name " << key << " not found." 
<< msg.str(); } } } -template -inline std::vector GetKeys( - const std::unordered_map& kwargs) { +template +inline std::vector GetKeys(const std::unordered_map& kwargs) { std::vector keys(kwargs.size()); std::transform(kwargs.begin(), kwargs.end(), keys.begin(), [](decltype(*kwargs.begin())& kv) { return kv.first; }); @@ -117,14 +111,14 @@ Symbol Symbol::Copy() const { std::unordered_map old_new; // use DFSVisit to copy all the nodes DFSVisit(this->outputs, [&old_new](const ObjectPtr& node) { - ObjectPtr np = Node::Create(); - np->attrs = node->attrs; - old_new[node.get()] = std::move(np); - }); + ObjectPtr np = Node::Create(); + np->attrs = node->attrs; + old_new[node.get()] = std::move(np); + }); // connect nodes of new graph - for (const auto &kv : old_new) { + for (const auto& kv : old_new) { for (const NodeEntry& e : kv.first->inputs) { - Node *ptr = e.node.get(); + Node* ptr = e.node.get(); kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } for (const ObjectPtr& p : kv.first->control_deps) { @@ -133,66 +127,64 @@ Symbol Symbol::Copy() const { } // set the head Symbol ret; - for (const NodeEntry &e : outputs) { + for (const NodeEntry& e : outputs) { ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); } return ret; } -void Symbol::Print(std::ostream &os) const { - if (outputs.size() == 1 && - outputs[0].node->inputs.size() == 0 && +void Symbol::Print(std::ostream& os) const { + if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0 && outputs[0].node->control_deps.size() == 0) { if (outputs[0].node->is_variable()) { os << "Variable:" << outputs[0].node->attrs.name << '\n'; } else { - os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n'; + os << "AtomicFunctor " + << " Op:" << outputs[0].node->op()->name << '\n'; } } else { // use DFSVisit to copy all the nodes os << "Symbol Outputs:\n"; for (size_t i = 0; i < outputs.size(); ++i) { - os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name - << '(' << outputs[i].index << ")\n"; + os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name << '(' << outputs[i].index + << ")\n"; } DFSVisit(this->outputs, [&os](const ObjectPtr& node) { - if (node->is_variable()) { - os << "Variable:" << node->attrs.name << '\n'; - } else { - os << "--------------------\n"; - os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' - << "Inputs:\n"; - for (size_t i = 0; i < node->inputs.size(); ++i) { - const NodeEntry& e = node->inputs[i]; - os << "\targ[" << i << "]=" << e.node->attrs.name - << '(' << e.index << ")"; - if (e.node->is_variable()) { - os << " version=" << e.version << '\n'; - } else { - os << '\n'; - } + if (node->is_variable()) { + os << "Variable:" << node->attrs.name << '\n'; + } else { + os << "--------------------\n"; + os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' << "Inputs:\n"; + for (size_t i = 0; i < node->inputs.size(); ++i) { + const NodeEntry& e = node->inputs[i]; + os << "\targ[" << i << "]=" << e.node->attrs.name << '(' << e.index << ")"; + if (e.node->is_variable()) { + os << " version=" << e.version << '\n'; + } else { + os << '\n'; } - if (!node->attrs.dict.empty()) { - os << "Attrs:\n"; - // make an ordered copy because unordered_map doesn't guarantee order. 
- std::map sorted_dict( - node->attrs.dict.begin(), node->attrs.dict.end()); - for (auto &kv : sorted_dict) { - os << '\t' << kv.first << '=' << kv.second << '\n'; - } + } + if (!node->attrs.dict.empty()) { + os << "Attrs:\n"; + // make an ordered copy because unordered_map doesn't guarantee order. + std::map sorted_dict(node->attrs.dict.begin(), + node->attrs.dict.end()); + for (auto& kv : sorted_dict) { + os << '\t' << kv.first << '=' << kv.second << '\n'; } - if (node->control_deps.size() != 0) { - os << "Control deps:\n"; - for (size_t i = 0; i < node->control_deps.size(); ++i) { - os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; - } + } + if (node->control_deps.size() != 0) { + os << "Control deps:\n"; + for (size_t i = 0; i < node->control_deps.size(); ++i) { + os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n'; } } - }); + } + }); } } -Symbol Symbol::operator[] (size_t index) const { +Symbol Symbol::operator[](size_t index) const { size_t nreturn = outputs.size(); CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index"; if (nreturn == 1) { @@ -208,25 +200,25 @@ std::vector Symbol::ListInputs(ListInputOption option) const { std::vector ret; if (option == kAll) { ret.reserve(this->outputs.size()); - DFSVisit(this->outputs, [&ret](const ObjectPtr &node) { - if (node->is_variable()) { - ret.push_back(node); - } - }); + DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { + if (node->is_variable()) { + ret.push_back(node); + } + }); } else { std::unordered_set mutable_set; std::vector vlist; vlist.reserve(this->outputs.size()); static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); - DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr &node) { - if (node->is_variable()) { - vlist.push_back(node); - } else if (fmutate_inputs.count(node->op())) { - for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){ - mutable_set.insert(node->inputs[i].node.get()); - } + DFSVisit(this->outputs, [&mutable_set, &vlist](const ObjectPtr& node) { + if (node->is_variable()) { + vlist.push_back(node); + } else if (fmutate_inputs.count(node->op())) { + for (uint32_t i : fmutate_inputs[node->op()](node->attrs)) { + mutable_set.insert(node->inputs[i].node.get()); } - }); + } + }); ret.reserve(vlist.size()); for (const ObjectPtr& node : vlist) { if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) || @@ -252,7 +244,7 @@ std::vector Symbol::ListOutputNames() const { std::vector ret; ret.reserve(outputs.size()); - for (auto &head : outputs) { + for (auto& head : outputs) { if (head.node->is_variable()) { ret.push_back(head.node->attrs.name); } else { @@ -291,8 +283,7 @@ void Symbol::Compose(const array_view& args, Node* n = outputs[0].node.get(); FInputGraph fng = fgraph.get(n->op(), nullptr); std::vector garg_idx; - if (fng != nullptr) - garg_idx = fng(n->attrs); + if (fng != nullptr) garg_idx = fng(n->attrs); // The names of the arguments that contain graphs. FListInputNames name_fn = flist_inputs.get(n->op(), nullptr); @@ -300,8 +291,7 @@ void Symbol::Compose(const array_view& args, std::vector garg_names(garg_idx.size()); for (size_t i = 0; i < garg_idx.size(); i++) { size_t idx = garg_idx[i]; - if (idx < arg_names.size()) - garg_names[i] = arg_names[idx]; + if (idx < arg_names.size()) garg_names[i] = arg_names[idx]; } // parameter check. @@ -309,13 +299,13 @@ void Symbol::Compose(const array_view& args, // If the argument isn't a graph, it should have only one output. 
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end()) CHECK_EQ(args[i]->outputs.size(), 1U) - << "Argument " << i << " is a tuple, single value is required"; + << "Argument " << i << " is a tuple, single value is required"; } for (const auto& kv : kwargs) { - if (garg_names.empty() - || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) + if (garg_names.empty() || + std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end()) CHECK_EQ(kv.second->outputs.size(), 1U) - << "Keyword Argument " << kv.first << " is a tuple, single value is required"; + << "Keyword Argument " << kv.first << " is a tuple, single value is required"; } // assign new name if (!name.empty()) outputs[0].node->attrs.name = name; @@ -323,14 +313,14 @@ void Symbol::Compose(const array_view& args, // Atomic functor composition. if (IsAtomic(outputs)) { uint32_t n_req = n->num_inputs(); - std::vector arg_vec(args.begin(), args.end()); + std::vector arg_vec(args.begin(), args.end()); std::unordered_map kwarg_map(kwargs.begin(), kwargs.end()); // If one of the input arguments is a graph, we need to remove it from the // list. if (fng != nullptr) { std::vector idxes = fng(n->attrs); for (auto idx : idxes) { - const Symbol *sym; + const Symbol* sym; if (idx < arg_vec.size()) { sym = arg_vec[idx]; } else { @@ -339,8 +329,7 @@ void Symbol::Compose(const array_view& args, sym = it->second; kwarg_map.erase(it); } - if (n_req != kVarg) - n_req--; + if (n_req != kVarg) n_req--; n->attrs.subgraphs.push_back(std::make_shared(*sym)); } // Because idxes does not contain duplicates, the loop below functions well. @@ -358,8 +347,7 @@ void Symbol::Compose(const array_view& args, if (n_req != kVarg) { n->inputs.resize(n_req); CHECK_LE(arg_vec.size(), n_req) - << "Incorrect number of arguments, requires " << n_req - << ", provided " << arg_vec.size(); + << "Incorrect number of arguments, requires " << n_req << ", provided " << arg_vec.size(); for (size_t i = 0; i < arg_vec.size(); ++i) { n->inputs[i] = arg_vec[i]->outputs[0]; } @@ -375,8 +363,7 @@ void Symbol::Compose(const array_view& args, n->inputs[i] = it->second->outputs[0]; ++nmatched; } else { - n->inputs[i] = NodeEntry{ - CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; + n->inputs[i] = NodeEntry{CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; // copy attribute of parent over automatically created variables n->inputs[i].node->attrs.dict = n->attrs.dict; } @@ -409,20 +396,19 @@ void Symbol::Compose(const array_view& args, } } else { // general composition - CHECK_EQ(args.size(), 0U) - << "General composition only support kwargs for now"; + CHECK_EQ(args.size(), 0U) << "General composition only support kwargs for now"; size_t nmatched = 0; size_t arg_counter = 0; - std::unordered_map replace_map; + std::unordered_map replace_map; // replace map stores the existing replacement plan for arguments node - auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map] - (const ObjectPtr &node) { + auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, + &replace_map](const ObjectPtr& node) { if (node->is_variable()) { if (arg_counter < args.size()) { replace_map[node.get()] = &(args[arg_counter]->outputs[0]); ++arg_counter; } else { - // match kwargs + // match kwargs auto kit = kwargs.find(node->attrs.name); if (kit != kwargs.end()) { replace_map[node.get()] = &(kit->second->outputs[0]); @@ -436,12 +422,11 @@ void Symbol::Compose(const array_view& args, 
if (nmatched == kwargs.size() && arg_counter <= args.size()) { std::vector update_nodes; std::vector > replace_plan; - auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] - (const ObjectPtr &node) { + auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes](const ObjectPtr& node) { // visit all the childs, find possible replacement bool repl = false; for (size_t i = 0; i < node->inputs.size(); ++i) { - NodeEntry *e = &(node->inputs[i]); + NodeEntry* e = &(node->inputs[i]); if (e->node->is_variable()) { auto iter = replace_map.find(e->node.get()); if (iter != replace_map.end()) { @@ -479,17 +464,16 @@ void Symbol::Compose(const array_view& args, } } -Symbol Symbol::operator () (const array_view& args, - const std::unordered_map& kwargs, - const std::string& name) const { +Symbol Symbol::operator()(const array_view& args, + const std::unordered_map& kwargs, + const std::string& name) const { Symbol s = this->Copy(); s.Compose(args, kwargs, name); return s; } void Symbol::AddControlDeps(const Symbol& src) { - CHECK_EQ(outputs.size(), 1U) - << "AddControlDeps only works for nongrouped symbol"; + CHECK_EQ(outputs.size(), 1U) << "AddControlDeps only works for nongrouped symbol"; Node* n = outputs[0].node.get(); for (const NodeEntry& sp : src.outputs) { n->control_deps.push_back(sp.node); @@ -500,21 +484,21 @@ Symbol Symbol::GetInternals() const { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& node) { - Node* n = node.get(); - if (n->is_variable()) { - // grab version from variable. - VariableParam& param = nnvm::get(n->attrs.parsed); - ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); - } else { - uint32_t nout = n->num_outputs(); - if (fnum_vis_output.count(n->op())) { - nout = fnum_vis_output[n->op()](n->attrs); - } - for (uint32_t i = 0; i < nout; ++i) { - ret.outputs.emplace_back(NodeEntry{node, i, 0}); - } + Node* n = node.get(); + if (n->is_variable()) { + // grab version from variable. 
+ VariableParam& param = nnvm::get(n->attrs.parsed); + ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); + } else { + uint32_t nout = n->num_outputs(); + if (fnum_vis_output.count(n->op())) { + nout = fnum_vis_output[n->op()](n->attrs); } - }); + for (uint32_t i = 0; i < nout; ++i) { + ret.outputs.emplace_back(NodeEntry{node, i, 0}); + } + } + }); return ret; } @@ -533,8 +517,7 @@ Symbol Symbol::GetChildren() const { void Symbol::SetAttrs(const std::vector >& attrs) { Node* node = outputs[0].node.get(); for (const NodeEntry& e : outputs) { - CHECK(node == e.node.get()) - << "Symbol.SetAttrs only works for non-grouped symbol"; + CHECK(node == e.node.get()) << "Symbol.SetAttrs only works for non-grouped symbol"; } for (const auto& kv : attrs) { if (kv.first == "name") { @@ -583,29 +566,27 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op if (option == kRecursive) { std::unordered_map ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { - for (const auto& it : n->attrs.dict) { - ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; - } - }); + for (const auto& it : n->attrs.dict) { + ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second; + } + }); return ret; } else { return outputs[0].node->attrs.dict; } } -std::vector > - Symbol::ListAttrsRecursive() const { +std::vector > Symbol::ListAttrsRecursive() const { std::vector > ret; DFSVisit(this->outputs, [&ret](const ObjectPtr& n) { - for (const auto& it : n->attrs.dict) { - ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); - } - }); + for (const auto& it : n->attrs.dict) { + ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); + } + }); return ret; } -Symbol Symbol::CreateFunctor(const Op* op, - std::unordered_map attrs) { +Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; ObjectPtr n = Node::Create(); @@ -641,9 +622,9 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) { return s; } -Symbol Symbol::CreateGroup(const std::vector &symbols) { +Symbol Symbol::CreateGroup(const std::vector& symbols) { Symbol ret; - for (const auto &s : symbols) { + for (const auto& s : symbols) { ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end()); } return ret; diff --git a/nnvm/src/pass/correct_layout.cc b/nnvm/src/pass/correct_layout.cc index e988ebd87915..b9024a56d143 100644 --- a/nnvm/src/pass/correct_layout.cc +++ b/nnvm/src/pass/correct_layout.cc @@ -22,16 +22,15 @@ * \brief Infer and correct layout. */ #include -#include #include -#include #include +#include +#include namespace nnvm { namespace pass { -nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, - const Layout& dst) { +nnvm::ObjectPtr CreateLayoutTransformNode(const Layout& src, const Layout& dst) { static const nnvm::Op* trans_op = nnvm::Op::Get("__layout_transform__"); static int count = 0; nnvm::ObjectPtr n = nnvm::Node::Create(); @@ -50,10 +49,7 @@ using LayoutAttrDict = std::unordered_map >; * insert layout transform nodes automatically. 
*/ nnvm::Graph CorrectLayout(nnvm::Graph src) { - static auto& op_correct_layout = - nnvm::Op::GetAttr("FCorrectLayout"); - static auto& op_correct_layout_ex = - nnvm::Op::GetAttr("FCorrectLayoutEx"); + static auto& op_correct_layout = nnvm::Op::GetAttr("FCorrectLayout"); const IndexedGraph& idx = src.indexed_graph(); std::vector mirror_vec(idx.num_nodes(), nullptr); @@ -67,13 +63,12 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { *new_node = *(inode.source); if (new_node->is_variable()) { // Variable node. No operator. Only one output entry. - auto input_iter = std::find( - idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); + auto input_iter = std::find(idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid); CHECK(input_iter != idx.input_nodes().cend()); int64_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter); if (src.HasAttr("layout_inputs")) { - new_layouts[new_node.get()] = - {src.GetAttr >("layout_inputs")[input_id]}; + new_layouts[new_node.get()] = { + src.GetAttr >("layout_inputs")[input_id]}; } else { new_layouts[new_node.get()] = {Layout::Undef()}; } @@ -111,24 +106,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { } } - if (op_correct_layout_ex.count(new_node->op())) { - std::vector input_shapes; - if (src.HasAttr("shape")) { - const auto &shapes = src.GetAttr >("shape"); - for (uint32_t i = 0; i < num_inputs; ++i) { - input_shapes.emplace_back(shapes[idx.entry_id(inode.inputs[i])]); - } - } - const auto &flayout = op_correct_layout_ex[new_node->op()]; - CHECK(flayout(new_node->attrs, &input_shapes, &request_ilayouts, - &last_request_ilayouts, &produce_olayouts)) - << "Layout infer fail"; - CHECK_EQ(request_ilayouts.size(), num_inputs); - CHECK_EQ(produce_olayouts.size(), num_outputs); - } else if (op_correct_layout.count(new_node->op())) { - const auto &flayout = op_correct_layout[new_node->op()]; + if (op_correct_layout.count(new_node->op())) { + const auto& flayout = op_correct_layout[new_node->op()]; CHECK(flayout(new_node->attrs, &request_ilayouts, &last_request_ilayouts, &produce_olayouts)) - << "Layout infer fail"; + << "Layout infer fail"; CHECK_EQ(request_ilayouts.size(), num_inputs); CHECK_EQ(produce_olayouts.size(), num_outputs); } @@ -191,10 +172,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) { // register pass NNVM_REGISTER_PASS(CorrectLayout) -.describe("Return a layout-transformed graph of src.") -.set_body(CorrectLayout) -.provide_graph_attr("layout") -.set_change_graph(true); + .describe("Return a layout-transformed graph of src.") + .set_body(CorrectLayout) + .provide_graph_attr("layout") + .set_change_graph(true); DMLC_JSON_ENABLE_ANY(LayoutVector, list_layout); diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 9c30a785cac2..1df3af7ffaaf 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -22,8 +22,9 @@ * \brief Passes that takes gradient of the graph * This code code was modified based on mxnet codebase by Min Lin */ -#include #include +#include + #include #include @@ -53,8 +54,7 @@ NodeEntry DefaultAggregateGradient(std::vector&& v) { } } -bool CheckGradAllZero(const std::vector& grads, - const std::vector& zero_ops) { +bool CheckGradAllZero(const std::vector& grads, const std::vector& zero_ops) { if (!grads.size() || !zero_ops.size()) return false; for (const auto& g : grads) { bool found = false; @@ -82,22 +82,18 @@ struct GradEntry { Graph Gradient(Graph src) { using nnvm::FGradient; - using MirrorFun = std::function; - using AttrHintFun = std::function; + using MirrorFun = 
std::function; + using AttrHintFun = std::function; - CHECK_NE(src.attrs.count("grad_ys"), 0U) - << "Gradient require grad_ys to be presented."; + CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented."; CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) << "Gradient require grad_ys_out_grad to be presented."; - CHECK_NE(src.attrs.count("grad_xs"), 0U) - << "Gradient require grad_xs to be presented."; - const std::vector& ys = - src.GetAttr >("grad_ys"); + CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented."; + const std::vector& ys = src.GetAttr >("grad_ys"); const std::vector& ys_out_grad = src.GetAttr >("grad_ys_out_grad"); - const std::vector& xs = - src.GetAttr >("grad_xs"); - using AggFun = std::function&& inputs)>; + const std::vector& xs = src.GetAttr >("grad_xs"); + using AggFun = std::function && inputs)>; AggFun agg_fun = DefaultAggregateGradient; if (src.attrs.count("grad_aggregate_fun") != 0) { agg_fun = src.GetAttr("grad_aggregate_fun"); @@ -114,31 +110,30 @@ Graph Gradient(Graph src) { if (src.attrs.count("zero_ops") != 0) { zero_ops = src.GetAttr >("zero_ops"); } - const Op* copy_op = (src.attrs.count("copy_op") != 0) ? - Op::Get(src.GetAttr("copy_op")) : - nullptr; + const Op* copy_op = + (src.attrs.count("copy_op") != 0) ? Op::Get(src.GetAttr("copy_op")) : nullptr; // topo sort std::vector topo_order; std::unordered_map > output_grads; DFSVisit(ys, [&](const ObjectPtr& node) { - if (output_grads.count(node.get()) == 0) { - output_grads[node.get()].resize(node->num_outputs()); - } - topo_order.push_back(node); - }); + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); CHECK_EQ(ys.size(), ys_out_grad.size()); for (size_t i = 0; i < ys.size(); ++i) { NodeEntry ograd = ys_out_grad[i]; - output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; + output_grads[ys[i].node.get()][ys[i].index].grads = {ograd}; } // Check that all xs are reachable from ys for (size_t i = 0; i < xs.size(); ++i) { CHECK(output_grads.find(xs[i].node.get()) != output_grads.end()) - << "Cannot differentiate with respect to the " << i+1 << "-th variable " + << "Cannot differentiate with respect to the " << i + 1 << "-th variable " << "because it is unreachable from the outputs."; } @@ -211,8 +206,7 @@ Graph Gradient(Graph src) { LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable " << "because it didn't register FGradient attribute."; } - for (const auto& nodeEntry : input_grads) - CHECK(nodeEntry.node); + for (const auto& nodeEntry : input_grads) CHECK(nodeEntry.node); auto git = input_grads.begin(); CHECK((*rit)->inputs.size() <= input_grads.size()); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { @@ -252,12 +246,12 @@ Graph Gradient(Graph src) { copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { - copy_node->attrs.op->attr_parser(&(copy_node->attrs)); + copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); } } else { - ret.outputs[counter] = entry.sum; + ret.outputs[counter] = entry.sum; } ++counter; } @@ -271,12 +265,12 @@ Graph Gradient(Graph src) { // register pass NNVM_REGISTER_PASS(Gradient) -.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") -.set_body(Gradient) -.set_change_graph(true) 
-.depend_graph_attr("grad_ys") -.depend_graph_attr("grad_xs") -.depend_graph_attr("grad_ys_out_grad"); + .describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") + .set_body(Gradient) + .set_change_graph(true) + .depend_graph_attr("grad_ys") + .depend_graph_attr("grad_xs") + .depend_graph_attr("grad_ys_out_grad"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/graph_algorithm.h b/nnvm/src/pass/graph_algorithm.h index 1d274ff3b96d..b305c08bc05f 100644 --- a/nnvm/src/pass/graph_algorithm.h +++ b/nnvm/src/pass/graph_algorithm.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,11 +22,12 @@ * \brief This header contains graph algorithms on StaticGraph. * It is used compute informations such as whether two * operations can run in parallel, and helps allocation. -*/ + */ #ifndef NNVM_PASS_GRAPH_ALGORITHM_H_ #define NNVM_PASS_GRAPH_ALGORITHM_H_ #include + #include namespace nnvm { @@ -41,10 +42,8 @@ namespace pass { * \param path the output path of nodes. * \return the total reward of best path. */ -inline uint32_t FindBestPath( - const IndexedGraph& graph, - const std::vector& node_reward, - std::vector* path) { +inline uint32_t FindBestPath(const IndexedGraph& graph, const std::vector& node_reward, + std::vector* path) { const uint32_t num_nodes = static_cast(graph.num_nodes()); CHECK_EQ(num_nodes, node_reward.size()); @@ -71,7 +70,8 @@ inline uint32_t FindBestPath( path->clear(); uint32_t reward = 0; for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) { - path->push_back(nid); reward += node_reward[nid]; + path->push_back(nid); + reward += node_reward[nid]; } CHECK_EQ(reward, best_solution); return best_solution; @@ -88,11 +88,8 @@ inline uint32_t FindBestPath( * \param color the color index of each of the node. * \return the total number of colors. */ -inline uint32_t ColorNodeGroup( - const IndexedGraph &graph, - std::vector node_importance, - uint32_t max_ncolor, - std::vector *color) { +inline uint32_t ColorNodeGroup(const IndexedGraph& graph, std::vector node_importance, + uint32_t max_ncolor, std::vector* color) { CHECK_NE(max_ncolor, 0U); CHECK_EQ(graph.num_nodes(), node_importance.size()); diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 876dce1c113d..fde1691ee96a 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -21,33 +21,24 @@ * \file infer_shape.cc * \brief Inference the shapes given existin information. 
*/ -#include -#include #include +#include +#include namespace nnvm { namespace pass { namespace { -template -Graph InferAttr(Graph &&ret, - const AttrType empty_val, - const char* infer_name, - const char* input_name, - const char* attr_key_name, - const char* attr_name, - const char* unknown_name, - IsNone fis_none, - FDefault fdefault) { +template +Graph InferAttr(Graph&& ret, const AttrType empty_val, const char* infer_name, + const char* input_name, const char* attr_key_name, const char* attr_name, + const char* unknown_name, IsNone fis_none, FDefault fdefault) { using AttrVector = std::vector; const IndexedGraph& idx = ret.indexed_graph(); - static auto& finfer_shape = - Op::GetAttr >(infer_name); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + static auto& finfer_shape = Op::GetAttr>(infer_name); + static auto& is_backward = Op::GetAttr("TIsBackward"); // gradient function, used to get node correspondence. - static auto& fgrad = - Op::GetAttr("FGradient"); + static auto& fgrad = Op::GetAttr("FGradient"); // reshape shape vector AttrVector rshape; if (ret.attrs.count(attr_name) != 0) { @@ -70,8 +61,7 @@ Graph InferAttr(Graph &&ret, // get the shape hints std::string shape_hints_key = std::string(attr_name) + "_hints"; if (ret.attrs.count(shape_hints_key)) { - NodeEntryMap shape_hints = - ret.GetAttr>(shape_hints_key); + NodeEntryMap shape_hints = ret.GetAttr>(shape_hints_key); for (const auto& kv : shape_hints) { NodeEntry e = kv.first; if (idx.exist(e.node.get())) { @@ -110,7 +100,7 @@ Graph InferAttr(Graph &&ret, } } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) { CHECK_GE(inode.control_deps.size(), 1U) - << "BackwardOp need to have control_deps to its forward op"; + << "BackwardOp need to have control_deps to its forward op"; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; ObjectPtr fwd_ptr = inode.source->control_deps[0]; CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; @@ -141,7 +131,7 @@ Graph InferAttr(Graph &&ret, } // out grad entries CHECK(igrad_node != nullptr) - << "Cannot find matching backward op for " << inode.source->attrs.name; + << "Cannot find matching backward op for " << inode.source->attrs.name; for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { const NodeEntry& e = igrad_node->inputs[i]; if (e.node == nullptr) { @@ -174,10 +164,9 @@ Graph InferAttr(Graph &&ret, throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); } } else { - CHECK(!last_iter) - << "Attribute " << infer_name - << " is not registered by op " << inode.source->op()->name - << " we are not able to complete the inference because of this"; + CHECK(!last_iter) << "Attribute " << infer_name << " is not registered by op " + << inode.source->op()->name + << " we are not able to complete the inference because of this"; } } // Save to the result map. 
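For reference, the per-operator inference functions that InferAttr looks up under `infer_name` (FInferShape, FInferType) follow the same signature as the SameType helper shown below: they receive the node attributes plus mutable input/output attribute vectors and return true once every entry is known. A minimal sketch of such a registration, using a hypothetical op name and assuming this repo's nnvm headers:

// Sketch only: hypothetical elementwise op whose output shape equals its
// first input shape.
#include <vector>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/tuple.h>

NNVM_REGISTER_OP(hypothetical_relu)
    .set_num_inputs(1)
    .set_num_outputs(1)
    .set_attr<nnvm::FInferShape>(
        "FInferShape",
        [](const nnvm::NodeAttrs& attrs, std::vector<nnvm::TShape>* in_shapes,
           std::vector<nnvm::TShape>* out_shapes) {
          if ((*in_shapes)[0].ndim() == 0) return false;  // input still unknown
          (*out_shapes)[0] = (*in_shapes)[0];             // copy shape through
          return true;                                    // inference complete
        });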
@@ -221,32 +210,30 @@ Graph InferAttr(Graph &&ret, } NNVM_REGISTER_PASS(InferShape) -.describe("Infer the shape of each node entries.") -.set_body([](Graph ret) { - return InferAttr( - std::move(ret), TShape(), - "FInferShape", "shape_inputs", "shape_attr_key", - "shape", "shape_num_unknown_nodes", - [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, - nullptr); - }) -.set_change_graph(false) -.provide_graph_attr("shape"); + .describe("Infer the shape of each node entries.") + .set_body([](Graph ret) { + return InferAttr( + std::move(ret), TShape(), "FInferShape", "shape_inputs", "shape_attr_key", "shape", + "shape_num_unknown_nodes", [](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + nullptr); + }) + .set_change_graph(false) + .provide_graph_attr("shape"); // inference function for same type -inline bool SameType(const NodeAttrs& attrs, - std::vector *iattr, - std::vector *oattr) { +inline bool SameType(const NodeAttrs& attrs, std::vector* iattr, std::vector* oattr) { int def_v = -1; for (int v : *oattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } if (def_v == -1) { for (int v : *iattr) { if (v != -1) { - def_v = v; break; + def_v = v; + break; } } } @@ -261,17 +248,14 @@ inline bool SameType(const NodeAttrs& attrs, } NNVM_REGISTER_PASS(InferType) -.describe("Infer the dtype of each node entries.") -.set_body([](Graph ret) { - return InferAttr( - std::move(ret), -1, - "FInferType", "dtype_inputs", "dtype_attr_key", - "dtype", "dtype_num_unknown_nodes", - [](const int t) { return t == -1; }, - SameType); - }) -.set_change_graph(false) -.provide_graph_attr("dtype"); + .describe("Infer the dtype of each node entries.") + .set_body([](Graph ret) { + return InferAttr( + std::move(ret), -1, "FInferType", "dtype_inputs", "dtype_attr_key", "dtype", + "dtype_num_unknown_nodes", [](const int t) { return t == -1; }, SameType); + }) + .set_change_graph(false) + .provide_graph_attr("dtype"); DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape); DMLC_JSON_ENABLE_ANY(DTypeVector, list_int); diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index b2fa2ca33e07..2575a03ace03 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -23,17 +23,15 @@ * To correctly order mutation and read to resolve * write after read problem and read after write problems. */ -#include #include +#include namespace nnvm { namespace pass { namespace { -template -inline T get_with_default(const std::unordered_map &map, - Node* key, - const T& def) { +template +inline T get_with_default(const std::unordered_map& map, Node* key, const T& def) { auto it = map.find(key); if (it != map.end()) return it->second; return def; @@ -46,19 +44,19 @@ inline bool IsMutate(const std::vector& mutate_inputs, uint32_t i) { Graph OrderMutation(const Graph& src) { std::unordered_map > version_hist; DFSVisit(src.outputs, [&version_hist](const ObjectPtr& n) { - for (const NodeEntry& e : n->inputs) { - if (e.node->is_variable()) { - if (e.version != 0 && version_hist.count(e.node.get()) == 0) { - version_hist[e.node.get()] = std::vector{}; - } + for (const NodeEntry& e : n->inputs) { + if (e.node->is_variable()) { + if (e.version != 0 && version_hist.count(e.node.get()) == 0) { + version_hist[e.node.get()] = std::vector{}; } } - }); + } + }); // no mutation happens, everything if fine. if (version_hist.size() == 0) return src; // start preparing for remapping the nodes. 
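For reference, the FMutateInputs attribute queried by this pass reports which input positions an operator writes in place; OrderMutation then inserts control dependencies so those writes are ordered against earlier reads of the same variable. A minimal sketch of registering it for a hypothetical in-place assign op, assuming this repo's nnvm headers:

// Sketch only: hypothetical op that overwrites its first input (index 0).
#include <cstdint>
#include <vector>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

NNVM_REGISTER_OP(hypothetical_assign)
    .set_num_inputs(2)
    .set_num_outputs(1)
    .set_attr<nnvm::FMutateInputs>(
        "FMutateInputs",
        [](const nnvm::NodeAttrs& attrs) {
          return std::vector<uint32_t>{0};  // input 0 is mutated in place
        });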
std::unordered_map old_new; - auto prepare = [&version_hist, &old_new] (const ObjectPtr& n) { + auto prepare = [&version_hist, &old_new](const ObjectPtr& n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (!n->is_variable() && fmutate_inputs.count(n->op())) { @@ -91,17 +89,17 @@ Graph OrderMutation(const Graph& src) { }; DFSVisit(src.outputs, prepare); // comparator of history entry - auto comparator = [](const NodeEntry& a, const NodeEntry &b) { + auto comparator = [](const NodeEntry& a, const NodeEntry& b) { if (a.version < b.version) return true; if (a.version > b.version) return false; return a.index > b.index; }; - for (auto &kv : version_hist) { + for (auto& kv : version_hist) { std::sort(kv.second.begin(), kv.second.end(), comparator); } // copy the nodes, as well as add control deps - for (auto &kv : old_new) { + for (auto& kv : old_new) { // copy the nodes for (const NodeEntry& e : kv.first->inputs) { auto it = old_new.find(e.node.get()); @@ -112,8 +110,7 @@ Graph OrderMutation(const Graph& src) { } } for (const ObjectPtr& p : kv.first->control_deps) { - kv.second->control_deps.emplace_back( - get_with_default(old_new, p.get(), p)); + kv.second->control_deps.emplace_back(get_with_default(old_new, p.get(), p)); } // add control deps static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); @@ -127,9 +124,8 @@ Graph OrderMutation(const Graph& src) { const NodeEntry& e = kv.first->inputs[i]; if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) { std::vector& vec = version_hist.at(e.node.get()); - auto it = std::lower_bound(vec.begin(), vec.end(), - NodeEntry{nullptr, 1, e.version}, - comparator); + auto it = + std::lower_bound(vec.begin(), vec.end(), NodeEntry{nullptr, 1, e.version}, comparator); if (IsMutate(mutate_inputs, i)) { int read_dep = 0; while (it != vec.begin()) { @@ -137,37 +133,35 @@ Graph OrderMutation(const Graph& src) { if (it->index != 0) break; ++read_dep; // depend on previous read - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } if (read_dep == 0 && it->index != 0) { // depend on last write - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } } else { // depend on last write if (it->index != 0) { - kv.second->control_deps.push_back( - get_with_default(old_new, it->node.get(), it->node)); + kv.second->control_deps.push_back(get_with_default(old_new, it->node.get(), it->node)); } } } } } Graph ret; - for (const NodeEntry &e : src.outputs) { - ret.outputs.emplace_back(NodeEntry{ - get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); + for (const NodeEntry& e : src.outputs) { + ret.outputs.emplace_back( + NodeEntry{get_with_default(old_new, e.node.get(), e.node), e.index, e.version}); } return ret; } NNVM_REGISTER_PASS(OrderMutation) -.describe("Return a new graph that adds control dependencies, "\ - "to order the mutation and reads if mutation exists.") -.set_body(OrderMutation) -.set_change_graph(true); + .describe( + "Return a new graph that adds control dependencies, " + "to order the mutation and reads if mutation exists.") + .set_body(OrderMutation) + .set_change_graph(true); } // namespace } // namespace pass diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc index 6d6866e472d6..d45658ae24ab 
100644 --- a/nnvm/src/pass/place_device.cc +++ b/nnvm/src/pass/place_device.cc @@ -22,9 +22,9 @@ * \brief Inference the device of each operator given known information. * Insert a copy node automatically when there is a cross device. */ -#include -#include #include +#include +#include namespace nnvm { namespace pass { @@ -43,8 +43,7 @@ Graph PlaceDevice(Graph src) { const Op* copy_op = Op::Get(src.GetAttr("device_copy_op")); auto& device_assign_map = src.GetAttr("device_assign_map"); const IndexedGraph& idx = src.indexed_graph(); - static auto& is_backward = - Op::GetAttr("TIsBackward"); + static auto& is_backward = Op::GetAttr("TIsBackward"); DeviceVector device; // copy on write semanatics if (src.attrs.count("device") != 0) { @@ -65,15 +64,15 @@ Graph PlaceDevice(Graph src) { << "The device assignment not found for group " << device_group; device[nid] = dit->second; } else { - if (!inode.source->is_variable() && - is_backward.get(inode.source->op(), false)) { + if (!inode.source->is_variable() && is_backward.get(inode.source->op(), false)) { if (device[inode.control_deps[0]] != -1) { device[nid] = device[inode.control_deps[0]]; } } else { for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (device[e.node_id] != -1) { - device[nid] = device[e.node_id]; break; + device[nid] = device[e.node_id]; + break; } } } @@ -121,20 +120,21 @@ Graph PlaceDevice(Graph src) { auto e = inode.inputs[index]; if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { LOG(FATAL) << " mutable state cannot go across device" - << " op=" << inode.source->op()->name - << " input_state_index=" << index; + << " op=" << inode.source->op()->name << " input_state_index=" << index; } } } for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { - need_mutate = true; break; + need_mutate = true; + break; } } if (!need_mutate) { for (const uint32_t cid : inode.control_deps) { - if (new_node_map[cid] != nullptr) { - need_mutate = true; break; + if (new_node_map[cid] != nullptr) { + need_mutate = true; + break; } } } @@ -151,17 +151,15 @@ Graph PlaceDevice(Graph src) { auto copy_key = std::make_tuple(e.node_id, e.index, dev_id); auto it = copy_map.find(copy_key); if (it != copy_map.end() && it->first == copy_key) { - new_node->inputs.emplace_back( - NodeEntry{it->second, 0, 0}); + new_node->inputs.emplace_back(NodeEntry{it->second, 0, 0}); } else { ObjectPtr copy_node = Node::Create(); std::ostringstream os; - os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; + os << inode.source->inputs[i].node->attrs.name << "_" << e.index << "_copy"; copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); if (new_node_map[e.node_id] != nullptr) { - copy_node->inputs.emplace_back( - NodeEntry{new_node_map[e.node_id], e.index, 0}); + copy_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0}); } else { copy_node->inputs.push_back(inode.source->inputs[i]); } @@ -170,13 +168,11 @@ Graph PlaceDevice(Graph src) { } copy_map[copy_key] = copy_node; new_device_map[copy_node.get()] = dev_id; - new_node->inputs.emplace_back( - NodeEntry{std::move(copy_node), 0, 0}); + new_node->inputs.emplace_back(NodeEntry{std::move(copy_node), 0, 0}); } } else { if (new_node_map[e.node_id] != nullptr) { - new_node->inputs.emplace_back( - NodeEntry{new_node_map[e.node_id], e.index, 0}); + new_node->inputs.emplace_back(NodeEntry{new_node_map[e.node_id], e.index, 0}); } else { 
new_node->inputs.push_back(inode.source->inputs[i]); } @@ -220,14 +216,15 @@ Graph PlaceDevice(Graph src) { } NNVM_REGISTER_PASS(PlaceDevice) -.describe("Infer the device type of each operator."\ - "Insert a copy node when there is cross device copy") -.set_body(PlaceDevice) -.set_change_graph(true) -.provide_graph_attr("device") -.depend_graph_attr("device_group_attr_key") -.depend_graph_attr("device_assign_map") -.depend_graph_attr("device_copy_op"); + .describe( + "Infer the device type of each operator." + "Insert a copy node when there is cross device copy") + .set_body(PlaceDevice) + .set_change_graph(true) + .provide_graph_attr("device") + .depend_graph_attr("device_group_attr_key") + .depend_graph_attr("device_assign_map") + .depend_graph_attr("device_copy_op"); DMLC_JSON_ENABLE_ANY(DeviceAssignMap, dict_str_int); diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index abd18eda5edd..7d478c646a1f 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -22,10 +22,12 @@ * \brief Assign memory tag to each of the data entries. */ #include -#include #include #include +#include + #include + #include "graph_algorithm.h" namespace nnvm { @@ -82,10 +84,10 @@ class GraphAllocator { auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // find a exact match, erase from map and return @@ -95,10 +97,10 @@ class GraphAllocator { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->device_id != dev_id) continue; - if (node_color_.size() != 0 && - node_color_[e->released_by_node] != node_color_[node_id]) continue; + if (node_color_.size() != 0 && node_color_[e->released_by_node] != node_color_[node_id]) + continue; // Use exect matching strategy e->max_bytes = std::max(size, e->max_bytes); // erase from map and return @@ -112,7 +114,7 @@ class GraphAllocator { void Release(StorageID id, uint32_t node_id) { CHECK_NE(id, kBadStorageID); if (id == kExternalStorageID || id == kDynamicStorageID) return; - StorageEntry *e = data_[id].get(); + StorageEntry* e = data_[id].get(); e->released_by_node = node_id; free_.insert({e->max_bytes, e}); } @@ -120,7 +122,7 @@ class GraphAllocator { // totoal number of bytes allocated size_t TotalAllocBytes() const { size_t total = 0; - for (auto &p : data_) { + for (auto& p : data_) { total += p->max_bytes; } return total; @@ -142,8 +144,7 @@ class GraphAllocator { if ((*idx_)[nid].source->is_variable()) continue; importance[nid] = 1; } - num_match_color_ = pass::ColorNodeGroup( - *idx_, importance, num_match_color_, &node_color_); + num_match_color_ = pass::ColorNodeGroup(*idx_, importance, num_match_color_, &node_color_); } } @@ -187,18 +188,16 @@ class GraphAllocator { * Internal method to perform the memory allocation for a graph * */ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, - const std::pair& node_range, - StorageVector* storage_ptr, + const std::pair& node_range, StorageVector* storage_ptr, std::vector* 
storage_inplace_index_ptr, - const std::vector& entry_ref_count, - GraphAllocator* allocator) { + const std::vector& entry_ref_count, GraphAllocator* allocator) { static auto& finplace_option = Op::GetAttr("FInplaceOption"); static auto& finplace_identity = Op::GetAttr("FInplaceIdentity"); static auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); // Get reference - auto &storage = *storage_ptr; - auto &storage_inplace_index = *storage_inplace_index_ptr; + auto& storage = *storage_ptr; + auto& storage_inplace_index = *storage_inplace_index_ptr; // Get attributes from the graph const ShapeVector& shape_vec = ret.GetAttr("shape"); @@ -234,19 +233,16 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, auto sid_out = storage[eid_out]; auto sid_in = storage[eid_in]; bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 && - fignore_inputs[inode.source->op()]( - inode.source->attrs).size() == inode.source->num_inputs()); + fignore_inputs[inode.source->op()](inode.source->attrs).size() == + inode.source->num_inputs()); // Identity should only be true if shape.Size() and types match bool real_identity = identity[ipair] && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && dtype_vec[eid_out] == dtype_vec[eid_in]; - if (taken[kv.first] == false && - sid_out == GraphAllocator::kBadStorageID && - sid_in >= 0 && + if (taken[kv.first] == false && sid_out == GraphAllocator::kBadStorageID && sid_in >= 0 && ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) && - entry_ref_count[eid_out] > 0 && - shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && - (dtype_vec[eid_out] == dtype_vec[eid_in] || + entry_ref_count[eid_out] > 0 && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && + (dtype_vec[eid_out] == dtype_vec[eid_in] || GetDTypeSize(dtype_vec[eid_out]) == GetDTypeSize(dtype_vec[eid_in]))) { // inplace optimization taken[kv.first] = true; @@ -265,21 +261,19 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, uint32_t eid = idx.entry_id(nid, index); // only request memory for kBadStorageID if (storage[eid] == GraphAllocator::kBadStorageID) { - auto &eshape = shape_vec[eid]; + auto& eshape = shape_vec[eid]; size_t esize = 0; if (eshape.ndim() != 0) esize = eshape.Size(); eids.insert(std::make_pair(esize, eid)); } } for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { - uint32_t eid = rit->second; - // normal allocation - const int dev_id = (device_vec != nullptr) ? device_vec->at(eid) : 0; - auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); - if (sid >= 0) { - storage_ref_count[sid] = entry_ref_count[eid]; - } - storage[eid] = sid; + uint32_t eid = rit->second; + auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); + if (sid >= 0) { + storage_ref_count[sid] = entry_ref_count[eid]; + } + storage[eid] = sid; } // check if certain inputs is ignored. std::vector ignore_inputs; @@ -320,7 +314,6 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, return num_not_allocated; } - // function to plan memory Graph PlanMemory(Graph ret) { // setup ref counter @@ -368,7 +361,7 @@ Graph PlanMemory(Graph ret) { size_t min_allocated_bytes = -1; size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); size_t min_match_range = - dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; + dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 
1 : max_match_range; for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) { // Make a copy of related fields StorageVector storage_vec(storage); @@ -378,9 +371,8 @@ Graph PlanMemory(Graph ret) { GraphAllocator allocator(&idx, match_range); // number of entries that are not statically allocated. - size_t storage_num_not_allocated = - AllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index, - ref_count, &allocator); + size_t storage_num_not_allocated = AllocMemory(ret, idx, node_range, &storage_vec, + &storage_inplace_index, ref_count, &allocator); size_t storage_allocated_bytes = allocator.TotalAllocBytes(); // Choose the plan which leads to minimal memory usage @@ -400,13 +392,13 @@ Graph PlanMemory(Graph ret) { } NNVM_REGISTER_PASS(PlanMemory) -.describe("Plan the memory allocation of each node entries.") -.set_body(PlanMemory) -.set_change_graph(false) -.depend_graph_attr("dtype") -.depend_graph_attr("shape") -.provide_graph_attr("storage_id") -.provide_graph_attr("storage_inplace_index"); + .describe("Plan the memory allocation of each node entries.") + .set_body(PlanMemory) + .set_change_graph(false) + .depend_graph_attr("dtype") + .depend_graph_attr("shape") + .provide_graph_attr("storage_id") + .provide_graph_attr("storage_inplace_index"); } // namespace } // namespace pass diff --git a/nnvm/src/pass/print_graph_ir.cc b/nnvm/src/pass/print_graph_ir.cc index a0127abe10f4..4fe92e665961 100644 --- a/nnvm/src/pass/print_graph_ir.cc +++ b/nnvm/src/pass/print_graph_ir.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,6 +24,7 @@ #include #include #include + #include namespace nnvm { @@ -31,47 +32,39 @@ namespace pass { using AttrPrinter = std::function; // NOLINT(*) -template +template AttrPrinter GetVectorPrinter_(const T& vec) { return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*) os << vec[index]; }; } -AttrPrinter GetVectorPrinter(const Graph& graph, - const std::string& key) { +AttrPrinter GetVectorPrinter(const Graph& graph, const std::string& key) { auto it = graph.attrs.find(key); - CHECK(it != graph.attrs.end()) - << "Cannot find " << key << " in graph attr"; + CHECK(it != graph.attrs.end()) << "Cannot find " << key << " in graph attr"; const any& value = *(it->second); if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else if (value.type() == typeid(std::vector)) { - return GetVectorPrinter_( - nnvm::get >(value)); + return GetVectorPrinter_(nnvm::get >(value)); } else { LOG(FATAL) << "Cannot handle type " << value.type().name(); return nullptr; } } - // print the graph ir in readable format -void PrintGraphIR_(Graph src, - const std::vector& join_entry_attrs, +void PrintGraphIR_(Graph src, const std::vector& join_entry_attrs, const std::vector& join_node_attrs, - std::ostream& os) { // NOLINT(*) + std::ostream& os) { // NOLINT(*) const IndexedGraph& idx = src.indexed_graph(); std::vector > 
trigger; // NOLINT(*) for (const std::string& key : join_entry_attrs) { AttrPrinter fp = GetVectorPrinter(src, key); - auto fprint = [&idx, key, fp]( - uint32_t nid, std::ostream& os) { // NOLINT(*) + auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*) const IndexedGraph::Node& inode = idx[nid]; os << ", " << key << "="; if (inode.source->num_outputs() != 1) { @@ -89,8 +82,7 @@ void PrintGraphIR_(Graph src, } for (const std::string& key : join_node_attrs) { AttrPrinter fp = GetVectorPrinter(src, key); - auto fprint = [&idx, key, fp]( - uint32_t nid, std::ostream& os) { // NOLINT(*) + auto fprint = [&idx, key, fp](uint32_t nid, std::ostream& os) { // NOLINT(*) os << ", " << key << "="; fp(idx.entry_id(nid, 0), os); }; @@ -101,7 +93,7 @@ void PrintGraphIR_(Graph src, if (idx.input_nodes().size() < 4) { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { uint32_t nid = idx.input_nodes()[i]; - if (i != 0) { + if (i != 0) { os << ", "; } os << '%' << idx[nid].source->attrs.name; @@ -109,7 +101,7 @@ void PrintGraphIR_(Graph src, } else { for (size_t i = 0; i < idx.input_nodes().size(); ++i) { uint32_t nid = idx.input_nodes()[i]; - if (i != 0) { + if (i != 0) { os << ",\n "; } os << '%' << idx[nid].source->attrs.name; @@ -141,8 +133,8 @@ void PrintGraphIR_(Graph src, for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; - os << " " << "%" << nid << " = " - << inode.source->op()->name << "("; + os << " " + << "%" << nid << " = " << inode.source->op()->name << "("; bool first = true; for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (first) { @@ -213,12 +205,10 @@ Graph PrintGraphIRPass(Graph src) { std::ostringstream os; std::vector join_entry_attrs, join_node_attrs; if (src.attrs.count("join_entry_attrs") != 0) { - join_entry_attrs = src.MoveCopyAttr >( - "join_entry_attrs"); + join_entry_attrs = src.MoveCopyAttr >("join_entry_attrs"); } if (src.attrs.count("join_node_attrs") != 0) { - join_node_attrs = src.MoveCopyAttr >( - "join_node_attrs"); + join_node_attrs = src.MoveCopyAttr >("join_node_attrs"); } PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os); Graph ret; @@ -228,8 +218,8 @@ Graph PrintGraphIRPass(Graph src) { // register pass NNVM_REGISTER_PASS(PrintGraphIR) -.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") -.set_body(PrintGraphIRPass); + .describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]") + .set_body(PrintGraphIRPass); } // namespace pass } // namespace nnvm diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 9389995d0521..3916da43618d 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -21,20 +21,21 @@ * \file saveload_json.cc * \brief Save and load graph to/from JSON file. 
*/ +#include #include #include -#include + #include namespace dmlc { namespace json { // overload handler for shared ptr -template<> -struct Handler > { - inline static void Write(JSONWriter *writer, const std::shared_ptr &data) { +template <> +struct Handler> { + inline static void Write(JSONWriter* writer, const std::shared_ptr& data) { writer->Write(*data); } - inline static void Read(JSONReader *reader, std::shared_ptr *data) { + inline static void Read(JSONReader* reader, std::shared_ptr* data) { any v; reader->Read(&v); *data = std::make_shared(std::move(v)); @@ -60,17 +61,16 @@ struct JSONNode { uint32_t index; uint32_t version; Entry() = default; - Entry(uint32_t node_id, uint32_t index, uint32_t version): - node_id(node_id), index(index), version(version) { - } - void Save(dmlc::JSONWriter *writer) const { + Entry(uint32_t node_id, uint32_t index, uint32_t version) + : node_id(node_id), index(index), version(version) {} + void Save(dmlc::JSONWriter* writer) const { writer->BeginArray(false); writer->WriteArrayItem(node_id); writer->WriteArrayItem(index); writer->WriteArrayItem(version); writer->EndArray(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginArray(); CHECK(reader->NextArrayItem()) << "invalid json format"; reader->Read(&node_id); @@ -95,7 +95,7 @@ struct JSONNode { std::vector subgraphs; // function to save JSON node. - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); if (node->op() != nullptr) { writer->WriteObjectKeyValue("op", node->op()->name); @@ -106,8 +106,7 @@ struct JSONNode { writer->WriteObjectKeyValue("name", node->attrs.name); if (node->attrs.dict.size() != 0) { // write attributes in order; - std::map dict( - node->attrs.dict.begin(), node->attrs.dict.end()); + std::map dict(node->attrs.dict.begin(), node->attrs.dict.end()); writer->WriteObjectKeyValue("attrs", dict); } writer->WriteObjectKeyValue("inputs", inputs); @@ -120,7 +119,7 @@ struct JSONNode { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { node = Node::Create(); control_deps.clear(); dmlc::JSONObjectReadHelper helper; @@ -143,10 +142,10 @@ struct JSONNode { if (op_type_str != "null") { try { node->attrs.op = Op::Get(op_type_str); - } catch (const dmlc::Error &err) { + } catch (const dmlc::Error& err) { std::ostringstream os; - os << "Failed loading Op " << node->attrs.name - << " of type " << op_type_str << ": " << err.what(); + os << "Failed loading Op " << node->attrs.name << " of type " << op_type_str << ": " + << err.what(); throw dmlc::Error(os.str()); } } else { @@ -161,9 +160,9 @@ struct JSONGraph { std::vector arg_nodes; std::vector node_row_ptr; std::vector heads; - std::unordered_map > attrs; + std::unordered_map> attrs; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("nodes", nodes); writer->WriteObjectKeyValue("arg_nodes", arg_nodes); @@ -175,7 +174,7 @@ struct JSONGraph { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("nodes", &nodes); @@ -187,7 +186,7 @@ struct JSONGraph { } }; -void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { +void Symbol2JSONGraph(std::shared_ptr src, JSONGraph* jgraph) { std::unordered_map node2index; jgraph->node_row_ptr.push_back(0); DFSVisit(src->outputs, 
[&node2index, jgraph](const ObjectPtr& n) { @@ -212,10 +211,10 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version); } // recursively construct subgraphs - for (JSONNode &jnode : jgraph->nodes) { + for (JSONNode& jnode : jgraph->nodes) { // construct jnode's subgraphs - const std::vector> &subgraphs = jnode.node->attrs.subgraphs; - std::vector &jsubgraphs = jnode.subgraphs; + const std::vector>& subgraphs = jnode.node->attrs.subgraphs; + std::vector& jsubgraphs = jnode.subgraphs; jsubgraphs.resize(subgraphs.size()); for (uint32_t i = 0; i < subgraphs.size(); ++i) { Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]); @@ -223,10 +222,10 @@ void Symbol2JSONGraph(std::shared_ptr src, JSONGraph *jgraph) { } } -std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) { - for (const JSONNode &n : jgraph.nodes) { +std::shared_ptr JSONGraph2Symbol(const JSONGraph& jgraph, bool no_parse) { + for (const JSONNode& n : jgraph.nodes) { n.node->inputs.reserve(n.inputs.size()); - for (const JSONNode::Entry &e : n.inputs) { + for (const JSONNode::Entry& e : n.inputs) { CHECK(e.node_id < jgraph.nodes.size()); n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } @@ -235,7 +234,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) CHECK(nid < jgraph.nodes.size()); n.node->control_deps.push_back(jgraph.nodes[nid].node); } - for (const JSONGraph &subgraph : n.subgraphs) { + for (const JSONGraph& subgraph : n.subgraphs) { // The "no_parse" option here, is to be compatible with // commit cfd3075e85807dcd8f9534c37e053583dee87524 // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524), @@ -248,7 +247,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) n.node->op()->attr_parser(&(n.node->attrs)); } else if (!no_parse && n.node->is_variable()) { n.node->attrs.parsed = - Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; + Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed; } } // consistency check @@ -258,7 +257,7 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) } std::shared_ptr symbol = std::make_shared(); symbol->outputs.reserve(jgraph.heads.size()); - for (const JSONNode::Entry &e : jgraph.heads) { + for (const JSONNode::Entry& e : jgraph.heads) { CHECK(e.node_id < jgraph.nodes.size()); symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version}); } @@ -267,10 +266,8 @@ std::shared_ptr JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) // Load a graph from JSON file. 
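For reference, the LoadJSON and SaveJSON passes below are driven entirely through the graph's "json" attribute (SaveJSON provides it, LoadJSON depends on it, as their registrations further down declare). A minimal round-trip sketch, assuming nnvm::ApplyPass from <nnvm/pass.h>:

// Sketch only: serialize a graph to JSON and rebuild it again.
#include <memory>
#include <string>
#include <utility>
#include <nnvm/graph.h>
#include <nnvm/pass.h>

std::string GraphToJSON(nnvm::Graph g) {
  nnvm::Graph out = nnvm::ApplyPass(std::move(g), "SaveJSON");
  return out.GetAttr<std::string>("json");  // provided by the SaveJSON pass
}

nnvm::Graph GraphFromJSON(const std::string& json_str) {
  nnvm::Graph g;
  g.attrs["json"] = std::make_shared<nnvm::any>(json_str);  // required by LoadJSON
  return nnvm::ApplyPass(std::move(g), "LoadJSON");
}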
Graph LoadJSON(Graph src) { - CHECK_NE(src.attrs.count("json"), 0U) - << "Load JSON require json to be presented."; - const std::string &json_str = - nnvm::get(*src.attrs.at("json")); + CHECK_NE(src.attrs.count("json"), 0U) << "Load JSON require json to be presented."; + const std::string& json_str = nnvm::get(*src.attrs.at("json")); bool no_parse = false; if (src.attrs.count("load_json_no_parse")) { no_parse = nnvm::get(*src.attrs.at("load_json_no_parse")); @@ -305,17 +302,16 @@ Graph SaveJSON(Graph src) { // register pass NNVM_REGISTER_PASS(LoadJSON) -.describe("Return a new Graph, loaded from src.attrs[\"json\"]") -.set_body(LoadJSON) -.set_change_graph(true) -.depend_graph_attr("json"); + .describe("Return a new Graph, loaded from src.attrs[\"json\"]") + .set_body(LoadJSON) + .set_change_graph(true) + .depend_graph_attr("json"); NNVM_REGISTER_PASS(SaveJSON) -.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") -.set_body(SaveJSON) -.set_change_graph(true) -.provide_graph_attr("json"); - + .describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]") + .set_body(SaveJSON) + .set_change_graph(true) + .provide_graph_attr("json"); DMLC_JSON_ENABLE_ANY(std::string, str); DMLC_JSON_ENABLE_ANY(std::vector, list_int); diff --git a/nnvm/tests/cpp/op_fallback_test.cc b/nnvm/tests/cpp/op_fallback_test.cc deleted file mode 100644 index 477d9be6626d..000000000000 --- a/nnvm/tests/cpp/op_fallback_test.cc +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace nnvm { - -extern nnvm::Graph AnnotateGraph(nnvm::Graph g); -extern nnvm::Graph InsertDataCopy(nnvm::Graph g); - -namespace cpptest { -using StringVector = std::vector; -using IntVector = std::vector; - -enum class AnnotationType : int { - kTarget = 1, - kDeivceTarget = 2 -}; - -NNVM_REGISTER_OP(add) - .describe("addtion operation") - .set_num_inputs(2) - .set_num_outputs(1); - -NNVM_REGISTER_OP(sub) - .describe("subtract operation") - .set_num_inputs(2) - .set_num_outputs(1) - .set_fallback_device(true); - -// Register a simple form of the copy operation for testing purpose. -NNVM_REGISTER_OP(device_copy_op) - .describe("Copy data across devices." 
NNVM_ADD_FILELINE) - .set_num_inputs(1) - .set_num_outputs(1); - -nnvm::Graph GetGraph() { - const auto* add = nnvm::Op::Get("add"); - nnvm::NodePtr add_node = nnvm::Node::Create(); - add_node->attrs.op = add; - add_node->attrs.name = "add"; - nnvm::Symbol add_sym; - add_sym.outputs.push_back(nnvm::NodeEntry{add_node, 0, 0}); - - const auto* sub = nnvm::Op::Get("sub"); - - nnvm::NodePtr sub_node = nnvm::Node::Create(); - sub_node->attrs.op = sub; - sub_node->attrs.name = "sub"; - sub_node->inputs.push_back(add_sym.outputs[0]); - nnvm::Symbol sub_sym; - sub_sym.outputs.push_back(nnvm::NodeEntry{sub_node, 0, 0}); - - nnvm::Symbol sym; - sym.outputs.insert(sym.outputs.end(), add_sym.outputs.begin(), - add_sym.outputs.end()); - sym.outputs.insert(sym.outputs.end(), sub_sym.outputs.begin(), - sub_sym.outputs.end()); - - nnvm::Graph g; - g.outputs = sym.outputs; - return g; -} - -TEST(NodeAttrTest, DefaultValueForNodes) { - nnvm::Graph g = GetGraph(); - const auto& idx = g.indexed_graph(); - const auto& add = idx[0U]; - const auto& sub = idx[1U]; - EXPECT_EQ(add.source->attrs.device_type, 0); - EXPECT_EQ(sub.source->attrs.device_type, 0); - EXPECT_TRUE(sub.source->attrs.op->fallback); -} - -TEST(TargetAnnotationTest, AnnotateNodesWithTarget) { - nnvm::Graph g = GetGraph(); - StringVector targets{"llvm"}; - IntVector devices{1}; - // Setup required attributes. - g.attrs["annotation_type"] = std::make_shared( - static_cast(AnnotationType::kTarget)); - g.attrs["target"] = std::make_shared(targets); - g.attrs["device_type"] = std::make_shared(std::move(devices)); - g = nnvm::AnnotateGraph(g); - const auto& idx = g.indexed_graph(); - const auto& add = idx[0U]; - const auto& sub = idx[1U]; - EXPECT_EQ(g.indexed_graph().num_nodes(), 2); - EXPECT_TRUE(add.source->attrs.dict.count("target")); - EXPECT_TRUE(sub.source->attrs.dict.count("target")); - EXPECT_EQ(add.source->attrs.dict.at("target"), targets[0]); - EXPECT_EQ(sub.source->attrs.dict.at("target"), targets[0]); -} - - -// Both add and sub are explicitly specified to device type 2. However, sub is -// registered with fallback. It, therefore, should be annotated with device -// type 1. -TEST(DeviceFallbackTest, SubOpFallbackToOne) { - nnvm::Graph g = GetGraph(); - int fallback_device = 1; - StringVector op_names{"add", "sub"}; - IntVector op_devices{2, 2}; - StringVector targets{"llvm", "cuda"}; - IntVector devices{1, 2}; - // Setup required attributes. - g.attrs["annotation_type"] = std::make_shared( - static_cast(AnnotationType::kDeivceTarget)); - g.attrs["target"] = std::make_shared(std::move(targets)); - g.attrs["device_type"] = std::make_shared(std::move(devices)); - g.attrs["op_name"] = std::make_shared(std::move(op_names)); - g.attrs["op_device"] = std::make_shared(std::move(op_devices)); - g.attrs["fallback"] = std::make_shared(std::move(fallback_device)); - g = nnvm::AnnotateGraph(g); - const auto& idx = g.indexed_graph(); - const auto& add = idx[0U]; - const auto& sub = idx[1U]; - EXPECT_EQ(g.indexed_graph().num_nodes(), 2); - // add should be annotated with device type 2 - EXPECT_EQ(add.source->attrs.device_type, 2); - // sub should have been scheduled to device type 1 - EXPECT_EQ(sub.source->attrs.device_type, 1); -} - -// No device information is explicitly specified for add. It should be -// annotatedc with the fallback device. 
-TEST(DeviceFallbackTest, AddOpFallbackToOne) { - nnvm::Graph g = GetGraph(); - int fallback_device = 1; - StringVector op_names{"sub"}; - std::vector op_devices{2}; - StringVector targets{"llvm", "cuda"}; - std::vector devices{1, 2}; - // Setup required attributes. - g.attrs["annotation_type"] = std::make_shared( - static_cast(AnnotationType::kDeivceTarget)); - g.attrs["target"] = std::make_shared(std::move(targets)); - g.attrs["device_type"] = std::make_shared(std::move(devices)); - g.attrs["op_name"] = std::make_shared(std::move(op_names)); - g.attrs["op_device"] = std::make_shared(std::move(op_devices)); - g.attrs["fallback"] = std::make_shared(std::move(fallback_device)); - g = nnvm::AnnotateGraph(g); - const auto& idx = g.indexed_graph(); - const auto& add = idx[0U]; - const auto& sub = idx[1U]; - EXPECT_EQ(g.indexed_graph().num_nodes(), 2); - // add should be annotated with device type 2 - EXPECT_EQ(add.source->attrs.device_type, 1); - // sub should have been scheduled to device type 1 - EXPECT_EQ(sub.source->attrs.device_type, 1); -} - -TEST(CopyNodeInsertionTest, CopyNodeInsertedIsAndAnnotated) { - nnvm::Graph g = GetGraph(); - int fallback_device = 1; - StringVector op_names{"add"}; - IntVector op_devices{2}; - StringVector targets{"llvm", "cuda"}; - IntVector devices{1, 2}; - // Setup required attributes. - g.attrs["annotation_type"] = std::make_shared( - static_cast(AnnotationType::kDeivceTarget)); - g.attrs["target"] = std::make_shared(targets); - g.attrs["device_type"] = std::make_shared(std::move(devices)); - g.attrs["op_name"] = std::make_shared(std::move(op_names)); - g.attrs["op_device"] = std::make_shared(std::move(op_devices)); - g.attrs["fallback"] = std::make_shared(std::move(fallback_device)); - g = nnvm::AnnotateGraph(g); - g = nnvm::InsertDataCopy(g); - const auto& idx = g.indexed_graph(); - const auto& add = idx[0U]; - const auto& copy = idx[1U]; - const auto& sub = idx[2U]; - // A copy node should be inserted. - EXPECT_EQ(g.indexed_graph().num_nodes(), 3); - EXPECT_EQ(add.source->attrs.device_type, 2); - // Both copy node and sub should have the same device type, which is 1. - EXPECT_EQ(copy.source->attrs.device_type, 1); - EXPECT_EQ(sub.source->attrs.device_type, 1); - - // Check annotated target for each node. - EXPECT_TRUE(add.source->attrs.dict.count("target")); - EXPECT_FALSE(copy.source->attrs.dict.count("target")); - EXPECT_TRUE(sub.source->attrs.dict.count("target")); - EXPECT_EQ(add.source->attrs.dict.at("target"), targets[1]); - EXPECT_EQ(sub.source->attrs.dict.at("target"), targets[0]); - - // Check device index array - EXPECT_TRUE(g.HasAttr("device_index")); - const auto& device_vec = g.MoveCopyAttr("device_index"); - EXPECT_THAT(device_vec, testing::ElementsAre(2, 1, 1)); -} - -} // namespace cpptest -} // namespace nnvm - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/nnvm/tests/cpp/op_test.cc b/nnvm/tests/cpp/op_test.cc index 4c771655d87b..2ebd14688f46 100644 --- a/nnvm/tests/cpp/op_test.cc +++ b/nnvm/tests/cpp/op_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,16 +20,15 @@ #include #include #include -#include -NNVM_REGISTER_OP(add) -.describe("add two data together") -.set_num_inputs(2) -.set_attr("inplace_pair", std::make_pair(0, 0)); +#include NNVM_REGISTER_OP(add) -.set_attr("nick_name", "plus"); + .describe("add two data together") + .set_num_inputs(2) + .set_attr("inplace_pair", std::make_pair(0, 0)); +NNVM_REGISTER_OP(add).set_attr("nick_name", "plus"); TEST(Op, GetAttr) { using namespace nnvm; @@ -39,7 +38,7 @@ TEST(Op, GetAttr) { CHECK_EQ(nick[add], "plus"); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/nnvm/tests/cpp/tuple_test.cc b/nnvm/tests/cpp/tuple_test.cc index 7bf59b5db7c8..2c2c307aadce 100644 --- a/nnvm/tests/cpp/tuple_test.cc +++ b/nnvm/tests/cpp/tuple_test.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,8 +22,8 @@ #include TEST(Tuple, Basic) { - using nnvm::Tuple; using nnvm::TShape; + using nnvm::Tuple; Tuple x{1, 2, 3}; Tuple y{1, 2, 3, 5, 6}; x = std::move(y); @@ -42,7 +42,7 @@ TEST(Tuple, Basic) { CHECK((s == TShape{1, 2, 3})); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/python/setup.py b/python/setup.py index 62f374923714..682589ef5e6f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -156,6 +156,7 @@ def get_package_data_files(): zip_safe=False, install_requires=[ 'numpy', + 'scipy', 'decorator', 'attrs', 'psutil', @@ -164,7 +165,7 @@ def get_package_data_files(): 'matplotlib'], 'extra_feature': ['tornado', 'psutil', - 'xgboost==0.90', + 'xgboost>=1.1.0', 'mypy', 'orderedset', 'antlr4-python3-runtime']}, diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index f781aef0a8be..6cbc6d2288ac 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -23,14 +23,14 @@ # top-level alias # tvm._ffi from ._ffi.base import TVMError, __version__ -from ._ffi.runtime_ctypes import TypeCode, DataType +from ._ffi.runtime_ctypes import DataTypeCode, DataType from ._ffi import register_object, register_func, register_extension, get_global_func # top-level alias # tvm.runtime from .runtime.object import Object from .runtime.ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl -from .runtime.ndarray import vpi, rocm, opengl, ext_dev, micro_dev, hexagon +from .runtime.ndarray import vpi, rocm, ext_dev, micro_dev, hexagon from .runtime import ndarray as nd # tvm.error @@ -63,12 +63,17 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel -# Clean subprocesses when TVM is interrupted -def tvm_excepthook(exctype, value, trbk): - print('\n'.join(traceback.format_exception(exctype, value, 
trbk))) - if hasattr(multiprocessing, 'active_children'): - # pylint: disable=not-callable - for p in multiprocessing.active_children(): - p.terminate() +def tvm_wrap_excepthook(exception_hook): + """Wrap given excepthook with TVM additional work.""" -sys.excepthook = tvm_excepthook + def wrapper(exctype, value, trbk): + """Clean subprocesses when TVM is interrupted.""" + exception_hook(exctype, value, trbk) + if hasattr(multiprocessing, 'active_children'): + # pylint: disable=not-callable + for p in multiprocessing.active_children(): + p.terminate() + + return wrapper + +sys.excepthook = tvm_wrap_excepthook(sys.excepthook) diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index b5dc65fd5e79..359b018f0431 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -18,7 +18,7 @@ """Runtime Object api""" import ctypes from ..base import _LIB, check_call -from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func +from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .ndarray import _register_ndarray, NDArrayBase @@ -50,18 +50,49 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + if issubclass(cls, PyNativeObject): + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.handle = handle + return cls.__from_tvm_object__(cls, obj) # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle return obj -RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object -C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func( - _return_object, TypeCode.OBJECT_HANDLE) +RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object +C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func( + _return_object, ArgTypeCode.OBJECT_HANDLE) + +C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( + _return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG) + + +class PyNativeObject: + """Base class of all TVM objects that also subclass python's builtin types.""" + __slots__ = [] + + def __init_tvm_object_by_constructor__(self, fconstructor, *args): + """Initialize the internal tvm_object by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. + So the return object is directly set into the object + """ + # pylint: disable=assigning-non-slot + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.__init_handle_by_constructor__(fconstructor, *args) + self.__tvm_object__ = obj -C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func( - _return_object, TypeCode.OBJECT_RVALUE_REF_ARG) class ObjectBase(object): diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 11bb65504c61..8a2f49a7e6b6 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -26,10 +26,10 @@ from ..runtime_ctypes import DataType, TVMByteArray, TVMContext, ObjectRValueRef from . 
import ndarray as _nd from .ndarray import NDArrayBase, _make_array -from .types import TVMValue, TypeCode +from .types import TVMValue, ArgTypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64 -from .object import ObjectBase, _set_class_object +from .object import ObjectBase, PyNativeObject, _set_class_object from . import object as _object PackedFuncHandle = ctypes.c_void_p @@ -115,30 +115,39 @@ def _make_tvm_args(args, temp_args): for i, arg in enumerate(args): if isinstance(arg, ObjectBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE elif arg is None: values[i].v_handle = None - type_codes[i] = TypeCode.NULL + type_codes[i] = ArgTypeCode.NULL elif isinstance(arg, NDArrayBase): values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) - type_codes[i] = (TypeCode.NDARRAY_HANDLE - if not arg.is_view else TypeCode.DLTENSOR_HANDLE) + type_codes[i] = (ArgTypeCode.NDARRAY_HANDLE + if not arg.is_view else ArgTypeCode.DLTENSOR_HANDLE) + elif isinstance(arg, PyNativeObject): + values[i].v_handle = arg.__tvm_object__.handle + type_codes[i] = ArgTypeCode.OBJECT_HANDLE elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode elif isinstance(arg, Integral): values[i].v_int64 = arg - type_codes[i] = TypeCode.INT + type_codes[i] = ArgTypeCode.INT elif isinstance(arg, Number): values[i].v_float64 = arg - type_codes[i] = TypeCode.FLOAT + type_codes[i] = ArgTypeCode.FLOAT elif isinstance(arg, DataType): values[i].v_str = c_str(str(arg)) - type_codes[i] = TypeCode.STR + type_codes[i] = ArgTypeCode.STR elif isinstance(arg, TVMContext): values[i].v_int64 = _ctx_to_int64(arg) - type_codes[i] = TypeCode.TVM_CONTEXT - elif isinstance(arg, bytearray): + type_codes[i] = ArgTypeCode.TVM_CONTEXT + elif isinstance(arg, (bytearray, bytes)): + # from_buffer only taeks in bytearray. 
+ if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), @@ -146,31 +155,31 @@ def _make_tvm_args(args, temp_args): arr.size = len(arg) values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr)) temp_args.append(arr) - type_codes[i] = TypeCode.BYTES + type_codes[i] = ArgTypeCode.BYTES elif isinstance(arg, string_types): values[i].v_str = c_str(arg) - type_codes[i] = TypeCode.STR + type_codes[i] = ArgTypeCode.STR elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)): arg = _FUNC_CONVERT_TO_OBJECT(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.OBJECT_HANDLE + type_codes[i] = ArgTypeCode.OBJECT_HANDLE temp_args.append(arg) elif isinstance(arg, _CLASS_MODULE): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.MODULE_HANDLE + type_codes[i] = ArgTypeCode.MODULE_HANDLE elif isinstance(arg, PackedFuncBase): values[i].v_handle = arg.handle - type_codes[i] = TypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE elif isinstance(arg, ctypes.c_void_p): values[i].v_handle = arg - type_codes[i] = TypeCode.HANDLE + type_codes[i] = ArgTypeCode.HANDLE elif isinstance(arg, ObjectRValueRef): values[i].v_handle = ctypes.cast(ctypes.byref(arg.obj.handle), ctypes.c_void_p) - type_codes[i] = TypeCode.OBJECT_RVALUE_REF_ARG + type_codes[i] = ArgTypeCode.OBJECT_RVALUE_REF_ARG elif callable(arg): arg = convert_to_tvm_func(arg) values[i].v_handle = arg.handle - type_codes[i] = TypeCode.PACKED_FUNC_HANDLE + type_codes[i] = ArgTypeCode.PACKED_FUNC_HANDLE temp_args.append(arg) else: raise TypeError("Don't know how to handle type %s" % type(arg)) @@ -231,7 +240,7 @@ def __init_handle_by_constructor__(fconstructor, args): raise get_last_ffi_error() _ = temp_args _ = args - assert ret_tcode.value == TypeCode.OBJECT_HANDLE + assert ret_tcode.value == ArgTypeCode.OBJECT_HANDLE handle = ret_val.v_handle return handle @@ -266,15 +275,15 @@ def _get_global_func(name, allow_missing=False): # setup return handle for function type _object.__init_by_constructor__ = __init_handle_by_constructor__ -RETURN_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _handle_return_func -RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module -RETURN_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) -C_TO_PY_ARG_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func( - _handle_return_func, TypeCode.PACKED_FUNC_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( - _return_module, TypeCode.MODULE_HANDLE) -C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False) -C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) +RETURN_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _handle_return_func +RETURN_SWITCH[ArgTypeCode.MODULE_HANDLE] = _return_module +RETURN_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) +C_TO_PY_ARG_SWITCH[ArgTypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func( + _handle_return_func, ArgTypeCode.PACKED_FUNC_HANDLE) +C_TO_PY_ARG_SWITCH[ArgTypeCode.MODULE_HANDLE] = _wrap_arg_func( + _return_module, ArgTypeCode.MODULE_HANDLE) +C_TO_PY_ARG_SWITCH[ArgTypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False) +C_TO_PY_ARG_SWITCH[ArgTypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True) _CLASS_MODULE = None _CLASS_PACKED_FUNC = None diff --git a/python/tvm/_ffi/_ctypes/types.py 
b/python/tvm/_ffi/_ctypes/types.py index 20be30a59b2f..d4e7b362cbe9 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -19,7 +19,7 @@ import ctypes import struct from ..base import py_str, check_call, _LIB -from ..runtime_ctypes import TVMByteArray, TypeCode, TVMContext +from ..runtime_ctypes import TVMByteArray, ArgTypeCode, TVMContext class TVMValue(ctypes.Union): """TVMValue in C API""" @@ -86,21 +86,21 @@ def _ctx_to_int64(ctx): RETURN_SWITCH = { - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.HANDLE: _return_handle, - TypeCode.NULL: lambda x: None, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.BYTES: _return_bytes, - TypeCode.TVM_CONTEXT: _return_context + ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.FLOAT: lambda x: x.v_float64, + ArgTypeCode.HANDLE: _return_handle, + ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.STR: lambda x: py_str(x.v_str), + ArgTypeCode.BYTES: _return_bytes, + ArgTypeCode.TVM_CONTEXT: _return_context } C_TO_PY_ARG_SWITCH = { - TypeCode.INT: lambda x: x.v_int64, - TypeCode.FLOAT: lambda x: x.v_float64, - TypeCode.HANDLE: _return_handle, - TypeCode.NULL: lambda x: None, - TypeCode.STR: lambda x: py_str(x.v_str), - TypeCode.BYTES: _return_bytes, - TypeCode.TVM_CONTEXT: _return_context + ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.FLOAT: lambda x: x.v_float64, + ArgTypeCode.HANDLE: _return_handle, + ArgTypeCode.NULL: lambda x: None, + ArgTypeCode.STR: lambda x: py_str(x.v_str), + ArgTypeCode.BYTES: _return_bytes, + ArgTypeCode.TVM_CONTEXT: _return_context } diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 0da66ac2e034..8c9e413813b9 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -22,7 +22,7 @@ from cpython cimport pycapsule from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, uint16_t import ctypes -cdef enum TVMTypeCode: +cdef enum TVMArgTypeCode: kInt = 0 kUInt = 1 kFloat = 2 diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index f2b5cc172d45..371cbbb0a4a2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle): object_type = OBJECT_TYPE handle = ctypes_handle(chandle) CALL(TVMObjectGetTypeIndex(chandle, &tindex)) + if tindex < len(OBJECT_TYPE): cls = OBJECT_TYPE[tindex] if cls is not None: + if issubclass(cls, PyNativeObject): + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle + return cls.__from_tvm_object__(cls, obj) obj = cls.__new__(cls) else: obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) else: obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + (obj).chandle = chandle return obj +class PyNativeObject: + """Base class of all TVM objects that also subclass python's builtin types.""" + __slots__ = [] + + def __init_tvm_object_by_constructor__(self, fconstructor, *args): + """Initialize the internal tvm_object by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Note + ---- + We have a special calling convention to call constructor functions. 
+ So the return object is directly set into the object + """ + obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) + obj.__init_handle_by_constructor__(fconstructor, *args) + self.__tvm_object__ = obj + + cdef class ObjectBase: cdef void* chandle diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 6977e108bf88..45bcf64a616d 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -109,6 +109,9 @@ cdef inline int make_arg(object arg, value[0].v_handle = (arg).chandle tcode[0] = (kTVMNDArrayHandle if not (arg).c_is_view else kTVMDLTensorHandle) + elif isinstance(arg, PyNativeObject): + value[0].v_handle = ((arg.__tvm_object__)).chandle + tcode[0] = kTVMObjectHandle elif isinstance(arg, _TVM_COMPATS): ptr = arg._tvm_handle value[0].v_handle = (ptr) @@ -139,7 +142,13 @@ cdef inline int make_arg(object arg, value[0].v_ctx = (( ctypes.addressof(arg)))[0] tcode[0] = kTVMContext - elif isinstance(arg, bytearray): + elif isinstance(arg, (bytes, bytearray)): + # from_buffer only takes in bytearray. + if isinstance(arg, bytes): + byte_arr = bytearray(arg) + temp_args.append(byte_arr) + arg = byte_arr + arr = TVMByteArray() arr.data = ctypes.cast( (ctypes.c_byte * len(arg)).from_buffer(arg), diff --git a/python/tvm/_ffi/base.py b/python/tvm/_ffi/base.py index 8d3ce19f9444..2cca014b1420 100644 --- a/python/tvm/_ffi/base.py +++ b/python/tvm/_ffi/base.py @@ -48,10 +48,14 @@ def _load_lib(): """Load libary by searching possible path.""" lib_path = libinfo.find_lib_path() lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) - # DMatrix functions lib.TVMGetLastError.restype = ctypes.c_char_p return lib, os.path.basename(lib_path[0]) +try: + import readline # pylint: disable=unused-import +except ImportError: + pass + # version number __version__ = libinfo.__version__ # library instance diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 0d1a4e214791..a1483a1b012b 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -88,6 +88,10 @@ def find_lib_path(name=None, search_path=None, optional=False): dll_path.append(install_lib_dir) + if os.path.isdir(source_dir): + dll_path.append(os.path.join(source_dir, "web", "dist", "wasm")) + dll_path.append(os.path.join(source_dir, "web", "dist")) + dll_path = [os.path.realpath(x) for x in dll_path] if search_path is not None: if isinstance(search_path, list): @@ -154,6 +158,7 @@ def find_include_path(name=None, search_path=None, optional=False): ffi_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) source_dir = os.path.join(ffi_dir, "..", "..", "..") install_include_dir = os.path.join(ffi_dir, "..", "..", "..", "..") + third_party_dir = os.path.join(source_dir, "3rdparty") header_path = [] diff --git a/python/tvm/_ffi/registry.py b/python/tvm/_ffi/registry.py index e4b8b18b4805..0942ccb277a6 100644 --- a/python/tvm/_ffi/registry.py +++ b/python/tvm/_ffi/registry.py @@ -122,7 +122,7 @@ def register_extension(cls, fcreate=None): @tvm.register_extension class MyTensor(object): - _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE + _tvm_tcode = tvm.ArgTypeCode.ARRAY_HANDLE def __init__(self): self.handle = _LIB.NewDLTensor() @@ -132,8 +132,8 @@ def _tvm_handle(self): return self.handle.value """ assert hasattr(cls, "_tvm_tcode") - if fcreate and cls._tvm_tcode < TypeCode.EXT_BEGIN: - raise ValueError("Cannot register create when extension tcode is same as buildin") + if fcreate: + raise ValueError("Extension with fcreate is no longer
supported") _reg_extension(cls, fcreate) return cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 6b06ad01c9ff..2e498e38cce8 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -23,7 +23,7 @@ tvm_shape_index_t = ctypes.c_int64 -class TypeCode(object): +class ArgTypeCode(object): """Type code used in API calls""" INT = 0 UINT = 1 @@ -42,23 +42,30 @@ class TypeCode(object): OBJECT_RVALUE_REF_ARG = 14 EXT_BEGIN = 15 - class TVMByteArray(ctypes.Structure): """Temp data structure for byte array.""" _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)] +class DataTypeCode(object): + """DataType code in DLTensor.""" + INT = 0 + UINT = 1 + FLOAT = 2 + HANDLE = 3 + + class DataType(ctypes.Structure): """TVM datatype structure""" _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)] CODE2STR = { - 0 : 'int', - 1 : 'uint', - 2 : 'float', - 4 : 'handle' + DataTypeCode.INT : 'int', + DataTypeCode.UINT : 'uint', + DataTypeCode.FLOAT : 'float', + DataTypeCode.HANDLE : 'handle' } def __init__(self, type_str): super(DataType, self).__init__() @@ -67,7 +74,7 @@ def __init__(self, type_str): if type_str == "bool": self.bits = 1 - self.type_code = 1 + self.type_code = DataTypeCode.UINT self.lanes = 1 return @@ -77,16 +84,16 @@ def __init__(self, type_str): bits = 32 if head.startswith("int"): - self.type_code = 0 + self.type_code = DataTypeCode.INT head = head[3:] elif head.startswith("uint"): - self.type_code = 1 + self.type_code = DataTypeCode.UINT head = head[4:] elif head.startswith("float"): - self.type_code = 2 + self.type_code = DataTypeCode.FLOAT head = head[5:] elif head.startswith("handle"): - self.type_code = 4 + self.type_code = DataTypeCode.HANDLE bits = 64 head = "" elif head.startswith("custom"): @@ -143,10 +150,10 @@ class TVMContext(ctypes.Structure): 8 : 'metal', 9 : 'vpi', 10: 'rocm', - 11: 'opengl', 12: 'ext_dev', 13: 'micro_dev', 14: 'hexagon', + 15: 'webgpu' } STR2MASK = { 'llvm': 1, @@ -165,10 +172,10 @@ class TVMContext(ctypes.Structure): 'metal': 8, 'vpi': 9, 'rocm': 10, - 'opengl': 11, 'ext_dev': 12, 'micro_dev': 13, 'hexagon': 14, + 'webgpu': 15, } def __init__(self, device_type, device_id): super(TVMContext, self).__init__() diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py index c576ffd76e56..0c0591ccf2a1 100644 --- a/python/tvm/autotvm/feature.py +++ b/python/tvm/autotvm/feature.py @@ -31,7 +31,6 @@ import tvm._ffi from tvm import target as _target -from tvm.tir import ir_pass from tvm.te import schedule from tvm.driver import build_module @@ -46,10 +45,12 @@ def ana_lower(sch, args, # Phase 0 bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds, True) - stmt = ir_pass.StorageFlatten(stmt, binds, 64) - stmt = ir_pass.CanonicalSimplify(stmt) + func = schedule.SchedulePostProcToPrimFunc(args, stmt, None) + mod = tvm.IRModule.from_expr(func._move()) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = tvm.tir.transform.Simplify()(mod._move()) assert simple_mode - return stmt + return mod["main"].body try: _get_buffer_curve_sample_flatten = tvm._ffi.get_global_func( diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index e7b4694cc53d..1cc4f39d35b4 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -81,7 +81,7 @@ def __init__(self, graph, 
input_shapes, records, target_ops, Each row of this file is an encoded record pair. Otherwise, it is an iterator. - target_ops : List of relay.op.Op + target_ops : List of tvm.ir.Op Target tuning operators. target : str or tvm.target diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 8470fb681599..b85c5624808c 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -38,7 +38,7 @@ def expr2graph(expr, target_ops, node_dict, node_list): expr : tvm.relay.Expr.Function Input relay function expression. - target_ops: List of relay.op.Op + target_ops: List of tvm.ir.Op List of target relay ops node_dict : dictionary from tvm.relay.Expr to int @@ -157,7 +157,7 @@ def _traverse_expr(node): elif isinstance(node, Constant): node_entry["name"] = "Constant_" + str(node_index) node_entry["types"] = [node.checked_type] - elif isinstance(node, relay.op.op.Op): + elif isinstance(node, tvm.ir.Op): return else: raise RuntimeError("Not supported relay node type in graph tuning: %s" diff --git a/python/tvm/autotvm/measure/local_executor.py b/python/tvm/autotvm/measure/local_executor.py index cf81e2b50e50..a0a826abccf6 100644 --- a/python/tvm/autotvm/measure/local_executor.py +++ b/python/tvm/autotvm/measure/local_executor.py @@ -145,7 +145,7 @@ def submit(self, func, *args, **kwargs): if not self.do_fork: return LocalFutureNoFork(func(*args, **kwargs)) - queue = Queue(2) + queue = Queue(2) # Size of 2 to avoid a race condition with size 1. process = Process(target=call_with_timeout, args=(queue, self.timeout, func, args, kwargs)) process.start() diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 698ddbc68dd7..b8969f55c00a 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -34,10 +34,9 @@ import numpy as np import tvm._ffi +import tvm.ir.transform from tvm import nd, rpc as _rpc, target as _target -from tvm.tir import ir_pass from tvm.error import TVMError -from tvm.target import build_config from tvm.driver import build from tvm.contrib import nvcc, ndk, tar @@ -232,7 +231,7 @@ def set_task(self, task): def get_build_kwargs(self): kwargs = {} if 'cuda' in self.task.target.keys or 'opencl' in self.task.target.keys or \ - 'rocm' in self.task.target.keys: + 'rocm' in self.task.target.keys or 'vulkan' in self.task.target.keys: remote = request_remote(self.key, self.host, self.port) ctx = remote.context(str(self.task.target), 0) max_dims = ctx.max_thread_dimensions @@ -246,6 +245,8 @@ def get_build_kwargs(self): if 'cuda' in self.task.target.keys: kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.')) + if self.task.target.device_name == 'micro_dev': + kwargs.setdefault('build_option', {})['tir.disable_vectorize'] = True return kwargs @@ -359,7 +360,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti opts = build_option or {} if check_gpu: # Add verify pass to filter out invalid configs in advance. 
- opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))] + opts["tir.add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))] if cuda_arch: set_cuda_target_arch(cuda_arch) @@ -370,7 +371,7 @@ def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_opti import vta func = vta.build(s, args, target_host=task.target_host) else: - with build_config(**opts): + with tvm.ir.transform.PassContext(config=opts): func = build(s, args, target_host=task.target_host) return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args) @@ -615,9 +616,9 @@ def gpu_verify_pass(**kwargs): """Verify the validity of a gpu kernel. This pass will check memory usage and number of threads per block. """ - def verify_pass(stmt): - valid = ir_pass.VerifyGPUCode(stmt, kwargs) + def verify_pass(f, *_): + valid = tvm.tir.analysis.verify_gpu_code(f, kwargs) if not valid: raise InstantiationError("Skipped because of invalid gpu kernel") - return stmt - return verify_pass + return f + return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0) diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index de183db41e2c..9751d903af5f 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -41,13 +41,13 @@ def _lower(mod, from tvm.relay.backend import graph_runtime_codegen if hasattr(target, 'device_name') and target.device_name == "vta": - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - import vta - with vta.build_config(): - mod, _ = relay.optimize(mod, target, params) - grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) - grc.codegen(mod["main"]) - return + import vta + with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + mod, _ = relay.optimize(mod, target, params) + grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target) + grc.codegen(mod["main"]) + return + # default case # Try graph codegen first to extract autotvm tasks. # If failed to compile, then fallback to use VM compiler. @@ -78,7 +78,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): The compilation target target_host: tvm.target.Target The host compilation target - ops: List[relay.op.Op] or None + ops: List[tvm.ir.Op] or None List of relay ops to be tuned. If not specified, all tunable ops will be extracted. Returns @@ -105,7 +105,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No The compilation target target_host: tvm.target.Target The host compilation target - ops: List[relay.op.Op] or None + ops: List[tvm.ir.Op] or None List of relay ops to be tuned. If not specified, all tunable ops will be extracted. Returns @@ -137,6 +137,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No args=(mod, target, param)) build_thread.start() build_thread.join() + relay.backend.compile_engine.get().clear() logger.disabled = old_state diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 00b667670c65..b7cd6f2b04ed 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -495,11 +495,11 @@ def _count_flop(exp): if isinstance(exp, expr.Select): return _count_flop(exp.condition) + max(_count_flop(exp.true_value), _count_flop(exp.false_value)) - if isinstance(exp, expr.Call): - if exp.call_type == expr.Call.Halide: - # Ignore flops from indexing expressions. 
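To make the new configuration style concrete, here is a minimal sketch (not taken from the patch) of attaching a phase-2 lowering pass the way the updated _build_func_common and gpu_verify_pass now expect; the schedule and the no-op pass are placeholders:

    import tvm
    from tvm import te

    # New-style TIR pass: receives a PrimFunc (plus module and pass context) instead of a raw Stmt.
    def _noop_pass(func, mod, ctx):
        return func

    custom_pass = tvm.tir.transform.prim_func_pass(_noop_pass, opt_level=0)

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # Lowering options are now ordinary PassContext config entries such as "tir.add_lower_pass".
    with tvm.ir.transform.PassContext(config={"tir.add_lower_pass": [(2, custom_pass)]}):
        mod = tvm.lower(s, [A, B], name="main")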
- return 0 + if isinstance(exp, expr.ProducerLoad): + # Ignore flops from indexing expressions. + return 0 + if isinstance(exp, expr.Call): return sum([_count_flop(x) for x in exp.args]) raise FlopCalculationError("Found unsupported operator in the compute expr") diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 67f9780c2f93..59e77f7d0098 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -61,7 +61,7 @@ def reset(self, wanted_relay_ops=None): Parameters ---------- - wanted_relay_ops: List of relay.op.Op + wanted_relay_ops: List of tvm.ir.Op The relay ops to be extracted """ self.task_collection = [] diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index f13ba5289ce5..c7c55ed1038a 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -50,12 +50,13 @@ 'llvm': "v0.04", 'cuda': "v0.08", - 'rocm': "v0.04", + 'rocm': "v0.05", 'opencl': "v0.04", 'mali': "v0.06", 'intel_graphics': "v0.02", 'vta': "v0.08", + 'amd_apu': "v0.01", } logger = logging.getLogger('autotvm') @@ -66,8 +67,10 @@ def _alias(name): 'vtacpu': 'vta', 'metal': 'opencl', + 'webgpu': 'opencl', 'vulkan': 'opencl', 'nvptx': 'cuda', + 'amd_apu': 'amd_apu' } return table.get(name, name) @@ -213,10 +216,12 @@ def load_reference_log(backend, model, workload_name): if key not in REFERENCE_LOG_CACHE: tmp = [] + # If TOPHUB_LOCATION is not AUTOTVM_TOPHUB_NONE_LOC, # Download the config file from tophub if not exists. if not os.path.exists(filename): tophub_location = _get_tophub_location() - download_package(tophub_location, package_name) + if tophub_location != AUTOTVM_TOPHUB_NONE_LOC: + download_package(tophub_location, package_name) if os.path.isfile(filename): # in case download failed find = False inp = None diff --git a/python/tvm/autotvm/tuner/callback.py b/python/tvm/autotvm/tuner/callback.py index 4c2fe87cf3c6..cfc1b2c38f85 100644 --- a/python/tvm/autotvm/tuner/callback.py +++ b/python/tvm/autotvm/tuner/callback.py @@ -23,6 +23,7 @@ import numpy as np from .. 
import record +from ..util import format_si_prefix logger = logging.getLogger('autotvm') @@ -105,7 +106,7 @@ def trial_timestamps(self): return np.array(self.timestamps) -def progress_bar(total, prefix=''): +def progress_bar(total, prefix='', si_prefix='G'): """Display progress bar for tuning Parameters @@ -114,6 +115,8 @@ def progress_bar(total, prefix=''): The total number of trials prefix: str The prefix of output message + si_prefix: str + SI prefix for flops """ class _Context(object): """Context to store local variables""" @@ -130,6 +133,9 @@ def __del__(self): ctx = _Context() tic = time.time() + # Validate si_prefix argument + format_si_prefix(0, si_prefix) + if logger.level < logging.DEBUG: # only print progress bar in non-debug mode sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) ' '| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic)) @@ -143,14 +149,15 @@ def _callback(tuner, inputs, results): if res.error_no == 0: flops = inp.task.flop / np.mean(res.costs) - if logger.level < logging.DEBUG: # only print progress bar in non-debug mode + if not logger.isEnabledFor(logging.DEBUG): # only print progress bar in non-debug mode ctx.cur_flops = flops ctx.best_flops = tuner.best_flops - sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) ' + sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) ' '| %.2f s' % - (prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total, - time.time() - tic)) + (prefix, format_si_prefix(ctx.cur_flops, si_prefix), + format_si_prefix(ctx.best_flops, si_prefix), si_prefix, + ctx.ct, ctx.total, time.time() - tic)) sys.stdout.flush() return _callback diff --git a/python/tvm/autotvm/tuner/ga_tuner.py b/python/tvm/autotvm/tuner/ga_tuner.py index a4c36bcd385e..da10f73d5a53 100644 --- a/python/tvm/autotvm/tuner/ga_tuner.py +++ b/python/tvm/autotvm/tuner/ga_tuner.py @@ -50,7 +50,11 @@ def __init__(self, task, pop_size=100, elite_num=3, mutation_prob=0.1): # space info self.space = task.config_space - self.dims = [len(x) for x in self.space.space_map.values()] + self.dim_keys = [] + self.dims = [] + for k, v in self.space.space_map.items(): + self.dim_keys.append(k) + self.dims.append(len(v)) self.visited = set([]) @@ -123,7 +127,7 @@ def update(self, inputs, results): if len(self.visited) < len(self.space): while knob2point(tmp_gene, self.dims) in self.visited: j = np.random.randint(len(self.dims)) - tmp_gene[j] = np.random.randint(self.dims[j]) + tmp_gene[j] = np.random.randint(self.dims[j]) # pylint: disable=invalid-sequence-index next_genes.append(tmp_gene) self.visited.add(knob2point(tmp_gene, self.dims)) else: diff --git a/python/tvm/autotvm/tuner/tuner.py b/python/tvm/autotvm/tuner/tuner.py index 76d088f4cfb3..2441a4ae642f 100644 --- a/python/tvm/autotvm/tuner/tuner.py +++ b/python/tvm/autotvm/tuner/tuner.py @@ -21,6 +21,7 @@ import numpy as np from ..measure import MeasureInput, create_measure_batch +from ..util import format_si_prefix from ..env import GLOBAL_SCOPE @@ -87,7 +88,7 @@ def update(self, inputs, results): """ - def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): + def tune(self, n_trial, measure_option, early_stopping=None, callbacks=(), si_prefix='G'): """Begin tuning Parameters @@ -104,6 +105,8 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): (Tuner, List of MeasureInput, List of MeasureResult) with no return value. These callback functions will be called on every measurement pair. 
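For reference, a short sketch of what the SI-prefix helper used above computes; the FLOPS value is made up, and format_si_prefix itself is added to tvm.autotvm.util further down in this diff:

    from tvm.autotvm.util import format_si_prefix

    flops = 3.2e9
    format_si_prefix(flops, 'G')   # 3.2       -> reported as "3.20 GFLOPS"
    format_si_prefix(flops, 'M')   # 3200.0    -> reported as "3200.00 MFLOPS"
    format_si_prefix(flops, 'k')   # 3200000.0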
See autotvm/tuner/callback.py for some examples. + si_prefix: str + One of tvm.autotvm.util.SI_PREFIXES. The SI prefix to use when reporting FLOPS. """ measure_batch = create_measure_batch(self.task, measure_option) n_parallel = getattr(measure_batch, 'n_parallel', 1) @@ -111,6 +114,9 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): self.n_trial = n_trial self.early_stopping = early_stopping + # Validate si_prefix arg + format_si_prefix(0, si_prefix) + old_level = logger.level GLOBAL_SCOPE.in_tuning = True @@ -140,9 +146,9 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): self.best_measure_pair = (inp, res) self.best_iter = i + k - logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s", - i + k + 1, flops / 1e9, self.best_flops / 1e9, - res, config) + logger.debug("No: %d\t%sFLOPS: %.2f/%.2f\tresult: %s\t%s", + i + k + 1, si_prefix, format_si_prefix(flops, si_prefix), + format_si_prefix(self.best_flops, si_prefix), res, config) i += len(results) self.ttl = min(early_stopping + self.best_iter, n_trial) - i diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index 305244808a33..15a3390d3522 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -118,7 +118,7 @@ def __init__(self, task, feature_type, loss_type, num_threads=None, log_interval else: raise RuntimeError("Invalid loss type: " + loss_type) - self.xgb_params['silent'] = 1 + self.xgb_params['verbosity'] = 0 if num_threads: self.xgb_params['nthread'] = num_threads self.bst = None diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 01d50e86a88a..0d81c123994a 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -23,8 +23,8 @@ from random import randrange import numpy as np - -from tvm.tir import expr, ir_pass +import tvm.arith +from tvm.tir import expr logger = logging.getLogger('autotvm') @@ -156,7 +156,8 @@ def get_const_int(exp): if isinstance(exp, int): return exp if not isinstance(exp, (expr.IntImm,)): - exp = ir_pass.Simplify(exp) + ana = tvm.arith.Analyzer() + exp = ana.simplify(exp) if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -180,9 +181,19 @@ def get_const_tuple(in_tuple): if isinstance(elem, expr.Var): ret.append(elem) elif not isinstance(elem, (expr.IntImm, int)): - elem = ir_pass.Simplify(elem) + ana = tvm.arith.Analyzer() + elem = ana.simplify(elem) if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: ret.append(get_const_int(elem)) return tuple(ret) + + +SI_PREFIXES = 'yzafpn\xb5m kMGTPEZY' +YOCTO_EXP10 = -24 + + +def format_si_prefix(x, si_prefix): + exp10 = 10 ** (SI_PREFIXES.index(si_prefix) * 3 + YOCTO_EXP10) + return float(x) / exp10 diff --git a/python/tvm/contrib/binutil.py b/python/tvm/contrib/binutil.py index 521e0885548c..d784b7b9bd6c 100644 --- a/python/tvm/contrib/binutil.py +++ b/python/tvm/contrib/binutil.py @@ -21,7 +21,9 @@ import tvm._ffi from . import util +# TODO does this file still belong in `contrib`. is it too µTVM-specific? 
+# TODO shouldn't need so many `ALIGN` directives RELOCATION_LD_SCRIPT_TEMPLATE = """ /* linker symbol for use in UTVMInit */ _utvm_stack_pointer_init = 0x{stack_pointer_init:x}; @@ -118,7 +120,7 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix): size of the section in bytes """ if not os.path.isfile(binary_path): - raise RuntimeError("no such file \"{}\"".format(binary_path)) + raise RuntimeError('no such file "{}"'.format(binary_path)) # We use the "-A" flag here to get the ".rodata" section's size, which is # not included by default. size_output = run_cmd(["{}size".format(toolchain_prefix), "-A", binary_path]) @@ -145,21 +147,15 @@ def tvm_callback_get_section_size(binary_path, section_name, toolchain_prefix): section_size += entry_size break - # NOTE: For some reason, the size of the BSS section on the RISC-V - # GCC is sometimes reported to be smaller than it is, so we need to adjust - # for this. - if "riscv" in toolchain_prefix and section_name == "bss": - # TODO(weberlo): Figure out why 32 is the minimum constant that works. - # - # The current hypothesis is that the last symbols in the ".bss" and - # ".sbss" sections may have size zero, since the symbols in these - # sections are uninitialized and there's no address that follows that - # would enforce a particular size. - # - # If this is the case, then 32 just happens to be a safe amount of - # padding for most cases, but symbols can be arbitrarily large, so this - # isn't bulletproof. - return section_size + 32 + # NOTE: in the past, section_size has been wrong on x86. it may be + # inconsistent. TODO: maybe stop relying on `*size` to give us the size and + # instead read the section with `*objcopy` and count the bytes. + # NOTE(areusch): I think the problem is due to alignment ops in the linker. + # Since this is going away in the impending switch to on-device runtime, + # add a constant to hopefully absorb these relocations. + if section_size > 0: + section_size += 64 + return section_size @@ -206,11 +202,13 @@ def tvm_callback_relocate_binary( rel_bin : bytearray the relocated binary """ + assert text_start < rodata_start < data_start < bss_start < stack_end stack_pointer_init = stack_end - word_size ld_script_contents = "" # TODO(weberlo): There should be a better way to configure this for different archs. + # TODO is this line even necessary? 
if "riscv" in toolchain_prefix: - ld_script_contents += "OUTPUT_ARCH( \"riscv\" )\n\n" + ld_script_contents += 'OUTPUT_ARCH( "riscv" )\n\n' ld_script_contents += RELOCATION_LD_SCRIPT_TEMPLATE.format( word_size=word_size, text_start=text_start, @@ -221,7 +219,7 @@ def tvm_callback_relocate_binary( tmp_dir = util.tempdir() rel_obj_path = tmp_dir.relpath("relocated.obj") - rel_ld_script_path = tmp_dir.relpath("relocated.lds") + rel_ld_script_path = tmp_dir.relpath("relocate.lds") with open(rel_ld_script_path, "w") as f: f.write(ld_script_contents) run_cmd([ @@ -229,8 +227,23 @@ def tvm_callback_relocate_binary( binary_path, "-T", rel_ld_script_path, "-o", rel_obj_path]) + with open(rel_obj_path, "rb") as f: rel_bin = bytearray(f.read()) + + gdb_init_dir = os.environ.get("MICRO_GDB_INIT_DIR") + if gdb_init_dir is not None: + gdb_init_path = f"{gdb_init_dir}/.gdbinit" + with open(gdb_init_path, "r") as f: + gdbinit_contents = f.read().split("\n") + new_contents = [] + for line in gdbinit_contents: + new_contents.append(line) + if line.startswith("target"): + new_contents.append(f"add-symbol-file {rel_obj_path}") + with open(gdb_init_path, "w") as f: + f.write("\n".join(new_contents)) + return rel_bin diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index ae37923a1dcf..cdde0cbed1d5 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -45,12 +45,39 @@ def create_shared(output, The compiler command. """ if sys.platform == "darwin" or sys.platform.startswith("linux"): - _linux_compile(output, objects, options, cc) + _linux_compile(output, objects, options, cc, compile_shared=True) elif sys.platform == "win32": _windows_shared(output, objects, options) else: raise ValueError("Unsupported platform") + +def create_executable(output, + objects, + options=None, + cc="g++"): + """Create executable binary. + + Parameters + ---------- + output : str + The target executable. + + objects : List[str] + List of object files. + + options : List[str] + The list of additional options string. + + cc : Optional[str] + The compiler command. + """ + if sys.platform == "darwin" or sys.platform.startswith("linux"): + _linux_compile(output, objects, options, cc) + else: + raise ValueError("Unsupported platform") + + def get_target_by_dump_machine(compiler): """ Functor of get_target_triple that can get the target triple using compiler. @@ -90,7 +117,8 @@ def get_target_triple(): def cross_compiler(compile_func, options=None, output_format=None, - get_target_triple=None): + get_target_triple=None, + add_files=None): """Create a cross compiler function by specializing compile_func with options. This function can be used to construct compile functions that @@ -111,6 +139,10 @@ def cross_compiler(compile_func, get_target_triple: Optional[Callable] Function that can target triple according to dumpmachine option of compiler. + add_files: Optional[List[str]] + List of paths to additional object, source, library files + to pass as part of the compilation. 
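A hedged usage sketch of the new add_files hook documented above; the toolchain name, flags, and extra object file are placeholders rather than values from the change:

    from tvm.contrib import cc

    # Cross-compiler callback that always appends one extra object file to the link step.
    fcompile = cc.cross_compiler(
        "aarch64-linux-gnu-g++",          # assumed cross toolchain available on PATH
        options=["-O2"],
        add_files=["extra_runtime.o"])    # hypothetical additional object file

    # Typical use afterwards: lib.export_library("deploy.so", fcompile=fcompile)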
+ Returns ------- fcompile : Callable[[str, str, Optional[str]], None] @@ -133,6 +165,7 @@ def cross_compiler(compile_func, """ base_options = [] if options is None else options kwargs = {} + add_files = [] if add_files is None else add_files # handle case where compile_func is the name of the cc if isinstance(compile_func, str): @@ -144,7 +177,7 @@ def _fcompile(outputs, objects, options=None): all_options = base_options if options is not None: all_options += options - compile_func(outputs, objects, options=all_options, **kwargs) + compile_func(outputs, objects + add_files, options=all_options, **kwargs) if not output_format and hasattr(compile_func, "output_format"): output_format = compile_func.output_format @@ -158,9 +191,10 @@ def _fcompile(outputs, objects, options=None): return _fcompile -def _linux_compile(output, objects, options, compile_cmd="g++"): +def _linux_compile(output, objects, options, + compile_cmd="g++", compile_shared=False): cmd = [compile_cmd] - if output.endswith(".so") or output.endswith(".dylib"): + if compile_shared or output.endswith(".so") or output.endswith(".dylib"): cmd += ["-shared", "-fPIC"] if sys.platform == "darwin": cmd += ["-undefined", "dynamic_lookup"] @@ -179,6 +213,7 @@ def _linux_compile(output, objects, options, compile_cmd="g++"): if proc.returncode != 0: msg = "Compilation error:\n" msg += py_str(out) + msg += "\nCommand line: " + " ".join(cmd) raise RuntimeError(msg) diff --git a/python/tvm/contrib/coreml_runtime.py b/python/tvm/contrib/coreml_runtime.py new file mode 100644 index 000000000000..d9f8c6a4652d --- /dev/null +++ b/python/tvm/contrib/coreml_runtime.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CoreML runtime that load and run coreml models.""" +import tvm._ffi +from ..rpc import base as rpc_base + +def create(model_dir, ctx): + """Create a runtime executor module given a coreml model and context. + Parameters + ---------- + model_dir : str + The directory where the compiled models are located. + ctx : TVMContext + The context to deploy the module. It can be local or remote when there + is only one TVMContext. + Returns + ------- + coreml_runtime : CoreMLModule + Runtime coreml module that can be used to execute the coreml model. + """ + device_type = ctx.device_type + runtime_func = "tvm.coreml_runtime.create" + + if device_type >= rpc_base.RPC_SESS_MASK: + fcreate = ctx._rpc_sess.get_function(runtime_func) + else: + fcreate = tvm._ffi.get_global_func(runtime_func) + + return CoreMLModule(fcreate(model_dir)) + + +class CoreMLModule(object): + """Wrapper runtime module. + + This is a thin wrapper of the underlying TVM module. 
+ you can also directly call set_input, run, and get_output + of underlying module functions + + Parameters + ---------- + module : Module + The internal tvm module that holds the actual coreml functions. + + Attributes + ---------- + module : Module + The internal tvm module that holds the actual coreml functions. + """ + + def __init__(self, module): + self.module = module + self.invoke = module["invoke"] + self.set_input = module["set_input"] + self.get_output = module["get_output"] + self.get_num_outputs = module["get_num_outputs"] diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index 5043520ccf13..0650b934b972 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -182,7 +182,8 @@ def conv_output_shape(tensor_format, x_shape, w_shape, data_dtype, - conv_dtype): + conv_dtype, + groups=1): """Get output shape of 2D or 3D convolution Paramters @@ -205,6 +206,8 @@ def conv_output_shape(tensor_format, data type conv_dtype: str convolution type + groups: int + number of groups Returns ------- @@ -228,7 +231,8 @@ def conv_output_shape(tensor_format, _get_np_int32_array_handle(wshape), _get_np_int32_array_handle(oshape), data_dtype, - conv_dtype) + conv_dtype, + groups) return list(oshape) @@ -240,7 +244,8 @@ def conv_find_algo(tensor_format, w_shape, y_shape, data_dtype, - conv_dtype): + conv_dtype, + groups=1): """Choose the best algo for the given input. Paramters @@ -265,6 +270,8 @@ def conv_find_algo(tensor_format, data type conv_dtype: str convolution type + groups: int + number of groups Returns ------- @@ -287,7 +294,8 @@ def conv_find_algo(tensor_format, _get_np_int32_array_handle(wshape), _get_np_int32_array_handle(yshape), data_dtype, - conv_dtype) + conv_dtype, + groups) def conv_forward(x, @@ -298,7 +306,8 @@ def conv_forward(x, conv_mode, tensor_format, algo, - conv_dtype): + conv_dtype, + groups=1): """Create an extern op that compute 2D or 3D convolution with CuDNN Parameters @@ -325,6 +334,8 @@ def conv_forward(x, if algo == -1, the best algo will be chosen by CUDNN conv_dtype: str convolution type + groups: int + the number of groups Returns ------- @@ -335,8 +346,7 @@ def conv_forward(x, assert dims in (4, 5) conv_dtype = x.dtype if conv_dtype is None else conv_dtype - pad, stride, dilation, _, _ = \ - _prepare_global_func_params(dims - 2, pad, stride, dilation) + pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) oshape = conv_output_shape(tensor_format, pad, @@ -345,7 +355,8 @@ def conv_forward(x, list(x.shape), list(w.shape), x.dtype, - conv_dtype) + conv_dtype, + groups) if algo == -1: # For now if we try to call `cudnnFindConvolutionForwardAlgorithm` when # using INT8 data type, CuDNN will crash down. 
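To illustrate the groups argument now threaded through conv_output_shape, conv_find_algo, and conv_forward, a small sketch; the shapes and dtypes are invented, tensor_format=0 is assumed to mean NCHW, and the pad/stride/dilation keyword names are taken from the surrounding file rather than from this hunk:

    from tvm.contrib import cudnn

    # Grouped 3x3 convolution: 32 input channels split into 32 groups (depthwise-style).
    out_shape = cudnn.conv_output_shape(
        tensor_format=0,
        pad=[1, 1],
        stride=[1, 1],
        dilation=[1, 1],
        x_shape=[1, 32, 56, 56],
        w_shape=[32, 1, 3, 3],        # one single-channel filter per group
        data_dtype="float32",
        conv_dtype="float32",
        groups=32)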
@@ -361,7 +372,8 @@ def conv_forward(x, list(w.shape), oshape, x.dtype, - conv_dtype) + conv_dtype, + groups) if dims == 4: return te.extern( @@ -380,7 +392,8 @@ def conv_forward(x, ins[0], ins[1], outs[0], - conv_dtype), name="y") + conv_dtype, + groups), name="y") return te.extern( oshape, [x, w], @@ -401,7 +414,8 @@ def conv_forward(x, ins[0], ins[1], outs[0], - conv_dtype), name="y") + conv_dtype, + groups), name="y") def softmax(x, axis=-1): """Compute softmax using CuDNN diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py index 18920c60719e..b1fe1b62b8a9 100644 --- a/python/tvm/contrib/debugger/debug_result.py +++ b/python/tvm/contrib/debugger/debug_result.py @@ -53,9 +53,9 @@ def __init__(self, graph_json, dump_path): self._dump_path = dump_path self._output_tensor_list = [] self._time_list = [] - self._parse_graph(graph_json) + json_obj = self._parse_graph(graph_json) # dump the json information - self.dump_graph_json(graph_json) + self._dump_graph_json(json_obj) def _parse_graph(self, graph_json): """Parse and extract the JSON graph and update the nodes, shapes and dltype. @@ -70,12 +70,12 @@ def _parse_graph(self, graph_json): self._shapes_list = json_obj['attrs']['shape'] self._dtype_list = json_obj['attrs']['dltype'] self._update_graph_json() + return json_obj def _update_graph_json(self): """update the nodes_list with name, shape and data type, for temporarily storing the output. """ - nodes_len = len(self._nodes_list) for i in range(nodes_len): node = self._nodes_list[i] @@ -192,7 +192,7 @@ def node_to_events(node, times, starting_time): with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w") as trace_f: json.dump(result, trace_f) - def dump_graph_json(self, graph): + def _dump_graph_json(self, graph): """Dump json formatted graph. Parameters diff --git a/python/tvm/contrib/emscripten.py b/python/tvm/contrib/emcc.py similarity index 65% rename from python/tvm/contrib/emscripten.py rename to python/tvm/contrib/emcc.py index 7f31273451f7..6e7e997d43ad 100644 --- a/python/tvm/contrib/emscripten.py +++ b/python/tvm/contrib/emcc.py @@ -16,18 +16,16 @@ # under the License. """Util to invoke emscripten compilers in the system.""" # pylint: disable=invalid-name -from __future__ import absolute_import as _abs - import subprocess -from .._ffi.base import py_str -from .._ffi.libinfo import find_lib_path +from tvm._ffi.base import py_str +from tvm._ffi.libinfo import find_lib_path + -def create_js(output, - objects, - options=None, - side_module=False, - cc="emcc"): - """Create emscripten javascript library. +def create_tvmjs_wasm(output, + objects, + options=None, + cc="emcc"): + """Create wasm that is supposed to run with the tvmjs. Parameters ---------- @@ -44,25 +42,28 @@ def create_js(output, The compile string. 
""" cmd = [cc] - cmd += ["-Oz"] - if not side_module: - cmd += ["-s", "RESERVED_FUNCTION_POINTERS=2"] - cmd += ["-s", "NO_EXIT_RUNTIME=1"] - extra_methods = ['cwrap', 'getValue', 'setValue', 'addFunction'] - cfg = "[" + (','.join("\'%s\'" % x for x in extra_methods)) + "]" - cmd += ["-s", "EXTRA_EXPORTED_RUNTIME_METHODS=" + cfg] - else: - cmd += ["-s", "SIDE_MODULE=1"] - cmd += ["-o", output] + cmd += ["-O3"] + + cmd += ["-std=c++14"] + cmd += ["-s", "ERROR_ON_UNDEFINED_SYMBOLS=0"] + cmd += ["-s", "STANDALONE_WASM=1"] + cmd += ["-s", "ALLOW_MEMORY_GROWTH=1"] + + objects = [objects] if isinstance(objects, str) else objects + with_runtime = False for obj in objects: - if obj.find("libtvm_web_runtime.bc") != -1: + if obj.find("wasm_runtime.bc") != -1: with_runtime = True - if not with_runtime and not side_module: - objects += [find_lib_path("libtvm_web_runtime.bc")[0]] + if not with_runtime: + objects += [find_lib_path("wasm_runtime.bc")[0]] + objects += [find_lib_path("tvmjs_support.bc")[0]] + objects += [find_lib_path("webgpu_runtime.bc")[0]] + + cmd += ["-o", output] cmd += objects if options: @@ -79,4 +80,4 @@ def create_js(output, msg += py_str(out) raise RuntimeError(msg) -create_js.object_format = "bc" +create_tvmjs_wasm.object_format = "bc" diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index 73235f71c77b..740d1c3f19f3 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -18,9 +18,10 @@ import numpy as np import tvm._ffi -from .._ffi.base import string_types -from .._ffi.runtime_ctypes import TVMContext -from ..rpc import base as rpc_base +from tvm.rpc import _ffi_api as _rpc_ffi_api +from tvm.rpc import base as rpc_base +from tvm._ffi.base import string_types +from tvm._ffi.runtime_ctypes import TVMContext def create(graph_json_str, libmod, ctx): @@ -99,7 +100,7 @@ def get_device_ctx(libmod, ctx): device_type = cur_ctx.device_type if device_type >= rpc_base.RPC_SESS_MASK: assert libmod.type_key == "rpc" - assert rpc_base._SessTableIndex( + assert _rpc_ffi_api.SessTableIndex( libmod) == cur_ctx._rpc_sess._tbl_index num_rpc_ctx += 1 device_type = cur_ctx.device_type % rpc_base.RPC_SESS_MASK diff --git a/tests/scripts/task_web_build.sh b/python/tvm/contrib/target/__init__.py old mode 100755 new mode 100644 similarity index 88% rename from tests/scripts/task_web_build.sh rename to python/tvm/contrib/target/__init__.py index ec1d15a04fbb..7d815413f28a --- a/tests/scripts/task_web_build.sh +++ b/python/tvm/contrib/target/__init__.py @@ -1,4 +1,3 @@ -#!/bin/bash # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,6 +14,5 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -cp /emsdk-portable/.emscripten ~/.emscripten -source /emsdk-portable/emsdk_env.sh -make -j4 +"""Codegen and runtime APIs for targets. +""" diff --git a/python/tvm/contrib/target/coreml.py b/python/tvm/contrib/target/coreml.py new file mode 100644 index 000000000000..e74457ee5378 --- /dev/null +++ b/python/tvm/contrib/target/coreml.py @@ -0,0 +1,226 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, import-outside-toplevel +"""Utility to compile CoreML models""" + +import os +import shutil + +import tvm._ffi +from ...relay.expr_functor import ExprVisitor +from .. import xcode, coreml_runtime + +def _convert_add(builder, name, inputs, outputs, args, attrs): + builder.add_elementwise( + name=name, + input_names=inputs, + output_name=outputs[0], + mode='ADD' + ) + +def _convert_multiply(builder, name, inputs, outputs, args, attrs): + builder.add_elementwise( + name=name, + input_names=inputs, + output_name=outputs[0], + mode='MULTIPLY' + ) + +def _convert_clip(builder, name, inputs, outputs, args, attrs): + builder.add_clip( + name=name, + input_name=inputs[0], + output_name=outputs[0], + min_value=attrs.a_min, + max_value=attrs.a_max + ) + +def _convert_batch_flatten(builder, name, inputs, outputs, args, attrs): + builder.add_flatten_to_2d( + name=name, + input_name=inputs[0], + output_name=outputs[0] + ) + +def _convert_softmax(builder, name, inputs, outputs, args, attrs): + builder.add_softmax_nd( + name=name, + input_name=inputs[0], + output_name=outputs[0], + axis=int(attrs['axis']) + ) + +def _convert_conv2d(builder, name, inputs, outputs, args, attrs): + weight = args[1].data.asnumpy() + if attrs['kernel_layout'] == 'OIHW': + # convert to 'HWIO' + weight = weight.transpose([2, 3, 1, 0]) + kh, kw, kc, oc = weight.shape + + builder.add_convolution( + name=name, + kernel_channels=kc, + output_channels=oc, + height=kh, + width=kw, + stride_height=int(attrs['strides'][0]), + stride_width=int(attrs['strides'][0]), + border_mode="valid", + groups=int(attrs['groups']), + W=weight, + b=None, + has_bias=False, + input_name=inputs[0], + output_name=outputs[0], + dilation_factors=[int(v) for v in attrs['dilation']], + padding_top=int(attrs['padding'][0]), + padding_bottom=int(attrs['padding'][2]), + padding_left=int(attrs['padding'][1]), + padding_right=int(attrs['padding'][3]) + ) + +def _convert_global_avg_pool2d(builder, name, inputs, outputs, args, attrs): + builder.add_pooling( + name=name, + height=1, + width=1, + stride_height=1, + stride_width=1, + layer_type='AVERAGE', + padding_type='VALID', + input_name=inputs[0], + output_name=outputs[0], + is_global=True + ) + +_convert_map = { + 'add' : _convert_add, + 'multiply' : _convert_multiply, + 'clip' : _convert_clip, + 'nn.batch_flatten' : _convert_batch_flatten, + 'nn.softmax' : _convert_softmax, + 'nn.conv2d' : _convert_conv2d, + 'nn.global_avg_pool2d' : _convert_global_avg_pool2d, +} + +class CodegenCoreML(ExprVisitor): + """ + A visitor to traverse subgraphs and build Core ML models. 
+ """ + def __init__(self, model_name, function): + import coremltools + from coremltools.models.neural_network import NeuralNetworkBuilder + + ExprVisitor.__init__(self) + self.model_name = model_name + self.function = function + self.out_map = {} + self.model_inputs_ = [] + self.buf_idx_ = 0 + + # Update inputs and outputs after we visit all the nodes. + # Set dummy values for now. + # TODO: support multiple outputs + inputs = [('', coremltools.models.datatypes.Array(1,)) for _ in self.function.params] + outputs = [('', coremltools.models.datatypes.Array(1,))] + self.builder = NeuralNetworkBuilder(inputs, outputs, + disable_rank5_shape_mapping=True) + + def visit_constant(self, const): + output = "buf_" + str(self.buf_idx_) + self.builder.add_load_constant_nd( + name=output, + output_name=output, + constant_value=const.data.asnumpy(), + shape=const.data.shape + ) + self.buf_idx_ = self.buf_idx_ + 1 + self.out_map[const] = [output] + + def visit_var(self, var): + name = var.name_hint + shape = [int(n) for n in var.type_annotation.shape] + dtype = var.type_annotation.dtype + self.model_inputs_.append((name, shape, dtype)) + self.out_map[var] = [name] + + def visit_call(self, call): + inputs = [] + for arg in call.args: + super().visit(arg) + for out in self.out_map[arg]: + inputs.append(out) + outputs = ["buf_" + str(self.buf_idx_)] + op_name = call.op.name + layer_name = op_name + "_" + str(self.buf_idx_) + + assert op_name in _convert_map, "{} is not supported".format(op_name) + _convert_map[op_name](self.builder, layer_name, inputs, outputs, + call.args, call.attrs) + + self.buf_idx_ = self.buf_idx_ + 1 + self.out_map[call] = outputs + + def compile(self, out_dir): + """ + Build a Core ML model and compile it with Xcode toolchain. + """ + import coremltools + from coremltools.proto.Model_pb2 import ArrayFeatureType + + FEATURE_TYPE_MAP = { + "float32": ArrayFeatureType.FLOAT32, + "float64": ArrayFeatureType.DOUBLE, + "int32": ArrayFeatureType.INT32, + } + + input_names, input_dims, input_dtypes = zip(*self.model_inputs_) + self.builder.set_input(input_names, input_dims) + for i, dtype in enumerate(input_dtypes): + assert dtype in FEATURE_TYPE_MAP + input_desc = self.builder.spec.description.input + input_desc[i].type.multiArrayType.dataType = FEATURE_TYPE_MAP[dtype] + + output_dim = [int(n) for n in self.function.ret_type.shape] + self.builder.set_output(self.out_map[self.function.body], [output_dim]) + for i, dtype in enumerate([self.function.ret_type.dtype]): + assert dtype in FEATURE_TYPE_MAP + output_desc = self.builder.spec.description.output + output_desc[i].type.multiArrayType.dataType = FEATURE_TYPE_MAP[dtype] + + model = coremltools.models.MLModel(self.builder.spec) + xcode.compile_coreml(model, self.model_name, out_dir) + + +@tvm._ffi.register_func("relay.ext.coremlcompiler") +def coreml_compiler(ref): + """ + Create a CoreML runtime from a Relay module. 
+ """ + model_dir = os.getcwd() + if isinstance(ref, tvm.ir.module.IRModule): + for var, func in ref.functions.items(): + name = var.name_hint + builder = CodegenCoreML(name, func) + builder.visit(func.body) + mlmodelc_path = "{}/{}.mlmodelc".format(model_dir, name) + if os.path.exists(mlmodelc_path): + shutil.rmtree(mlmodelc_path) + builder.compile(model_dir) + + ctx = tvm.cpu(0) + return coreml_runtime.create(model_dir, ctx).module diff --git a/python/tvm/contrib/tf_op/module.py b/python/tvm/contrib/tf_op/module.py index f13670e39895..7daf45fcb0c2 100644 --- a/python/tvm/contrib/tf_op/module.py +++ b/python/tvm/contrib/tf_op/module.py @@ -17,6 +17,7 @@ """Module container of TensorFlow TVMDSO op""" import tensorflow as tf from tensorflow.python.framework import load_library +from tensorflow.python import platform class OpModule: @@ -67,7 +68,7 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape): elif output_shape is not None: self.dynamic_output_shape = self._pack_shape_tensor(output_shape) - self.module = load_library.load_op_library('tvm_dso_op.so') + self.module = self._load_platform_specific_library("libtvm_dso_op") self.tvm_dso_op = self.module.tvm_dso_op def apply(self, *params): @@ -82,6 +83,16 @@ def apply(self, *params): def __call__(self, *params): return self.apply(*params) + def _load_platform_specific_library(self, lib_name): + system = platform.system() + if system == "Darwin": + lib_file_name = lib_name + ".dylib" + elif system == "Windows": + lib_file_name = lib_name + ".dll" + else: + lib_file_name = lib_name + ".so" + return load_library.load_op_library(lib_file_name) + def _is_static_shape(self, shape): if shape is None or not isinstance(shape, list): return False diff --git a/python/tvm/contrib/util.py b/python/tvm/contrib/util.py index 2ebe175e8160..8f6dfc7f28ec 100644 --- a/python/tvm/contrib/util.py +++ b/python/tvm/contrib/util.py @@ -15,8 +15,12 @@ # specific language governing permissions and limitations # under the License. """Common system utilities""" +import atexit +import contextlib +import datetime import os import tempfile +import threading import shutil try: import fcntl @@ -24,26 +28,97 @@ fcntl = None +class DirectoryCreatedPastAtExit(Exception): + """Raised when a TempDirectory is created after the atexit hook runs.""" + class TempDirectory(object): """Helper object to manage temp directory during testing. Automatically removes the directory when it went out of scope. """ + + # When True, all TempDirectory are *NOT* deleted and instead live inside a predicable directory + # tree. 
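A brief usage sketch of the keep-for-debug switch described in the comment above; the context manager it relies on, set_keep_for_debug, is defined a few lines further down:

    from tvm.contrib import util

    # Keep every temp dir created in this block past interpreter exit, under
    # <system tempdir>/tvm-debug-mode-tempdirs/<timestamp>/<counter>/.
    with util.TempDirectory.set_keep_for_debug():
        tmp = util.tempdir()
        obj_path = tmp.relpath("lib.o")   # survives for post-mortem inspection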
+ _KEEP_FOR_DEBUG = False + + # In debug mode, each tempdir is named after the sequence + _NUM_TEMPDIR_CREATED = 0 + _NUM_TEMPDIR_CREATED_LOCK = threading.Lock() + @classmethod + def _increment_num_tempdir_created(cls): + with cls._NUM_TEMPDIR_CREATED_LOCK: + to_return = cls._NUM_TEMPDIR_CREATED + cls._NUM_TEMPDIR_CREATED += 1 + + return to_return + + _DEBUG_PARENT_DIR = None + @classmethod + def _get_debug_parent_dir(cls): + if cls._DEBUG_PARENT_DIR is None: + all_parents = f'{tempfile.gettempdir()}/tvm-debug-mode-tempdirs' + if not os.path.isdir(all_parents): + os.makedirs(all_parents) + cls._DEBUG_PARENT_DIR = tempfile.mkdtemp( + prefix=datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S___'), dir=all_parents) + return cls._DEBUG_PARENT_DIR + + TEMPDIRS = set() + @classmethod + def remove_tempdirs(cls): + temp_dirs = getattr(cls, 'TEMPDIRS', None) + if temp_dirs is None: + return + + for path in temp_dirs: + shutil.rmtree(path, ignore_errors=True) + + cls.TEMPDIRS = None + + @classmethod + @contextlib.contextmanager + def set_keep_for_debug(cls, set_to=True): + """Keep temporary directories past program exit for debugging.""" + old_keep_for_debug = cls._KEEP_FOR_DEBUG + try: + cls._KEEP_FOR_DEBUG = set_to + yield + finally: + cls._KEEP_FOR_DEBUG = old_keep_for_debug + def __init__(self, custom_path=None): + if self.TEMPDIRS is None: + raise DirectoryCreatedPastAtExit() + + self._created_with_keep_for_debug = self._KEEP_FOR_DEBUG if custom_path: os.mkdir(custom_path) self.temp_dir = custom_path else: - self.temp_dir = tempfile.mkdtemp() - self._rmtree = shutil.rmtree + if self._created_with_keep_for_debug: + parent_dir = self._get_debug_parent_dir() + self.temp_dir = f'{parent_dir}/{self._increment_num_tempdir_created():05d}' + os.mkdir(self.temp_dir) + else: + self.temp_dir = tempfile.mkdtemp() + + if not self._created_with_keep_for_debug: + self.TEMPDIRS.add(self.temp_dir) def remove(self): """Remote the tmp dir""" if self.temp_dir: - self._rmtree(self.temp_dir, ignore_errors=True) + if not self._created_with_keep_for_debug: + shutil.rmtree(self.temp_dir, ignore_errors=True) + self.TEMPDIRS.remove(self.temp_dir) self.temp_dir = None def __del__(self): + temp_dirs = getattr(self, 'TEMPDIRS', None) + if temp_dirs is None: + # Do nothing if the atexit hook has already run. + return + self.remove() def relpath(self, name): @@ -72,6 +147,9 @@ def listdir(self): return os.listdir(self.temp_dir) +atexit.register(TempDirectory.remove_tempdirs) + + def tempdir(custom_path=None): """Create temp dir which deletes the contents when exit. diff --git a/python/tvm/contrib/xcode.py b/python/tvm/contrib/xcode.py index f78850d570e5..dd067c35bbcf 100644 --- a/python/tvm/contrib/xcode.py +++ b/python/tvm/contrib/xcode.py @@ -21,6 +21,7 @@ import os import sys import subprocess +import json from .._ffi.base import py_str from . import util @@ -170,6 +171,26 @@ def compile_metal(code, path_target=None, sdk="macosx"): return libbin +def compile_coreml(model, model_name="main", out_dir="."): + """Compile coreml model and return the compiled model path. 
+ """ + mlmodel_path = os.path.join(out_dir, model_name + ".mlmodel") + mlmodelc_path = os.path.join(out_dir, model_name + ".mlmodelc") + metadata = { + "inputs": list(model.input_description), + "outputs": list(model.output_description) + } + # Use the description field to send info to CoreML runtime + model.short_description = json.dumps(metadata) + model.save(mlmodel_path) + + res = xcrun(["coremlcompiler", "compile", mlmodel_path, out_dir]) + if not os.path.isdir(mlmodelc_path): + raise RuntimeError("Compile failed: %s" % res) + + return mlmodelc_path + + class XCodeRPCServer(object): """Wrapper for RPC server diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index a429d0775dae..a19b097168c0 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -25,8 +25,8 @@ from tvm.runtime import ndarray from tvm.ir import container from tvm.ir import CallingConv -from tvm.target import codegen, BuildConfig -from tvm.tir import ir_pass +from tvm.ir.transform import PassContext +from tvm.target import codegen from tvm.te import tensor from tvm.te import schedule from tvm import target as _target @@ -57,7 +57,6 @@ def get_binds(args, compact=False, binds=None): The list of symbolic buffers of arguments. """ binds = {} if binds is None else binds.copy() - cfg = BuildConfig.current() arg_list = [] for x in args: if isinstance(x, tensor.Tensor): @@ -68,8 +67,6 @@ def get_binds(args, compact=False, binds=None): x.shape, dtype=x.dtype, name=x.name, - data_alignment=cfg.data_alignment, - offset_factor=cfg.offset_factor, buffer_type=buffer_type) binds[x] = buf arg_list.append(buf) @@ -84,28 +81,49 @@ def get_binds(args, compact=False, binds=None): return binds, arg_list -def form_body(sch): - """According to the given schedule, form the raw body +def form_irmodule(sch, args, name, binds): + """According to the given schedule, form a function. + Parameters ---------- sch : tvm.te.schedule.Schedule - The given scheduler to form the raw body + The given scheduler to form the raw body + + args : list of Buffer or Tensor or Var + The argument lists to the function. + + name : str + The name of result function. + + binds : dict of :any:`Tensor` to :any:`Buffer`, optional + The binds information Returns ------- The body formed according to the given schedule """ # normalize schedule first + pass_ctx = PassContext.current() sch = sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) - stmt = ir_pass.InjectPrefetch(stmt) - return stmt + + compact = schedule.VerifyCompactBuffer(stmt) + binds, arg_list = get_binds(args, compact, binds) + + stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds) + func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds) + + func = func.with_attr("global_symbol", name) + + if pass_ctx.config.get("tir.noalias", True): + func = func.with_attr("tir.noalias", True) + return tvm.IRModule({name: func}) def lower(sch, args, - name="default_function", + name="main", binds=None, simple_mode=False): """Lowering step before build into target. @@ -136,10 +154,12 @@ def lower(sch, The result IRModule, if simple_mode=False Then the Stmt before make api is returned. 
""" - cfg = BuildConfig.current() - add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] - if cfg.dump_pass_ir: - add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass) + # config setup + pass_ctx = PassContext.current() + instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False)) + disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False)) + add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", []) + lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] @@ -147,59 +167,48 @@ def lower(sch, # Phase 0 if isinstance(sch, schedule.Schedule): - stmt = form_body(sch) - - for f in lower_phase0: - stmt = f(stmt) - - compact = ir_pass.VerifyCompactBuffer(stmt) - binds, arg_list = get_binds(args, compact, binds) + mod = form_irmodule(sch, args, name, binds) + else: + mod = sch + pass_list = lower_phase0 # Phase 1 - stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds) - stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers) - stmt = ir_pass.NarrowDataType(stmt, 32) - stmt = ir_pass.CanonicalSimplify(stmt) - for f in lower_phase1: - stmt = f(stmt) + pass_list += [ + tvm.tir.transform.InjectPrefetch(), + tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers), + tvm.tir.transform.NarrowDataType(32), + tvm.tir.transform.Simplify(), + ] + pass_list += lower_phase1 # Phase 2 if not simple_mode: - stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop) - if cfg.disable_vectorize: - stmt = ir_pass.SkipVectorize(stmt) - else: - stmt = ir_pass.VectorizeLoop(stmt) - stmt = ir_pass.InjectVirtualThread(stmt) - stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop) - stmt = ir_pass.StorageRewrite(stmt) - stmt = ir_pass.UnrollLoop( - stmt, - cfg.auto_unroll_max_step, - cfg.auto_unroll_max_depth, - cfg.auto_unroll_max_extent, - cfg.unroll_explicit) - for f in lower_phase2: - stmt = f(stmt) + pass_list += [(tvm.tir.transform.LoopPartition())] + + pass_list += [ + tvm.tir.transform.VectorizeLoop(not disable_vectorize), + tvm.tir.transform.InjectVirtualThread(), + tvm.tir.transform.InjectDoubleBuffer(), + tvm.tir.transform.StorageRewrite(), + tvm.tir.transform.UnrollLoop() + ] + pass_list += lower_phase2 # Phase 3 - stmt = ir_pass.Simplify(stmt) - stmt = ir_pass.RemoveNoOp(stmt) - if not cfg.disable_select_rewriting: - stmt = ir_pass.RewriteUnsafeSelect(stmt) - for f in lower_phase3: - stmt = f(stmt) + pass_list += [ + tvm.tir.transform.Simplify(), + tvm.tir.transform.RemoveNoOp(), + ] + + pass_list += [tvm.tir.transform.RewriteUnsafeSelect()] + pass_list += lower_phase3 + # Instrument BoundCheckers - if cfg.instrument_bound_checkers: - stmt = ir_pass.InstrumentBoundCheckers(stmt) - if simple_mode: - return stmt - - f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( - "global_symbol", tvm.runtime.String(name)) - if cfg.restricted_func: - f = f.with_attr("tir.noalias", True) - mod = tvm.IRModule({name: f}) + if instrument_bound_checkers: + pass_list += [tvm.tir.transform.InstrumentBoundCheckers()] + + optimize = tvm.transform.Sequential(pass_list) + mod = optimize(mod) return mod @@ -232,12 +241,12 @@ def _build_for_device(input_mod, target, target_host): mod_mixed = input_mod mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) - tvm.tir.analysis.verify_memory(mod_mixed) - opt_mixed = [] + opt_mixed = 
[tvm.tir.transform.VerifyMemory()] if len(mod_mixed.functions) == 1: opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] - if BuildConfig.current().detect_global_barrier: + + if PassContext.current().config.get("tir.detect_global_barrier", False): opt_mixed += [tvm.tir.transform.ThreadSync("global")] opt_mixed += [tvm.tir.transform.ThreadSync("shared"), tvm.tir.transform.ThreadSync("warp"), @@ -254,6 +263,7 @@ def _build_for_device(input_mod, target, target_host): lambda f: "calling_conv" in f.attrs and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH), tvm.tir.transform.LowerWarpMemory(), + tvm.tir.transform.Simplify(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerIntrin()]) mod_dev = opt_device(mod_mixed) diff --git a/python/tvm/error.py b/python/tvm/error.py index 4c3e6060c25a..b3502f6b0ead 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -57,6 +57,11 @@ def __init__(self, msg): register_error("KeyError", KeyError) +@register_error +class RPCError(RuntimeError): + """Error thrown by the remote server handling the RPC call.""" + + @register_error class OpError(TVMError): """Base class of all operator errors in frontends.""" diff --git a/python/tvm/exec/rpc_proxy.py b/python/tvm/exec/rpc_proxy.py index 4cf341335ea7..eb802866efba 100644 --- a/python/tvm/exec/rpc_proxy.py +++ b/python/tvm/exec/rpc_proxy.py @@ -29,17 +29,22 @@ def find_example_resource(): """Find resource examples.""" curr_path = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - base_path = os.path.join(curr_path, "../../../") - index_page = os.path.join(base_path, "web/example_rpc.html") - js_files = [ - os.path.join(base_path, "web/tvm_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js"), - os.path.join(base_path, "build/libtvm_web_runtime.js.mem") + base_path = os.path.abspath(os.path.join(curr_path, "..", "..", "..")) + index_page = os.path.join(base_path, "web", "apps", "browser", "rpc_server.html") + resource_files = [ + os.path.join(base_path, "web", "dist", "tvmjs.bundle.js"), + os.path.join(base_path, "web", "dist", "wasm", "tvmjs_runtime.wasi.js") ] - for fname in [index_page] + js_files: + resource_base = os.path.join(base_path, "web", "dist", "www") + if os.path.isdir(resource_base): + for fname in os.listdir(resource_base): + full_name = os.path.join(resource_base, fname) + if os.path.isfile(full_name): + resource_files.append(full_name) + for fname in [index_page] + resource_files: if not os.path.exists(fname): raise RuntimeError("Cannot find %s" % fname) - return index_page, js_files + return index_page, resource_files def main(args): @@ -69,7 +74,7 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default="0.0.0.0", + parser.add_argument('--host', type=str, default="localhost", help='the hostname of the server') parser.add_argument('--port', type=int, default=9090, help='The port of the RPC') diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index dbb690267e2a..e281e58e3879 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -20,6 +20,7 @@ import argparse import ast +import json import multiprocessing import sys import logging @@ -41,7 +42,7 @@ def main(args): tracker_addr = (url, port) if not args.key: raise RuntimeError( - "Need key to present type of resource when tracker is available") + 'Need key to present type of resource when tracker is available') 
else: tracker_addr = None @@ -75,8 +76,8 @@ def init_utvm(args): dev_config = json.load(dev_conf_file) else: dev_config_args = ast.literal_eval(args.utvm_dev_config_args) - default_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['default_config'] - dev_config = default_config_func(*dev_config_args) + generate_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['generate_config'] + dev_config = generate_config_func(*dev_config_args) if args.utvm_dev_config or args.utvm_dev_id: # add MicroTVM overrides @@ -100,8 +101,8 @@ def server_shutdown(): parser.add_argument('--port-end', type=int, default=9199, help='The end search port of the RPC') parser.add_argument('--tracker', type=str, - help="The address of RPC tracker in host:port format. " - "e.g. (10.77.1.234:9190)") + help=("The address of RPC tracker in host:port format. " + "e.g. (10.77.1.234:9190)")) parser.add_argument('--key', type=str, default="", help="The key used to identify the device type in tracker.") parser.add_argument('--silent', action='store_true', @@ -110,17 +111,24 @@ def server_shutdown(): help="Additional library to load") parser.add_argument('--no-fork', dest='fork', action='store_false', help="Use spawn mode to avoid fork. This option \ - is able to avoid potential fork problems with Metal, OpenCL \ - and ROCM compilers.") + is able to avoid potential fork problems with Metal, OpenCL \ + and ROCM compilers.") parser.add_argument('--custom-addr', type=str, help="Custom IP Address to Report to RPC Tracker") parser.add_argument('--utvm-dev-config', type=str, - help='JSON config file for the target device (if using MicroTVM)') - parser.add_argument('--utvm-dev-id', type=str, - help='Unique ID for the target device (if using MicroTVM)') + help=('JSON config file for the target device (if using MicroTVM). ' + 'This file should contain serialized output similar to that returned ' + "from the device module's generate_config. Can't be specified when " + '--utvm-dev-config-args is specified.')) parser.add_argument('--utvm-dev-config-args', type=str, - help=('Python list of literals required to generate a default' - ' MicroTVM config (if --utvm-dev-id is specified)')) + help=("Arguments to the device module's generate_config function. " + 'Must be a python literal parseable by literal_eval. If specified, ' + "the device configuration is generated using the device module's " + "generate_config. Can't be specified when --utvm-dev-config is " + "specified.")) + parser.add_argument('--utvm-dev-id', type=str, + help=('Unique ID for the target device (if using MicroTVM). 
Should ' + 'match the name of a module underneath tvm.micro.device).')) parser.set_defaults(fork=True) args = parser.parse_args() diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 1aabf3e5bca7..f1d1d502a27e 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -23,6 +23,7 @@ from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range +from .op import Op, register_op_attr from .function import CallingConv, BaseFunc from .adt import Constructor, TypeData from .module import IRModule diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index 11ef107f5514..e7374de3a63e 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -22,7 +22,7 @@ from tvm.runtime import _ffi_node_api -@tvm._ffi.register_object +@tvm._ffi.register_object("Array") class Array(Object): """Array container of TVM. @@ -61,14 +61,20 @@ def items(self): def __len__(self): return _ffi_node_api.MapSize(self) + def get(self, key, default=None): + """Get an element with a default value. -@tvm._ffi.register_object -class StrMap(Map): - """A special map container that has str as key. + Parameters + ---------- + key : object + The attribute key. - You can use convert to create a dict[str->Object] into a Map. - """ - def items(self): - """Get the items from the map""" - akvs = _ffi_node_api.MapItems(self) - return [(akvs[i].value, akvs[i+1]) for i in range(0, len(akvs), 2)] + default : object + The default object. + + Returns + ------- + value: object + The result value. + """ + return self[key] if key in self else default diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 9a881cfa6d5b..8b7568574b47 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -16,6 +16,9 @@ # under the License. """Tool to upgrade json from historical versions.""" import json +import tvm.ir +import tvm.runtime + def create_updater(node_map, from_ver, to_ver): """Create an updater to update json loaded data. 
@@ -41,8 +44,12 @@ def _updater(data): nodes = data["nodes"] for idx, item in enumerate(nodes): f = node_map.get(item["type_key"], None) - if f: - nodes[idx] = f(item, nodes) + if isinstance(f, list): + for fpass in f: + item = fpass(item, nodes) + elif f: + item = f(item, nodes) + nodes[idx] = item data["attrs"]["tvm_version"] = to_ver return data return _updater @@ -80,17 +87,35 @@ def _convert(item, _): return _convert def _update_global_key(item, _): - item["repr_str"] = item["global_key"] - del item["global_key"] + if "global_key" in item: + item["repr_str"] = item["global_key"] + del item["global_key"] return item + def _update_from_std_str(key): + def _convert(item, nodes): + str_val = item["attrs"][key] + jdata = json.loads(tvm.ir.save_json(tvm.runtime.String(str_val))) + root_idx = jdata["root"] + val = jdata["nodes"][root_idx] + sidx = len(nodes) + nodes.append(val) + item["attrs"][key] = '%d' % sidx + return item + + return _convert + + node_map = { # Base IR "SourceName": _update_global_key, "EnvFunc": _update_global_key, - "relay.Op": _update_global_key, - "relay.TypeVar": _ftype_var, - "relay.GlobalTypeVar": _ftype_var, + "relay.Op": [_update_global_key, _rename("Op")], + "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")], + "TypeVar": _update_from_std_str("name_hint"), + "relay.Id": [_update_from_std_str("name_hint")], + "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")], + "GlobalTypeVar": _update_from_std_str("name_hint"), "relay.Type": _rename("Type"), "relay.TupleType": _rename("TupleType"), "relay.TypeConstraint": _rename("TypeConstraint"), @@ -98,18 +123,63 @@ def _update_global_key(item, _): "relay.IncompleteType": _rename("IncompleteType"), "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), + "relay.Constructor": [_update_from_std_str("name_hint")], "relay.Module": _rename("IRModule"), "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), - "relay.GlobalVar": _rename("GlobalVar"), + "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")], + "GlobalVar": _update_from_std_str("name_hint"), "relay.Pass": _rename("transform.Pass"), "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), + "StrMap": _rename("Map"), # TIR - "Variable": _update_tir_var("tir.Var"), - "SizeVar": _update_tir_var("tir.SizeVar"), + "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], + "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], + "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")], + "Cast": [_rename("tir.Cast")], + "Add": [_rename("tir.Add")], + "Sub": [_rename("tir.Sub")], + "Mul": [_rename("tir.Mul")], + "Div": [_rename("tir.Div")], + "Mod": [_rename("tir.Mod")], + "FloorDiv": [_rename("tir.FloorDiv")], + "FloorMod": [_rename("tir.FloorMod")], + "Min": [_rename("tir.Min")], + "Max": [_rename("tir.Max")], + "EQ": [_rename("tir.EQ")], + "NE": [_rename("tir.NE")], + "LT": [_rename("tir.LT")], + "LE": [_rename("tir.LE")], + "GT": [_rename("tir.GT")], + "GE": [_rename("tir.GE")], + "And": [_rename("tir.And")], + "Or": [_rename("tir.Or")], + "Not": [_rename("tir.Not")], + "Select": [_rename("tir.Select")], + "Load": [_rename("tir.Load")], + "BufferLoad": [_rename("tir.BufferLoad")], + "Ramp": [_rename("tir.Ramp")], + "Broadcast": [_rename("tir.Broadcast")], + "Shuffle": 
[_rename("tir.Shuffle")], + "Call": [_rename("tir.Call"), _update_from_std_str("name")], + "Let": [_rename("tir.Let")], + "Any": [_rename("tir.Any")], + "LetStmt": [_rename("tir.LetStmt")], + "AssertStmt": [_rename("tir.AssertStmt")], + "Store": [_rename("tir.Store")], + "BufferStore": [_rename("tir.BufferStore")], + "BufferRealize": [_rename("tir.BufferRealize")], + "Allocate": [_rename("tir.Allocate")], + "IfThenElse": [_rename("tir.IfThenElse")], + "Evaluate": [_rename("tir.Evaluate")], + "Prefetch": [_rename("tir.Prefetch")], + "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")], + "Layout": [_rename("tir.Layout"), _update_from_std_str("name")], + "Buffer": [ + _rename("tir.Buffer"), _update_from_std_str("name"), _update_from_std_str("scope")], } return create_updater(node_map, "0.6", "0.7") diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py new file mode 100644 index 000000000000..da546ceb0eec --- /dev/null +++ b/python/tvm/ir/op.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Primitive operators in the TVM IR.""" +import tvm._ffi +from . expr import RelayExpr +from . import _ffi_api + + +@tvm._ffi.register_object("Op") +class Op(RelayExpr): + """Primitive operator in the IR.""" + def __init__(self): + raise RuntimeError("Cannot create op, use get instead") + + @staticmethod + def get(op_name): + """Get the Op for a given name + + Parameters + ---------- + op_name : str + The operator name + + Returns + ------- + op : Op + The op of the corresponding name + """ + return _ffi_api.GetOp(op_name) + + def get_attr(self, attr_name): + """Get additional attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name. + + Returns + ------- + value : object + The attribute value + """ + return _ffi_api.OpGetAttr(self, attr_name) + + def set_attr(self, attr_name, value, plevel=10): + """Set attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name + + value : object + The attribute value + + plevel : int + The priority level + """ + _ffi_api.OpSetAttr(self, attr_name, value, plevel) + + def reset_attr(self, attr_name): + """Reset attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name + """ + _ffi_api.OpResetAttr(self, attr_name) + + +def register_op_attr(op_name, attr_key, value=None, level=10): + """Register an operator property of an operator by name. + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. 
+ """ + def _register(v): + """internal register function""" + _ffi_api.RegisterOpAttr(op_name, attr_key, v, level) + return v + return _register(value) if value is not None else _register diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 614f9690903a..358ad19ff21a 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -21,9 +21,7 @@ import functools import tvm._ffi - import tvm.runtime -from tvm.runtime import ndarray as _nd from . import _ffi_transform_api @@ -61,30 +59,21 @@ class PassContext(tvm.runtime.Object): opt_level : Optional[int] The optimization level of this pass. - fallback_device : Optional[Union[int, str, TVMContext]] - The fallback device type. It is also used as the default device for - operators that are not annotated during heterogeneous execution. - required_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are required by a certain pass. disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]] The list of passes that are disabled. + + config : Optional[Dict[str, Object]] + Additional configurations for specific passes. """ def __init__(self, opt_level=2, - fallback_device=_nd.cpu(), required_pass=None, disabled_pass=None, - trace=None): - if isinstance(fallback_device, str): - fallback_device = _nd.context(fallback_device).device_type - elif isinstance(fallback_device, tvm.runtime.TVMContext): - fallback_device = fallback_device.device_type - if not isinstance(fallback_device, int): - raise TypeError("fallback_device is expected to be the type of " + - "int/str/TVMContext.") - + trace=None, + config=None): required = list(required_pass) if required_pass else [] if not isinstance(required, (list, tuple)): raise TypeError("required_pass is expected to be the type of " + @@ -95,9 +84,9 @@ def __init__(self, raise TypeError("disabled_pass is expected to be the type of " + "list/tuple/set.") + config = config if config else None self.__init_handle_by_constructor__(_ffi_transform_api.PassContext, opt_level, - fallback_device, required, - disabled, trace) + required, disabled, trace, config) def __enter__(self): _ffi_transform_api.EnterPassContext(self) @@ -157,11 +146,6 @@ class Sequential(Pass): """A pass that works on a sequence of pass objects. Multiple passes can be executed sequentially using this class. - Some typical usage of the sequential pass are: - 1. Users provide a list of passes for optimization. - 2. Only an optimization level is provided so that the backend system has - to glob all passes at this level and below to perform the optimizations. - Note that users can also provide a series of passes that they don't want to apply when running a sequential pass. Pass dependency will be resolved in the backend as well. @@ -173,6 +157,9 @@ class Sequential(Pass): opt_level : Optional[int] The optimization level of this sequential pass. + The opt_level of a default sequential pass is set to 0. + Note that some of the passes within the Sequantial may still not be executed + if their opt_level is higher than the provided opt_level. name : Optional[str] The name of the sequential pass. 
diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index 9e984c08fe2c..7c1389cc4eef 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -17,6 +17,7 @@ """MicroTVM module for bare-metal backends""" from ..contrib import binutil -from .base import Session, create_micro_mod, cross_compiler -from .base import LibType, get_micro_host_driven_dir, get_micro_device_dir +from .base import DEVICE_SECTIONS +from .base import Session, create_micro_mod, cross_compiler, LibType +from .base import get_micro_host_driven_dir, get_micro_device_dir from . import device diff --git a/python/tvm/micro/base.py b/python/tvm/micro/base.py index 9f50f9855303..cb3c8430a6c5 100644 --- a/python/tvm/micro/base.py +++ b/python/tvm/micro/base.py @@ -19,6 +19,7 @@ from __future__ import absolute_import import os +import re import sys from enum import Enum @@ -28,6 +29,18 @@ from tvm.contrib import util as _util from tvm.contrib import cc as _cc +# all sections that comprise a device's memory layout, in order from lowest +# starting address to highest +DEVICE_SECTIONS = [ + "text", + "rodata", + "data", + "bss", + "args", + "heap", + "workspace", + "stack", +] class LibType(Enum): """Enumeration of library types that can be compiled and loaded onto a device""" @@ -51,9 +64,9 @@ class Session: .. code-block:: python c_mod = ... # some module generated with "c" as the target - dev_config = micro.device.arm.stm32f746xx.default_config("127.0.0.1", 6666) + dev_config = micro.device.arm.stm32f746xx.default_config('127.0.0.1', 6666) with tvm.micro.Session(dev_config) as sess: - micro_mod = create_micro_mod(c_mod, dev_config) + micro_mod = sess.create_micro_mod(c_mod) """ def __init__(self, config): @@ -62,19 +75,20 @@ def __init__(self, config): # grab a binutil instance from the ID in the config dev_funcs = tvm.micro.device.get_device_funcs(config["device_id"]) - self.create_micro_lib = dev_funcs["create_micro_lib"] self.toolchain_prefix = config["toolchain_prefix"] self.mem_layout = config["mem_layout"] - self.word_size = config["word_size"] + self.word_size_bits = config["word_size_bits"] self.thumb_mode = config["thumb_mode"] + self.use_device_timer = config["use_device_timer"] self.comms_method = config["comms_method"] # First, find and compile runtime library. 
runtime_src_path = os.path.join(get_micro_host_driven_dir(), "utvm_runtime.c") tmp_dir = _util.tempdir() runtime_obj_path = tmp_dir.relpath("utvm_runtime.obj") - self.create_micro_lib(runtime_obj_path, runtime_src_path, LibType.RUNTIME) - #input(f"check {runtime_obj_path}: ") + options = ["-I{}".format(get_micro_host_driven_dir())] + dev_funcs["create_micro_lib"]( + runtime_obj_path, runtime_src_path, LibType.RUNTIME, options=options) comms_method = config["comms_method"] if comms_method == "openocd": @@ -86,6 +100,8 @@ def __init__(self, config): else: raise RuntimeError(f"unknown communication method: f{self.comms_method}") + assert all(map(lambda sec: sec in self.mem_layout, DEVICE_SECTIONS)), \ + "not all sections have an assigned memory layout" self.module = _CreateSession( comms_method, runtime_obj_path, @@ -106,12 +122,16 @@ def __init__(self, config): self.mem_layout["workspace"]["size"], self.mem_layout["stack"].get("start", 0), self.mem_layout["stack"]["size"], - self.word_size, + self.word_size_bits, self.thumb_mode, + self.use_device_timer, server_addr, - server_port) + server_port, + config.get("debug_func")) self._enter = self.module["enter"] self._exit = self.module["exit"] + self.get_last_batch_time = self.module["get_last_batch_time"] + self.get_last_batch_cycles = self.module["get_last_batch_cycles"] def _check_system(self): """Check if the user's system is supported by MicroTVM. @@ -119,7 +139,7 @@ def _check_system(self): Raises error if not supported. """ if not sys.platform.startswith("linux"): - raise RuntimeError("MicroTVM is currently only supported on Linux hosts") + raise RuntimeError("MicroTVM is currently only supported on Linux") # TODO(weberlo): Add 32-bit support. # It's primarily the compilation pipeline that isn't compatible. if sys.maxsize <= 2**32: @@ -133,44 +153,91 @@ def __exit__(self, exc_type, exc_value, exc_traceback): self._exit() -def create_micro_mod(c_mod, dev_config): +def _calc_max_workspace_usage(src): + # TODO factor in alignment to the calculation (alloc sizes will be aligned up to the word size) + alloc_re = re.compile( + r'.*\* ?(.+) = (\(.+\))? TVMBackendAllocWorkspace\(.+, .+, \(uint64_t\)(.+), .+, .+\).*') + free_re = re.compile(r'.*if \(TVMBackendFreeWorkspace\(.+, .+, (\(void\*\))? (.+)\) != 0\) {.*') + max_usage = 0 + alloc_map = {} + for line in src.split("\n"): + if line.strip().startswith("//"): + continue + match = alloc_re.match(line) + if match is not None: + alloc_map[match.group(1)] = int(match.group(3)) + max_usage = max(max_usage, sum(alloc_map.values())) + else: + match = free_re.match(line) + if match is not None: + print(alloc_map) + del alloc_map[match.group(2)] + return max_usage + + +def create_micro_mod(c_mod, dev_config, lib_src_paths=None, lib_headers=None, + lib_include_paths=None): """Produces a micro module from a given module. 
Parameters ---------- - c_mod : tvm.runtime.Module + c_mod : tvm.module.Module module with "c" as its target backend - dev_config : Dict[str, Any] - MicroTVM config dict for the target device + lib_src_paths: TODO + TODO + + lib_headers: TODO + TODO + + lib_include_paths: TODO + TODO Return ------ - micro_mod : tvm.runtim.Module + micro_mod : tvm.module.Module micro module for the target device """ temp_dir = _util.tempdir() lib_obj_path = temp_dir.relpath("dev_lib.obj") + # TODO use dev config to dispatch on the type of C codegen to run through + # (e.g., CodeGenCArm, CodeGenCHost, CodeGenCRiscV) c_mod.export_library( lib_obj_path, - fcompile=cross_compiler(dev_config, LibType.OPERATOR)) + fcompile=cross_compiler( + dev_config, + LibType.OPERATOR, + lib_src_paths=lib_src_paths, + lib_headers=lib_headers, + lib_include_paths=lib_include_paths)) micro_mod = tvm.runtime.load_module(lib_obj_path) return micro_mod -def cross_compiler(dev_config, lib_type): - """Create a cross-compile function that wraps `create_lib` for a `Binutil` instance. +def cross_compiler(dev_config, lib_type, lib_src_paths=None, lib_headers=None, + lib_include_paths=None): + """Create a cross compile function that wraps `create_lib` for a `Binutil` instance. For use in `tvm.runtime.Module.export_library`. Parameters ---------- - dev_config : Dict[str, Any] - MicroTVM config dict for the target device + create_micro_lib : func + function for creating MicroTVM libraries for a specific device (e.g., + `tvm.micro.device.get_device_funcs('arm.stm32f746xx')['create_micro_lib']`) lib_type : micro.LibType whether to compile a MicroTVM runtime or operator library + lib_src_paths: TODO + TODO + + lib_headers: TODO + e.g., `['cmsis_gcc.h', 'arm_math.h']` + + lib_include_paths: TODO + TODO + Return ------ func : Callable[[str, str, Optional[str]], None] @@ -183,16 +250,49 @@ def cross_compiler(dev_config, lib_type): c_mod = ... 
# some module generated with "c" as the target fcompile = tvm.micro.cross_compiler(dev_config, LibType.OPERATOR) - c_mod.export_library("dev_lib.obj", fcompile=fcompile) + c_mod.export_library('dev_lib.obj', fcompile=fcompile) """ - dev_funcs = tvm.micro.device.get_device_funcs(dev_config['device_id']) - create_micro_lib = dev_funcs['create_micro_lib'] + assert (lib_headers is None) == (lib_include_paths is None), \ + "must specify both `lib_headers` and `lib_include_paths` or neither" + + if lib_src_paths is None: + lib_src_paths = [] + if lib_include_paths is None: + lib_include_paths = [] + include_options = [] + for include_path in lib_include_paths: + include_options.append("-I") + include_options.append(include_path) + create_micro_lib = tvm.micro.device.get_device_funcs( + dev_config["device_id"])["create_micro_lib"] + mem_layout = dev_config["mem_layout"] + def compile_func(obj_path, src_path, **kwargs): if isinstance(obj_path, list): obj_path = obj_path[0] if isinstance(src_path, list): src_path = src_path[0] - create_micro_lib(obj_path, src_path, lib_type, kwargs.get("options", None)) + options = kwargs.get("options", []) + options += include_options + + # check that workspace allocations don't exceed available workspace memory + with open(src_path) as f: + src_contents = f.read() + max_ws_usage = _calc_max_workspace_usage(src_contents) + available_mem = mem_layout["workspace"]["size"] + if max_ws_usage > available_mem: + raise RuntimeError(f"workspace allocations in library ({max_ws_usage}) " + f"exceed available memory ({available_mem})") + # inject headers into new source path, if requested + if lib_headers: + headers_to_inject = "\n".join(map(lambda s: f"#include <{s}>", lib_headers)) + "\n" + new_src_contents = headers_to_inject + src_contents + tmp_dir = _util.tempdir() + src_path = tmp_dir.relpath(os.path.basename(src_path)) + with open(src_path, "w") as f: + f.write(new_src_contents) + + create_micro_lib(obj_path, src_path, lib_type, options, lib_src_paths=lib_src_paths) return _cc.cross_compiler(compile_func, output_format="obj") diff --git a/python/tvm/micro/device/__init__.py b/python/tvm/micro/device/__init__.py index 1ccd6847edd8..89731b9aa797 100644 --- a/python/tvm/micro/device/__init__.py +++ b/python/tvm/micro/device/__init__.py @@ -16,7 +16,8 @@ # under the License. """Device-specific configuration for MicroTVM""" -from .base import register_device, get_device_funcs, create_micro_lib_base +from .base import create_micro_lib_base, gen_mem_layout +from .base import MemConstraint, register_device, get_device_funcs from . import host from . import arm from . import riscv_spike diff --git a/python/tvm/micro/device/arm/stm32f746xx.py b/python/tvm/micro/device/arm/stm32f746xx.py index 31b44cf9d36b..997093b0c349 100644 --- a/python/tvm/micro/device/arm/stm32f746xx.py +++ b/python/tvm/micro/device/arm/stm32f746xx.py @@ -14,13 +14,32 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Compilation and config definitions for ARM STM32F746XX devices""" -from .. import create_micro_lib_base, register_device +"""Compilation and config definitions for Arm STM32F746XX devices""" +import os +from .. 
import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint DEVICE_ID = "arm.stm32f746xx" TOOLCHAIN_PREFIX = "arm-none-eabi-" +WORD_SIZE_BITS = 32 +# +# [Device Memory Layout] +# RAM (rwx) : START = 0x20000000, LENGTH = 320K +# Flash (rx) : START = 0x8000000, LENGTH = 1024K +# +BASE_ADDR = 0x20000000 +AVAILABLE_MEM = 320000 +DEFAULT_SECTION_CONSTRAINTS = { + "text": (18000, MemConstraint.ABSOLUTE_BYTES), + "rodata": (100, MemConstraint.ABSOLUTE_BYTES), + "data": (100, MemConstraint.ABSOLUTE_BYTES), + "bss": (640, MemConstraint.ABSOLUTE_BYTES), + "args": (4096, MemConstraint.ABSOLUTE_BYTES), + "heap": (100.0, MemConstraint.WEIGHT), + "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), + "stack": (32, MemConstraint.ABSOLUTE_BYTES), +} -def create_micro_lib(obj_path, src_path, lib_type, options=None): +def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): """Wrapper over `create_micro_lib_base` to add device-specific options Parameters @@ -36,23 +55,40 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): options : Optional[List[str]] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + TODO """ if options is None: options = [] + else: + options = list(options) + options += [ + # TODO(weberlo): make a debug flag + "-O2", "-mcpu=cortex-m7", "-mlittle-endian", "-mfloat-abi=hard", "-mfpu=fpv5-sp-d16", "-mthumb", + "-ffast-math", "-gdwarf-5", + "-DARM_MATH_CM7", + "-D__FPU_PRESENT=1U", + "-DARM_MATH_DSP", + "-Wno-unused-variable", + "-Wno-unused-parameter", + "-I{}".format(os.environ["CMSIS_ST_PATH"]), + "-I{}/Core/Include".format(os.environ["CMSIS_ST_PATH"]) ] create_micro_lib_base( - obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options) + obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options, + lib_src_paths=lib_src_paths) -def default_config(server_addr, server_port): - """Generates a default configuration for ARM STM32F746XX devices +def generate_config(server_addr, server_port, section_constraints=None): + """Generates a configuration for Arm STM32F746XX devices Parameters ---------- @@ -62,55 +98,23 @@ def default_config(server_addr, server_port): server_port : int port of OpenOCD server to connect to + section_constraints: Optional[Dict[str, [Number, MemConstraint]]] + maps section name to the quantity of available memory + Return ------ config : Dict[str, Any] MicroTVM config dict for this device """ + if section_constraints is None: + section_constraints = DEFAULT_SECTION_CONSTRAINTS return { "device_id": DEVICE_ID, "toolchain_prefix": TOOLCHAIN_PREFIX, - # - # [Device Memory Layout] - # RAM (rwx) : START = 0x20000000, LENGTH = 320K - # FLASH (rx) : START = 0x8000000, LENGTH = 1024K - # - "mem_layout": { - "text": { - "start": 0x20000180, - "size": 20480, - }, - "rodata": { - "start": 0x20005180, - "size": 20480, - }, - "data": { - "start": 0x2000a180, - "size": 768, - }, - "bss": { - "start": 0x2000a480, - "size": 768, - }, - "args": { - "start": 0x2000a780, - "size": 1280, - }, - "heap": { - "start": 0x2000ac80, - "size": 262144, - }, - "workspace": { - "start": 0x2004ac80, - "size": 20480, - }, - "stack": { - "start": 0x2004fc80, - "size": 80, - }, - }, - "word_size": 4, + "mem_layout": gen_mem_layout(BASE_ADDR, AVAILABLE_MEM, WORD_SIZE_BITS, section_constraints), + "word_size_bits": WORD_SIZE_BITS, "thumb_mode": True, + "use_device_timer": True, "comms_method": "openocd", "server_addr": server_addr, "server_port": server_port, @@ -119,5 +123,5 @@ def 
default_config(server_addr, server_port): register_device(DEVICE_ID, { "create_micro_lib": create_micro_lib, - "default_config": default_config, + "generate_config": generate_config, }) diff --git a/python/tvm/micro/device/base.py b/python/tvm/micro/device/base.py index ae53b9cc539f..767284c9c254 100644 --- a/python/tvm/micro/device/base.py +++ b/python/tvm/micro/device/base.py @@ -17,12 +17,13 @@ """Base definitions for MicroTVM config""" import glob import os -from pathlib import Path +import enum +import pathlib from tvm.contrib import util as _util from tvm.contrib.binutil import run_cmd from tvm._ffi.libinfo import find_include_path -from tvm.micro import LibType, get_micro_host_driven_dir, get_micro_device_dir +from tvm.micro import DEVICE_SECTIONS, LibType, get_micro_host_driven_dir, get_micro_device_dir _DEVICE_REGISTRY = {} @@ -38,7 +39,7 @@ def register_device(device_id, device_funcs): dictionary with compilation and config generation functions as values """ if device_id in _DEVICE_REGISTRY: - raise RuntimeError(f"\"{device_id}\" already exists in the device registry") + raise RuntimeError(f'"{device_id}" already exists in the device registry') _DEVICE_REGISTRY[device_id] = device_funcs @@ -56,7 +57,7 @@ def get_device_funcs(device_id): dictionary with compilation and config generation functions as values """ if device_id not in _DEVICE_REGISTRY: - raise RuntimeError(f"\"{device_id}\" does not exist in the binutil registry") + raise RuntimeError(f'"{device_id}" does not exist in the binutil registry') device_funcs = _DEVICE_REGISTRY[device_id] return device_funcs @@ -67,7 +68,9 @@ def create_micro_lib_base( toolchain_prefix, device_id, lib_type, - options=None): + options=None, + lib_src_paths=None, + ): """Compiles code into a binary for the target micro device. Parameters @@ -92,7 +95,12 @@ def create_micro_lib_base( options : List[str] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + paths to additional source files to be compiled into the library """ + # look at these (specifically `strip`): + # https://stackoverflow.com/questions/15314581/g-compiler-flag-to-minimize-binary-size base_compile_cmd = [ f"{toolchain_prefix}gcc", "-std=c11", @@ -100,7 +108,6 @@ def create_micro_lib_base( "-Wextra", "--pedantic", "-c", - "-O0", "-g", "-nostartfiles", "-nodefaultlibs", @@ -114,40 +121,48 @@ def create_micro_lib_base( src_paths = [] include_paths = find_include_path() + [get_micro_host_driven_dir()] tmp_dir = _util.tempdir() - # we might transform the src path in one of the branches below + # we need to create a new src file in the operator branch new_in_src_path = in_src_path if lib_type == LibType.RUNTIME: dev_dir = _get_device_source_dir(device_id) + dev_src_paths = glob.glob(f"{dev_dir}/*.[csS]") # there needs to at least be a utvm_timer.c file assert dev_src_paths assert "utvm_timer.c" in map(os.path.basename, dev_src_paths) + src_paths += dev_src_paths elif lib_type == LibType.OPERATOR: - # create a temporary copy of the source, so we can inject the dev lib + # create a temporary copy of the operator source, so we can inject the dev lib # header without modifying the original. 
temp_src_path = tmp_dir.relpath("temp.c") with open(in_src_path, "r") as f: src_lines = f.read().splitlines() - src_lines.insert(0, "#include \"utvm_device_dylib_redirect.c\"") + src_lines.insert(0, '#include "utvm_device_dylib_redirect.c"') with open(temp_src_path, "w") as f: f.write("\n".join(src_lines)) new_in_src_path = temp_src_path - base_compile_cmd += ["-c"] else: raise RuntimeError("unknown lib type") src_paths += [new_in_src_path] + # add any src paths required by the operator + if lib_src_paths is not None: + src_paths += lib_src_paths + + # print(f"include paths: {include_paths}") for path in include_paths: base_compile_cmd += ["-I", path] prereq_obj_paths = [] + # print(src_paths) for src_path in src_paths: - curr_obj_path = Path(src_path).with_suffix(".o").name + curr_obj_path = tmp_dir.relpath(pathlib.Path(src_path).with_suffix(".o").name) assert curr_obj_path not in prereq_obj_paths prereq_obj_paths.append(curr_obj_path) curr_compile_cmd = base_compile_cmd + [src_path, "-o", curr_obj_path] + # TODO(weberlo): make compilation fail if there are any warnings run_cmd(curr_compile_cmd) ld_cmd = [f"{toolchain_prefix}ld", "-relocatable"] @@ -156,6 +171,65 @@ def create_micro_lib_base( run_cmd(ld_cmd) +# TODO we shouldn't need an enum for this. too much bureaucracy. +class MemConstraint(enum.Enum): + """Represents a constraint on the device's memory layout""" + ABSOLUTE_BYTES = 0 + WEIGHT = 1 + + +def gen_mem_layout(base_addr, available_mem, word_size_bits, section_constraints): + """Template function to generate memory layout for devices. + + Parameters + ---------- + base_addr: Number + The address where usable memory begins on this device. + + available_mem: Number + Available memory at base_addr, given in bytes. + + word_size_bits: Number + Number of bits in one word on this device. + + section_constraints: Optional[Dict[str, [Number, MemConstraint]]] + maps section name to the quantity of available memory + """ + assert word_size_bits in (32, 64), "only 32- or 64-bit devices are supported now" + word_size_bytes = word_size_bits // 8 + byte_sum = sum(x[0] + for x in section_constraints.values() + if x[1] == MemConstraint.ABSOLUTE_BYTES) + weight_sum = sum(x[0] + for x in section_constraints.values() + if x[1] == MemConstraint.WEIGHT) + assert byte_sum <= available_mem + available_weight_mem = available_mem - byte_sum + + res = {} + curr_addr = base_addr + for section in DEVICE_SECTIONS: + (val, cons_type) = section_constraints[section] + if cons_type == MemConstraint.ABSOLUTE_BYTES: + assert val % word_size_bytes == 0, \ + f"constraint {val} for {section} section is not word-aligned" + size = val + res[section] = { + "start": curr_addr, + "size": size, + } + else: + size = int((val / weight_sum) * available_weight_mem) + size = (size // word_size_bytes) * word_size_bytes + res[section] = { + "start": curr_addr, + "size": size, + } + curr_addr += size + + return res + + def _get_device_source_dir(device_id): """Grabs the source directory for device-specific uTVM files""" dev_subdir = "/".join(device_id.split(".")) diff --git a/python/tvm/micro/device/host.py b/python/tvm/micro/device/host.py index a5495b60cf99..cad65b919e65 100644 --- a/python/tvm/micro/device/host.py +++ b/python/tvm/micro/device/host.py @@ -17,12 +17,26 @@ """Compilation and config definitions for the host emulated device""" import sys -from . import create_micro_lib_base, register_device +from . 
import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint DEVICE_ID = "host" TOOLCHAIN_PREFIX = "" +WORD_SIZE_BITS = 64 if sys.maxsize > 2**32 else 32 -def create_micro_lib(obj_path, src_path, lib_type, options=None): +# we pretend we only have 320kb in the default case, so we can use `gen_mem_layout` +DEFAULT_AVAILABLE_MEM = 3200000 +DEFAULT_SECTION_CONSTRAINTS = { + "text": (20480, MemConstraint.ABSOLUTE_BYTES), + "rodata": (20480, MemConstraint.ABSOLUTE_BYTES), + "data": (768, MemConstraint.ABSOLUTE_BYTES), + "bss": (4096, MemConstraint.ABSOLUTE_BYTES), + "args": (4096, MemConstraint.ABSOLUTE_BYTES), + "heap": (262144, MemConstraint.ABSOLUTE_BYTES), + "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), + "stack": (80, MemConstraint.ABSOLUTE_BYTES), +} + +def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): """Wrapper over `create_micro_lib_base` to add device-specific options Parameters @@ -38,59 +52,66 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): options : Optional[List[str]] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + paths to additional source files to be compiled into the library """ if options is None: options = [] + else: + options = list(options) + # Cannot increase optimization level on host due to code loading method. + options.append("-O0") if sys.maxsize > 2**32 and sys.platform.startswith("linux"): options += ["-mcmodel=large"] + options.append('-DUTVM_TARGET_HOST') create_micro_lib_base( - obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options) + obj_path, src_path, TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, options=options, + lib_src_paths=lib_src_paths) -def default_config(): - """Generates a default configuration for the host emulated device +def generate_config(available_mem=None, section_constraints=None): + """Generates a configuration for the host emulated device + + Parameters + ---------- + available_mem: int + number of RW bytes available for use on device + + section_constraints: Optional[Dict[str, Dict[Number, MemConstraint]]] + maps section name to the quantity of available memory Return ------ config : Dict[str, Any] MicroTVM config dict for this device """ + if available_mem is None: + available_mem = DEFAULT_AVAILABLE_MEM + if section_constraints is None: + section_constraints = DEFAULT_SECTION_CONSTRAINTS + mem_layout = gen_mem_layout(0, available_mem, WORD_SIZE_BITS, section_constraints) + # TODO the host emulated device is an outlier, since we don't know how what + # its base address will be until we've created it in the C++. is there any + # way to change the infrastructure around this so it's not so much of an + # outlier? 
+ + # need to zero out all start addresses, because they don't make sense for a + # host device (the memory region is allocated in the backend) + for section in mem_layout: + mem_layout[section]["start"] = 0 return { "device_id": DEVICE_ID, "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": { - "text": { - "size": 20480, - }, - "rodata": { - "size": 20480, - }, - "data": { - "size": 768, - }, - "bss": { - "size": 768, - }, - "args": { - "size": 1280, - }, - "heap": { - "size": 262144, - }, - "workspace": { - "size": 20480, - }, - "stack": { - "size": 80, - }, - }, - "word_size": 8 if sys.maxsize > 2**32 else 4, + "mem_layout": mem_layout, + "word_size_bits": WORD_SIZE_BITS, "thumb_mode": False, + "use_device_timer": False, "comms_method": "host", } register_device(DEVICE_ID, { "create_micro_lib": create_micro_lib, - "default_config": default_config, + "generate_config": generate_config, }) diff --git a/python/tvm/micro/device/riscv_spike.py b/python/tvm/micro/device/riscv_spike.py index 923e5dfb23a2..32881cab6ba9 100644 --- a/python/tvm/micro/device/riscv_spike.py +++ b/python/tvm/micro/device/riscv_spike.py @@ -15,14 +15,25 @@ # specific language governing permissions and limitations # under the License. """Compilation and config definitions for Spike, a RISC-V functional ISA simulator""" -from collections import OrderedDict -from . import create_micro_lib_base, register_device +from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint DEVICE_ID = "riscv_spike" TOOLCHAIN_PREFIX = "riscv64-unknown-elf-" +WORD_SIZE_BITS = 64 -def create_micro_lib(obj_path, src_path, lib_type, options=None): +DEFAULT_SECTION_CONSTRAINTS = { + "text": (18000, MemConstraint.ABSOLUTE_BYTES), + "rodata": (128, MemConstraint.ABSOLUTE_BYTES), + "data": (128, MemConstraint.ABSOLUTE_BYTES), + "bss": (2048, MemConstraint.ABSOLUTE_BYTES), + "args": (4096, MemConstraint.ABSOLUTE_BYTES), + "heap": (100.0, MemConstraint.WEIGHT), + "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), + "stack": (32, MemConstraint.ABSOLUTE_BYTES), +} + +def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): """Wrapper over `create_micro_lib_base` to add device-specific options Parameters @@ -38,6 +49,9 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): options : Optional[List[str]] additional options to pass to GCC + + lib_src_paths : Optional[List[str]] + TODO """ create_micro_lib_base( obj_path, @@ -45,11 +59,13 @@ def create_micro_lib(obj_path, src_path, lib_type, options=None): TOOLCHAIN_PREFIX, DEVICE_ID, lib_type, - options=options) + options=options, + lib_src_paths=lib_src_paths + ) -def default_config(base_addr, server_addr, server_port): - """Generates a default configuration for Spike +def generate_config(base_addr, available_mem, server_addr, server_port, section_constraints=None): + """Generates a configuration for Spike Parameters ---------- @@ -62,56 +78,31 @@ def default_config(base_addr, server_addr, server_port): server_port : int port of OpenOCD server to connect to + TODO correct type annotation? 
+ section_constraints: Optional[Dict[str, Tuple[Number, MemConstraint]]] + TODO + Return ------ config : Dict[str, Any] MicroTVM config dict for this device """ - res = { + if section_constraints is None: + section_constraints = DEFAULT_SECTION_CONSTRAINTS + return { "device_id": DEVICE_ID, "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": OrderedDict([ - ("text", { - "size": 20480, - }), - ("rodata", { - "size": 20480, - }), - ("data", { - "size": 768, - }), - ("bss", { - "size": 768, - }), - ("args", { - "size": 1280, - }), - ("heap", { - "size": 262144, - }), - ("workspace", { - "size": 20480, - }), - ("stack", { - "size": 80, - }), - ]), - "word_size": 4, - "thumb_mode": True, + "mem_layout": gen_mem_layout(base_addr, available_mem, WORD_SIZE_BITS, section_constraints), + "word_size_bits": WORD_SIZE_BITS, + "thumb_mode": False, + "use_device_timer": False, "comms_method": "openocd", "server_addr": server_addr, "server_port": server_port, } - # generate section start addresses from the given `base_addr` - curr_offset = 0 - mem_layout = res["mem_layout"] - for region_dict in mem_layout.values(): - region_dict["start"] = base_addr + curr_offset - curr_offset += region_dict["size"] - return res register_device(DEVICE_ID, { "create_micro_lib": create_micro_lib, - "default_config": default_config, + "generate_config": generate_config, }) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 4e520198664c..9c565409a49b 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -40,7 +40,6 @@ from .backend import vm # Root operators -from .op import Op from .op import nn from .op import image from .op import annotation @@ -53,10 +52,17 @@ from . import frontend from . import backend from . import quantize +from . import data_dep_optimization # Dialects from . 
import qnn +from .scope_builder import ScopeBuilder + +# Load Memory Passes +from .transform import memory_alloc +from .transform import memory_plan + # Required to traverse large programs setrecursionlimit(10000) diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 4a73e572f924..eb567658f2a1 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -114,7 +114,12 @@ def convert(self, v): def __call__(self, args, attrs, type_args): if attrs is None: attrs = {} - x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()}) + if self.operator in (op.reshape, op.strided_slice): + x = self.operator(*args) + elif self.operator in (op.zeros, op.ones, op.full, op.broadcast_to): + x = self.operator(*args, dtype=attrs["dtype"]) + else: + x = self.operator(*args, **{k: self.convert(v) for k, v in attrs.items()}) if isinstance(x, expr.TupleWrapper): x = x.astuple() return x @@ -151,7 +156,9 @@ def __call__(self, args, attrs, type_args): "nn.dropout": op.nn.dropout_raw, "zeros": op.zeros, "split": op.split, - "cast": op.cast + "cast": op.cast, + "clip": op.clip, + "right_shift": op.right_shift, } TYPE_PREFIXES = [ @@ -371,7 +378,7 @@ def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]: return self.module # Exprs - def visitOpIdent(self, ctx) -> op.Op: + def visitOpIdent(self, ctx) -> tvm.ir.Op: op_name = ".".join([name.getText() for name in ctx.CNAME()]) if op_name in FUNC_OPS: return FuncOp(FUNC_OPS[op_name]) diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index a1833c3c08b2..e5b21cb107f5 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -28,3 +28,4 @@ # Feature from . import feature +from . import sparse_dense diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 21f3edfb99eb..c237859eb987 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -333,3 +333,21 @@ def extract_fused_functions(mod): for hash_, func in ret_mod.functions.items(): ret[hash_] = func return ret + + +def search_fc_transpose(expr): + """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0])) + + This function is used in the data_dep_optimization.simplify_fc_transpose method + + Parameters + ---------- + expr : tvm.relay.Expr + + Returns + ------- + ret : Array[String] + Array of weight variable name in pattern y = nn.dense(x, transpose(w, [1, 0])) + """ + ret = _ffi_api.search_fc_transpose(expr) + return ret diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py index fc8e85ac8743..f29b72669474 100644 --- a/python/tvm/relay/analysis/annotated_regions.py +++ b/python/tvm/relay/analysis/annotated_regions.py @@ -31,9 +31,9 @@ def __init__(self, expr, region_begin_op, region_end_op): ---------- expr : tvm.relay.Expr The expression from which to construct the regions. - region_begin_op : tvm.relay.Op + region_begin_op : tvm.ir.Op The region begin annotation. - region_end_op : tvm.relay.Op + region_end_op : tvm.ir.Op The region end annotation. """ diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py new file mode 100644 index 000000000000..7e8f4345e336 --- /dev/null +++ b/python/tvm/relay/analysis/sparse_dense.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return +# pylint: disable=unidiomatic-typecheck +""" +This file contains helper functions for convert dense model +to block sparse model +""" +from collections import namedtuple +import numpy as np +import scipy.sparse as sp +import tvm +from . import _ffi_api + + +SparseAnalysisResult = namedtuple("SparseAnalysisResult", [ + "weight_name", + "weight_shape", +]) + +def _search_dense_op_weight(expr): + """Search name of weight in all ```nn.dense``` operator + This is a helpful function to determine which param need + to be converted to sparse + + Parameters + ---------- + expr : relay.Expr + Expr will be searched + + Returns + ------- + ret : Array[String] + name of weight in all ``nn.dense``` operator + """ + return _ffi_api.search_dense_op_weight(expr) + + +def process_params(expr, params, block_size, sparsity_threshold): + """[summary] + + Parameters + ---------- + expr : Relay.Expr + Expr of the network + params : Dict[String, tvm.nd.array] + parameters of the network + block_size : Tuple(int, int) + Blocksize in BSR matrix + sparsity_threshold : float + Minimal sparsity requirement for converting to sparse operation + + Returns + ------- + ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]] + return names of qualified dense weight and the shape in BSR format + """ + memo = SparseAnalysisResult(weight_name=[], weight_shape=[]) + weight_names = _search_dense_op_weight(expr) + for name in weight_names: + name = str(name) + w_np = params[name].asnumpy() + sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size) + if sparsity >= sparsity_threshold: + sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size) + # remove dense weight + del params[name] + memo.weight_name.append(name) + memo.weight_shape.append(list(sparse_weight.data.shape) + + list(sparse_weight.indices.shape) + + list(sparse_weight.indptr.shape)) + params[name + ".data"] = tvm.nd.array(sparse_weight.data) + params[name + ".indices"] = tvm.nd.array(sparse_weight.indices) + params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr) + ret = SparseAnalysisResult( + weight_name=tvm.runtime.convert(memo.weight_name), + weight_shape=tvm.runtime.convert(memo.weight_shape) + ) + return ret diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 3e35bd22e08f..eb5c2b32c0ef 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -26,7 +26,6 @@ from ... import target as _target from ... import autotvm from .. import function as _function -from .. import op as _op from .. import ty as _ty from . import _backend @@ -98,7 +97,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): Parameters ---------- - op : relay.op.Op + op : tvm.ir.Op Relay operator. 
attrs : object @@ -157,7 +156,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) Parameters ---------- - op : relay.op.Op + op : tvm.ir.Op Relay operator. attrs : object @@ -215,7 +214,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" - assert isinstance(call.op, _op.Op) + assert isinstance(call.op, tvm.ir.Op) op = call.op # Prepare the call_node->checked_type(). For the call node inputs, we ensure that diff --git a/python/tvm/relay/data_dep_optimization/__init__.py b/python/tvm/relay/data_dep_optimization/__init__.py new file mode 100644 index 000000000000..ab0caa20f0bb --- /dev/null +++ b/python/tvm/relay/data_dep_optimization/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument, not-context-manager +"""Optimizations involves changing of paramters""" + +from . import bsr_dense +from . import simplify_fc_transpose diff --git a/python/tvm/relay/data_dep_optimization/bsr_dense.py b/python/tvm/relay/data_dep_optimization/bsr_dense.py new file mode 100644 index 000000000000..cc3e5deb302e --- /dev/null +++ b/python/tvm/relay/data_dep_optimization/bsr_dense.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
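For orientation (illustrative only, not part of this patch): the per-weight BSR conversion that process_params (added above in sparse_dense.py) performs can be reproduced directly with scipy; the weight shape, block size, and threshold below are made-up values.

import numpy as np
import scipy.sparse as sp

w_np = np.random.rand(128, 64).astype("float32")
w_np[w_np < 0.9] = 0.0                                   # make the weight ~90% sparse
sparsity = 1.0 - np.count_nonzero(w_np) / w_np.size
if sparsity >= 0.8:                                      # sparsity_threshold
    bsr = sp.bsr_matrix(w_np, blocksize=(16, 1))
    # these three arrays are what process_params stores back into params
    # as "<name>.data", "<name>.indices" and "<name>.indptr"
    data, indices, indptr = bsr.data, bsr.indices, bsr.indptr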
+#pylint: disable=unused-argument, not-context-manager
+"""Automatically convert a model from dense to block sparse"""
+
+from tvm import relay
+from tvm.relay.analysis.sparse_dense import process_params
+
+from .utils import _run_opt_pass
+
+def convert(func, params, blocksize, sparsity_threshold):
+    """Convert a dense function and its parameters to block-sparse form
+
+    Parameters
+    ----------
+    func : relay.Expr
+        Expr to be optimized with sparse operations
+    params : Dict[String, tvm.nd.array]
+        Parameters of the Expr
+    blocksize : Tuple(int, int)
+        Blocksize for BSR matrix
+    sparsity_threshold : float
+        Minimal sparsity requirement for converting.
+        If weight sparsity is lower than this threshold,
+        the dense operation will be kept.
+
+    Returns
+    -------
+    new_func: relay.Expr
+        Mutated Expr with sparse operations
+
+    params: Dict[String, tvm.nd.array]
+        New params with BSR matrix for mutated Expr
+    """
+    weight_info = process_params(func, params, blocksize, sparsity_threshold)
+    new_func = _run_opt_pass(
+        func,
+        relay.transform.DenseToSparse(
+            weight_info.weight_name,
+            weight_info.weight_shape
+        )
+    )
+    return new_func, params
diff --git a/python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py b/python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
new file mode 100644
index 000000000000..345c579499f5
--- /dev/null
+++ b/python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
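A usage sketch for the convert helper defined above (illustrative only, not part of the patch); func and params are assumed to come from an earlier relay.frontend import, and the block size and threshold are arbitrary.

from tvm.relay.data_dep_optimization import bsr_dense

# func: the Relay function (e.g. mod["main"]); params: Dict[str, tvm.nd.NDArray]
sparse_func, sparse_params = bsr_dense.convert(
    func, params, blocksize=(16, 1), sparsity_threshold=0.8)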
+#pylint: disable=unused-argument, not-context-manager +"""Automatic optimize fc tranpose""" +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.analysis import search_fc_transpose + +from .utils import _run_opt_pass + + +def convert(func, params): + """convert all ```y = nn.dense(x, transpose(w, [1, 0]))``` to + ```y = nn.dense(x, wt)``` + + Parameters + ---------- + func : relay.Expr + Expr will be optimized + params : Dict[String, tvm.nd.array] + Parameters of Expr + + Returns + ------- + new_func : relay.Expr + Mutated Expr from ```y = nn.dense(x, transpose(w, [1, 0]))``` to + ```y = nn.dense(x, wt)``` + params: Dict[String, tvm.nd.array] + Parameters of mutated Expr, with weights pre-transposed + """ + weight_info = search_fc_transpose(func) + for item in weight_info: + name = str(item) + w_np = params[name].asnumpy() + new_w = np.transpose(w_np, axes=[1, 0]) + params[name + ".T"] = tvm.nd.array(new_w) + del params[name] + new_func = _run_opt_pass( + func, + relay.transform.SimplifyFCTranspose( + weight_info, + ) + ) + return new_func, params diff --git a/python/tvm/relay/data_dep_optimization/utils.py b/python/tvm/relay/data_dep_optimization/utils.py new file mode 100644 index 000000000000..6b46f815474a --- /dev/null +++ b/python/tvm/relay/data_dep_optimization/utils.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument, not-context-manager +"""Utils functions for optimizations""" + +import tvm + +def _run_opt_pass(expr, opt_pass): + """Helper function to run pass + + Parameters + ---------- + expr : relay.Expr + Expr will be optimized + opt_pass : relay.Pass + Optimization pass + + Returns + ------- + ret: relay.Expr + Optimized Expr by running opt_pass + """ + assert isinstance(opt_pass, tvm.transform.Pass) + mod = tvm.IRModule.from_expr(expr) + mod = opt_pass(mod) + return mod["main"] diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py new file mode 100644 index 000000000000..915842c8e5fa --- /dev/null +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -0,0 +1,783 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The Relay Pattern Language and tooling.""" +# pylint: disable=no-member +from typing import Callable, Dict, List, Optional + +import tvm._ffi +from tvm.relay.expr import RelayExpr as Expr + +from ... import _ffi as tvm_ffi +from ...ir import make_node +from ...ir.base import Node +from ...runtime import Object +from ..op import get +from . import _ffi as ffi + + +def register_df_node(type_key=None): + """Register a Relay node type. + + Parameters + ---------- + type_key : str or cls + The type key of the node. + """ + if not isinstance(type_key, str): + return tvm._ffi.register_object( + "relay.dataflow_pattern." + type_key.__name__)(type_key) + return tvm._ffi.register_object(type_key) + + +class DFPattern(Node): + """Base class of all Patterns. + """ + + def __call__(self, *args): + return CallPattern(self, list(args)) + + def __or__(self, other): + return AltPattern(self, other) + + def __add__(self, other): + return is_op("add")(self, other) + + def __sub__(self, other): + return is_op("subtract")(self, other) + + def __mul__(self, other): + return is_op("multiply")(self, other) + + def __truediv__(self, other): + return is_op("divide")(self, other) + + def has_attr(self, attrs: Dict[str, Object]): + """ + Add an attribute constraint to this pattern + + Parameters + ---------- + attrs: Dict[str, Object] + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting AttrPattern + """ + attrs = make_node("DictAttrs", **attrs) + return AttrPattern(self, attrs) + + def has_type(self, ttype: tvm.ir.type.Type): + """ + Add a type constraint to this pattern + + Parameters + ---------- + ttype: tvm.ir.type.Type + The type to match + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting TypePattern + """ + return has_type(ttype, self) + + def has_dtype(self, dtype: str): + """ + Add a type constraint to this pattern + + Parameters + ---------- + dtype: str + The dtype to match + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting DataTypePattern + """ + return has_dtype(dtype, self) + + def has_shape(self, shape: List[tvm.ir.PrimExpr]): + """ + Add a type constraint to this pattern + + Parameters + ---------- + shape: List[tvm.ir.PrimExpr] + The shape to match + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting ShapePattern + """ + return has_shape(shape, self) + + def match(self, expr: Expr) -> bool: + """ + Match this pattern to an expression + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to match. + + Returns + ------- + result: bool + Whether or not the expression matches the pattern + """ + return match(self, expr) + + def partition(self, + expr: Expr, + attrs: Optional[Dict[str, Object]] = None, + check: Callable[[Expr], bool] = lambda x: True) -> Expr: + """ + Parition the expression into functions defined by this pattern + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to match. 
+ attrs : Optional[Dict[str, Object]] + A dictionary of Attribute name/values to add to the paritioned function + check : Callable[[Expr], bool] + A function to perform more complicated checks on the matched expression. + Returns true if partitioning should proceed, false otherwise. + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraphs replaced by function calls to that subgraph + """ + return partition(self, expr, attrs, check) + + def dominates(self, parent: "DFPattern", path: "DFPattern" = None): + """ + Create a dominator for this pattern. + + Parameters + ---------- + parent: tvm.relay.dataflow_pattern.DFPattern + The parent pattern this pattern dominates. + path: tvm.relay.dataflow_pattern.DFPattern + The fuzzy path pattern. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting DominatorPattern. + """ + if path is None: + path = wildcard() + return DominatorPattern(parent, path, self) + + def optional(self, option_constructor: Callable[["DFPattern"], "DFPattern"]): + """ + Create a optional user of this pattern. + + Parameters + ---------- + option_constructor: function + A function that takes a single Pattern parameter and returns + a constructed pattern matching the option + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting Pattern + """ + return self | option_constructor(self) + + +def is_var(name: str = "") -> "DFPattern": + """ + Syntatic sugar for creating an optionally named VarPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return VarPattern(name) + + +def is_constant() -> "DFPattern": + """ + Syntatic sugar for creating a ConstantPattern. + + Parameters + ---------- + name: str + The name of the input pattern to match. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return ConstantPattern() + + +def is_expr(expr: Expr) -> "DFPattern": + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + expr: Expr + The Relay expression to match. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return ExprPattern(expr) + + +def is_op(op_name: str) -> "DFPattern": + """ + Syntatic sugar for creating an operator ExprPattern. + + Parameters + ---------- + op_name: String + The name of the relay op + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting ExprPattern + """ + op = get(op_name) + return ExprPattern(op) + + +def is_tuple(fields: tvm.ir.container.Array) -> "DFPattern": + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + fields : Array[tvm.relay.dataflow_pattern.DFPattern] + The fields in the tuple. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return TuplePattern(fields) + + +def is_tuple_get_item(tuple_value: "DFPattern", index: int) -> "DFPattern": + """ + Syntatic sugar for creating an ExprPattern. + + Parameters + ---------- + tuple_value: tvm.relay.dataflow_pattern.DFPattern + The input tuple expression. + + index: int + The index. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return TupleGetItemPattern(tuple_value, index) + + +def wildcard() -> "DFPattern": + """ + Syntatic sugar for creating a WildcardPattern. 
+ + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return WildcardPattern() + + +def has_type(ttype: tvm.ir.type.Type, pattern: "DFPattern" = None) -> "DFPattern": + """ + Syntatic sugar for creating a TypePattern + + Parameters + ---------- + ttype: tvm.ir.type.Type + The type to match + + pattern: tvm.relay.dataflow_pattern.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting TypePattern + """ + if pattern is None: + pattern = wildcard() + return TypePattern(pattern, ttype) + + +def has_dtype(dtype: str, pattern: "DFPattern" = None) -> "DFPattern": + """ + Syntatic sugar for creating a DataTypePattern + + Parameters + ---------- + dtype: str + The dtype to match + + pattern: tvm.relay.dataflow_pattern.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting DataTypePattern + """ + if pattern is None: + pattern = wildcard() + return DataTypePattern(pattern, dtype) + + +def has_shape(shape: List[tvm.ir.PrimExpr], pattern: "DFPattern" = None) -> "DFPattern": + """ + Syntatic sugar for creating a ShapePattern + + Parameters + ---------- + shape: List[tvm.ir.PrimExpr] + The shape to match + + pattern: tvm.relay.dataflow_pattern.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting ShapePattern + """ + if pattern is None: + pattern = wildcard() + return ShapePattern(pattern, shape) + + +def has_attr(attrs, pattern=None) -> "DFPattern": + """ + Syntatic sugar for creating an AttrPattern + + Parameters + ---------- + attrs: Dict[str, Object] + The attributes to match + + pattern: Optional[tvm.relay.dataflow_pattern.DFPattern] + The input pattern. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting AttrPattern + """ + if pattern is None: + pattern = wildcard() + return pattern.has_attr(attrs) + + +def dominates(parent: "DFPattern", path: "DFPattern", child: "DFPattern") -> "DFPattern": + """ + Syntatic sugar for creating an Dominator pattern + + Parameters + ---------- + parent: tvm.relay.dataflow_pattern.DFPattern + The parent pattern. + path: tvm.relay.dataflow_pattern.DFPattern + The fuzzy path pattern. + child: tvm.relay.dataflow_pattern.DFPattern + The child pattern. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting DominatorPattern. + """ + return DominatorPattern(parent, path, child) + + +def match(pattern: "DFPattern", expr: Expr) -> bool: + """ + Match a pattern to an expression + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern. + expr : tvm.relay.Expr + The expression to match. + """ + return ffi.match(pattern, expr) + + +@register_df_node +class ExprPattern(DFPattern): + """A pattern which matches a constant expression. + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to match. + """ + + def __init__(self, expr: Expr): + self.__init_handle_by_constructor__(ffi.ExprPattern, expr) + + +@register_df_node +class VarPattern(DFPattern): + """A local variable in Relay. + + Local variable can be used to declare input + arguments to a function, or intermediate variables. + + Parameters + ---------- + name_hint: str + The name of the variable. Optional, if not provided, + the pattern will match any VarNode. 
+ + type_annotation: tvm.ir.type.Type, optional + The type annotation on the variable. + """ + + def __init__(self, name_hint: str = "", type_annotation: Optional[tvm.ir.type.Type] = None): + self.__init_handle_by_constructor__(ffi.VarPattern, name_hint, type_annotation) + + +@register_df_node +class ConstantPattern(DFPattern): + """A pattern matching a Relay Constant. + """ + def __init__(self): + self.__init_handle_by_constructor__(ffi.ConstantPattern) + + +@register_df_node +class CallPattern(DFPattern): + """A pattern matching a function call node in Relay. + + Parameters + ---------- + op: realy.dataflow_pattern.DFPattern + The operation to be called. + + args: List[realy.dataflow_pattern.DFPattern] + The arguments to the call. + + attrs: Optional[tvm.ir.attrs.Attrs] + Attributes to the call, can be None + + type_args: Optional[List[tvm.ir.type.Type]] + The additional type arguments, this is only + used in advanced usecase of template functions. + """ + + def __init__(self, + op: "DFPattern", + args: List["DFPattern"], + attrs: Optional[tvm.ir.attrs.Attrs] = None, + type_args: Optional[List[tvm.ir.type.Type]] = None): + if not type_args: + type_args = [] + self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, type_args) + + +@register_df_node +class TuplePattern(DFPattern): + """A patern matching a Relay Tuple. + + Parameters + ---------- + fields : Array[tvm.relay.dataflow_pattern.DFPattern] + The fields in the tuple. + """ + + def __init__(self, fields: tvm.ir.container.Array): + self.__init_handle_by_constructor__(ffi.TuplePattern, fields) + + def __getitem__(self, index: int): + if index >= len(self): + raise IndexError("TuplePattern index out of range") + return self.fields[index] + + def __len__(self): + return len(self.fields) + + def astype(self, _): + raise TypeError("astype cannot be used on TuplePattern") + + +@register_df_node +class TupleGetItemPattern(DFPattern): + """Get index-th item from a TuplePattern. + + Parameters + ---------- + tuple_value: tvm.relay.dataflow_pattern.DFPattern + The input tuple expression. + + index: int + The index. + """ + + def __init__(self, tuple_value: "DFPattern", index: int): + self.__init_handle_by_constructor__(ffi.TupleGetItemPattern, tuple_value, index) + + +@register_df_node +class AltPattern(DFPattern): + """Create a Pattern that can match one of two conditions + + Parameters + ---------- + left: tvm.relay.dataflow_pattern.DFPattern + One possible matching pattern. + right: tvm.relay.dataflow_pattern.DFPattern + One possible matching pattern. + """ + + def __init__(self, left: "DFPattern", right: "DFPattern"): + self.__init_handle_by_constructor__(ffi.AltPattern, left, right) + + +@register_df_node +class WildcardPattern(DFPattern): + """A pattern which matches anything. + """ + + def __init__(self): + self.__init_handle_by_constructor__(ffi.WildcardPattern) + + +@register_df_node +class TypePattern(DFPattern): + """A pattern that matches another pattern with a certain type annotation. + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern that needs type annotation. + + ttype: tvm.ir.type.Type + The type to match. 
+ """ + + def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): + self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) + + +@register_df_node +class DataTypePattern(DFPattern): + """A pattern that matches another pattern with certain data type + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern that needs type annotation. + + dtype: str + The dtype to match. + """ + + def __init__(self, pattern: "DFPattern", dtype: str): + self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype) + + +@register_df_node +class ShapePattern(DFPattern): + """A pattern that matches another pattern with a certain tensor shape + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern that needs type annotation. + + shape: List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) + + +@register_df_node +class AttrPattern(DFPattern): + """Get match an expression with a certain attributes. + Currently only supports Op Attributes, not call Attributes. + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern. + + attrs: tvm.ir.attrs.Attrs + The attributes to match. + """ + + def __init__(self, pattern: "DFPattern", attrs: tvm.ir.attrs.Attrs): + self.__init_handle_by_constructor__(ffi.AttrPattern, pattern, attrs) + + +@register_df_node +class DominatorPattern(DFPattern): + """Match a domination graph. + + Parameters + ---------- + parent: tvm.relay.dataflow_pattern.DFPattern + The parent, i.e., the single node which produces something, + later aggregated by the child. + path: tvm.relay.dataflow_pattern.DFPattern + The fuzzy path pattern between parent and child, + typically matches elementwise ops. + child: tvm.relay.dataflow_pattern.DFPattern + The last node in the domination which is the end user + for all nodes in the path and the parent. + """ + + def __init__(self, parent: "DFPattern", path: "DFPattern", child: "DFPattern"): + self.__init_handle_by_constructor__(ffi.DominatorPattern, parent, path, child) + + +class DFPatternCallback: + """A Callback for Pattern Rewriting. + + When rewrite is called on this DFPatternCallback, the backend will find matches for the + pattern, call the callback function, and replace the matched expression with whatever + the callback returns. + + Users are expect to inherit from this class and provide a "self.pattern" to match + """ + + def rewrite(self, expr: Expr) -> Expr: + """ + Rewrite expression with this callback + + Parameters + ---------- + expr : tvm.relay.Expr + The expression to rewrite. + + Returns + ------- + result : tvm.relay.Expr + The Expression with matched subgraphs rewritten by the callbacks. + """ + return rewrite(self, expr) + + def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr: + """ + Callback function to use when we found a match to the pattern + + Parameters + ---------- + pre : tvm.relay.Expr + The matching expression from the original graph. 
+        post : tvm.relay.Expr
+            The matching expression with rewritten inputs
+        node_map : tvm.ir.container.Map[DFPattern, List[Expr]]
+            The map between patterns and matched expressions
+
+        Returns
+        -------
+        result : tvm.relay.Expr
+            The Expression with matched subgraph rewritten by the callback
+        """
+        raise NotImplementedError("callback must be implemented by the subclass")
+
+class _DFPatternCallback(Object):
+    """C++ implementation"""
+    def __init__(self, pattern, callback):
+        self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback)
+
+
+def rewrite(callbacks, expr: Expr) -> Expr:
+    """
+    Rewrite expression with the given callbacks.
+
+    Parameters
+    ----------
+    callbacks: tvm.relay.dataflow_pattern.DFPatternCallback
+        The input callback or list of callbacks.
+    expr : tvm.relay.Expr
+        The expression to rewrite.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The Expression with matched subgraphs rewritten by the callbacks.
+    """
+    if isinstance(callbacks, DFPatternCallback):
+        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
+    else:
+        tmp = []
+        for callback in callbacks:
+            tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
+
+    return ffi.rewrite(tmp, expr)
+
+
+def partition(pattern: "DFPattern",
+              expr: Expr,
+              attrs: Optional[Dict[str, Object]] = None,
+              check: Callable[[Expr], bool] = lambda x: True) -> Expr:
+    """
+    Partition the expression into a series of functions that match the pattern
+
+    Parameters
+    ----------
+    pattern: tvm.relay.dataflow_pattern.DFPattern
+        The pattern to match
+    expr : tvm.relay.Expr
+        The expression to split into functions
+    attrs : Optional[Dict[str, Object]]
+        A dict of attributes to apply to the partitioned function
+    check : Callable[[Expr], bool]
+        A function to perform more complicated checks on the matched expression.
+        Returns true if partitioning should proceed, false otherwise.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The Expression with matched subgraphs replaced by function calls to that subgraph
+    """
+    return ffi.partition(pattern, expr, attrs, check)
diff --git a/python/tvm/relay/dataflow_pattern/_ffi.py b/python/tvm/relay/dataflow_pattern/_ffi.py
new file mode 100644
index 000000000000..b0a702c1d2f5
--- /dev/null
+++ b/python/tvm/relay/dataflow_pattern/_ffi.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
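To make the new pattern-language API concrete, a small end-to-end sketch (illustrative only, not code from this diff) that matches add(nn.dense(x, w), b) and rewrites it to nn.bias_add via a DFPatternCallback; the variable names and shapes are assumptions.

from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard, rewrite

# Pattern: add(nn.dense(*, *), *)
dense = is_op("nn.dense")(wildcard(), wildcard())
bias = wildcard()
pattern = is_op("add")(dense, bias)

x = relay.var("x", shape=(1, 16))
w = relay.var("w", shape=(8, 16))
b = relay.var("b", shape=(8,))
expr = relay.add(relay.nn.dense(x, w), b)
assert pattern.match(expr)

class DenseBiasRewrite(DFPatternCallback):
    """Turn add(dense(x, w), b) into nn.bias_add(dense(x, w), b)."""
    def __init__(self):
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        # node_map maps each sub-pattern to the expressions it matched
        return relay.nn.bias_add(node_map[dense][0], node_map[bias][0], axis=1)

new_expr = rewrite(DenseBiasRewrite(), expr)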
+"""DataFlow Pattern Language FFI bindings.""" +import tvm._ffi + +tvm._ffi._init_api("relay.dataflow_pattern", __name__) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index ff1368394917..fbb98fcf9e3c 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -221,7 +221,7 @@ def __init__(self, name_hint, type_annotation=None): @property def name_hint(self): """Get name hint of the current var.""" - name = self.vid.name_hint + name = str(self.vid.name_hint) return name @@ -234,7 +234,7 @@ class Call(ExprWithOp): Parameters ---------- - op: tvm.relay.Op or any tvm.relay.Expr with function type. + op: tvm.ir.Op or any tvm.relay.Expr with function type. The operation to be called. args: List[tvm.relay.Expr] @@ -504,6 +504,7 @@ def const(value, dtype=None): if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") + return Constant(value) diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 874a3a75d5bf..fd9b253c1478 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression functor of Relay.""" +from tvm.ir import Op from .function import Function from .expr import Call, Let, Var, GlobalVar from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause -from .op import Op class ExprFunctor: """ diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index fa258f48ac76..aba9eea494be 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -24,10 +24,6 @@ from __future__ import absolute_import from .mxnet import from_mxnet -from .mxnet_qnn_op_utils import dequantize_mxnet_min_max -from .mxnet_qnn_op_utils import quantize_mxnet_min_max -from .mxnet_qnn_op_utils import get_mkldnn_int8_scale -from .mxnet_qnn_op_utils import get_mkldnn_uint8_scale from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras from .onnx import from_onnx diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index e86890f3639a..6310e3bfcf29 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -497,15 +497,15 @@ def infer_value(input_val, params, mod=None): portion of the relay graph. This is often needed for functions that whose output shape depends on the value of a tensor. """ + # Check that all free variables have associated parameters. + assert all(var.name_hint in params.keys() for var in analysis.free_vars( + input_val)), "All inputs to infer must be available in params." try: # TODO(kevinthesun): Use VM for all cases. # pylint: disable=import-outside-toplevel from tvm.contrib import graph_runtime - # Check that all free variables have associated parameters. - assert all(var.name_hint in params.keys() for var in analysis.free_vars( - input_val)), "All inputs to infer must be available in params." 
func = _function.Function(analysis.free_vars(input_val), input_val) - with tvm.relay.build_config(opt_level=0): + with tvm.transform.PassContext(opt_level=0): graph, lib, params = tvm.relay.build(func, target="llvm", params=params) ctx = tvm.cpu(0) m = graph_runtime.create(graph, lib, ctx) @@ -520,7 +520,7 @@ def infer_value(input_val, params, mod=None): exc = tvm.relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") inputs = [] for param in mod['main'].params: - inputs.append(tvm.nd.array(params[param.name_hint])) + inputs.append(params[param.name_hint]) result = exc.evaluate()(*inputs) return result diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py index 6658803b3ade..0027c7faab20 100644 --- a/python/tvm/relay/frontend/coreml.py +++ b/python/tvm/relay/frontend/coreml.py @@ -77,10 +77,7 @@ def _ConvolutionLayerParams(op, inexpr, etab): pad_b = valid.paddingAmounts.borderAmounts[0].endEdgeSize pad_r = valid.paddingAmounts.borderAmounts[1].endEdgeSize if not all(v == 0 for v in (pad_t, pad_l, pad_b, pad_r)): - inexpr = _op.nn.pad(data=inexpr, pad_width=((0, 0), - (0, 0), - (pad_t, pad_b), - (pad_l, pad_r))) + params['padding'] = (pad_t, pad_l, pad_b, pad_r) elif op.WhichOneof('ConvolutionPaddingType') == 'same': assert op.same.asymmetryMode == 0, "Only support BOTTOM_RIGHT_HEAVY mode, " \ "which is used by tf/caffe and so on" @@ -88,11 +85,7 @@ def _ConvolutionLayerParams(op, inexpr, etab): strides = params['strides'] pad_t, pad_b = get_pad_value(H, kernel[0], strides[0]) pad_l, pad_r = get_pad_value(W, kernel[1], strides[1]) - inexpr = _op.nn.pad(data=inexpr, pad_width=((0, 0), - (0, 0), - (pad_t, pad_b), - (pad_l, pad_r))) - + params['padding'] = (pad_t, pad_l, pad_b, pad_r) else: raise NotImplementedError("Valid/Same convolution padding implemented") diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py index 936d7c0dc87f..62a320780564 100644 --- a/python/tvm/relay/frontend/darknet.py +++ b/python/tvm/relay/frontend/darknet.py @@ -637,12 +637,12 @@ def _get_darknet_attrs(self, layer, layer_num): attr.update({'coords' : layer.coords}) attr.update({'background' : layer.background}) attr.update({'softmax' : layer.softmax}) - attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + attr.update({'shape' : (-1, layer.c, layer.h, layer.w)}) elif LAYERTYPE.YOLO == layer_type: attr.update({'n' : layer.n}) attr.update({'classes' : layer.classes}) - attr.update({'shape' : (1, layer.c, layer.h, layer.w)}) + attr.update({'shape' : (-1, layer.c, layer.h, layer.w)}) elif LAYERTYPE.UPSAMPLE == layer_type: attr.update({'scale' : layer.stride}) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index f86092d110b0..ef76eb69311d 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -186,8 +186,11 @@ def _convert_merge(inexpr, keras_layer, _): elif merge_type == 'Subtract': assert len(inexpr) == 2, "Subtract merge takes 2 inputs." 
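In the infer_value hunk above, the deprecated relay.build_config context is replaced by tvm.transform.PassContext; for reference, the compile-at-opt-level-0 idiom now reads as below (func and params as in the surrounding code, target chosen for illustration):

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

with tvm.transform.PassContext(opt_level=0):
    graph, lib, params = relay.build(func, target="llvm", params=params)
m = graph_runtime.create(graph, lib, tvm.cpu(0))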
ret = _op.subtract(ret, inexpr[1]) - elif merge_type in ['Add', 'Multiply', 'Maximum']: - op_map = {'Add': _op.add, 'Multiply': _op.multiply, 'Maximum': _op.maximum} + elif merge_type in ['Add', 'Multiply', 'Minimum', 'Maximum']: + op_map = {'Add': _op.add, + 'Multiply': _op.multiply, + 'Minimum': _op.minimum, + 'Maximum': _op.maximum} for i in range(1, len(inexpr)): ret = op_map[merge_type](ret, inexpr[i]) elif merge_type == 'Average': @@ -204,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _): return _op.transpose(inexpr, axes=(0,) + keras_layer.dims) +def _convert_embedding(inexpr, keras_layer, etab): + indices = inexpr + weightList = keras_layer.get_weights() + weight = etab.new_const(weightList[0]) + out = _op.take(weight, indices.astype('int32'), axis=0) + + return out + def _convert_dense(inexpr, keras_layer, etab): weightList = keras_layer.get_weights() weight = etab.new_const(weightList[0].transpose([1, 0])) @@ -287,15 +298,7 @@ def _convert_convolution(inexpr, keras_layer, etab): in_w = keras_layer.input_shape[2] pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h) pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w) - if pad_t == pad_b and pad_l == pad_r: - params['padding'] = (pad_t, pad_l) - elif etab.data_layout == 'NCHW': - inexpr = _op.nn.pad(data=inexpr, pad_width=( - (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) - else: - inexpr = _op.nn.pad(data=inexpr, pad_width=( - (0, 0), (pad_t, pad_b), (pad_l, pad_r), (0, 0))) - + params['padding'] = (pad_t, pad_l, pad_b, pad_r) else: msg = 'Padding with {} is not supported for operator Convolution ' \ 'in frontend Keras.' @@ -370,7 +373,7 @@ def _convert_convolution3d(inexpr, keras_layer, etab): pad_d3 = _get_pad_pair(in_d3, dilated_kernel_d3, stride_d3) params['padding'] = [pad_d1[0], pad_d2[0], pad_d3[0], pad_d1[1], pad_d2[1], pad_d3[1]] else: - msg = 'Padding with {} is not supported for operator Convolution ' \ + msg = 'Padding with {} is not supported for operator Convolution3D ' \ 'in frontend Keras.' raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding)) out = _op.nn.conv3d(data=inexpr, **params) @@ -421,15 +424,7 @@ def _convert_separable_convolution(inexpr, keras_layer, etab): in_w = keras_layer.input_shape[2] pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h) pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w) - if pad_t == pad_b and pad_l == pad_r: - params0['padding'] = (pad_t, pad_l) - elif etab.data_layout == 'NCHW': - inexpr = _op.nn.pad(data=inexpr, pad_width=( - (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r))) - else: - inexpr = _op.nn.pad(data=inexpr, pad_width=( - (0, 0), (pad_t, pad_b), (pad_l, pad_r), (0, 0))) - + params0['padding'] = (pad_t, pad_l, pad_b, pad_r) else: msg = 'Padding with {} is not supported for operator Separable ' \ 'Convolution in frontend Keras.' 
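The new Embedding conversion above lowers the Keras layer to a gather, _op.take(weight, indices, axis=0). A numpy sketch of the same computation (shapes are illustrative):

import numpy as np

weight = np.arange(12, dtype="float32").reshape(4, 3)    # vocab=4, embedding dim=3
indices = np.array([[1, 3, 0]])                          # one batch of token ids
out = np.take(weight, indices.astype("int32"), axis=0)   # shape (1, 3, 3)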
@@ -548,6 +543,23 @@ def _convert_pooling3d(inexpr, keras_layer, etab): return _op.transpose(out, axes=(0, 2, 3, 4, 1)) + +def _convert_global_pooling3d(inexpr, keras_layer, etab): + _check_data_format(keras_layer) + pool_type = type(keras_layer).__name__ + + global_pool_params = {'layout': etab.data_layout} + if pool_type == 'GlobalMaxPooling3D': + out = _op.nn.global_max_pool3d(inexpr, **global_pool_params) + elif pool_type == 'GlobalAveragePooling3D': + out = _op.nn.global_avg_pool3d(inexpr, **global_pool_params) + else: + raise tvm.error.OpNotImplemented( + 'Operator {} is not supported for frontend Keras.'.format(keras_layer)) + + return _convert_flatten(out, keras_layer, etab) + + def _convert_upsample(inexpr, keras_layer, etab): _check_data_format(keras_layer) upsample_type = type(keras_layer).__name__ @@ -599,8 +611,8 @@ def _convert_cropping(inexpr, keras_layer, _): raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend Keras.'.format(crop_type)) int32_max = np.iinfo(np.int32).max - return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \ - end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) + return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \ + end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r])) def _convert_batchnorm(inexpr, keras_layer, etab): @@ -890,8 +902,8 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument # 'SeparableConv3D' : _convert_convolution3d, 'MaxPooling3D' : _convert_pooling3d, 'AveragePooling3D' : _convert_pooling3d, - # 'GlobalMaxPooling3D' : _convert_pooling3d, - # 'GlobalAveragePooling3D' : _convert_pooling3d, + 'GlobalMaxPooling3D' : _convert_global_pooling3d, + 'GlobalAveragePooling3D' : _convert_global_pooling3d, 'UpSampling3D' : _convert_upsample3d, 'ZeroPadding3D' : _convert_padding3d, @@ -902,14 +914,16 @@ def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument # 'TimeDistributed' : _default_skip, 'Average' : _convert_merge, + 'Minimum' : _convert_merge, 'Maximum' : _convert_merge, 'Dot' : _convert_merge, 'Permute' : _convert_permute, - # 'Embedding' : _convert_embedding, + 'Embedding' : _convert_embedding, # 'RepeatVector' : _convert_repeat_vector, 'InputLayer' : _default_skip, 'Dropout' : _default_skip, + 'AlphaDropout' : _default_skip, 'SpatialDropout2D' : _default_skip, 'SpatialDropout1D' : _default_skip, 'GaussianDropout' : _default_skip, diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 4edf0b80de4c..1d8842d69d12 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -87,10 +87,12 @@ def _mx_fully_connected(inputs, attrs): def _get_channel_axis(layout, op_name): - if layout == "NCHW": + if layout in ["NCHW", "NCDHW"]: return 1 if layout == "NHWC": return 3 + if layout == "NDHWC": + return 4 raise tvm.error.OpAttributeInvalid( 'Value {} in attribute "layout" of operator {} is not valid.'.format(layout, op_name)) @@ -149,13 +151,15 @@ def _mx_zeros(inputs, attrs): def _mx_conv(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") - if len(kernel_size) == 2: + if len(kernel_size) == 3: + return _mx_conv3d(inputs, attrs) + elif len(kernel_size) == 2: return _mx_conv2d(inputs, attrs) elif len(kernel_size) == 1: return _mx_conv1d(inputs, attrs) else: raise tvm.error.OpAttributeInvalid( - '1D or 2D kernels only are supported for operator Convolution') + '1D, 2D or 3D kernels only are supported for operator Convolution') def _mx_conv1d(inputs, attrs): 
kernel_size = attrs.get_int_tuple("kernel") @@ -226,15 +230,53 @@ def _mx_conv2d(inputs, attrs): return res +def _get_mx_conv3d_attrs(attrs): + kernel_size = attrs.get_int_tuple("kernel") + data_layout = attrs.get_str("layout", "NCDHW") + if "kernel_layout" in attrs.attrs: + kernel_layout = attrs.get_str("kernel_layout") + else: + kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW" + new_attrs = {} + new_attrs["channels"] = attrs.get_int("num_filter") + new_attrs["kernel_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1)) + new_attrs["groups"] = attrs.get_int("num_group", 1) + new_attrs["data_layout"] = data_layout + new_attrs["kernel_layout"] = kernel_layout + return new_attrs + + +def _mx_conv3d(inputs, attrs): + kernel_size = attrs.get_int_tuple("kernel") + data_layout = attrs.get_str("layout", "NCDHW") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Only 3D kernels are supported for operator Convolution') + + new_attrs = _get_mx_conv3d_attrs(attrs) + channel_axis = _get_channel_axis(data_layout, "conv3d") + use_bias = not attrs.get_bool("no_bias", False) + res = _op.nn.conv3d(inputs[0], inputs[1], **new_attrs) + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) + return res + + def _mx_conv_transpose(inputs, attrs): kernel_size = attrs.get_int_tuple("kernel") - if len(kernel_size) == 2: + if len(kernel_size) == 3: + return _mx_conv3d_transpose(inputs, attrs) + elif len(kernel_size) == 2: return _mx_conv2d_transpose(inputs, attrs) elif len(kernel_size) == 1: return _mx_conv1d_transpose(inputs, attrs) else: raise tvm.error.OpAttributeInvalid( - '1D or 2D kernels only are supported for operator Convolution') + '1D, 2D or 3D kernels only are supported for operator Convolution') def _mx_conv1d_transpose(inputs, attrs): @@ -300,6 +342,41 @@ def _mx_conv2d_transpose(inputs, attrs): return res +def _mx_conv3d_transpose(inputs, attrs): + if "target_shape" in attrs.attrs: + raise tvm.error.OpAttributeUnImplemented( + 'Attribute "target_shape" is not supported for operator Conv3D-transpose.') + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Non-3D kernels are not supported for operator Conv3D-transpose.') + data_layout = attrs.get_str("layout", "NCDHW") + channel_axis = _get_channel_axis(data_layout, "conv3d_transpose") + + if "kernel_layout" in attrs.attrs: + kernel_layout = attrs.get_str("kernel_layout") + else: + kernel_layout = "DHWIO" if data_layout == "NDHWC" else "OIDHW" + + new_attrs = {} + new_attrs["channels"] = attrs.get_int("num_filter") + new_attrs["kernel_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["output_padding"] = attrs.get_int_tuple("adj", (0, 0, 0)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["dilation"] = attrs.get_int_tuple("dilate", (1, 1, 1)) + new_attrs["groups"] = attrs.get_int("num_group", 1) + new_attrs["data_layout"] = data_layout + new_attrs["kernel_layout"] = kernel_layout + use_bias = not attrs.get_bool("no_bias", True) + res = _op.nn.conv3d_transpose(inputs[0], inputs[1], **new_attrs) + + if use_bias: + assert len(inputs) == 3 + res = _op.nn.bias_add(res, inputs[2], axis=channel_axis) + return res + + def _mx_pooling(inputs, attrs): global_pool = 
attrs.get_bool("global_pool", False) pool_type = attrs.get_str("pool_type") @@ -318,6 +395,34 @@ def _pool2d(new_op, is_avg): new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) return new_op(inputs[0], **new_attrs) + def _pool3d(new_op, is_avg): + kernel_size = attrs.get_int_tuple("kernel") + if len(kernel_size) != 3: + raise tvm.error.OpAttributeInvalid( + 'Only 3D kernels are supported for operator Pool3D.') + new_attrs = {} + new_attrs["pool_size"] = kernel_size + new_attrs["strides"] = attrs.get_int_tuple("stride", (1, 1, 1)) + new_attrs["padding"] = attrs.get_int_tuple("pad", (0, 0, 0)) + new_attrs["ceil_mode"] = (attrs.get_str("pooling_convention", "valid") == "full") + if is_avg: + new_attrs["count_include_pad"] = attrs.get_bool("count_include_pad", True) + return new_op(inputs[0], **new_attrs) + + #3D pooling + if len(_infer_shape(inputs[0])) == 5: + if pool_type == "max": + if global_pool: + return _op.nn.global_max_pool3d(inputs[0]) + return _pool3d(_op.nn.max_pool3d, False) + if pool_type == "avg": + if global_pool: + return _op.nn.global_avg_pool3d(inputs[0]) + return _pool3d(_op.nn.avg_pool3d, True) + raise tvm.error.OpNotImplemented( + 'Operator {} Pooling is not supported for frontend MXNet.' \ + .format(pool_type.capitalize())) + #2D Pooling if pool_type == "max": if global_pool: return _op.nn.global_max_pool2d(inputs[0]) @@ -327,7 +432,8 @@ def _pool2d(new_op, is_avg): return _op.nn.global_avg_pool2d(inputs[0]) return _pool2d(_op.nn.avg_pool2d, True) raise tvm.error.OpNotImplemented( - 'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize())) + 'Operator {} Pooling is not supported for frontend MXNet.' \ + .format(pool_type.capitalize())) def _mx_adaptive_avg_pooling(inputs, attrs): @@ -382,16 +488,22 @@ def _mx_slice(inputs, attrs): begin = list(attrs.get_int_tuple('begin', None)) end = list(attrs.get_int_tuple('end', None)) stride = attrs.get_int_tuple('step', None) + input_shape = _infer_type(inputs[0]).checked_type.shape if begin is None: raise tvm.error.OpAttributeRequired( 'Attribute "begin" not found in operator Slice.') if end is None: raise tvm.error.OpAttributeRequired( 'Attribute "end" not found in operator Slice.') - begin = tuple(x if x is not None else 0 for x in begin) - new_attrs = {'begin': begin, 'end': end} + begin = (x if x is not None else 0 for x in begin) + for i, ed in enumerate(end): + if ed is None: + end[i] = input_shape[i] + new_attrs = {'begin': _expr.const(list(begin), dtype="int32"), + 'end': _expr.const(list(end), dtype="int32")} if stride is not None: - new_attrs['strides'] = stride + stride = (x if x is not None else 1 for x in stride) + new_attrs['strides'] = _expr.const(list(stride), dtype="int32") return _op.strided_slice(inputs[0], **new_attrs) @@ -431,7 +543,9 @@ def _mx_slice_axis(inputs, attrs): else: begin.append(ax_beg) end.append(ax_end) - return _op.strided_slice(inputs[0], begin, end) + return _op.strided_slice(inputs[0], + _expr.const(begin, dtype="int32"), + _expr.const(end, dtype="int32")) def _mx_crop_like(inputs, attrs): @@ -451,9 +565,9 @@ def _mx_crop_like(inputs, attrs): return _op.slice_like(*inputs, **new_attrs) expr = _infer_type(inputs[1]) like_shape = expr.checked_type.shape - new_attrs['begin'] = [0, 0, offset[0], offset[1]] - new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2], - offset[1]+like_shape[3]] + new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32") + new_attrs['end'] = _expr.const([like_shape[0], 
like_shape[1], offset[0]+like_shape[2], + offset[1]+like_shape[3]], dtype="int32") return _op.strided_slice(inputs[0], **new_attrs) @@ -627,7 +741,7 @@ def _mx_multibox_detection(inputs, attrs): ret = _op.vision.multibox_transform_loc(inputs[0], inputs[1], inputs[2], **new_attrs0) - return _op.vision.non_max_suppression(ret[0], ret[1], **new_attrs1) + return _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **new_attrs1) def _mx_batch_dot(inputs, attrs): @@ -693,6 +807,10 @@ def _mx_take(inputs, attrs): axis = attrs.get_int("axis", 0) return _op.take(inputs[0], inputs[1].astype("int32"), axis, mode) +def _mx_gather_nd(inputs, attrs): + assert len(inputs) == 2 + assert len(_infer_shape(inputs[1])) > 1, "index tensor to have at least 2 dimensions" + return _op.gather_nd(inputs[0], inputs[1]) def _mx_reverse(inputs, attrs): assert len(inputs) == 1 @@ -724,6 +842,26 @@ def _mx_resize(inputs, attrs): return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners") +def _mx_grid_generator(inputs, attrs): + transform_type = attrs.get_str("transform_type") + if transform_type == 'affine': + target_shape = attrs.get_int_tuple("target_shape") + return _op.image.affine_grid(_op.reshape(inputs[0], (0, 2, 3)), target_shape) + if transform_type == 'warp': + checked_type = _infer_type(inputs[0]).checked_type + batch, _, height, width = get_const_tuple(checked_type.shape) + dtype = checked_type.dtype + identity_affine = relay.const(np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=dtype)) + identity_affine = _op.broadcast_to(identity_affine, (batch, 2, 3)) + normalizer = (2.0 / np.array([width - 1, height - 1])).reshape(1, -1, 1, 1).astype(dtype) + normalized_flow = inputs[0] * relay.const(normalizer) + grid = _op.image.affine_grid(identity_affine, (height, width)) + return grid + normalized_flow + raise ValueError("unknown transform type" + transform_type) + +def _mx_bilinear_sampler(inputs, attrs): + return _op.image.grid_sample(inputs[0], inputs[1], 'bilinear', 'NCHW') + def _mx_roi_pooling(inputs, attrs): new_attrs = {} new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") @@ -767,6 +905,7 @@ def _mx_box_nms(inputs, attrs): id_index=id_index, score_index=score_index) nms_out = _op.vision.non_max_suppression(ret[1], ret[0], + ret[2], iou_threshold=iou_thresh, force_suppress=force_suppress, top_k=top_k, @@ -789,6 +928,24 @@ def _mx_l2_normalize(inputs, attrs): return _op.nn.l2_normalize(inputs[0], **new_attrs) +def _mx_softsign(inputs, attrs): + return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0])) + + +def _mx_softmin(inputs, attrs): + axis = attrs.get_int("axis", -1) + return _op.nn.softmax(_op.negative(inputs[0]), axis) + + +def _mx_hard_sigmoid(inputs, attrs): + x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5) + return _op.clip(x, a_min=0.0, a_max=1.0) + + +def _mx_reciprocal(inputs, attrs): + return _expr.const(1.0) /inputs[0] + + def _mx_shape_array(inputs, attrs): assert len(inputs) == 1 if attrs.get_int("lhs_begin", None) is not None: @@ -1073,6 +1230,33 @@ def _mx_one_hot(inputs, attrs): return _op.one_hot(indices, on_value, off_value, depth, -1, dtype) +def _mx_depth_to_space(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["block_size"] = attrs.get_int("block_size") + return _op.nn.depth_to_space(*inputs, **new_attrs) + + +def _mx_space_to_depth(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["block_size"] = attrs.get_int("block_size") + return _op.nn.space_to_depth(*inputs, **new_attrs) + + +def 
_mx_correlation(inputs, attrs): + assert len(inputs) == 2 + new_attrs = {} + new_attrs["kernel_size"] = attrs.get_int("kernel_size", 1) + new_attrs["max_displacement"] = attrs.get_int("max_displacement", 1) + new_attrs["stride1"] = attrs.get_int("stride1", 1) + new_attrs["stride2"] = attrs.get_int("stride2", 1) + new_attrs["padding"] = attrs.get_int("pad_size", 0) + new_attrs["is_multiply"] = attrs.get_bool("is_multiply", True) + new_attrs["layout"] = "NCHW" + return _op.nn.correlation(*inputs, **new_attrs) + + def _mx_contrib_fifo_buffer(inputs, attrs): new_attrs = {} new_attrs['axis'] = attrs.get_int('axis') @@ -1698,45 +1882,96 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): res = _op.nn.relu(res) return res + +def _mx_broadcast_to(inputs, attrs): + data = inputs[0] + tgt_shape = attrs.get_int_tuple("shape", []) + + return _op.broadcast_to(data, tgt_shape) + + +def _mx_logical_not(inputs, input_types): + data = inputs[0] + dtype = _infer_type(data).checked_type.dtype + data = _op.cast(data, "bool") if dtype != "bool" else data + + return _op.cast(_op.logical_not(data), dtype) + + +def _mx_broadcast_logical(logical_op): + def impl(inputs, input_types): + lhs_type = _infer_type(inputs[0]).checked_type.dtype + rhs_type = _infer_type(inputs[1]).checked_type.dtype + lhs = _op.cast(inputs[0], "bool") if lhs_type != "bool" else inputs[0] + rhs = _op.cast(inputs[1], "bool") if rhs_type != "bool" else inputs[1] + + return _op.cast(logical_op(lhs, rhs), lhs_type) + return impl + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ + "abs", "log", "exp", "erf", "sqrt", "floor", "ceil", + "round", + "trunc", + "sign", "sigmoid", - "tanh", "negative", "reshape_like", "zeros_like", "ones_like", "where", - "gather_nd", - "tan", "cos", - "sin" + "cosh", + "sin", + "sinh", + "tan", + "tanh", ] _convert_map = { "_copy" : _rename(_op.copy), "relu" : _rename(_op.nn.relu), "broadcast_add" : _rename(_op.add), + "broadcast_plus" : _rename(_op.add), "broadcast_sub" : _rename(_op.subtract), + "broadcast_minus" : _rename(_op.subtract), "broadcast_mul" : _rename(_op.multiply), "broadcast_div" : _rename(_op.divide), "broadcast_mod" : _rename(_op.mod), "broadcast_maximum" : _rename(_op.maximum), "broadcast_minimum" : _rename(_op.minimum), + "broadcast_power" : _rename(_op.power), + "arccos" : _rename(_op.acos), + "arcsin" : _rename(_op.asin), "arctan" : _rename(_op.atan), + "arccosh" : _rename(_op.acosh), + "arcsinh" : _rename(_op.asinh), + "arctanh" : _rename(_op.atanh), "broadcast_equal" : _mx_compare(_op.equal, _rename), "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename), "broadcast_greater" : _mx_compare(_op.greater, _rename), "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename), "broadcast_lesser" : _mx_compare(_op.less, _rename), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename), + "broadcast_logical_or" : _mx_broadcast_logical(_op.logical_or), + "broadcast_logical_and" : _mx_broadcast_logical(_op.logical_and), + "broadcast_logical_xor" : _mx_broadcast_logical(_op.logical_xor), + "broadcast_to" : _mx_broadcast_to, + "logical_not" : _mx_logical_not, + "_equal" : _mx_compare(_op.equal, _rename), + "_not_equal" : _mx_compare(_op.not_equal, _rename), + "_greater" : _mx_compare(_op.greater, _rename), + "_greater_equal" : _mx_compare(_op.greater_equal, _rename), + "_lesser" : _mx_compare(_op.less, _rename), + "_lesser_equal" : _mx_compare(_op.less_equal, _rename), "elemwise_add" : 
_rename(_op.add), "elemwise_sub" : _rename(_op.subtract), "elemwise_mul" : _rename(_op.multiply), @@ -1794,6 +2029,10 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "softmax" : _softmax_op(_op.nn.softmax), "log_softmax" : _softmax_op(_op.nn.log_softmax), "Softmax" : _softmax_op(_op.nn.softmax), + "softsign" : _mx_softsign, + "softmin" : _mx_softmin, + "hard_sigmoid" : _mx_hard_sigmoid, + "reciprocal" : _mx_reciprocal, # per op specialization "Reshape" : _reshape, "reshape" : _reshape, @@ -1837,9 +2076,11 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "pad" : _mx_pad, "Pad" : _mx_pad, "take" : _mx_take, + "gather_nd" : _mx_gather_nd, "reverse" : _mx_reverse, "squeeze" : _mx_squeeze, "broadcast_axis": _mx_broadcast_axis, + "broadcast_axes": _mx_broadcast_axis, "BlockGrad" : _mx_BlockGrad, "shape_array" : _mx_shape_array, "Embedding" : _mx_embedding, @@ -1854,6 +2095,9 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "make_loss" : _mx_make_loss, "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim, "one_hot" : _mx_one_hot, + "depth_to_space" : _mx_depth_to_space, + "space_to_depth" : _mx_space_to_depth, + "Correlation" : _mx_correlation, # vision "_contrib_BilinearResize2D" : _mx_resize, "_contrib_MultiBoxPrior" : _mx_multibox_prior, @@ -1865,6 +2109,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "_contrib_box_nms" : _mx_box_nms, "_contrib_DeformableConvolution" : _mx_deformable_convolution, "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling, + "GridGenerator" : _mx_grid_generator, + "BilinearSampler" : _mx_bilinear_sampler, # NLP "RNN" : _mx_rnn_layer, "_rnn_param_concat" : _mx_rnn_param_concat, @@ -1875,7 +2121,6 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # - # "broadcast_to", # "contrib_fifo_buffer": _mx_contrib_fifo_buffer, "ring_buffer": _mx_contrib_fifo_buffer, # Qnn ops diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 527a1ed2f07b..05a067d3ff14 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -26,12 +26,30 @@ from .. import expr as _expr from .. import function as _function from .. import op as _op +from .. 
import vision as _vision + +from ..function import Function +from ..expr import Call, Let +from ..expr import If, Tuple, TupleGetItem +from ..expr import RefCreate, RefRead, RefWrite +from ..expr_functor import ExprFunctor +from ..adt import Match, Clause + from .common import AttrCvt, Renamer from .common import get_relay_op, new_var, infer_shape, infer_channels -from .common import infer_type, infer_value, infer_value_simulated, get_name +from .common import infer_type, get_name +from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated __all__ = ['from_onnx'] +g = None + +def infer_value(input_val, params, mod=None): + return g.infer_value(input_val, params, mod) + +def infer_value_simulated(input_val, params): + return g.infer_value_simulated(input_val, params) class onnx_input(): """ Dual purpose list or dictionary access object.""" @@ -57,8 +75,7 @@ def __setitem__(self, item, value): if isinstance(item, int): self.input_dict[self.input_keys[item]] = value elif isinstance(item, str): - if item not in self.input_dict: - self.input_keys.append(item) + self.input_keys.append(item) self.input_dict[item] = value else: raise ValueError("Only integer and string indexed writes allowed.") @@ -272,7 +289,7 @@ def _impl_v1(cls, inputs, attr, params): 'kernel_shape': 'pool_size', 'pads': ('padding', 0) }, - ignores=['dilations'], + ignores=['dilations', 'storage_order'], custom_check=dimension_constraint())(inputs, attr, params) @@ -325,7 +342,6 @@ class Conv(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): # Use shape of input to determine convolution type. input_shape = infer_shape(inputs[0]) - if 'auto_pad' in attr: attr['auto_pad'] = attr['auto_pad'].decode('utf-8') if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): @@ -350,7 +366,10 @@ def _impl_v1(cls, inputs, attr, params): attr.pop('auto_pad') elif len(attr['kernel_shape']) == 2: sym_pad = True - padding = attr['pads'] + if 'pads' in attr: + padding = attr['pads'] + else: + padding = [0, 0, 0, 0] for i in range(0, len(padding), 2): sym_pad = sym_pad and padding[i] == padding[i + 1] @@ -458,8 +477,15 @@ def _impl_v1(cls, inputs, attr, params): if not transB: inputs[1] = _op.transpose(inputs[1], axes=(1, 0)) inputs[0] = _op.nn.batch_flatten(inputs[0]) - out = _op.nn.dense(_expr.const(alpha) * inputs[0], - inputs[1], units=channels) + + if alpha != 1.0: + inputs[0] *= _expr.const(alpha) + out = _op.nn.dense(inputs[0], inputs[1], units=channels) + + # skip (beta * C) if zero + C_array = params[inputs[2].name_hint].asnumpy() + if (beta == 0.0) or np.array_equal(C_array, np.array([0])): + return out return _op.nn.bias_add(out, _expr.const(beta) * inputs[2]) @@ -495,11 +521,76 @@ def _impl_v1(cls, inputs, attr, params): return _op.nn.dense(inputs[0], input_1_t) +class Mod(OnnxOpConverter): + """ Operator converter for Mod. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 2, "Mod op take 2 inputs, {} given".format(len(inputs)) + if attr['fmod'] == 1: + op_name = "floor_mod" + else: + op_name = "mod" + return AttrCvt(op_name)(inputs, {}, params) + + class MaxPool(Pool): """ Operator converter for MaxPool """ name = 'max_pool' +class LpPool(OnnxOpConverter): + """ A helper class for lppool op converters. 
+ """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + input_shape = infer_shape(inputs[0]) + dtype = infer_type(inputs[0]).checked_type.dtype + + if 'auto_pad' in attr: + attr['auto_pad'] = attr['auto_pad'].decode('utf-8') + if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'): + pad_tuple = [] + for axis in range(len(input_shape) - 2): + axis_shape = input_shape[2 + axis] + stride = attr['strides'][axis] + kernel = attr['kernel_shape'][axis] + pad = get_pad_pair(axis_shape, kernel, stride) + pad_tuple.append(pad) + pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair]) + attr['pads'] = pad_tuple + elif attr['auto_pad'] == 'VALID': + attr['pads'] = 0 + elif attr['auto_pad'] == 'NOTSET': + pass + else: + msg = 'Value {} in attribute "auto_pad" of operator {} is invalid.' + raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'], "LpPool")) + attr.pop("auto_pad") + + if 'storage_order' in attr: + attr['layout'] = onnx_storage_order2layout(attr['storage_order'], + dims=(len(input_shape) - 2)) + else: + attr['layout'] = onnx_default_layout(dims=(len(input_shape) - 2)) + + p = _expr.const(attr['p'], dtype) + reci_p = _expr.const(1.0 / attr['p'], dtype) + inputs[0] = _op.power(inputs[0], p) + + out = AttrCvt(op_name=dimension_picker("avg_pool"), + transforms={ + 'kernel_shape': 'pool_size', + 'pads': ('padding', 0) + }, + extras={'count_include_pad': True}, + ignores=['p'], + custom_check=dimension_constraint())(inputs, attr, params) + kernels = attr['kernel_shape'] + out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype)) + return _op.power(out, reci_p) + class Mul(Elemwise): """ Operator converter for Multiply. @@ -557,6 +648,31 @@ def _impl_v2(cls, inputs, attr, params): }, )(inputs, attr, params) + @classmethod + def _impl_v11(cls, inputs, attr, params): + pad_width = [] + pads = infer_value_simulated(inputs[1], params).asnumpy() + if len(inputs) == 3: + value = infer_value_simulated(inputs[2], params).asnumpy().item() + else: + value = 0 + attr["pad_value"] = value + dims = int(len(pads) / 2) + for i in range(dims): + pad_width.append((pads[i], pads[i+dims])) + attr['pad_width'] = pad_width + pad_mode = attr.get('mode', b'constant').decode('utf-8') + if pad_mode in ['constant', 'edge', 'reflect']: + attr['pad_mode'] = pad_mode + attr.pop('mode', None) + else: + raise tvm.error.OpAttributeInvalid( + 'Value ' + pad_mode + ' in attribute "mode" is invalid for operator Pad.') + + return AttrCvt('pad')(inputs[:1], attr, params) + + + class ParametricSoftPlus(OnnxOpConverter): """ Operator converter for ParametricSoftPlus. @@ -576,7 +692,12 @@ class Prelu(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): assert len(inputs) == 2, "Prelu need 2 inputs, {} given".format(len(inputs)) - return _op.nn.prelu(inputs[0], inputs[1]) + alpha_shape = infer_shape(inputs[1]) + if len(alpha_shape) != 1: + alpha = _op.reshape(inputs[1], (-1,)) + else: + alpha = inputs[1] + return _op.nn.prelu(inputs[0], alpha) class Reciprocal(OnnxOpConverter): @@ -616,7 +737,7 @@ def _impl_v1(cls, inputs, attr, params): def _impl_v5(cls, inputs, attr, params): if get_name(inputs[1]) in params: # pop shape out of parameters since it wont be needed later. - shape = tuple(params.pop(inputs[1].name_hint).asnumpy()) + shape = tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32")) out = _op.reshape(inputs[0], shape) else: data, shape = inputs @@ -782,7 +903,10 @@ def _impl_v9(cls, inputs, attr, params): if not scales: #Here we are going to higher OPSET version. 
assert len(inputs) == 2, "Upsample op take 2 inputs, {} given".format(len(inputs)) - scales = params[inputs[1].name_hint].asnumpy() + if get_name(inputs[1]) in params: + scales = params[inputs[1].name_hint].asnumpy() + else: + scales = infer_value_simulated(inputs[1], params).asnumpy() inputs = inputs[:1] assert scales[0] == 1.0 and scales[1] == 1.0 input_shape = infer_shape(inputs[0]) @@ -910,11 +1034,12 @@ def _impl_v1(cls, inputs, attr, params): attr['ends'] = new_ends except KeyError: pass + begin = list(attr['starts']) + end = list(attr['ends']) - return AttrCvt('strided_slice', - transforms={'starts': 'begin', - 'ends': 'end'}, - ignores=['axes'])(inputs, attr) + return _op.strided_slice(inputs[0], + begin=_expr.const(begin, dtype="int32"), + end=_expr.const(end, dtype="int32")) @classmethod def _impl_v10(cls, inputs, attr, params): @@ -930,7 +1055,9 @@ def _impl_v10(cls, inputs, attr, params): starts, ends, axes) starts = new_starts ends = new_ends - return _op.strided_slice(inputs[0], begin=starts, end=ends) + return _op.strided_slice(inputs[0], + begin=_expr.const(starts, dtype="int32"), + end=_expr.const(ends, dtype="int32")) class Gather(OnnxOpConverter): @@ -943,6 +1070,24 @@ def _impl_v1(cls, inputs, attr, params): extras={'axis': axis})(inputs, {}) +class GatherND(OnnxOpConverter): + """ Operator converter for GatherND. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + return _op.gather_nd(inputs[0], inputs[1]) + + +class Scatter(OnnxOpConverter): + """ Operator converter for Scatter. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get('axis', 0) + return _op.scatter(inputs[0], inputs[1], inputs[2], axis) + + class Greater(OnnxOpConverter): """ Operator logical greater. """ @@ -1060,6 +1205,77 @@ class ReduceProd(Reduce): """ name = 'prod' +class ReduceLogSumExp(Reduce): + """ Operator converter for ReduceLogSumExp. + """ + name = 'logsumexp' + + +class ReduceSumSquare(OnnxOpConverter): + """ Operator converter for ReduceSumSquare. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + inputs[0] = inputs[0] * inputs[0] + + return AttrCvt("sum")(inputs, attr) + + +class ReduceL1(OnnxOpConverter): + """ Operator converter for ReduceL1. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + inputs[0] = _op.abs(inputs[0]) + + return AttrCvt("sum")(inputs, attr) + + +class ReduceL2(OnnxOpConverter): + """ Operator converter for ReduceL2. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + inputs[0] = inputs[0] * inputs[0] + out = AttrCvt("sum")(inputs, attr) + + return _op.sqrt(out) + + +class ReduceLogSum(OnnxOpConverter): + """ Operator converter for ReduceLogSum. 
+ """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + out = AttrCvt("sum")(inputs, attr) + + return _op.log(out) + + class ArgMax(OnnxOpConverter): """ Operator converter for ArgMax. """ @@ -1469,8 +1685,69 @@ def _impl_v9(cls, inputs, attr, params): raise ValueError("Expect 1 input only") output = AttrCvt(op_name='argwhere')(inputs, attr, params) + # ONNX NonZero always outputs int64 + output = _op.cast(output, "int64") return _op.transpose(output, axes=(1, 0)) +class TopK(OnnxOpConverter): + """Operator converter for TopK + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if len(inputs) != 2: + raise ValueError("Expect 2 input only") + axis = attr.get("axis", -1) + largest = attr.get("largest", 1) + + if largest == 0: + raise ValueError("TVM only supports finding TopK largest elements") + + K = int(infer_value(inputs[1], params).asnumpy()[0]) + + return _op.topk(inputs[0], k=K, axis=axis) + + +class MaxRoiPool(OnnxOpConverter): + """Operator converter for MaxRoiPool. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + assert len(inputs) == 2, "MMaxRoiPool op take 2 inputs, {} given".format(len(inputs)) + + data = inputs[0] + rois = inputs[1] + pooled_shape = attr.get("pooled_shape") + spatial_scale = attr.get("spatial_scale", 1.0) + + return _vision.roi_pool(data, rois, pooled_shape, spatial_scale) + + +class RoiAlign(OnnxOpConverter): + """Operator converter for RoiAlign. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if len(inputs) != 3: + raise ValueError("Expect 3 inputs only") + x = inputs[0] + rois = inputs[1] + batch_indices = inputs[2] + mode = attr.get("mode", "avg") + if mode != b'avg': + raise ValueError("RoiAlign in Relay only uses avg mode") + output_height = attr.get("output_height", 1) + output_width = attr.get("output_width", 1) + + sampling_ratio = attr.get("sampling_ratio", 0) + spatial_scale = attr.get("spatial_scale", 1.0) + + batch_indices = _op.expand_dims(batch_indices, axis=1, num_newaxis=1) + batch_indices = _op.cast( + batch_indices, infer_type(rois).type_annotation.dtype) + rois = _op.concatenate([batch_indices, rois], 1) + + return _vision.roi_align(x, rois, [output_height, output_width], + spatial_scale, sampling_ratio) # compatible operators that do NOT require any conversion. 
_identity_list = [] @@ -1521,6 +1798,9 @@ def _get_convert_map(opset): 'Reciprocal': Reciprocal.get_converter(opset), 'Floor': Renamer('floor'), 'Ceil': Renamer('ceil'), + 'Round': Renamer('round'), + 'IsInf': Renamer('isinf'), + 'IsNaN': Renamer('isnan'), 'Sqrt': Renamer('sqrt'), 'Relu': Renamer('relu'), 'LeakyRelu': Renamer('leaky_relu'), @@ -1530,6 +1810,17 @@ def _get_convert_map(opset): 'Greater': Greater.get_converter(opset), 'Less': Less.get_converter(opset), 'Log': Renamer('log'), + 'ACos': Renamer('acos'), + 'ACosh': Renamer('acosh'), + 'ASin': Renamer('asin'), + 'ASinh': Renamer('asinh'), + 'ATan': Renamer('atan'), + 'ATanh': Renamer('atanh'), + 'Cos': Renamer('cos'), + 'Cosh': Renamer('cosh'), + 'Sin': Renamer('sin'), + 'Sinh': Renamer('sinh'), + 'Tan': Renamer('tan'), 'Tanh': Renamer('tanh'), 'Pow': Renamer('power'), 'PRelu': Prelu.get_converter(opset), @@ -1549,9 +1840,12 @@ def _get_convert_map(opset): 'SoftPlus': SoftPlus.get_converter(opset), 'Gemm': Gemm.get_converter(opset), 'MatMul': MatMul.get_converter(opset), + 'Mod': Mod.get_converter(opset), + 'Xor': Renamer('logical_xor'), # defs/nn 'AveragePool': AveragePool.get_converter(opset), + 'LpPool': LpPool.get_converter(opset), 'MaxPool': MaxPool.get_converter(opset), 'Conv': Conv.get_converter(opset), 'ConvTranspose': ConvTranspose.get_converter(opset), @@ -1566,16 +1860,26 @@ def _get_convert_map(opset): # Recurrent Layers 'LSTM': LSTM.get_converter(opset), + # defs/vision + 'MaxRoiPool': MaxRoiPool.get_converter(opset), + 'RoiAlign': RoiAlign.get_converter(opset), + # defs/reduction 'ReduceMax': ReduceMax.get_converter(opset), 'ReduceMin': ReduceMin.get_converter(opset), 'ReduceSum': ReduceSum.get_converter(opset), 'ReduceMean': ReduceMean.get_converter(opset), 'ReduceProd': ReduceProd.get_converter(opset), - # 'ReduceProd' - # 'ReduceLogSumExp' + 'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset), + 'ReduceLogSum': ReduceLogSum.get_converter(opset), + 'ReduceSumSquare': ReduceSumSquare.get_converter(opset), + 'ReduceL1': ReduceL1.get_converter(opset), + 'ReduceL2': ReduceL2.get_converter(opset), + + #defs/sorting 'ArgMax': ArgMax.get_converter(opset), 'ArgMin': ArgMin.get_converter(opset), + 'TopK': TopK.get_converter(opset), # defs/tensor 'Cast': Cast.get_converter(opset), @@ -1588,6 +1892,9 @@ def _get_convert_map(opset): 'DepthToSpace': DepthToSpace.get_converter(opset), 'SpaceToDepth': SpaceToDepth.get_converter(opset), 'Gather': Gather.get_converter(opset), + 'GatherND': GatherND.get_converter(opset), + 'Scatter': Scatter.get_converter(opset), + 'ScatterElements': Scatter.get_converter(opset), 'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}), 'Unsqueeze': Unsqueeze.get_converter(opset), 'Pad': Pad.get_converter(opset), @@ -1604,8 +1911,7 @@ def _get_convert_map(opset): 'NonZero': NonZero.get_converter(opset), } - -class GraphProto(object): +class GraphProto(ExprFunctor): """A helper class for handling Relay expression copying from pb2.GraphProto. 
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto @@ -1627,6 +1933,101 @@ def __init__(self, shape, dtype): self._shape = shape if shape else {} self._dtype = dtype + #For infering Values + self._tmp_params = {} + self._infer_simulated = True + self._mod = None + super(GraphProto, self).__init__() + + def infer_value(self, input_val, params, mod=None): + self._tmp_params = params + self._infer_simulated = False + self._mod = mod + return self.visit(input_val).data + #return _infer_value(input_val, params, mod) + + def infer_value_simulated(self, input_val, params): + self._tmp_params = params + self._infer_simulated = True + return self.visit(input_val).data + #return _infer_value_simulated(input_val, params) + + def infer(self, expr): + if self._infer_simulated: + out = _infer_value_simulated(expr, self._tmp_params) + else: + out = _infer_value(expr, self._tmp_params) + return _expr.const(out.asnumpy()) + + def visit_function(self, fn): + new_params = [self.visit(x) for x in fn.params] + new_body = self.visit(fn.body) + return self.infer(Function( + list(new_params), + new_body, + fn.ret_type, + fn.type_params, + fn.attrs)) + + def visit_let(self, let): + newvar = self.visit(let.var) + newval = self.visit(let.value) + newbody = self.visit(let.body) + return self.infer(Let(newvar, newval, newbody)) + + def visit_call(self, call): + new_fn = self.visit(call.op) + new_args = [self.visit(arg) for arg in call.args] + return self.infer(Call(new_fn, new_args, call.attrs)) + + def visit_var(self, var): + return self.infer(var) + + def visit_global_id(self, global_var): + return self.infer(global_var) + + def visit_if(self, ite): + return self.infer(If( + self.visit(ite.cond), + self.visit(ite.true_branch), + self.visit(ite.false_branch))) + + def visit_tuple(self, tup): + return Tuple([self.visit(field) for field in tup.fields]) + + def visit_tuple_getitem(self, op): + tuple_value = self.visit(op.tuple_value) + if not tuple_value.same_as(op.tuple_value): + return self.infer(TupleGetItem(tuple_value, op.index)) + return self.infer(op) + + def visit_global_var(self, gvar): + return self.infer(gvar) + + def visit_op(self, op): + return op + + def visit_constant(self, const): + return const + + def visit_constructor(self, con): + return con + + def visit_match(self, m): + return self.infer(Match( + self.visit(m.data), + [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], + complete=m.complete)) + + def visit_ref_create(self, r): + return RefCreate(self.visit(r.value)) + + def visit_ref_write(self, r): + return RefWrite(self.visit(r.ref), self.visit(r.value)) + + def visit_ref_read(self, r): + return RefRead(self.visit(r.ref)) + def from_onnx(self, graph, opset): """Construct Relay expression from ONNX graph. 
@@ -1885,6 +2286,7 @@ def from_onnx(model, warnings.warn(str(e)) except ImportError: pass + global g g = GraphProto(shape, dtype) graph = model.graph if opset is None: @@ -1893,4 +2295,5 @@ def from_onnx(model, except AttributeError: opset = 1 mod, params = g.from_onnx(graph, opset) + g = None return mod, params diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index af60bf20c847..d2451cd80635 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -34,6 +34,7 @@ from .common import get_relay_op from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import infer_value_simulated as _infer_value_simulated from .common import infer_type as _infer_type from ..prelude import Prelude, StaticTensorArrayOps @@ -132,27 +133,53 @@ def _impl(inputs, input_types): return get_relay_op(name)(data0, data1) return _impl -def _abs(): + +def _unary(name): def _impl(inputs, input_types): - data = inputs[0] - return _op.abs(data) + input_type = input_types[0] + data = _convert_elemwise_input(inputs[0], input_type) + + return get_relay_op(name)(data) return _impl + +def _log1p(): + def _impl(inputs, input_types): + # 1_plus_log x = log(x + 1) + one = _expr.const(1, dtype="float32") + return _op.log(inputs[0] + one) + return _impl + + def _arange(): def _impl(inputs, input_types): + def _get_value(val, dtype): + if isinstance(val, _expr.Expr): + return _op.cast(val, _convert_data_type(dtype)) + return _create_typed_const(val, dtype) + + def _get_type(val, inp_type): + if isinstance(val, _expr.Expr): + dtype = str(_infer_type(val).checked_type) + return dtype if dtype != "float32" else "float" + return inp_type + if len(inputs) == 5: - dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1]) - start = _create_typed_const(0, dtype) - stop = _create_typed_const(inputs[0], dtype) - step = _create_typed_const(1, dtype) + dtype0 = _get_type(inputs[0], input_types[0]) + dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1]) + start = _get_value(0, dtype) + stop = _get_value(inputs[0], dtype) + step = _get_value(1, dtype) elif len(inputs) == 7: - dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) - start = _create_typed_const(inputs[0], dtype) - stop = _create_typed_const(inputs[1], dtype) - step = _create_typed_const(inputs[2], dtype) + types = [_get_type(inputs[i], input_types[i]) for i in range(3)] + dtype = "float" if "float" in types else _convert_dtype_value(inputs[3]) + start = _get_value(inputs[0], dtype) + stop = _get_value(inputs[1], dtype) + step = _get_value(inputs[2], dtype) else: msg = "Unknown number of arguments (%d) to parse." 
% (len(inputs)) raise AssertionError(msg) + return _op.transform.arange(start=start, stop=stop, step=step, @@ -184,12 +211,12 @@ def tensor_array_concat(lst, axis): assert axis == 0, "Tensor array concat supported only for axis 0" tensor_array, shape = _convert_to_tensor_array(lst, prelude) concat_shape = (Any(),) + shape[1:] - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(concat_shape) - concat = prelude.get_var_static('tensor_array_concat', "float32", shape) concatenated = concat(tensor_array) - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", concat_shape) + static_tensor_array_ops.register() + get_tensor = prelude.get_var_static('tensor_get_data', "float32", concat_shape) return get_tensor(concatenated) def _impl(inputs, input_types): @@ -223,15 +250,25 @@ def _impl(inputs, input_types): begin = [0] * len(end) dim = int(inputs[1]) - begin[dim] = int(inputs[2]) + if isinstance(inputs[2], _expr.Call): + begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int)) + else: + begin[dim] = int(inputs[2]) if isinstance(inputs[3], str) and inputs[3].isdigit(): end[dim] = min(end[dim], int(inputs[3])) else: - end[dim] = inputs[3] + if isinstance(inputs[3], _expr.Call): + end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int)) + else: + end[dim] = inputs[3] strides.append(int(inputs[4])) - return _op.transform.strided_slice(data, begin, end, strides) + return _op.transform.strided_slice(data, + begin=_expr.const(begin), + end=_expr.const(end), + strides=_expr.const(strides), + slice_mode="size") return _impl def _split(): @@ -275,15 +312,7 @@ def _impl(inputs, input_types): def _take(): def _impl(inputs, input_types): data = inputs[0] - import torch - - if isinstance(inputs[1], _expr.Var): - indices = _op.cast(inputs[1], "int32") - elif isinstance(inputs[1], torch.Tensor): - indices = _wrap_const(inputs[1].numpy()) - else: - msg = "Data type %s could not be parsed in take operator." 
% (type(inputs[1])) - raise AssertionError(msg) + indices = _op.cast(inputs[1], "int32") return _op.transform.take(data, indices=indices) return _impl @@ -333,6 +362,40 @@ def _impl(inputs, input_types): return _op.transform.repeat(data, repeats=repeats, axis=axis) return _impl + +def _addcdiv(): + def _impl(inputs, input_types): + data = inputs[0] + c = _expr.const(inputs[3]) + t1 = inputs[1] + t2 = inputs[2] + + return data + (c * (t1 / t2)) + return _impl + + +def _addcmul(): + def _impl(inputs, input_types): + data = inputs[0] + c = _expr.const(inputs[3]) + t1 = inputs[1] + t2 = inputs[2] + + return data + (c * (t1 * t2)) + return _impl + + +def _where(): + def _impl(inputs, input_types): + cond = inputs[0] + x = inputs[1] + y = inputs[2] + + return _op.where(cond, x, y) + + return _impl + + def _ones(): def _impl(inputs, input_types): data = inputs[0] @@ -348,12 +411,25 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in ones op" % (type(data)) raise AssertionError(msg) - dtype_map = {6: "float32", 3: "int32"} - dtype_id = inputs[1] - assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id - return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id]) + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + + return _op.full(_expr.const(1), shape, dtype=dtype) return _impl +def _ones_like(): + def _impl(inputs, input_types): + data = inputs[0] + out = _op.ones_like(data) + + # If the input and the output datatype is different, do a cast + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + if input_types[0] not in dtype: + out = _op.cast(out, dtype) + + return out + return _impl + + def _zeros(): def _impl(inputs, input_types): data = inputs[0] @@ -369,12 +445,91 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in zeros op" % (type(data)) raise AssertionError(msg) - dtype_map = {6: "float32", 3: "int32"} - dtype_id = inputs[1] - assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id - return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id]) + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + + return _op.full(_expr.const(0), shape, dtype=dtype) return _impl + +def _zeros_like(): + def _impl(inputs, input_types): + data = inputs[0] + out = _op.zeros_like(data) + + # If the input and the output datatype is different, do a cast + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + if input_types[0] not in dtype: + out = _op.cast(out, dtype) + + return out + return _impl + + +def _full(): + def _impl(inputs, input_types): + data = inputs[0] + + fill_value = inputs[1] + import torch + if isinstance(data, _expr.Expr): + shape = _infer_shape(data) + elif isinstance(data, list): + shape = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + shape = data.shape + else: + msg = "Data type %s could not be parsed in zeros op" % (type(data)) + raise AssertionError(msg) + + if inputs[2] is not None: # dtype given + dtype = _convert_data_type(_convert_dtype_value(inputs[2])) + else: + dtype = data.type_annotation.dtype + + return _op.full(_expr.const(fill_value), shape, dtype=dtype) + return _impl + +def _full_like(): + def _impl(inputs, input_types): + data = inputs[0] + fill_value = inputs[1] + + out = _op.full_like(data, _expr.const(fill_value)) + + # If the input and the output datatype is different, do a cast + dtype = _convert_data_type(_convert_dtype_value(inputs[2])) + if input_types[0] not in dtype: + out = _op.cast(out, dtype) + + return out + return _impl + + 
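Side note on the full / full_like converters just above: both lower to relay.full / relay.full_like, with the fill value wrapped in a Relay constant and the dtype resolved from the torch dtype code. A minimal standalone sketch of the resulting expressions follows; the shape and fill value are made up purely for illustration and are not taken from this patch.

    from tvm import relay

    # Roughly what torch.full((2, 3), 7.0) is expected to lower to:
    fill_value = relay.const(7.0, "float32")
    full_expr = relay.full(fill_value, (2, 3), dtype="float32")

    # full_like keeps the shape of an existing tensor and only replaces its values.
    data = relay.var("data", shape=(2, 3), dtype="float32")
    full_like_expr = relay.full_like(data, relay.const(7.0, "float32"))

    print(full_expr)
    print(full_like_expr)
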
+def _linspace(): + def _impl(inputs, input_types): + start = inputs[0] + stop = inputs[1] + step = inputs[2] + + # Find the spacing between values as step + if step != 1: + step = (stop - start) / (step - 1) + stop = stop + step + else: + stop = start + step + + dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) + start = _create_typed_const(start, dtype) + stop = _create_typed_const(stop, dtype) + step = _create_typed_const(step, dtype) + + return _op.transform.arange(start=start, + stop=stop, + step=step, + dtype=_convert_data_type(dtype)) + return _impl + + def _relu(): def _impl(inputs, input_types): data = inputs[0] @@ -415,14 +570,13 @@ def _impl(inputs, input_types): def _gelu(): def _impl(inputs, input_types): - import math data = inputs[0] - - def _pow3(x): - return x * x * x - return _expr.const(0.5) * data * (_expr.const(1.0) + - _op.tanh(_expr.const(math.sqrt(2.0 / math.pi)) * - (data + _expr.const(0.044715) * _pow3(data)))) + # gelu is data * normcdf(data) + # normcdf expressed as erf because we don't currently have that intrinsic + # note that there is also a fastgelu variant approximating normcdf + # with tanh and third order polynomials, but this is "true" gelu + return data * (_expr.const(0.5) + + _op.erf(data * _expr.const(0.5**0.5)) * _expr.const(0.5)) return _impl def _selu(): @@ -602,17 +756,24 @@ def _impl(inputs, input_types): if isinstance(dilation, _expr.Expr): dilation = _infer_shape(dilation) - data_layout = "NCHW" - kernel_layout = "OIHW" - conv_op = _op.nn.conv2d - if use_transpose: - assert len(kernel_size) == 2, "ConvTranspose 3D not supported" - conv_op = _op.nn.conv2d_transpose + if len(kernel_size) == 3: + conv_op = _op.nn.conv3d_transpose + else: + conv_op = _op.nn.conv2d_transpose + else: + if len(kernel_size) == 3: + conv_op = _op.nn.conv3d + else: + conv_op = _op.nn.conv2d + if len(kernel_size) == 3: - conv_op = _op.nn.conv3d data_layout = "NCDHW" kernel_layout = "OIDHW" + else: + data_layout = "NCHW" + kernel_layout = "OIHW" + conv_out = conv_op(data, weight, @@ -748,6 +909,26 @@ def _impl(inputs, input_types): scale=True) return _impl + +def _group_norm(): + def _impl(inputs, input_types): + data = inputs[0] + gamma = inputs[2] + beta = inputs[3] + num_groups = inputs[1] + epsilon = float(inputs[4]) + + return _op.nn.group_norm(data, + gamma=gamma, + beta=beta, + num_groups=num_groups, + axis=1, + epsilon=epsilon, + center=True, + scale=True) + return _impl + + def _transpose(prelude): def _impl(inputs, input_types): data = inputs[0] @@ -782,7 +963,7 @@ def _impl(inputs, input_types): axes[src] = dst axes[dst] = src else: - axes = inputs[1] + axes = _infer_shape(inputs[1], prelude.mod) return _op.transform.transpose(data, axes) return _impl @@ -850,7 +1031,10 @@ def _impl(inputs, input_types): def _numtotensor(): def _impl(inputs, input_types): val = inputs[0] - dtype = type(val) + dtype = input_types[0] + + if isinstance(val, _expr.Expr): + return val if isinstance(val, tvm.tir.IntImm): val = val.__int__() @@ -860,21 +1044,34 @@ def _impl(inputs, input_types): return arr return _impl + +def _tensortonum(): + def _impl(inputs, input_types): + return inputs[0] + return _impl + + def _view(): def _impl(inputs, input_types): data = inputs[0] if len(inputs) == 3: - new_shape = [inputs[1], _infer_shape(inputs[2])[0]] + shape_inp = [inputs[1], _infer_shape(inputs[2])[0]] else: if isinstance(inputs[1], list): - new_shape = inputs[1] + shape_inp = inputs[1] else: - new_shape = _infer_shape(inputs[1]) + shape_inp = 
_infer_shape(inputs[1]) + new_shape = shape_inp + for i, shape in enumerate(shape_inp): + if isinstance(shape, _expr.Expr): + val = _infer_value_simulated(shape, {}) + new_shape[i] = np.asscalar(val.asnumpy()) return _op.transform.reshape(data, new_shape) return _impl + def _reshape(): def _impl(inputs, input_types): data = inputs[0] @@ -987,6 +1184,44 @@ def _impl(inputs, input_types): return _impl +def _norm(): + def _impl(inputs, input_types): + data = inputs[0] + axis = None + keepdims = False + if len(inputs) > 3: + axis = list(_infer_shape(inputs[2])) + keepdims = bool(inputs[3]) + + order = inputs[1] + if order == np.inf: + return _op.reduce.max(_op.abs(data), axis=axis, keepdims=keepdims) + elif order == np.NINF: + return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims) + else: + reci_order = _expr.const(1.0 / order) + order = _expr.const(order) + return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order), + axis=axis, + keepdims=keepdims), + reci_order) + return _impl + + +def _frobenius_norm(): + def _impl(inputs, input_types): + data = inputs[0] + axis = None + keepdims = False + if len(inputs) > 2: + axis = list(_infer_shape(inputs[1])) + keepdims = bool(inputs[2]) + + return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims)) + + return _impl + + def _std(): def _impl(inputs, input_types): data = inputs[0] @@ -1079,7 +1314,10 @@ def _impl(inputs, input_types): end[axis] = i + unif_size stride = [1] * len(shape) - chunk_out = _op.transform.strided_slice(data, begin, end, stride) + chunk_out = _op.transform.strided_slice(data, + begin=_expr.const(begin), + end=_expr.const(end), + strides=_expr.const(stride)) chunks.append(chunk_out) if dim % num_chunks: @@ -1089,42 +1327,88 @@ def _impl(inputs, input_types): end[axis] = dim stride = [1] * len(shape) - chunk_out = _op.transform.strided_slice(data, begin, end, stride) + chunk_out = _op.transform.strided_slice(data, + begin=_expr.const(begin), + end=_expr.const(end), + strides=_expr.const(stride)) chunks.append(chunk_out) return chunks return _impl -def _matmul(): - def _impl(inputs, input_types): - data0 = inputs[0] - data1 = inputs[1] - data1_t = _op.transpose(data1, axes=(1, 0)) +def _matmul(prelude): + def _impl(inputs, input_types): + + inputs_0 = inputs[0] + inputs_1 = inputs[1] + + # Need to check input shape as batch matmul must be supported. + a_shape = _infer_shape(inputs_0, prelude.mod) + b_shape = _infer_shape(inputs_1, prelude.mod) + + # When performing a batch matmul, we need to properly handle N-dim shapes. + if len(a_shape) > 2 or len(b_shape) > 2: + # Convert a and b into 3 dimensional tensors. + a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]]) + b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]]) + # Broadcast b to match batch size of a + new_b_shape = list(_infer_shape(b, prelude.mod)) + new_a_shape = _infer_shape(a, prelude.mod) + if new_a_shape[0] > new_b_shape[0]: + new_b_shape[0] = new_a_shape[0] + b = _op.broadcast_to(b, new_b_shape) + # Transpose matrix dimensions of b. + b = _op.transpose(b, [0, 2, 1]) + # Perform a batch matmul. + output = _op.nn.batch_matmul(a, b) + # Reshape output to original dimensions. + return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]]) + + # Otherwise a simple dense op will get the job done. 
+ if len(b_shape) == 1: + input_1 = _op.expand_dims(inputs_1, 0, 1) + else: + input_1 = _op.transpose(inputs_1, axes=(1, 0)) + + out = _op.nn.dense(inputs_0, input_1) + + if len(b_shape) == 1: + out = _op.squeeze(out, axis=[-1]) + + return out - return _op.nn.dense(data0, data1_t) return _impl + def _expand(): def _impl(inputs, input_types): data_in = inputs[0] if isinstance(data_in, _expr.Expr): - shape = _infer_shape(data_in) + shape = list(_infer_shape(data_in)) ndims = len(shape) sizes = _infer_shape(inputs[1]) out = inputs[0] + out_dims = len(sizes) + if ndims < out_dims: + num_newaxis = out_dims - ndims + out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis) + shape = [1] * num_newaxis + shape + for i in range(ndims): - if sizes[i] in {-1, shape[i]}: + if sizes[i] == -1 or sizes[i] == shape[i]: continue data = list() for temp in range(sizes[i]): data.append(out) - call = _op.tensor.concatenate(data, i) - return call + out = _op.tensor.concatenate(data, i) + + return out return _impl + def _int(): def _impl(inputs, input_types): if isinstance(inputs[0], _expr.Expr): @@ -1142,33 +1426,36 @@ def _impl(inputs, input_types): return None return _impl -def _pad(): - def _impl(inputs, input_types): - data = inputs[0] - padding = inputs[1] - pad_width = list(zip(padding, padding)) - pad_value = inputs[2] - return _op.nn.pad(data, pad_width, pad_value) - return _impl - -def _sqrt(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.sqrt(data) - return _impl - - -def _rsqrt(): +def _pad(mode): def _impl(inputs, input_types): data = inputs[0] - return _op.tensor.rsqrt(data) - return _impl - + if isinstance(inputs[1], list): + pad_list = inputs[1] + else: + pad_list = list(_infer_shape(inputs[1])) + + # initialize paddings based on input len + pad_len = len(_infer_shape(data)) * 2 + paddings = [0] * pad_len + + if len(pad_list) >= 2: + paddings[-1] = pad_list[1] + paddings[-2] = pad_list[0] + if len(pad_list) >= 4: + paddings[-3] = pad_list[3] + paddings[-4] = pad_list[2] + if len(pad_list) >= 6: + paddings[-5] = pad_list[5] + paddings[-6] = pad_list[4] + + # group into tuple of 2 ints + paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)] + + if mode == "constant": + return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode) + else: + return _op.nn.pad(data, paddings, pad_mode=mode) -def _ceil(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.ceil(data) return _impl @@ -1181,20 +1468,6 @@ def _impl(inputs, input_types): return _impl -def _floor(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.floor(data) - return _impl - - -def _round(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.round(data) - return _impl - - def _to(): def _impl(inputs, input_types): data = inputs[0] @@ -1263,6 +1536,32 @@ def func(x): return _impl + +def _upsample3d(method): + def _impl(inputs, input_types): + if isinstance(inputs[1], _expr.Var): + out_size = _infer_shape(inputs[1]) + elif isinstance(inputs[1], list): + infer_res = [_infer_value(size, {}) for size in inputs[1]] + out_size = [np.asscalar(res.asnumpy().astype(np.int)) + for res in infer_res] + + data = inputs[0] + + if len(inputs) > 2: + align_corners = inputs[2] + else: + align_corners = False + + if align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return _op.image.resize3d(data, out_size, "NCDHW", method, coord_trans) + return _impl + + def _expand_as(): def _impl(inputs, input_types): # TODO: maybe fix 
this @@ -1272,17 +1571,6 @@ def _impl(inputs, input_types): return inputs[0] return _impl -def _neg(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.negative(data) - return _impl - -def _tanh(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.tanh(data) - return _impl def _Bool(): def _impl(inputs, input_types): @@ -1320,16 +1608,7 @@ def _impl(inputs, input_types): def _bitwise_xor(): def _impl(inputs, input_types): lhs = inputs[0] - - import torch - if isinstance(inputs[1], _expr.Var): - rhs = inputs[1] - elif isinstance(inputs[1], torch.Tensor): - rhs = _wrap_const(inputs[1].numpy()) - else: - msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1])) - raise AssertionError(msg) - + rhs = inputs[1] lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int") rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int") @@ -1348,34 +1627,12 @@ def _impl(inputs, input_types): def _logical_xor(): def _impl(inputs, input_types): lhs = _op.cast(inputs[0], "bool") - - import torch - if isinstance(inputs[1], _expr.Var): - rhs = inputs[1] - elif isinstance(inputs[1], torch.Tensor): - rhs = _wrap_const(inputs[1].numpy()) - else: - msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1])) - raise AssertionError(msg) - - rhs = _op.cast(rhs, "bool") + rhs = _op.cast(inputs[1], "bool") return _op.logical_xor(lhs, rhs) return _impl -def _isfinite(): - def _impl(inputs, input_types): - return _op.isfinite(inputs[0]) - return _impl - - -def _isnan(): - def _impl(inputs, input_types): - return _op.isnan(inputs[0]) - return _impl - - def _list_getitem(prelude): def _impl(inputs, input_types): return prelude.nth(inputs[0], _wrap_const(inputs[1])) @@ -1388,6 +1645,14 @@ def _impl(inputs, input_types): return _impl +def _type_as(): + def _impl(inputs, input_types): + assert len(inputs) == 2 + assert len(input_types) == 2 + return _op.cast(inputs[0], _convert_data_type(input_types[1])) + return _impl + + def _add(prelude): # add_ is overloaded for tensor add and list concat def _impl(inputs, input_types): @@ -1400,18 +1665,62 @@ def _impl(inputs, input_types): def _tensor_array_stack(prelude): def _impl(inputs, input_types): tensor_array, shape = _convert_to_tensor_array(inputs[0], prelude) + + stacked_shape = (Any(),) + shape stack = prelude.get_var_static('tensor_array_stack', "float32", shape) stacked = stack(tensor_array) - stacked_shape = (Any(),) + shape - static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", shape) - static_tensor_array_ops.define_tensor_get_data(stacked_shape) - # passing stacked_shape below gives "'Prelude' object has no attribute" error - get_tensor = prelude.get_var_static('tensor_get_data', "float32", shape) + static_tensor_array_ops = StaticTensorArrayOps(prelude, "float32", stacked_shape) + static_tensor_array_ops.register() + get_tensor = prelude.get_var_static('tensor_get_data', "float32", stacked_shape) return get_tensor(stacked) return _impl +def _rsub(): + def _impl(inputs, input_types): + # TODO: Figure out a better way to get typing to work for tensor + scalar + type0 = input_types[0] + if isinstance(inputs[1], _expr.Expr): + type0 = input_types[1] + + type1 = input_types[1] + if isinstance(inputs[0], _expr.Expr): + type1 = input_types[0] + + data1 = _convert_elemwise_input(inputs[0], type0) + data0 = _convert_elemwise_input(inputs[1], type1) + alpha = _expr.const(float(inputs[2])) + + return 
get_relay_op("subtract")(data0, alpha * data1) + return _impl + + +def _embedding(): + def _impl(inputs, input_types): + weight = inputs[0] + indices = inputs[1] + + return _op.take(weight, indices.astype('int32'), axis=0) + return _impl + + +def _one_hot(): + def _impl(inputs, input_types): + indices = inputs[0].astype('int32') + num_classes = inputs[1] + if num_classes == -1: + msg = "Inferring the number of classes is not yet supported." + raise NotImplementedError(msg) + + dtype = 'int32' + on_value = tvm.relay.const(1.0, dtype) + off_value = tvm.relay.const(0.0, dtype) + + return _op.one_hot(indices, on_value, off_value, num_classes, -1, dtype) + return _impl + + # Helper functions for operator implementation def _convert_dtype_value(val): convert_torch_dtype_map = {7:"torch.float64", @@ -1491,6 +1800,7 @@ def _wrap_const(c): def _get_convert_map(prelude): convert_map = { "aten::device" : _none(), + "prim::device" : _none(), "aten::sub" : _elemwise("subtract"), "aten::sub_" : _elemwise("subtract"), "aten::max" : _elemwise("maximum"), @@ -1498,12 +1808,19 @@ def _get_convert_map(prelude): "aten::mul" : _elemwise("multiply"), "aten::mul_" : _elemwise("multiply"), "aten::pow" : _elemwise("power"), - "aten::abs" : _abs(), "aten::arange" : _arange(), "aten::div" : _elemwise("divide"), "aten::div_" : _elemwise("divide"), + "aten::floor_divide" : _elemwise("floor_divide"), + "aten::addcdiv" : _addcdiv(), + "aten::addcmul" : _addcmul(), "aten::ones" : _ones(), + "aten::ones_like" : _ones_like(), "aten::zeros" : _zeros(), + "aten::zeros_like" : _zeros_like(), + "aten::full" : _full(), + "aten::full_like" : _full_like(), + "aten::linspace" : _linspace(), "aten::reciprocal" : _reciprocal(), "aten::repeat" : _repeat(), "aten::repeat_interleave" : _repeat_interleave(), @@ -1516,12 +1833,14 @@ def _get_convert_map(prelude): "aten::split_with_sizes" : _split_with_sizes(), "aten::select" : _select(), "aten::take" : _take(), + "aten::where" : _where(), "aten::topk" : _topk(), "aten::relu" : _relu(), "aten::relu_" : _relu(), "aten::prelu" : _prelu(), "aten::leaky_relu" : _leaky_relu(), "aten::elu" : _elu(), + "aten::elu_" : _elu(), "aten::celu" : _celu(), "aten::gelu" : _gelu(), "aten::selu" : _selu(), @@ -1542,6 +1861,7 @@ def _get_convert_map(prelude): "aten::batch_norm" : _batch_norm(), "aten::instance_norm" : _instance_norm(), "aten::layer_norm" : _layer_norm(), + "aten::group_norm" : _group_norm(), "aten::transpose" : _transpose(prelude), "aten::transpose_" : _transpose(prelude), "aten::t" : _transpose(prelude), @@ -1562,27 +1882,60 @@ def _get_convert_map(prelude): "aten::alpha_dropout" : _dropout(), "aten::mean" : _mean(), "aten::chunk" : _chunk(prelude), - "aten::matmul" : _matmul(), + "aten::matmul" : _matmul(prelude), "aten::expand" : _expand(), "aten::Int" : _int(), "prim::NumToTensor" : _numtotensor(), - "aten::constant_pad_nd" : _pad(), + "prim::ImplicitTensorToNum" : _tensortonum(), + "aten::ScalarImplicit" : _tensortonum(), + "aten::constant_pad_nd" : _pad("constant"), + "aten::reflection_pad1d" : _pad("reflect"), + "aten::reflection_pad2d" : _pad("reflect"), + "aten::replication_pad1d" : _pad("edge"), + "aten::replication_pad2d" : _pad("edge"), + "aten::replication_pad3d" : _pad("edge"), "aten::permute" : _transpose(prelude), "aten::sum" : _reduce("sum"), "aten::prod" : _reduce("prod"), "aten::argmin" : _reduce("argmin"), "aten::argmax" : _reduce("argmax"), + "aten::norm" : _norm(), + "aten::frobenius_norm" : _frobenius_norm(), "aten::std" : _std(), "aten::var" : _variance(), - 
"aten::sqrt" : _sqrt(), - "aten::rsqrt" : _rsqrt(), - "aten::ceil" : _ceil(), + "aten::abs" : _unary("abs"), + "aten::neg" : _unary("negative"), + "aten::cos" : _unary("cos"), + "aten::cosh" : _unary("cosh"), + "aten::sin" : _unary("sin"), + "aten::sinh" : _unary("sinh"), + "aten::tan" : _unary("tan"), + "aten::tanh" : _unary("tanh"), + "aten::acos" : _unary("acos"), + "aten::asin" : _unary("asin"), + "aten::atan" : _unary("atan"), + "aten::log" : _unary("log"), + "aten::log2" : _unary("log2"), + "aten::log10" : _unary("log10"), + "aten::log1p" : _log1p(), + "aten::exp" : _unary("exp"), + "aten::erf" : _unary("erf"), + "aten::trunc" : _unary("trunc"), + "aten::sign" : _unary("sign"), + "aten::sqrt" : _unary("sqrt"), + "aten::rsqrt" : _unary("rsqrt"), + "aten::ceil" : _unary("ceil"), + "aten::floor" : _unary("floor"), + "aten::round" : _unary("round"), + "aten::isfinite" : _unary("isfinite"), + "aten::isinf" : _unary("isinf"), + "aten::isnan" : _unary("isnan"), "aten::clamp" : _clamp(), - "aten::floor" : _floor(), - "aten::round" : _round(), "aten::detach" : _identity(), "aten::upsample_bilinear2d" : _upsample("bilinear"), "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), + "aten::upsample_trilinear3d" : _upsample3d("trilinear"), + "aten::upsample_nearest3d" : _upsample3d("nearest_neighbor"), "aten::expand_as" : _expand_as(), "aten::lt" : _elemwise("less"), "aten::gt" : _elemwise("greater"), @@ -1594,21 +1947,21 @@ def _get_convert_map(prelude): "aten::logical_xor" : _logical_xor(), "aten::bitwise_not" : _bitwise_not(), "aten::bitwise_xor" : _bitwise_xor(), - "aten::isfinite" : _isfinite(), - "aten::isnan" : _isnan(), "aten::Bool" : _Bool(), "aten::Float" : _Float(), - "aten::neg" : _neg(), - "aten::tanh" : _tanh(), "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), - "aten::mm" : _matmul(), + "aten::rsub" : _rsub(), + "aten::embedding" : _embedding(), + "aten::one_hot" : _one_hot(), + "aten::mm" : _matmul(prelude), "relay::tensor_array_stack" : _tensor_array_stack(prelude), "aten::add" : _add(prelude), "aten::add_" : _add(prelude), "aten::stack" : _tensor_array_stack(prelude), "aten::__getitem__" : _list_getitem(prelude), "aten::len" : _list_len(prelude), + "aten::type_as" : _type_as(), } return convert_map @@ -1767,7 +2120,7 @@ def _get_constant(node): tensor = node.t(attr_name) if len(tensor.shape) == 0: # tensor(0.1) return float(tensor) - return tensor + return _wrap_const(tensor.numpy()) elif ty == "DeviceObjType": return node.s(attr_name) elif ty == "FunctionType": diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 120631ea31dc..af098771a521 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except -# pylint: disable=import-outside-toplevel +# pylint: disable=import-outside-toplevel, redefined-builtin """TF: Tensorflow frontend.""" import warnings from collections import defaultdict @@ -27,7 +27,7 @@ from tvm.ir import IRModule from tvm.relay.prelude import Prelude, StaticTensorArrayOps, get_tensor_array_shape -from tvm.ir import structural_hash as s_hash +from topi.util import get_const_tuple from .. import analysis from .. 
import expr as _expr @@ -40,7 +40,6 @@ from .common import infer_shape as _infer_shape from .common import infer_channels as _infer_channels from .common import infer_value as _infer_value -from .common import infer_value_simulated as _infer_value_simulated __all__ = ['from_tensorflow'] @@ -96,6 +95,23 @@ def _get_tuple_param(params, input_node): def _need_prelude_for_shape_inference(op): return "TensorArray" in op +def _get_more_static_shape(shape0, shape1): + """Compare two shapes with the same rank, + and return the one with fewer symbolic dimension. + """ + assert len(shape0) == len(shape1) + num_sym_dim0 = 0 + num_sym_dim1 = 0 + for dim0, dim1 in zip(list(shape0), list(shape1)): + if not isinstance(dim0, int): + num_sym_dim0 += 1 + if not isinstance(dim1, int): + num_sym_dim1 += 1 + + if num_sym_dim0 < num_sym_dim1: + return shape0 + return shape1 + def _rsqrt(): def _impl(inputs, attr, params, mod): inputs.append(tvm.relay.const(-0.5, attr['T'].name)) @@ -275,7 +291,7 @@ def _impl(inputs, attr, params, mod): inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] # NCHW Layout require weights transpose - weights_shape = _infer_shape(inputs[1]) + weights_shape = _infer_shape(inputs[1], mod) if attr['data_format'] == 'NCHW': tmp_shape = weights_shape if opname in ['conv', 'conv_transpose']: @@ -287,7 +303,7 @@ def _impl(inputs, attr, params, mod): weights_shape = tmp_shape - input_shape = _infer_shape(inputs_data) + input_shape = _infer_shape(inputs_data, mod) if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) @@ -379,9 +395,6 @@ def _impl(inputs, attr, params, mod): else: attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' - use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) - channel_axis = 1 if attr['data_format'] == "NCHW" else 3 - # Ignore the new attributes from TF2.0, for now. 
out = AttrCvt( op_name=_dimension_picker('conv', @@ -394,11 +407,6 @@ def _impl(inputs, attr, params, mod): 'group': ('groups', 1)}, custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr) - if use_bias: - out = _op.nn.bias_add(out, - inputs[2] if opname != 'conv_transpose' else inputs[3], - axis=channel_axis) - if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 1)) @@ -595,7 +603,7 @@ def _impl(inputs, attr, params, mod): out = AttrCvt( op_name=_dimension_picker('conv', surfix="_transpose" if opname == 'conv_transpose' else ""), - ignores=['explicit_paddings'], + ignores=['explicit_paddings', 'Tshape'], transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', @@ -614,6 +622,62 @@ def _impl(inputs, attr, params, mod): return out return _impl +def _nms(): + def _impl(inputs, attr, params, mod): + # Get parameter values + # TODO(yongwww) change nms in relay to support symbolic max_output_size + try: + max_output_size = int(np.atleast_1d(inputs[2].data.asnumpy() + .astype("int64"))[0]) + except Exception: + try: + max_output_size = _infer_value(inputs[2], params, + mod).asnumpy().astype("int64").tolist()[0] + except Exception: + max_output_size = -1 + iou_threshold = np.atleast_1d(inputs[3].data.asnumpy())[0] + # score_threshold was introduced from V3 + score_threshold = np.atleast_1d(inputs[4].data.asnumpy())[0] if len(inputs) > 4 else 0.0 + + # Generate data with shape (1, num_anchors, 5) + scores = AttrCvt(op_name="expand_dims", + ignores=['T_threshold'], + extras={'axis': -1, 'num_newaxis': 1})([inputs[1]], attr) + data = get_relay_op('concatenate')([scores, inputs[0]], -1) + data = get_relay_op('expand_dims')(data, 0, 1) + + # reason why using get_valid_counts is for inference performance + ct, data, indices = get_relay_op('get_valid_counts')(data, + score_threshold=score_threshold, + id_index=-1, + score_index=0) + # TensorFlow NMS doesn't have parameter top_k + top_k = -1 + # TF doesn't have class id for nms input + score_index = 0 + nms_ret = get_relay_op('non_max_suppression')(data=data, + valid_count=ct, + indices=indices, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + force_suppress=True, + top_k=top_k, + coord_start=1, + score_index=score_index, + id_index=-1, + return_indices=True, + invalid_to_bottom=False) + + # squeeze it, TF NMS is not batched + size = get_relay_op("squeeze")(nms_ret[1], axis=[1]) + data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0]) + + # slice to get the dynamic result + ret = get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]), + end=size, slice_mode="size") + return ret + return _impl + def _decode_image(): def _impl(inputs, attr, params, mod): # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer. @@ -633,7 +697,7 @@ def _impl(inputs, attr, params, mod): try: crop_size = _get_list_param(params, inputs[3]) except (IndexError, KeyError): - crop_size = _infer_value(inputs[3], params).asnumpy().tolist() + crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist() method = attr['method'].decode() method = 'nearest_neighbor' if method == 'nearest' else method @@ -667,9 +731,9 @@ def _impl(inputs, attr, params, mod): # Important that the size is defined. If an axis is not, we need to infer what # the shape should be. 
if -1 in size: - size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + size = _infer_value(inputs[1], params, mod).asnumpy().reshape([-1]).tolist() else: - size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist() + size = _infer_value(inputs[1], params, mod).asnumpy().reshape([-1]).tolist() attr['size'] = size inputs.pop(1) @@ -788,52 +852,20 @@ def _impl(inputs, attr, params, mod): def _tensor_array(): def _impl(inputs, attr, params, prelude): - try: - from tensorflow.python.framework import tensor_util - except ImportError as e: - raise ImportError( - "Unable to import tensorflow which is required {}".format(e)) - dtype_str = attr.get('dtype').name assert not attr["dynamic_size"], "Dynamic size tensor array is " \ "not supported in TVM yet." - raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape']) - elem_shape = [] - for dim in raw_elem_shape: - if dim < 0: - elem_shape.append(Any()) - else: - elem_shape.append(dim) - - if elem_shape: - # Element shape is specified. - # Directly create static tensor array with given shape. - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - elem_shape) - static_tensor_array_ops.register() - tensor_array_constructor = prelude.get_var_static('tensor_array', - dtype_str, - elem_shape) - tensor_array = tensor_array_constructor(inputs[0]) - _static_tensor_array_map[tensor_array] = tensor_array - elif attr['identical_element_shapes']: - # identical_element_shapes is set but element shape is not given. - # We create a static tensor array with dummy shape and record it in - # _static_tensor_array_map. Later when creating other tensor array ops - # which uses this tensor array, we reconstruct this tensor array with - # actual shape. - dummy_shape = () + if "shape" in attr: + shape = attr["shape"] static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, - dummy_shape) + shape) static_tensor_array_ops.register() tensor_array_constructor = prelude.get_var_static('tensor_array', dtype_str, - dummy_shape) + shape) tensor_array = tensor_array_constructor(inputs[0]) - _static_tensor_array_map[tensor_array] = None else: tensor_array_constructor = prelude.get_var('tensor_array', dtype_str) tensor_array = tensor_array_constructor(inputs[0]) @@ -856,21 +888,12 @@ def _impl(inputs, attr, params, prelude): values = unstack_function(inputs[2]) tensor_array_scatter_func = prelude.get_var('tensor_array_scatter', dtype_str) else: + input_t_shape = _get_more_static_shape(input_t_shape, input_shape) + values_shape = (values_shape[0],) + input_t_shape static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_t_shape) static_tensor_array_ops.register() - # For scatter operation, it is possible to write to a newly create - # tensor array. We need to check and recreate its input tensor array. 
- if input_ta in _static_tensor_array_map and \ - _static_tensor_array_map[input_ta] is None: - ta_constructor = prelude.get_var_static('tensor_array', - dtype_str, - input_t_shape) - new_ta = ta_constructor(input_ta.args[0]) - _static_tensor_array_map[input_ta] = new_ta - input_ta = new_ta - # Register static indices shape if isinstance(indices_shape[0], int): static_tensor_array_ops.define_tensor_array_scatter(indices_shape, True) @@ -904,24 +927,28 @@ def _impl(inputs, attr, params, prelude): dtype_str, input_shape) static_tensor_array_ops.register() + if not isinstance(indices_shape[0], int): gather_function = prelude.get_var_static('tensor_array_gather', dtype_str, input_shape) out_tensor_t = gather_function(inputs[2], inputs[1]) + out_shape = (indices_shape[0],) + input_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + out_shape) + static_tensor_array_ops.register() # Output shape is (indices_shape[0],) + input_shape - static_tensor_array_ops.define_tensor_get_data((indices_shape[0],) + input_shape) get_data_func = prelude.get_var_static('tensor_get_data', dtype_str, - input_shape) + out_shape) out = get_data_func(out_tensor_t) else: # For fixed length indices, directly generate static shape output read_func = prelude.get_var_static('tensor_array_read', dtype_str, input_shape) - static_tensor_array_ops.define_tensor_get_data(input_shape) get_data_func = prelude.get_var_static('tensor_get_data', dtype_str, input_shape) @@ -931,7 +958,10 @@ def _impl(inputs, attr, params, prelude): out_tensor = get_data_func(read_func(inputs[2], index)) tensor_list.append(_op.expand_dims(out_tensor, axis=0)) - out = _op.concatenate(tensor_list, axis=0) + if indices_shape[0] > 1: + out = _op.concatenate(tensor_list, axis=0) + else: + out = tensor_list[0] return out return _impl @@ -955,34 +985,30 @@ def _impl(inputs, attr, params, prelude): v = tensor_func(inputs[2]) write_func = prelude.get_var('tensor_array_write', dtype_str) else: - # For write operation, it is possible to write to a newly create - # tensor array. We need to check and recreate its input tensor array. - if input_ta in _static_tensor_array_map and \ - _static_tensor_array_map[input_ta] is None: - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_t_shape) - static_tensor_array_ops.register() - ta_constructor = prelude.get_var_static('tensor_array', - dtype_str, - input_t_shape) - new_ta = ta_constructor(input_ta.args[0]) - _static_tensor_array_map[input_ta] = new_ta - input_ta = new_ta - input_ta_shape = input_t_shape - else: - input_ta_rank = len(input_ta_shape) - assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ - format(input_ta_rank, input_rank) - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_ta_shape) - static_tensor_array_ops.register() + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". 
\ + format(input_ta_rank, input_rank) + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() tensor_func = prelude.get_var_static("tensor_constructor", dtype_str, input_ta_shape) v = tensor_func(inputs[2]) + # Write tensor with more static shape + actual_shape = _get_more_static_shape(input_t_shape, input_ta_shape) + if actual_shape != input_t_shape: + new_shape = [] + num_any_dim = 0 + for dim in actual_shape: + if not isinstance(dim, int): + num_any_dim += 1 + new_shape.append(dim if isinstance(dim, int) else -1) + if num_any_dim <= 1: + v = tensor_func(_op.reshape(inputs[2], new_shape)) + write_func = prelude.get_var_static('tensor_array_write', dtype_str, input_ta_shape) @@ -1003,7 +1029,6 @@ def _impl(inputs, attr, params, prelude): dtype_str, input_shape) static_tensor_array_ops.register() - static_tensor_array_ops.define_tensor_get_data(input_shape) read_func = prelude.get_var_static("tensor_array_read", dtype_str, input_shape) out_tensor = read_func(inputs[2], _op.take(inputs[1], tvm.relay.const(0))) get_data_func = prelude.get_var_static('tensor_get_data', @@ -1019,39 +1044,22 @@ def _impl(inputs, attr, params, prelude): dtype_str = attr.get('T').name input_ta = inputs[0] input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) - input_t_shape = _infer_shape(inputs[1], prelude.mod) - input_rank = len(input_t_shape) lengths = _op.cast(inputs[2], 'int32') lengths_shape = _infer_shape(lengths, prelude.mod) value_shape = _infer_shape(inputs[1], prelude.mod) + input_rank = len(value_shape) if input_ta_shape is None: v = prelude.get_var("tensor{}".format(input_rank), dtype_str)(inputs[1]) split_func = prelude.get_var('tensor_array_split', dtype_str) else: - # For split operation, it is possible to write to a newly create - # tensor array. We need to check and recreate its input tensor array. - if input_ta in _static_tensor_array_map and \ - _static_tensor_array_map[input_ta] is None: - input_ta_shape = (Any(),) + input_t_shape[1:] - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_ta_shape) - static_tensor_array_ops.register() - ta_constructor = prelude.get_var_static('tensor_array', - dtype_str, - input_ta_shape) - new_ta = ta_constructor(input_ta.args[0]) - _static_tensor_array_map[input_ta] = new_ta - input_ta = new_ta - else: - input_ta_rank = len(input_ta_shape) - assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". \ - format(input_ta_rank, input_rank) - static_tensor_array_ops = StaticTensorArrayOps(prelude, - dtype_str, - input_ta_shape) - static_tensor_array_ops.register() + input_ta_rank = len(input_ta_shape) + assert input_ta_rank == input_rank, "Shape rank mismatch: {} vs {}". 
\ + format(input_ta_rank, input_rank) + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + input_ta_shape) + static_tensor_array_ops.register() # Check static value/indices shape if isinstance(value_shape[0], int) or isinstance(lengths_shape[0], int): @@ -1093,10 +1101,14 @@ def _impl(inputs, attr, params, prelude): static_tensor_array_ops.register() concat_func = prelude.get_var_static("tensor_array_concat", dtype_str, input_shape) out_tensor = concat_func(inputs[1]) - static_tensor_array_ops.define_tensor_get_data((Any(),) + input_shape[1:]) + out_shape = (Any(),) + input_shape[1:] + static_tensor_array_ops = StaticTensorArrayOps(prelude, + dtype_str, + out_shape) + static_tensor_array_ops.register() get_data_func = prelude.get_var_static('tensor_get_data', dtype_str, - input_shape) + out_shape) out = get_data_func(out_tensor) return out @@ -1104,9 +1116,13 @@ def _impl(inputs, attr, params, prelude): def _tile(): def _impl(inputs, attr, params, mod): - reps = _get_list_param(params, inputs.pop()) - new_input = [] - new_input.append(inputs.pop(0)) + reps_input = inputs.pop() + if isinstance(reps_input, _expr.Call): + np_reps = _infer_value(reps_input, params, mod).asnumpy() + reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])] + else: + reps = _get_list_param(params, reps_input) + new_input = [inputs.pop(0)] return AttrCvt( op_name='tile', @@ -1119,25 +1135,31 @@ def _impl(inputs, attr, params, mod): try: begin = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - begin = _infer_value(inputs[1], params).asnumpy().tolist()[0] + # Handle symbolic begin + try: + begin = _infer_value(inputs[1], params, mod).asnumpy().tolist() + except Exception: + begin = inputs[1] try: size = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): # Handle symbolic size try: - size = _infer_value(inputs[2], params).asnumpy().tolist()[0] + size = _infer_value(inputs[2], params, mod).asnumpy().tolist() except Exception: size = inputs[2] - data_shape = _infer_shape(inputs[0], mod) - data_dim = len(data_shape) - end = size - if not isinstance(end, (_expr.Call, _expr.Var)): - for i in range(data_dim): - if size[i] == -1: - end[i] = data_shape[i] - else: - end[i] += begin[i] - return _op.strided_slice(inputs[0], begin=begin, end=end) + + # Align begin and strides for dynamic shape. + data_dim = len(_infer_shape(inputs[0], mod)) + strides = [1] * data_dim + if not isinstance(begin, (_expr.Call, _expr.Var)): + for _ in range(len(begin), data_dim): + begin.append(0) + elif not isinstance(size, (_expr.Call, _expr.Var)): + for _ in range(len(size), data_dim): + size.append(-1) + return _op.strided_slice(inputs[0], begin=begin, end=size, + strides=strides, slice_mode="size") return _impl @@ -1151,18 +1173,16 @@ def _impl(inputs, attr, params, mod): # Shape operator is already pruned, hence # try to infer shape by precompute prune if possible. try: - params_new = _infer_value(pop_node, params) - shape_arg = tuple(params_new.asnumpy().astype('int64').flatten()) + params_new = _infer_value(pop_node, params, mod) + shape_arg = tuple(params_new.asnumpy().astype('int32').flatten()) except Exception: # Deal with symbolic shape case. - # Currently only shape_of can be the direct ancestor. 
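# A standalone sketch of the slice_mode="size" form of relay.strided_slice that the
# reworked _slice converter above now emits; shapes and names below are illustrative.
# In "size" mode, `end` holds per-axis sizes rather than stop indices, and -1 means
# "take everything remaining along that axis".
import tvm
from tvm import relay

data = relay.var("data", shape=(4, 8), dtype="float32")
sliced = relay.strided_slice(data, begin=[1, 0], end=[2, -1],
                             strides=[1, 1], slice_mode="size")
func = relay.Function([data], sliced)   # yields a (2, 8) slice starting at row 1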
- if not isinstance(pop_node, tvm.relay.expr.Call) or \ - "shape_of" not in str(pop_node.op): - raise RuntimeError("If shape operator is used in reshape to " - "express reshape_like, shape_of must be " - "the direct ancestor of reshape when input " - "shape is symbolic.") - return _op.reshape_like(inputs[0], pop_node.args[0]) + if isinstance(pop_node, _expr.Call) and \ + "shape_of" in str(pop_node.op): + # shape_of is the direct ancestor. + return _op.reshape_like(inputs[0], pop_node.args[0]) + shape_arg = pop_node + return AttrCvt( op_name="reshape", extras={'newshape': shape_arg}, @@ -1170,6 +1190,7 @@ def _impl(inputs, attr, params, mod): return _impl + def _depth_to_space(): def _impl(inputs, attr, params, mod): block_size = int(attr['block_size']) @@ -1191,7 +1212,8 @@ def _impl(inputs, attr, params, mod): def _bias_add(): def _impl(inputs, attr, params, mod): # Must expand for proper broadcasting in NCHW. - if attr['data_format'].decode("utf-8") == 'NCHW': + if 'data_format' in attr and \ + attr['data_format'].decode("utf-8") == 'NCHW': bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1)) else: bias = inputs[1] @@ -1203,7 +1225,7 @@ def _impl(inputs, attr, params, mod): if isinstance(inputs[1], _expr.Var): shape = params[inputs[1].name_hint] else: - shape = _infer_value(inputs[1], params) + shape = _infer_value(inputs[1], params, mod) shape = list(shape.asnumpy().reshape([-1])) return _op.broadcast_to(inputs[0], shape) return _impl @@ -1230,7 +1252,7 @@ def _impl(inputs, attr, params, mod): attr['data_format'] = attr['data_format'].decode("utf-8") if attr['data_format'] == 'NCHW': axis = 1 - if 'U' in attr: + if 'U' in attr and attr['U'].name != attr['T'].name: need_cast = True inputs[0] = _op.cast(inputs[0], dtype=attr['U'].name) # Check if mean and variance are empty @@ -1238,7 +1260,7 @@ def _impl(inputs, attr, params, mod): # For run-time calculation moving_mean_shape = [int(n) for n in inputs[3].type_annotation.shape] moving_variance_shape = [int(n) for n in inputs[4].type_annotation.shape] - if (moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0): + if moving_mean_shape[0] == 0 and moving_variance_shape[0] == 0: inputs[3] = _op.mean(inputs[0], axis=axis, keepdims=False, exclude=True) inputs[4] = _op.variance(inputs[0], axis=axis, keepdims=False, exclude=True) out = AttrCvt(op_name='batch_norm', @@ -1300,16 +1322,12 @@ def _impl(inputs, attr, params, mod): def _fill(): def _impl(inputs, attr, params, mod): - output_shape = attr['_output_shapes'][0] - # Output shape must be defined to avoid errors. If any axis is not, we must - # try to compute its shape. 
- if output_shape is None or -1 in output_shape: - output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist() + try: + output_shape = _infer_value(inputs[0], params, mod).asnumpy().tolist() + except Exception: + output_shape = inputs[0] - fill_arg = _get_num_param(params, inputs.pop(1)) - dtype = attr['T'].name - return _op.full(tvm.relay.const(fill_arg, dtype), - output_shape, dtype) + return _op.full(inputs[1], output_shape, attr['T'].name) return _impl def _lrn(): @@ -1339,6 +1357,8 @@ def _reduce(op): def _impl(inputs, attr, params, mod): axis = _get_list_param(params, inputs[1]) axis = tuple(axis) + if not axis: + axis = None return AttrCvt( op_name=op, extras={'axis': axis}, @@ -1396,15 +1416,49 @@ def _impl(inputs, attr, params, mod): begin = _get_list_param(params, inputs[1]) end = _get_list_param(params, inputs[2]) stride = _get_list_param(params, inputs[3]) + begin_mask = int(attr.get('begin_mask', 0)) end_mask = int(attr.get('end_mask', 0)) ellipsis_mask = int(attr.get('ellipsis_mask', 0)) new_axis_mask = int(attr.get('new_axis_mask', 0)) shrink_axis_mask = int(attr.get('shrink_axis_mask', 0)) - data_shape = _infer_shape(inputs[0], mod) + in_type = _infer_type(inputs[0], mod) + data_shape = get_const_tuple(in_type.checked_type.shape) data_dim = len(data_shape) stride_dim = len(stride) + # This is a special routine to handle strided_slice after shape_of. + # We need this since in some cases we want to do strided_slice on + # a partial symbolic shape, such as (1, ?), and get a static shape + # (1,). Directly slice on shape_of will result in fully dynamic shape. + # TODO(kevinthesun): Can we generalize this process with partial eval? + if isinstance(inputs[0], _expr.Call) and inputs[0].op == _op.get("shape_of"): + bg = begin[0] + ed = end[0] + st = stride[0] + + if ed <= 0 < st: + ed += data_shape[0] + + in_shape = _infer_shape(inputs[0].args[0], mod) + dtype = in_type.checked_type.dtype + out_data = [] + idx = bg + while idx < ed: + if isinstance(in_shape[idx], int): + out_data.append(in_shape[idx]) + else: + break + idx += st + + # Only return when in_shape is fully static in the range from begin to end. 
+ if idx >= st: + ret = _expr.const(out_data, dtype) + if shrink_axis_mask: + ret = _op.squeeze(ret) + + return ret + def _transform_mask(stride_dim, ellipsis_mask): """Handle mask inputs to create new begin, end, stride and output shape""" m_begin = [0] * data_dim @@ -1444,19 +1498,19 @@ def _transform_mask(stride_dim, ellipsis_mask): break if mask & begin_mask: m_begin[final_index] = data_shape[final_index] \ - if stride[index] < 0 else 0 + if stride[index] < 0 else 0 elif begin[index]: m_begin[final_index] = begin[index] if mask & end_mask: m_end[final_index] = 0 if stride[index] < 0 \ - else data_shape[final_index] + else data_shape[final_index] elif end[index]: m_end[final_index] = end[index] m_stride[final_index] = stride[index] if mask & shrink_axis_mask: #Tensorflow make axis with shrink_axis_mask as dimension 1 m_begin[final_index] = data_shape[final_index] + begin[index] \ - if begin[index] < 0 else begin[index] + if begin[index] < 0 else begin[index] m_end[final_index] = begin[index] + 1 m_stride[final_index] = 1 fshape_indices.append(-2) @@ -1469,8 +1523,11 @@ def _transform_mask(stride_dim, ellipsis_mask): fshape_indices = None if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) - out = _op.strided_slice(inputs[0], begin=begin, end=end, strides=stride) - out_shape = _infer_shape(out, mod) + out = _op.strided_slice(inputs[0], + begin=begin, + end=end, + strides=stride) + out_shape = _infer_shape(out, mod=mod) if not fshape_indices: fshape_indices = range(len(out_shape)) @@ -1537,7 +1594,7 @@ def _impl(inputs, attr, params, mod): try: axes = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - axes = _infer_value_simulated(inputs[1], params).asnumpy() + axes = _infer_value(inputs[1], params, mod).asnumpy().tolist() return _op.transpose(inputs[0], axes=axes) return _impl @@ -1569,7 +1626,8 @@ def _impl(inputs, attr, params, mod): input_shape = _infer_shape(inputs[0], mod) name = attr["_node_name"] - params[name] = tvm.nd.array([len(input_shape)]) + params[name] = tvm.nd.array(np.array([len(input_shape)]) + .astype("int32")) return [_expr.var(name, shape=params[name].shape, dtype='int32')] @@ -1582,24 +1640,22 @@ def _impl(inputs, attr, params, mod): start = _get_param(params, inputs[0])[0] except (IndexError, KeyError, AttributeError): try: - start = _infer_value(inputs[1], params).asnumpy().tolist() + start = _infer_value(inputs[1], params, mod).asnumpy().tolist() start = start if not isinstance(start, list) else start[0] except Exception: # Symbolic start start = inputs[0] - if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant): - limit = _get_param(params, inputs[1])[0] - else: - if any(['Rank' in param for param in params]): - limit = params.pop('Rank').asnumpy()[0] - else: - try: - limit = _infer_value(inputs[1], params, mod).asnumpy().tolist() - limit = limit if not isinstance(limit, list) else limit[0] - except Exception: - # Symbolic limit - limit = inputs[1] + try: + limit = _get_param(params, inputs[1])[0] \ + if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \ + else params.pop('Rank').asnumpy()[0] + except (IndexError, KeyError, AttributeError): + try: + limit = _infer_value(inputs[1], params, mod).asnumpy().tolist() + limit = limit if not isinstance(limit, list) else limit[0] + except Exception: + limit = inputs[1] try: delta = _get_param(params, inputs[2])[0] @@ -1734,16 +1790,21 
@@ def _impl(inputs, attr, params, mod): try: k = int(_get_num_param(params, k_input)) except (IndexError, KeyError, AttributeError): - k = int(_infer_value(k_input, params).asnumpy().tolist()) - if k < 1: - raise tvm.error.OpAttributeInvalid( - 'Attribute k must be positive in operator TopKV2') + try: + k = int(_infer_value(k_input, params, mod).asnumpy().tolist()) + except Exception: + k = k_input + if isinstance(k, int): + if k < 1: + raise tvm.error.OpAttributeInvalid( + 'Attribute k must be positive in operator TopKV2') + k = _expr.const(k) if attr['sorted'] is False: raise tvm.error.OpAttributeUnImplemented( 'Attribute sorted=False is not supported in operator TopKV2') return AttrCvt(op_name='topk', ignores=['sorted'], - extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) + extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})([inputs[0]], attr) return _impl def _floordiv(): @@ -1770,12 +1831,12 @@ def _impl(inputs, attr, params, mod): try: block_shape = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - block_shape = _infer_value(inputs[1], params).asnumpy().tolist() + block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist() try: paddings = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): - paddings = _infer_value(inputs[2], params).asnumpy() + paddings = _infer_value(inputs[2], params, mod).asnumpy() paddings = np.squeeze(paddings) if len(paddings.shape) == 1: paddings = np.expand_dims(paddings, axis=0) @@ -1800,7 +1861,7 @@ def _impl(inputs, attr, params, mod): axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) - permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded) + permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, mod) # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, # producing an output tensor of shape: # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., @@ -1820,12 +1881,12 @@ def _impl(inputs, attr, params, mod): try: block_shape = _get_list_param(params, inputs[1]) except (IndexError, KeyError, AttributeError): - block_shape = _infer_value(inputs[1], params).asnumpy().tolist() + block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist() try: crops = _get_list_param(params, inputs[2]) except (IndexError, KeyError, AttributeError): - crops = _infer_value(inputs[2], params).asnumpy() + crops = _infer_value(inputs[2], params, mod).asnumpy() crops = np.squeeze(crops) if len(crops.shape) == 1: crops = np.expand_dims(crops, axis=0) @@ -1854,7 +1915,7 @@ def _impl(inputs, attr, params, mod): # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], # input_shape[M+1], ..., input_shape[N-1]] - reshaped_permuted_shape = _infer_shape(reshaped_permuted) + reshaped_permuted_shape = _infer_shape(reshaped_permuted, mod) cropped = reshaped_permuted for axis in range(1, M+1): crop = crops[axis - 1] @@ -1930,7 +1991,6 @@ def _impl(inputs, attr, params, mod): return _res return _impl - # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1947,6 +2007,8 @@ def _impl(inputs, attr, params, mod): # for N to 1 mapping, currently not supported(?) 
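# Several converters above (Slice, Fill, Range, TopKV2, SpaceToBatchND, ...) share the
# same fallback pattern for operands that may be dynamic: read a graph param if one
# exists, otherwise try to constant-fold with _infer_value(..., mod), and only then keep
# the symbolic Relay expression. A minimal standalone sketch of that pattern; the helpers
# are passed in explicitly here, and `static_or_symbolic` is an illustrative name only.
def static_or_symbolic(operand, params, mod, get_list_param, infer_value):
    """Return a Python list when `operand` folds to a constant, else the expr itself."""
    try:
        return get_list_param(params, operand)                 # stored as a graph param
    except (IndexError, KeyError, AttributeError):
        try:
            return infer_value(operand, params, mod).asnumpy().tolist()  # constant-foldable
        except Exception:
            return operand                                     # genuinely dynamic input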
_convert_map = { 'Abs' : AttrCvt('abs'), + 'Acos' : AttrCvt('acos'), + 'Acosh' : AttrCvt('acosh'), 'Add' : _elemwise('add'), 'AddN' : _add_n(), 'AddV2' : _elemwise('add'), @@ -1954,8 +2016,11 @@ def _impl(inputs, attr, params, mod): 'Any' : _reduce('any'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), + 'Asin' : AttrCvt('asin'), + 'Asinh' : AttrCvt('asinh'), 'Assert' : _assert(), 'Atan' : AttrCvt('atan'), + 'Atanh' : AttrCvt('atanh'), 'Atan2' : _atan2(), 'AvgPool' : _pooling('avg_pool'), 'AvgPool3D' : _pool3d('avg_pool3d'), @@ -1974,7 +2039,9 @@ def _impl(inputs, attr, params, mod): 'Conv2D' : _conv('conv'), 'Conv2DBackpropInput' : _conv('conv_transpose'), 'Conv3D' : _conv3d('conv'), + 'Conv3DBackpropInputV2' : _conv3d('conv_transpose'), 'Cos' : AttrCvt('cos'), + 'Cosh' : AttrCvt('cosh'), 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthToSpace' : _depth_to_space(), @@ -2024,6 +2091,8 @@ def _impl(inputs, attr, params, mod): 'Mod' : _elemwise('mod'), 'Mul' : _elemwise('multiply'), 'Neg' : AttrCvt('negative'), + 'NonMaxSuppressionV2' : _nms(), + 'NonMaxSuppressionV3' : _nms(), 'NoOp' : _no_op(), 'NotEqual' : _broadcast('not_equal'), 'OneHot' : _one_hot(), @@ -2051,6 +2120,7 @@ def _impl(inputs, attr, params, mod): 'Sigmoid' : AttrCvt('sigmoid'), 'Sign' : AttrCvt('sign'), 'Sin' : AttrCvt('sin'), + 'Sinh' : AttrCvt('sinh'), 'Size' : _size(), 'Slice' : _slice(), 'Softmax' : _softmax(), @@ -2336,29 +2406,36 @@ def _get_abs_layer_name(node): # 1.x. _control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond'] -# A map to record tensor array with fixed rank shape -_static_tensor_array_map = {} - -class RewriteSubgraph(ExprMutator): - """ - A helper class to rewrite expr in while loop function to variable - - Parameters - ---------- - rewrite_map : Dict[expr, expr] - A dictionay contains a set of expr to var mapping. - """ - def __init__(self, rewrite_map): - ExprMutator.__init__(self) - self.rewrite_map = rewrite_map - - def visit(self, expr): - if expr in self.rewrite_map: - return self.rewrite_map[expr] - return super().visit(expr) +# A map to record tensor array write ops and input ta/tensor indices +# Value is (index of tensor array, index of written node) +_tensor_array_write_ops = { + "TensorArrayWrite" : (3, 2), + "TensorArrayScatter" : (0, 2), + "TensorArraySplit" : (0, 1), +} -def rewrite_subgraph(expr, rewrites): - return RewriteSubgraph(rewrites).visit(expr) +def is_tensor_array_constuctor(tf_node): + """Check whether is tensor array constructor node.""" + is_ta = False + ta_start = "TensorArrayV" + if tf_node.op.startswith(ta_start): + is_ta = tf_node.op[len(ta_start)].isnumeric() + return is_ta + +def find_parent_loop_name(node_name, while_loop_name_set): + """Find name of direct parent while loop.""" + ploop_name = "" + name_prefix = node_name.rsplit('/', 1)[0] + if name_prefix.startswith("^"): + name_prefix = name_prefix[1:] + for lname in while_loop_name_set: + if name_prefix.startswith(lname) and len(ploop_name) < len(lname): + ploop_name = lname + + if len(ploop_name) == 0: + ploop_name = name_prefix + + return ploop_name def _in_while_loop(control_flow_node_map, op_name): """ @@ -2385,6 +2462,28 @@ def _in_while_loop(control_flow_node_map, op_name): return op_name in control_flow_node_map and \ "LoopCond" in control_flow_node_map[op_name] +class RewriteSubgraph(ExprMutator): + """ + A helper class to rewrite expr in while loop function to variable. 
+ + Parameters + ---------- + rewrite_map : Dict[expr, expr] + A dictionay contains a set of expr to var mapping. + """ + def __init__(self, rewrite_map): + ExprMutator.__init__(self) + self.rewrite_map = rewrite_map + + def visit(self, expr): + if expr in self.rewrite_map: + return self.rewrite_map[expr] + return super().visit(expr) + +def rewrite_subgraph(expr, rewrites): + """Rewrite loop body.""" + return RewriteSubgraph(rewrites).visit(expr) + class Branch: """A class contains the components that are used to build up a Relay if node. @@ -2465,118 +2564,50 @@ def if_node(self): self._if = self._if_node() return self._if +class VarChecker(ExprVisitor): + """Check whether a Variable is used in loop body. -class LoopBound(ExprVisitor): - """ - When a loop body is create, we get a Relay expression backtracing all - the way back to input node. This will result in lots of unnecessary - expression placed into loop body and compute multiple times. For example, - consider the following tensorflow code: - - .. code-block:: python - - i = tf.constant(0) - data = tf.compat.v1.placeholder(tf.float32, shape=(1024, 1024)) - slice = tf.strided_slice(data, 0, 512) - def c(i): return tf.less(i, 10) - def b(i): return [tf.add(i, 1), tf.add(i, 1) + slice] - r = tf.while_loop(c, b, [i]) - - If we directly create recursive function, slice will be placed into function body. - Instead, we recognize whether slice is inside while_loop block and pass it as an - extra loop variable to avoid duplicate computation. - - TODO(kevinthesun): Add a LICM pass for Relay to handle generic loop/function. + Parameters + ---------- + var : relay.expr.Var + Relay Variable to be checked. """ - def __init__(self, loop_name, hash2tfnode, while_loop_name_set): + def __init__(self, var): ExprVisitor.__init__(self) - self._loop_name = loop_name - self._hash2tfnode = hash2tfnode - self._while_loop_name_set = while_loop_name_set - self.extra_loop_var_names = set() - - def _find_parent_loop_name(self, node_name): - """Find name of direct parent while loop.""" - ploop_name = "" - name_prefix = node_name.rsplit('/', 1)[0] - if name_prefix.startswith("^"): - name_prefix = name_prefix[1:] - # To get the name of the direct parent while loop for a given node, - # we iterate all the while loop names inside TensorFlow graph def. - # If we find a loop name with which current node name starts, - # it means current node is under this loop. However, due to nested - # loop, this loop may not be the direct parent while loop of current - # node. We need to keep the longest loop name, which represents the - # innermost while loop corresponding to current node. - for lname in self._while_loop_name_set: - if name_prefix.startswith(lname) and len(ploop_name) < len(lname): - ploop_name = lname - - if len(ploop_name) == 0: - ploop_name = name_prefix - - return ploop_name + self._var = var + self.used = False def visit(self, expr): - """ - For each expression in the body, look up the corresponding - TensorFlow node with its structural hash. If the current loop is the - direct parent of this node, we check whether its every input node belongs - to the current loop. If not, we mark this input node as an extra loop - variable to the current loop. - """ - expr_hash = s_hash(expr) - - if expr_hash in self._hash2tfnode: - node = self._hash2tfnode[expr_hash] - ploop_name = self._find_parent_loop_name(node.name) - # It is possibel that a node is under nested loop of current loop. - # We only check the direct children of current loop. 
- if ploop_name == self._loop_name: - for iname in node.input: - iploop_name = self._find_parent_loop_name(iname) - # Use startswith to deal with nested loop - if not iploop_name.startswith(self._loop_name): - if iname not in self.extra_loop_var_names: - self.extra_loop_var_names.add(iname) + if self._var == expr: + self.used = True super().visit(expr) - class Loop: """ A class contains the components that are used to build up a Relay recursive call. - Parameters ---------- - loop_vars : List[tvm.relay.Expr] - The loop variables that used in a while loop. - - cond : tvm.relay.Expr - The condition of a while loop. + mod : tvm.IRModule + Module for current parsed IR. - body : tvm.relay.Expr - The body of a matched while loop. + loop_name : str + Name prefix of while loop in TensorFlow graph. - _loop : tvm.relay.Expr - An internal variable indicates where a recursive call is already created - for a matched TF while loop construct. + lvar2expr : dict from str to dict from Relay.expr.Var to Relay.expr + A dictionary recording all loop vars and corresponding + relay expression. Examples -------- The following is a vanilla loop from TensorFlow: - .. code-block:: python - i = tf.constant(0) c = lambda i: tf.less(i, 10) b = lambda i: tf.add(i, 1) r = tf.while_loop(c, b, [i]) - It will be converted to the following recursive call in Relay: - .. code-block:: python - fn (%while/Less/y: Tensor[(1,), int32], %while/Add/y: Tensor[(1,), int32], %Const: Tensor[(1,), int32]) { @@ -2598,86 +2629,74 @@ class Loop: %6 } """ - def __init__(self, mod, loop_name, hash2tfnode, - node_map, while_loop_name_set): - self.loop_vars = [] + def __init__(self, mod, loop_name, lvar2expr): self.cond = None self.body = [] self._loop = None self._mod = mod self._loop_name = loop_name - self._hash2tfnode = hash2tfnode - self._node_map = node_map - self._while_loop_name_set = while_loop_name_set + self._lvar2expr = lvar2expr + self.loop_vars = [] + self.aligned = False def _while_loop(self): """An internal API to create a Relay recursive call for a matched TF `while_loop` construct. 
""" + bind_map = {} wl = tvm.relay.var('while_loop') - sb = tvm.relay.scope_builder.ScopeBuilder() - loop_checker = LoopBound(self._loop_name, - self._hash2tfnode, - self._while_loop_name_set) - for body in self.body: - loop_checker.visit(body) - - loop_vars = [] - bind_map = {} - loop_var_hash_set = set() - for var in self.loop_vars: - loop_var_hash_set.add(s_hash(var)) - - extra_nodes = [] - for extra_loop_var_name in loop_checker.extra_loop_var_names: - extra_loop_var_name = extra_loop_var_name.split(':')[0].split("^")[-1] - extra_node = self._node_map[extra_loop_var_name] - extra_node = extra_node if isinstance(extra_node, _expr.Tuple) else extra_node[0] - if s_hash(extra_node) not in loop_var_hash_set: - self.loop_vars.append(extra_node) - extra_nodes.append(extra_node) - - for i, var in enumerate(self.loop_vars): - if not isinstance(var, _expr.Var): - var_chk = _infer_type(var, self._mod) - var_type = var_chk.checked_type - else: - var_type = var.type_annotation - - v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type) - loop_vars.append(v) - bind_map[var] = v - - - self.cond = rewrite_subgraph(self.cond, bind_map) - self.body = [rewrite_subgraph(b, bind_map) for b in self.body] - - self.body_shape = [] - for body in self.body: - current_node = body - shape = _infer_shape(current_node, self._mod) - while not isinstance(shape, (tuple, list)): - current_node = current_node.args[-1] - shape = _infer_shape(current_node, self._mod) - self.body_shape.append(shape) + lv_list = [] + expr_list = [] + extra_vars = [] + + for i, lv in enumerate(self.loop_vars): + if self._loop_name not in self._lvar2expr: + self._lvar2expr[self._loop_name] = {} + + # Handle the case when loop var is not properly lifted. + # This can happen when loop var node name is set accidentally + # beginning with loop name. 
+ if lv not in self._lvar2expr[self._loop_name]: + var_name = "{}_loop_var_{}".format(self._loop_name, i) + var_type = _infer_type(lv, self._mod).checked_type + loop_var = tvm.relay.var(var_name, type_annotation=var_type) + self._lvar2expr[self._loop_name][loop_var] = lv + bind_map[lv] = loop_var + self.loop_vars[i] = loop_var + lv = loop_var + + lv_list.append(lv) + expr_list.append(self._lvar2expr[self._loop_name][lv]) + + if bind_map: + self.cond = rewrite_subgraph(self.cond, bind_map) + self.body = [rewrite_subgraph(b, bind_map) for b in self.body] cond = tvm.relay.op.min(self.cond) + for lv, exp in self._lvar2expr[self._loop_name].items(): + if lv not in self.loop_vars: + var_checker = VarChecker(lv) + for bd in self.body + [cond]: + var_checker.visit(bd) + if var_checker.used: + lv_list.append(lv) + expr_list.append(exp) + extra_vars.append(lv) + break + with sb.if_scope(cond): - extra_args = [] - if extra_nodes: - extra_args = list(loop_vars[-len(extra_nodes):]) - sb.ret(wl(*list(self.body + extra_args))) + sb.ret(wl(*list(self.body + extra_vars))) with sb.else_scope(): - sb.ret(tvm.relay.Tuple(loop_vars)) + sb.ret(tvm.relay.Tuple(lv_list)) - loop_fn = tvm.relay.Function(loop_vars, sb.get()) + loop_fn = tvm.relay.Function(lv_list, sb.get()) sb = tvm.relay.scope_builder.ScopeBuilder() sb.let(wl, loop_fn) - loop_ret = wl(*self.loop_vars) + loop_ret = wl(*expr_list) sb.ret(loop_ret) ret = sb.get() @@ -2711,10 +2730,15 @@ def __init__(self): self._control_flow_node_map = defaultdict(set) self._loop_body_order = {} self._loop_var_order = {} - self._hash2tfnode = {} + self._lvar2expr = {} + self._lname_map = {} + self._sorted_cf_node_names = [] self._while_loop_name_set = set() + self._main_graph_proto = self + self._tensor_array_shapes = {} + self._tensor_array_shape_nodes = {} - def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None): """Construct relay nodes from tensorflow graph definition - GraphDef. Follow the tensorflow graph definition to parse and convert it to Relay. 
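# A standalone sketch of the construction Loop._while_loop performs above: a TF
# while_loop becomes a Relay recursive function built with ScopeBuilder, where the
# if-branch recurses with updated loop variables and the else-branch returns them.
# The "count to 10" loop below is illustrative only.
import tvm
from tvm import relay

def build_count_loop():
    wl = relay.var("while_loop")                       # variable bound to the loop function
    i = relay.var("loop_var_0", shape=(), dtype="int32")
    sb = tvm.relay.scope_builder.ScopeBuilder()
    with sb.if_scope(relay.less(i, relay.const(10, "int32"))):
        sb.ret(wl(relay.add(i, relay.const(1, "int32"))))   # recurse with the updated var
    with sb.else_scope():
        sb.ret(relay.Tuple([i]))                            # done: return the loop vars
    loop_fn = relay.Function([i], sb.get())

    outer = tvm.relay.scope_builder.ScopeBuilder()
    outer.let(wl, loop_fn)
    outer.ret(wl(relay.const(0, "int32")))                  # initial value of the loop var
    return outer.get()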
@@ -2760,6 +2784,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): missing_operators = self._parse_import_prerequisites(graph) control_flow_nodes = [] + ta_write_nodes = [] + ta_gather_nodes = [] + ta_construct_nodes = [] self._in_shape = shape self._layout = layout self._graph = graph @@ -2823,6 +2850,50 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): if node.op == "Exit": self._while_loop_name_set.add(node_name_prefix) control_flow_nodes.append(node) + elif node.op.startswith("TensorArray"): + if is_tensor_array_constuctor(node): + ta_construct_nodes.append(node) + else: + for ta_write_name, idx in _tensor_array_write_ops.items(): + if node.op.startswith(ta_write_name): + ta_write_nodes.append((node, idx)) + break + if node.op.startswith("TensorArrayGather"): + ta_gather_nodes.append(node) + + # Use tensor array gather to infer static tensor array shape + for gather_node in ta_gather_nodes: + input_ta_name = gather_node.input[0] + input_ta_node = self._tf_node_map[input_ta_name] + if is_tensor_array_constuctor(input_ta_node): + gather_attr = self._parse_attr(gather_node.attr) + if "element_shape" not in gather_attr: + continue + raw_elem_shape = tensor_util.TensorShapeProtoToList(gather_attr["element_shape"]) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(int(dim)) + self._tensor_array_shapes[input_ta_node.name] = elem_shape + + # Fetch node contains static tensor array shape + for item in ta_write_nodes: + wnode = item[0] + ta_idx, inode_idx = item[1] + + stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]] + while stack: + cnode = stack.pop(0) + if not cnode.op.startswith("TensorArray"): + for iname in cnode.input: + stack.append(self._tf_node_map[iname.split(":")[0]]) + elif cnode.name != wnode.name: + if is_tensor_array_constuctor(cnode): + inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] + self._tensor_array_shape_nodes[cnode.name] = (inode, wnode.op) + break # First, parse all control flow nodes. # Convert tf.cond to Branch and tf.while_loop to Loop. 
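# A standalone sketch of the element_shape handling used above when inferring static
# tensor array shapes from TensorArrayGather nodes: TensorFlow marks unknown dimensions
# with -1, and those become relay Any() dimensions. `to_relay_shape` is an illustrative
# name only, not a helper defined in this module.
from tvm.relay import Any

def to_relay_shape(raw_dims):
    """Map TF-style dims (negative means unknown) to a Relay-friendly shape list."""
    return [Any() if dim < 0 else int(dim) for dim in raw_dims]

# Example: to_relay_shape([-1, 128]) gives [Any(), 128], matching elem_shape above.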
@@ -2847,6 +2918,9 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): if i == len(control_flow_nodes) - 1: sorted_cf_nodes.extend(exits) + for node in sorted_cf_nodes: + self._sorted_cf_node_names.append(node.name) + for node in sorted_cf_nodes: self._backtrack_construct(node.name) @@ -2880,7 +2954,20 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): out.append(out_rnn) out = out[0] if len(out) == 1 else _expr.Tuple(out) - func = _function.Function(analysis.free_vars(out), out) + fvars = analysis.free_vars(out) + func = _function.Function(fvars, out) + final_params = {} + for fv in fvars: + if fv.name_hint in self._params: + final_params[fv.name_hint] = self._params[fv.name_hint] + self._params = final_params + return func + + def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function + which is used as main function for the Relay module + """ + func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs) self._mod["main"] = func return self._mod, self._params @@ -2891,16 +2978,24 @@ def _parse_import_prerequisites(self, graph): which are not supported """ missing_operators = set() + from tensorflow.python.framework import op_def_registry for node in graph.node: + getOpDef = op_def_registry._registered_ops.get if hasattr(op_def_registry,\ + "_registered_ops") else op_def_registry.get + op_def = getOpDef(node.op) if node.op == "Placeholder" or node.op == 'PlaceholderWithDefault': pass elif node.op == "Const": pass + elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]: + pass else: if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn, _control_flow_nodes]]): pass + elif op_def is not None and op_def.is_stateful: + missing_operators.add(node.op) else: missing_operators.add(node.op) @@ -3053,37 +3148,40 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ Converted relay expression. 
""" node_name_prefix = node.name.rsplit('/', 1)[0] + plname = find_parent_loop_name(node.name, self._while_loop_name_set) if node.op == "Merge": if _in_while_loop(self._control_flow_node_map, node_name_prefix): - op = self._backtrack_construct(node.input[0]) + op = self._licm_construct(plname, node.input[0]) if node_name_prefix not in self._loops: self._loops[node_name_prefix] = Loop(self._mod, - node_name_prefix, - self._hash2tfnode, - self._nodes, - self._while_loop_name_set) + plname, + self._lvar2expr) else: - if len(self._branches) == 0: - raise RuntimeError("Cannot find a created " - "conditional for merge node") + if node_name_prefix not in self._branches: + switch_prefix = node_name_prefix + "/Switch" + merge_idx = self._sorted_cf_node_names.index(node.name) + for i in range(merge_idx - 1, -1, -1): + cf_name = self._sorted_cf_node_names[i] + if cf_name.startswith(switch_prefix): + self._backtrack_construct(cf_name) + break + branch = self._branches[node_name_prefix] - false_br = self._backtrack_construct(node.input[0]) - true_br = self._backtrack_construct(node.input[1]) - assert len(true_br) == 1 - assert len(false_br) == 1 - branch.true_branch = true_br[0] - branch.false_branch = false_br[0] - op = [branch.if_node()] + false_br = self._licm_construct(plname, node.input[0]) + true_br = self._licm_construct(plname, node.input[1]) + branch.true_branch = true_br + branch.false_branch = false_br + op = branch.if_node() if node_name_prefix not in self._while_loop_name_set: try: cond_val = np.all(_infer_value(branch.cond, self._params, self._mod).asnumpy()) if cond_val: - op = [branch.true_branch] + op = branch.true_branch else: - op = [branch.false_branch] + op = branch.false_branch except Exception: - op = [branch.if_node()] + op = branch.if_node() elif node.op == "Exit": loop = self._loops[node_name_prefix] @@ -3109,17 +3207,15 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ if exit_number == j: body_pos = i break - op = [_expr.TupleGetItem(expr, body_pos)] + op = _expr.TupleGetItem(expr, body_pos) elif node.op == "Enter": - op = self._backtrack_construct(node.input[0]) + op = self._licm_construct(plname, node.input[0]) elif node.op == "LoopCond": - op = self._backtrack_construct(node.input[0]) - assert len(op) == 1 - self._loops[node_name_prefix].cond = op[0] + op = self._licm_construct(plname, node.input[0]) + self._loops[node_name_prefix].cond = op elif node.op == "Switch": - op = self._backtrack_construct(node.input[0]) - cond = self._backtrack_construct(node.input[1]) - assert len(op) == 1 + op = self._licm_construct(plname, node.input[0]) + cond = self._licm_construct(plname, node.input[1]) if _in_while_loop(self._control_flow_node_map, node_name_prefix): if node_name_prefix not in self._loop_var_order: self._loop_var_order[node_name_prefix] = [] @@ -3128,11 +3224,11 @@ def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_ else: self._loop_var_order[node_name_prefix].\ append(int(node.name.split("Switch_")[-1])) - self._loops[node_name_prefix].loop_vars.append(op[0]) + self._loops[node_name_prefix].loop_vars.append(op) else: if node_name_prefix not in self._branches: self._branches[node_name_prefix] = Branch() - self._branches[node_name_prefix].cond = cond[0] + self._branches[node_name_prefix].cond = cond elif node.op == "NextIteration": if node_name_prefix not in self._loop_body_order: self._loop_body_order[node_name_prefix] = [] @@ -3141,16 +3237,99 @@ def _convert_control_flow_operator(self, node, inputs, attrs, 
control_flow_node_ else: self._loop_body_order[node_name_prefix].\ append(int(node.name.split("NextIteration_")[-1])) - op = self._backtrack_construct(node.input[0]) - - assert len(op) == 1 - self._loops[node_name_prefix].body.append(op[0]) + op = self._licm_construct(plname, node.input[0]) + self._loops[node_name_prefix].body.append(op) else: raise Exception("Cannot identify control flow operator: " + "{}".format(node.op)) return op + def _partition_call_operator(self, inputs, attr): + """ + Convert the Relay Partition call ops into Relay Function calls and + function definitions from Tensorflow graph library attribute to Relay global + functions + + Parameters + ---------- + node: TensorFlow graph node object. + A TensorFlow graph node object. + + inputs : List[tvm.relay.Expr] + List of input symbols. + + attrs : Dict[tvm.Attrs] + Dict of operator attributes. + + Returns + ------- + op : tvm.relay.Expr + Converted relay expression. + """ + + try: + from tensorflow.python.framework import function_def_to_graph + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + + main_graph_proto = self._main_graph_proto + outer_graph_def = main_graph_proto._graph + + node_func_name = attr.get('f').name + func = next((f for f in outer_graph_def.library.function + if f.signature.name == node_func_name), None) + if func: + devices = set(node.device for node in func.node_def) + if len(devices) > 1: + raise Exception("Found inconsistent Device assignment in the "\ + "Stateful Partitioned SubGraph. Rejecting "\ + "the subgraph ") + # Convert function definition to graph + func_input_shapes = func.attr["_input_shapes"].list.shape + subgraph, _ = function_def_to_graph.\ + function_def_to_graph_def(func, func_input_shapes) + + # Computing subgraph's input shape dictionary + subgraph_shape_dict, input_expr_dict = {}, {} + for f_arg, input in zip(func.signature.input_arg, inputs): + input_expr_dict[f_arg.name] = input + subgraph_shape_dict[f_arg.name] = _infer_shape(input, main_graph_proto._mod) + + func_name = 'func_{}'.format(func.signature.name) + try: + global_func = main_graph_proto._mod[func_name] + sub_func = global_func + sub_params = main_graph_proto._params + except ValueError: + # Construct relay nodes from the subgraph + g1 = SubGraphProto(main_graph_proto) + sub_func, sub_params = g1.from_tensorflow(subgraph, shape=subgraph_shape_dict) + main_graph_proto._params.update(sub_params) + func_expr = _function.Function(sub_func.params, sub_func.body) + global_func = tvm.relay.GlobalVar(func_name) + main_graph_proto._mod[global_func] = func_expr + + param_exprs = [] + for param_expr in sub_func.params: + # sub_params is subset of sub_func.params + param_name = param_expr.vid.name_hint + if param_name in input_expr_dict.keys(): + param_exprs.append(input_expr_dict[param_name]) + elif param_name in sub_params.keys(): + param_exprs.append(param_expr) + else: + raise Exception("Input parameter {} not found".format(param_name)) + + sb = tvm.relay.scope_builder.ScopeBuilder() + loop_ret = global_func(*param_exprs) + sb.ret(loop_ret) + ret = sb.get() + else: + raise Exception("Function not found - {}".format(node_func_name)) + return ret + def _convert_operator(self, op_name, inputs, attrs, graph, identity_list=None, convert_map=None): """Convert from Tensorflow operator to relay operator. 
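# A standalone sketch of the IRModule mechanics that _partition_call_operator above
# relies on: a TF (Stateful)PartitionedCall subgraph becomes a Relay global function
# that "main" simply calls. All names below are illustrative.
import tvm
from tvm import relay

mod = tvm.IRModule()
x = relay.var("x", shape=(4,), dtype="float32")
sub_fn = relay.Function([x], relay.add(x, relay.const(1.0, "float32")))
gv = relay.GlobalVar("func_increment")        # analogous to 'func_{signature.name}'
mod[gv] = sub_fn                              # register the converted subgraph
y = relay.var("y", shape=(4,), dtype="float32")
mod["main"] = relay.Function([y], gv(y))      # the partitioned call becomes gv(y)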
@@ -3192,10 +3371,62 @@ def _convert_operator(self, op_name, inputs, attrs, sym = self._convert_rnn_operator(op_name, inputs, attrs, self._params, graph, convert_map_rnn) + + elif op_name in ["PartitionedCall", "StatefulPartitionedCall"]: + sym = self._partition_call_operator(inputs, attrs) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) return sym + def _licm_construct(self, loop_name, node_name): + """Construct a node by considering whether it is + loop invariant with the given while loop. If yes, we + generate a loop Variable. Otherwise, return regular + converted relay expression. + + Parameters + ---------- + loop_name : str + TensorFlow while loop name to be checked. + + node_name : str + TensorFlow node name. + + Returns + ------- + out : relay.Expr or relay.Var + Converted relay expression or loop var. + """ + actual_expr = self._backtrack_construct(node_name) + tn = node_name.split(':') + node_name = tn[0].split("^")[-1] + cloop_name = find_parent_loop_name(node_name, self._while_loop_name_set) + + if loop_name in self._while_loop_name_set and not cloop_name.startswith(loop_name): + if loop_name not in self._lvar2expr: + self._lvar2expr[loop_name] = {} + if loop_name not in self._lname_map: + self._lname_map[loop_name] = {} + + if node_name not in self._lname_map[loop_name]: + var_name = "{}_loop_var".format(node_name) + var_type = _infer_type(actual_expr, self._mod).checked_type + loop_var = tvm.relay.var(var_name, type_annotation=var_type) + try: + extra_param = _infer_value(actual_expr, self._params, self._mod) + self._params[var_name] = extra_param + except Exception: + pass + self._lvar2expr[loop_name][loop_var] = actual_expr + self._lname_map[loop_name][node_name] = loop_var + ret = loop_var + else: + ret = self._lname_map[loop_name][node_name] + else: + ret = actual_expr + + return ret + def _backtrack_construct(self, node_name): """Convert a specific tensorflow node to relay expression. @@ -3208,17 +3439,23 @@ def _backtrack_construct(self, node_name): Parameters ---------- node_name : str - Tensorflow node name. + TensorFlow node name. 
Returns ------- op : relay.Expr Converted relay expression """ - node_name = node_name.split(':')[0].split("^")[-1] + try: + from tensorflow.python.framework import tensor_util + except ImportError as e: + raise ImportError( + "Unable to import tensorflow which is required {}".format(e)) + + input_op_name = node_name.split(':')[0].split("^")[-1] - if node_name not in self._nodes: - node = self._tf_node_map[node_name] + if input_op_name not in self._nodes: + node = self._tf_node_map[input_op_name] attr = self._parse_attr(node.attr) if node.op in _control_flow_nodes: @@ -3227,20 +3464,50 @@ def _backtrack_construct(self, node_name): attr, self._control_flow_node_map) else: - attr["_output_shapes"] = self._output_shapes[node_name] + attr["_output_shapes"] = self._output_shapes[input_op_name] attr["_node_name"] = node.name attr["_target_layout"] = self._layout - inputs = [] - for iname in node.input: - in_op = self._backtrack_construct(iname) - if isinstance(in_op, _expr.TupleWrapper): - tn = iname.split(':') - tensor_slot = int(tn[1]) if len(tn) > 1 else 0 - in_op = in_op[tensor_slot] - else: - in_op = in_op[0] - inputs.append(in_op) + inputs = [self._backtrack_construct(iname) for iname in node.input] + + plname = find_parent_loop_name(node_name, self._while_loop_name_set) + + # For TensorArrayV3 op, we need to infer shape first + if is_tensor_array_constuctor(node): + raw_elem_shape = tensor_util.TensorShapeProtoToList(attr['element_shape']) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(dim) + + if elem_shape: + attr["shape"] = elem_shape + if attr['identical_element_shapes'] or elem_shape: + shape_node, wnode_op = self._tensor_array_shape_nodes[node.name] + converted = self._backtrack_construct(shape_node.name) + shape = _infer_shape(converted, self._mod) + if wnode_op.startswith("TensorArraySplit"): + shape = (Any(),) + shape[1:] + elif wnode_op.startswith("TensorArrayScatter"): + shape = shape[1:] + + if node.name in self._tensor_array_shapes: + preset_shape = self._tensor_array_shapes[node.name] + shape = _get_more_static_shape(shape, preset_shape) + + if "shape" in attr: + attr["shape"] = _get_more_static_shape(shape, attr["shape"]) + else: + attr["shape"] = shape + + # LICM + if plname in self._while_loop_name_set: + for i, iname in enumerate(node.input): + actual_input = self._licm_construct(plname, iname) + inputs[i] = actual_input + op = self._convert_operator(node.op, inputs, attr, self._graph) if isinstance(op, np.ndarray): @@ -3252,11 +3519,32 @@ def _backtrack_construct(self, node_name): elif isinstance(op, (_expr.Expr, _expr.TupleGetItem)): op = [op] - node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else s_hash(op[0]) - self._hash2tfnode[node_hash] = node - self._nodes[node_name] = op + self._nodes[input_op_name] = op + + out = self._nodes[input_op_name] + + if isinstance(out, _expr.TupleWrapper): + tn = node_name.split(':') + tensor_slot = int(tn[1]) if len(tn) > 1 else 0 + return out[tensor_slot] + + return out[0] + + +class SubGraphProto(GraphProto): + """ A helper class for handling relay subgraph copying from Tensorflow GraphDef. + """ + def __init__(self, main_graph_proto): + super().__init__() + self._main_graph_proto = main_graph_proto # holds main graph proto object + + def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None): + """ Wrapper to _get_relay_func which converts Tensorflow graph to Relay function. 
+ Return Relay function and params + """ + func = self._get_relay_func(graph, layout=layout, shape=shape, outputs=outputs) + return func, self._params - return self._nodes[node_name] def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): """Load tensorflow graph which is a python tensorflow graph object into relay. @@ -3284,6 +3572,7 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None): params : dict of str to tvm.nd.NDArray Dict of converted parameters stored in tvm.nd.NDArray format """ + g = GraphProto() mod, params = g.from_tensorflow(graph, layout, shape, outputs) return mod, params diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py index fdbb8768597f..771aed06ac10 100644 --- a/python/tvm/relay/frontend/tensorflow_parser.py +++ b/python/tvm/relay/frontend/tensorflow_parser.py @@ -30,6 +30,10 @@ class TFParser(object): model_dir : tensorflow frozen pb file or a directory that contains saved model or checkpoints. + outputs : List of output tensor names (Optional) + Optional output node names. This will be protected for saved model + when we do remove training nodes. + Examples -------- .. code-block:: python @@ -38,11 +42,12 @@ class TFParser(object): graphdef = parser.parse() """ - def __init__(self, model_dir): + def __init__(self, model_dir, outputs=None): from tensorflow.core.framework import graph_pb2 self._tmp_dir = util.tempdir() self._model_dir = model_dir self._graph = graph_pb2.GraphDef() + self._outputs = outputs or [] def _set_graph(self, graph): """Set Graph""" @@ -128,7 +133,8 @@ def _load_saved_model(self): output_graph_def = graph_pb2.GraphDef() with open(output_graph_filename, "rb") as f: output_graph_def.ParseFromString(f.read()) - output_graph_def = graph_util.remove_training_nodes(output_graph_def) + output_graph_def = graph_util.remove_training_nodes(output_graph_def, + protected_nodes=self._outputs) return output_graph_def def _load_ckpt(self): diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index d489bd34f7ac..113f764b4cf9 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel """Tensorflow lite frontend.""" import math +import itertools import numpy as np import tvm from tvm.ir import IRModule @@ -28,9 +29,10 @@ from .. import op as _op from .. import qnn as _qnn from ... 
import nd as _nd -from .util import get_scalar_from_constant from .common import ExprTable from .common import infer_shape as _infer_shape +from .tflite_flexbuffer import FlexBufferDecoder + __all__ = ['from_tflite'] @@ -64,6 +66,7 @@ def __init__(self, model, subgraph, exp_tab): self.convert_map = { 'ABS': self.convert_abs, 'ADD': self.convert_add, + 'ADD_N': self.convert_add_n, 'AVERAGE_POOL_2D': self.convert_average_pool2d, 'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd, 'CAST': self.convert_cast, @@ -73,29 +76,35 @@ def __init__(self, model, subgraph, exp_tab): 'COS': self.convert_cos, 'DEPTH_TO_SPACE': self.convert_depth_to_space, 'DEPTHWISE_CONV_2D': self.convert_depthwise_conv2d, + 'DEQUANTIZE': self.convert_dequantize, 'DETECTION_POSTPROCESS': self.convert_detection_postprocess, 'DIV': self.convert_div, 'ELU': self.convert_elu, 'EQUAL': self.convert_equal, 'EXP': self.convert_exp, + 'FILL': self.convert_fill, 'FLOOR_DIV': self.convert_floor_div, 'FLOOR_MOD': self.convert_floor_mod, 'FLOOR': self.convert_floor, 'FULLY_CONNECTED': self.convert_fully_connected, + 'GATHER': self.convert_gather, + 'GATHER_ND' : self.convert_gather_nd, 'GREATER_EQUAL': self.convert_greater_equal, 'GREATER': self.convert_greater, 'HARD_SWISH': self.convert_hard_swish, 'L2_NORMALIZATION': self.convert_l2_normalization, + 'L2_POOL_2D': self.convert_l2_pool2d, 'LESS_EQUAL': self.convert_less_equal, 'LESS': self.convert_less, 'LOCAL_RESPONSE_NORMALIZATION': self.convert_lrn, 'LOG': self.convert_log, 'LOGICAL_AND': self.convert_logical_and, + 'LOGICAL_NOT': self.convert_logical_not, 'LOGICAL_OR': self.convert_logical_or, 'LOGISTIC': self.convert_logistic, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, - 'MEAN': self._convert_reduce_mean, + 'MEAN': self.convert_reduce_mean, 'MINIMUM': self.convert_minimum, 'MIRROR_PAD': self.convert_mirror_pad, 'MUL': self.convert_mul, @@ -105,28 +114,35 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'POW': self.convert_pow, 'PRELU': self.convert_prelu, - 'REDUCE_ANY': self._convert_reduce_any, - 'REDUCE_MAX': self._convert_reduce_max, - 'REDUCE_MIN': self._convert_reduce_min, - 'REDUCE_PROD': self._convert_reduce_prod, + 'RANGE': self.convert_range, + 'QUANTIZE': self.convert_quantize, + 'REDUCE_ANY': self.convert_reduce_any, + 'REDUCE_MAX': self.convert_reduce_max, + 'REDUCE_MIN': self.convert_reduce_min, + 'REDUCE_PROD': self.convert_reduce_prod, 'RELU':self.convert_relu, 'RESHAPE': self.convert_reshape, 'RESIZE_BILINEAR': self.convert_resize_bilinear, 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'SELECT': self.convert_select, + 'SHAPE': self.convert_shape, 'SIN': self.convert_sin, 'SLICE': self.convert_slice, 'SOFTMAX': self.convert_softmax, 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd, 'SPACE_TO_DEPTH': self.convert_space_to_depth, + 'SPARSE_TO_DENSE': self.convert_sparse_to_dense, 'SPLIT': self.convert_split, + 'SPLIT_V': self.convert_split_v, 'SQRT': self.convert_sqrt, 'SQUARE': self.convert_square, 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'SQUEEZE': self.convert_squeeze, + 'STRIDED_SLICE': self.convert_strided_slice, 'SUB': self.convert_sub, - 'SUM': self._convert_reduce_sum, + 'SUM': self.convert_reduce_sum, 'TAN': self.convert_tan, 'TANH':self.convert_tanh, 'TILE': self.convert_tile, @@ -134,6 +150,7 @@ def __init__(self, model, subgraph, exp_tab): 'TRANSPOSE_CONV': self.convert_transpose_conv, 'TRANSPOSE': 
self.convert_transpose, 'UNPACK': self.convert_unpack, + 'WHERE': self.convert_select, 'ZEROS_LIKE': self.convert_zeros_like, } @@ -159,7 +176,12 @@ def convert_op_to_relay(self): op = self.subgraph.Operators(op_idx) op_code_str = self.get_op_code_str(op) output_tensors = self.get_output_tensors(op) + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + assert isinstance(op, Operator) ret = self.convert_map[op_code_str](op) if len(output_tensors) == 1: @@ -261,6 +283,8 @@ def get_tensor_type_str(self, tensor_type): except ImportError: raise ImportError("The tflite package must be installed") + if tensor_type == TensorType.INT8: + return "int8" if tensor_type == TensorType.UINT8: return "uint8" if tensor_type == TensorType.FLOAT32: @@ -288,12 +312,6 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor): def is_quantized(self, op): """Check if an input tensor is quantized.""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) first_tensor = input_tensors[0] return first_tensor.qnn_params is not None @@ -315,6 +333,45 @@ def dequantize(self, expr, tensor): input_zero_point=tensor.qnn_params['zero_point']) return dequantized + + def convert_qnn_fused_activation_function(self, expr, fused_activation_fn, + scale, zero_point, dtype): + """Convert TFLite fused activation function. The expr is an input quantized tensor with + scale and zero point """ + try: + from tflite.ActivationFunctionType import ActivationFunctionType + except ImportError: + raise ImportError("The tflite package must be installed") + + # Quantize a float value to an quantized integer value + quantize = lambda x: float(int(round(x / scale)) + zero_point) + + # Get min/max of the output dtype. This will be used to ensure that clip a_min/a_max are not + # beyond the dtype range. + qmin = float(tvm.tir.op.min_value(dtype).value) + qmax = float(tvm.tir.op.max_value(dtype).value) + + # The input expr is a quantized tensor with its scale and zero point. We calculate the + # suitable clip off points based on these scale and zero point. 
+ if fused_activation_fn == ActivationFunctionType.NONE: + return expr + if fused_activation_fn == ActivationFunctionType.RELU6: + return _op.clip(expr, + a_min=max(qmin, quantize(0)), + a_max=min(qmax, quantize(6.0))) + if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1: + return _op.clip(expr, + a_min=max(qmin, quantize(-1.0)), + a_max=min(qmax, quantize(1.0))) + if fused_activation_fn == ActivationFunctionType.RELU: + return _op.clip(expr, + a_min=max(qmin, quantize(0.0)), + a_max=qmax) + + fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] + raise tvm.error.OpNotImplemented( + 'Quantized activation {} is not supported yet.'.format(fused_activation_fn_str)) + def convert_conv2d(self, op): """Convert TFLite conv2d""" return self.convert_conv(op, "conv2d") @@ -331,16 +388,18 @@ def convert_max_pool2d(self, op): """Convert TFLite max pool2d""" return self.convert_pool2d(op, "max") + def convert_l2_pool2d(self, op): + """Convert TFLite l2 pool2d""" + return self.convert_pool2d(op, "l2") + def convert_reshape(self, op): """Convert TFLite reshape""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.ReshapeOptions import ReshapeOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert input_tensors, "input tensors should not be empty" input_tensor = input_tensors[0] @@ -368,7 +427,6 @@ def _convert_resize(self, method, op): """Generic method to Convert TFLite RESIZE operators""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.ResizeBilinearOptions import ResizeBilinearOptions # ResizeNearestNeighborOptions was added in tflite v1.13 tflite_ver = 1120 @@ -378,7 +436,6 @@ def _convert_resize(self, method, op): except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -421,14 +478,11 @@ def convert_resize_nearest_neighbor(self, op): def convert_l2_normalization(self, op): """Convert TFLite L2_NORMALIZATION """ try: - from tflite.Operator import Operator from tflite.BuiltinOptions import BuiltinOptions from tflite.L2NormOptions import L2NormOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -450,30 +504,26 @@ def convert_l2_normalization(self, op): if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + # TFL uses only the default epsilon value out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1]) # if we have fused activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'TFLite quantized L2_NORMALIZATION operator\ - with fused activation function is not supported yet.') + if output_tensor.qnn_params: + raise tvm.error.OpNotImplemented( + 'TFLite quantized L2_NORMALIZATION operator is not supported yet.') + out = 
self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_lrn(self, op): """Convert TFLite LOCAL_RESPONSE_NORMALIZATION """ try: - from tflite.Operator import Operator from tflite.BuiltinOptions import BuiltinOptions from tflite.LocalResponseNormalizationOptions import LocalResponseNormalizationOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFlite quantized LRN operator is not supported yet.') @@ -503,12 +553,6 @@ def convert_lrn(self, op): def convert_logistic(self, op): """Convert TFLite LOGISTIC""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -529,12 +573,6 @@ def convert_logistic(self, op): def convert_softmax(self, op): """Convert TFLite softmax""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -564,12 +602,6 @@ def convert_softmax(self, op): def convert_tanh(self, op): """Convert TFLite TANH""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -579,14 +611,41 @@ def convert_tanh(self, op): return out - def convert_relu(self, op): - """Convert TFLite ReLU""" + def convert_range(self, op): + """Convert TFLite Range""" try: - from tflite.Operator import Operator + from tflite.TensorType import TensorType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + + start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2] + + expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]] + + # out type inference + if delta.tensor.Type() == TensorType.FLOAT32: + out_type = self.get_tensor_type_str(delta.tensor.Type()) + else: + out_type = self.get_tensor_type_str(start.tensor.Type()) + + out = _op.arange(expressions[0], expressions[1], expressions[2], out_type) + + return out + + def convert_shape(self, op): + """Convert TFLite Shape""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + out = _op.shape_of(self.get_tensor_expr(input_tensors[0])) + + return out + + def convert_relu(self, op): + """Convert TFLite ReLU""" input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -598,12 +657,6 @@ def convert_relu(self, op): def convert_hard_swish(self, op): """Convert TFLite Hard swish""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) - input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -635,14 +688,11 @@ def _hard_swish(data): def 
convert_concatenation(self, op): """Convert TFLite concatenation""" try: - from tflite.Operator import Operator from tflite.ConcatenationOptions import ConcatenationOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) >= 1, "input tensors should greater than 1" in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] @@ -671,24 +721,24 @@ def convert_concatenation(self, op): output_zero_point=output_tensor.qnn_params['zero_point'], axis=concatenation_axis) - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.concatenate')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def _convert_unary_elemwise(self, relay_op, op): """Generic method to convert TFLite unary elemwise functions""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -784,12 +834,6 @@ def convert_neg(self, op): def convert_elu(self, op): """Convert TFLite ELU""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) - if self.is_quantized(op): raise tvm.error.OpNotImplemented( 'TFlite quantized ELU operator is not supported yet.') @@ -807,12 +851,6 @@ def convert_elu(self, op): def convert_square(self, op): """Convert TFLite SQUARE""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -834,43 +872,21 @@ def convert_square(self, op): def _convert_elemwise(self, relay_op, op): """Generic method to Convert TFLite elemwise""" try: - from tflite.Operator import Operator from tflite.AddOptions import AddOptions from tflite.SubOptions import SubOptions from tflite.MulOptions import MulOptions from tflite.DivOptions import DivOptions from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" lhs_tensor = input_tensors[0] - 
if self.has_expr(lhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. - lhs_expr = self.get_expr(lhs_tensor.tensor_idx) - else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. - lhs_type_str = self.get_tensor_type_str(lhs_tensor.tensor.Type()) - lhs_expr = self.exp_tab.new_const(self.get_tensor_value(lhs_tensor), - dtype=lhs_type_str) - rhs_tensor = input_tensors[1] - if self.has_expr(rhs_tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. - rhs_expr = self.get_expr(rhs_tensor.tensor_idx) - else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. - rhs_type_str = self.get_tensor_type_str(rhs_tensor.tensor.Type()) - rhs_expr = self.exp_tab.new_const(self.get_tensor_value(rhs_tensor), - dtype=rhs_type_str) + lhs_expr = self.get_tensor_expr(lhs_tensor) + rhs_expr = self.get_tensor_expr(rhs_tensor) output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" @@ -906,13 +922,20 @@ def _convert_elemwise(self, relay_op, op): op_options = op.BuiltinOptions() options.Init(op_options.Bytes, op_options.Pos) fused_activation_fn = options.FusedActivationFunction() - # if we have activation fn - if fused_activation_fn != ActivationFunctionType.NONE: - if output_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Elemwise operators with fused activation are not supported yet.') - out = self.convert_fused_activation_function(out, fused_activation_fn) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type()) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) return out def convert_add(self, op): @@ -922,6 +945,20 @@ def convert_add(self, op): return self._convert_elemwise(_qnn.op.add, op) return self._convert_elemwise(_op.add, op) + def convert_add_n(self, op): + """Convert TFLite ADD_N""" + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + + input_tensors = self.get_input_tensors(op) + assert not input_tensors[0].qnn_params, "TFLite does not support quantized ADD_N." 
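        # Sketch of the lowering (hypothetical inputs): ADD_N over tensors [a, b, c]
        # folds left in the loop below into _op.add(_op.add(a, b), c).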
+ lhs_expr = self.get_tensor_expr(input_tensors[0]) + for rhs_tensor in input_tensors[1:]: + assert not rhs_tensor.qnn_params, "TFLite does not support quantized ADD_N" + rhs_expr = self.get_tensor_expr(rhs_tensor) + lhs_expr = _op.add(lhs_expr, rhs_expr) + return lhs_expr + def convert_sub(self, op): """Convert TFLite SUB""" # Check if the input tensor is quantized, call QNN op @@ -1025,12 +1062,6 @@ def convert_not_equal(self, op): def _convert_logical_binary(self, relay_op, op): """Generic method to convert logical binary ops""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -1050,14 +1081,254 @@ def convert_logical_or(self, op): """Convert tflite LOGICAL_OR""" return self._convert_logical_binary(_op.logical_or, op) - def convert_zeros_like(self, op): - """Convert TFLite ZEROS LIKE""" + def convert_logical_not(self, op): + """Convert tflite LOGICAL_NOT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + data = self.get_expr(input_tensors[0].tensor_idx) + out = _op.logical_not(data) + + return out + + def convert_gather(self, op): + """Method to Convert TFLite GATHER operator""" try: - from tflite.Operator import Operator + from tflite.BuiltinOptions import BuiltinOptions + from tflite.GatherOptions import GatherOptions + from tflite.TensorType import TensorType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + data = self.get_expr(input_tensors[0].tensor_idx) + + indices = input_tensors[1] + indices_type = indices.tensor.Type() + assert indices_type in (TensorType.INT32, TensorType.INT64) + indices_type_str = self.get_tensor_type_str(indices_type) + indices = self.exp_tab.new_const(self.get_tensor_value(indices), + dtype=indices_type_str) + + assert op.BuiltinOptionsType() == BuiltinOptions.GatherOptions + op_options = op.BuiltinOptions() + gather_options = GatherOptions() + gather_options.Init(op_options.Bytes, op_options.Pos) + axis = gather_options.Axis() + + # Check the indices are with in bounds. + data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) + data_dim = len(data_shape) + + axis_n = axis + if axis_n < 0: + axis_n += axis_n + data_dim + assert axis_n >= 0, "Axis out of bounds" + assert axis_n < data_dim, "Axis out of bounds" + + indices_val = self.get_tensor_value(input_tensors[1]) + indices_shape = list(indices_val.shape) + indices_len = len(indices_shape) + + out_shape = [] + for i in range(data_dim): + if axis_n == i: + for j in range(indices_len): + out_shape.append(indices_shape[j]) + else: + out_shape.append(data_shape[i]) + + loopover = [range(s) for s in out_shape] + for idx in list(itertools.product(*loopover)): + indices_position = [idx[j] for j in range(axis_n, axis_n+indices_len)] + + real_indices = [idx[j] for j in range(axis_n)] + real_indices.append(indices_val[tuple(indices_position)]) + real_indices.extend([idx[j] for j in range(axis_n + indices_len, len(idx))]) + for r, d in zip(real_indices, data_shape): + if r >= d: + raise ValueError("TFLite out of bound indices are not supported.") + + # Use mode 'fast' since indices are already checked within bounds. 
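        # Worked example (assumed shapes): data of shape (4, 3), axis 0 and
        # indices [[0, 2]] give an output of shape (1, 2, 3); any index >= 4
        # would already have raised ValueError in the bounds check above.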
+ out = _op.take(data, indices, axis=axis, mode="fast") + return out + + def convert_gather_nd(self, op): + """Method to Convert TFLite GATHER_ND operator""" + try: + from tflite.TensorType import TensorType + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + for t in input_tensors: + assert not t.qnn_params, "Quantized input is not expected." + + data = self.get_tensor_expr(input_tensors[0]) + indices = self.get_tensor_expr(input_tensors[1]) + + indices_type = input_tensors[1].tensor.Type() + assert indices_type in (TensorType.INT32, TensorType.INT64) + + indices_dims = len(_infer_shape(indices)) + indices_t = _op.transpose(indices, axes=[-1] + list(range(indices_dims-1))) + + out = _op.gather_nd(data, indices_t) + return out + + def convert_strided_slice(self, op): + """Method to Convert TFLite STRIDED_SLICE operator. + NOTE: Eventhough tensorflow supports begin_mask, end_mask, ellipsis_mask, new_axis_mask + and shrink_axis_mask, tflite doesn't support these and expect these values to be zero. + But in future, they may open up the mask implementation, so kept the implementation + same as tensorflow. + + This op extracts a slice of size (end - begin) / stride from the given input tensor. + Starting at the location specified by begin the slice continues by adding stride to the + index until all dimensions are not less than end. Note that a stride can be negative, + which causes a reverse slice. + + For slice input[val0, val1, ..., valn], begin/end/strides will be vectors of length n. + + In each mask field(begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask) + the ith bit will correspond to the ith val. + + If the ith bit of begin_mask is set, begin[i] is ignored and the fullest possible range + in that dimension is used instead. + + If the ith bit of ellipsis_mask is set, as many unspecified dimensions as needed will be + inserted between other dimensions. Only one non-zero bit is allowed in ellipsis_mask. + + If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a + new length 1 dimension is added at this point in the output tensor. + + If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks + the dimensionality by 1, taking on the value at index begin[i]. end[i] and strides[i] + are ignored in this case. + begin and end are zero-indexed. strides entries must be non-zero. + + TVM Relay implementation of doesn't support mask, so the mask values are processed in + this function and begin/end/strides are updated accordingly. If any mask is present, and + since tvm doesn't support mask computation directly, the output need a final reshape. 
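        For example (illustrative values): slicing a (2, 3) input with
        begin=[0, 0], end=[1, 3], strides=[1, 1] and shrink_axis_mask=1 produces
        an intermediate slice of shape (1, 3), and the final reshape collapses
        the shrunk axis so the output has shape (3,).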
+ """ + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.StridedSliceOptions import StridedSliceOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 4, "input tensors length should be 4" + + data_expr = self.get_expr(input_tensors[0].tensor_idx) + + begin = list(self.get_tensor_value(input_tensors[1])) + end = list(self.get_tensor_value(input_tensors[2])) + stride = list(self.get_tensor_value(input_tensors[3])) + + assert op.BuiltinOptionsType() == BuiltinOptions.StridedSliceOptions + op_options = op.BuiltinOptions() + options = StridedSliceOptions() + options.Init(op_options.Bytes, op_options.Pos) + begin_mask = options.BeginMask() + end_mask = options.EndMask() + ellipsis_mask = options.EllipsisMask() + new_axis_mask = options.NewAxisMask() + shrink_axis_mask = options.ShrinkAxisMask() + + data_shape = list(input_tensors[0].tensor.ShapeAsNumpy()) + data_dim = len(data_shape) + stride_dim = len(stride) + def _transform_mask(stride_dim, ellipsis_mask): + """Handle mask inputs to create new begin, end, stride and output shape""" + m_begin = [0] * data_dim + m_end = [0] * data_dim + m_stride = [0] * data_dim + fshape_indices = [] + #Count new axis after ellipsis_mask, consider while applying ellipsis_mask. + ellipsis_seen = False + new_axes_after_ellipsis = 0 + for i in range(stride_dim): + mask = 1 << i + if ellipsis_seen and (mask & new_axis_mask) != 0: + new_axes_after_ellipsis += 1 + if (mask & ellipsis_mask) != 0: + ellipsis_seen = True + if not ellipsis_seen: + #Used later for extending the stride attributes in the below loop. + ellipsis_mask |= (1 << stride_dim) + stride_dim += 1 + final_index = 0 + for index in range(stride_dim): + mask = 1 << index + if mask & ellipsis_mask: + #Identify the end index for applying ellipsis_mask + to_index = min(((data_dim - (stride_dim-index)) + 1 \ + + new_axes_after_ellipsis), data_dim) + for i in range(final_index, to_index): + m_begin[final_index] = 0 + m_end[final_index] = data_shape[final_index] + m_stride[final_index] = 1 + fshape_indices.append(final_index) + final_index += 1 + elif mask &new_axis_mask: + fshape_indices.append(-1) + elif not mask & new_axis_mask: + if final_index == len(m_begin): + break + if mask & begin_mask: + m_begin[final_index] = data_shape[final_index] \ + if stride[index] < 0 else 0 + elif begin[index]: + m_begin[final_index] = begin[index] + if mask & end_mask: + m_end[final_index] = 0 if stride[index] < 0 \ + else data_shape[final_index] + elif end[index]: + m_end[final_index] = end[index] + m_stride[final_index] = stride[index] + if mask & shrink_axis_mask: + #Tensorflow make axis with shrink_axis_mask as dimension 1 + m_begin[final_index] = data_shape[final_index] + begin[index] \ + if begin[index] < 0 else begin[index] + m_end[final_index] = begin[index] + 1 + m_stride[final_index] = 1 + fshape_indices.append(-2) + else: + fshape_indices.append(final_index) + + final_index += 1 + return m_begin, m_end, m_stride, fshape_indices + + fshape_indices = None + if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: + begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) + + out = _op.strided_slice(data_expr, begin=begin, end=end, strides=stride) + out_shape = _infer_shape(out) + if not fshape_indices: + fshape_indices = range(len(out_shape)) + + #Create final output shape. 
+ final_output = [] + for gather_index in fshape_indices: + if gather_index == -1: + final_output.append(1) + elif gather_index == -2: + pass + else: + final_output.append(out_shape[gather_index]) + + if not final_output: + return out + return _op.reshape(out, newshape=tuple(final_output)) + + def convert_zeros_like(self, op): + """Convert TFLite ZEROS LIKE""" input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -1067,16 +1338,29 @@ def convert_zeros_like(self, op): return out + def convert_fill(self, op): + """Convert TFLite FILL""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + if self.has_expr(input_tensors[0].tensor_idx): + raise tvm.error.OpNotImplemented("For dims parameter of Fill operator," + " only constant values are supported.") + + in_dims = list(self.get_tensor_value(input_tensors[0])) + in_value_expr = self.get_expr(input_tensors[1].tensor_idx) + out = _op.full(in_value_expr, in_dims) + + return out + def _convert_reduce(self, relay_op, op): - """Generic method to Convert TFLite MEAN operators""" + """Generic method to Convert TFLite REDUCE operators""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.ReducerOptions import ReducerOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -1114,36 +1398,33 @@ def _convert_reduce(self, relay_op, op): return out - def _convert_reduce_min(self, op): + def convert_reduce_min(self, op): return self._convert_reduce(_op.reduce.min, op) - def _convert_reduce_max(self, op): + def convert_reduce_max(self, op): return self._convert_reduce(_op.reduce.max, op) - def _convert_reduce_mean(self, op): + def convert_reduce_mean(self, op): return self._convert_reduce(_op.reduce.mean, op) - def _convert_reduce_prod(self, op): + def convert_reduce_prod(self, op): return self._convert_reduce(_op.reduce.prod, op) - def _convert_reduce_sum(self, op): + def convert_reduce_sum(self, op): return self._convert_reduce(_op.reduce.sum, op) - def _convert_reduce_any(self, op): + def convert_reduce_any(self, op): return self._convert_reduce(_op.reduce.any, op) def convert_fully_connected(self, op): """Convert TFLite fully connected""" try: - from tflite.Operator import Operator from tflite.FullyConnectedOptions import FullyConnectedOptions from tflite.BuiltinOptions import BuiltinOptions from tflite.TensorType import TensorType - from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) >= 2, "input tensors length should be >= 2" @@ -1160,16 +1441,28 @@ def convert_fully_connected(self, op): input_tensor_shape = input_tensor.tensor.ShapeAsNumpy() weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy() - # reshape input tensor from N H W C to N H*W*C - input_size_per_batch = 1 - for s in range(1, len(input_tensor_shape)): - input_size_per_batch *= input_tensor_shape[s] - assert input_size_per_batch == weight_tensor_shape[1], \ - "input size and weight size are mismatched" - target_shape = tuple((input_tensor_shape[0], input_size_per_batch)) + # Weight should have only 2 dimensions(TFLite convention) + assert 
len(weight_tensor_shape) == 2, "Weight should be only 2-dim" + + # Input shape: [i_batch_size, ..., n_inputs] + # Filter shape: [n_inputs, n_units] + # + # As we will transform Fully_Connected Input to Dense Op inputs as below + # Dense expected Input shape: [batch_size, n_units] + # Dense expected Weight shape: [out_dim, n_units] + # Dense output shape: [batch_size, out_dim] + # So it is evident that input shape: [batch_size = input_size / n_units, n_units] + input_size = 1 + for _, shape in enumerate(input_tensor_shape): + input_size *= shape + + # First get the batch size + batch_size = int(input_size / weight_tensor_shape[1]) + target_shape = tuple((batch_size, weight_tensor_shape[1])) in_expr = self.get_expr(input_tensor_idx) in_expr = _op.reshape(in_expr, target_shape) + #TODO: Change the output shape calculation based on keep_dim option assert op.BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions op_options = op.BuiltinOptions() fully_connected_options = FullyConnectedOptions() @@ -1181,8 +1474,11 @@ def convert_fully_connected(self, op): assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32) weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type) - weight_value = self.get_tensor_value(weight_tensor) - weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) + if self.has_expr(weight_tensor.tensor_idx): + weight_expr = self.get_expr(weight_tensor.tensor_idx) + else: + weight_value = self.get_tensor_value(weight_tensor) + weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str) weight_shape = _infer_shape(weight_expr) if input_tensor.qnn_params: @@ -1207,15 +1503,6 @@ def convert_fully_connected(self, op): dtype=bias_tensor_type_str) out = _op.nn.bias_add(out, bias_expr) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.dense')) - # Finally if the dense is quantized. Add a requantize at the end. 
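        # Rough numeric sketch (hypothetical scales): with an input scale of 0.5
        # and a weight scale of 0.02, the int32 accumulator effectively carries
        # scale 0.5 * 0.02 = 0.01 with zero point 0; the requantize below then
        # maps that onto the output tensor's own scale and zero point.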
if output_tensor.qnn_params: data_scale = input_tensor.qnn_params['scale'] @@ -1225,6 +1512,8 @@ def convert_fully_connected(self, op): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, @@ -1232,18 +1521,29 @@ def convert_fully_connected(self, op): output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_squeeze(self, op): """Convert TFLite squeeze""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.SqueezeOptions import SqueezeOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) output_tensors = self.get_output_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -1268,7 +1568,9 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): from tflite.ActivationFunctionType import ActivationFunctionType except ImportError: raise ImportError("The tflite package must be installed") - assert fused_activation_fn != ActivationFunctionType.NONE + + if fused_activation_fn == ActivationFunctionType.NONE: + return in_expr if fused_activation_fn == ActivationFunctionType.RELU6: return _op.clip(in_expr, a_min=0, a_max=6) if fused_activation_fn == ActivationFunctionType.RELU: @@ -1279,22 +1581,19 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn): return _op.tanh(in_expr) fused_activation_fn_str = self.activation_fn_type[fused_activation_fn] raise tvm.error.OpNotImplemented( - 'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str)) + 'Fused activation {} is not supported yet.'.format(fused_activation_fn_str)) def convert_conv(self, op, conv_type): """convolution implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType from tflite.TensorType import TensorType - from tflite.Operator import Operator from tflite.Conv2DOptions import Conv2DOptions from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions from tflite.Padding import Padding except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) >= 2, "input tensors length should be >= 2" @@ -1389,13 +1688,7 @@ def convert_conv(self, op, conv_type): pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) do_pad = not (pad_top == 0 and pad_bottom == 0 and pad_left == 0 and pad_right == 0) if do_pad: - pad_value = 0 - if input_tensor.qnn_params: - pad_value = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) - in_expr = _op.nn.pad(data=in_expr, pad_width=((0, 0), - (pad_top, pad_bottom), - 
(pad_left, pad_right), - (0, 0)), pad_value=float(pad_value)) + params['padding'] = [pad_top, pad_left, pad_bottom, pad_right] else: raise tvm.error.OpAttributeUnImplemented( @@ -1424,17 +1717,9 @@ def convert_conv(self, op, conv_type): channel_axis = 3 out = _op.nn.bias_add(out, bias_expr, axis=channel_axis) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if not output_tensor.qnn_params: - out = self.convert_fused_activation_function(out, fused_activation_fn) - else: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.conv2d')) - - # Finally if the conv is quantized. Add a requantize at the end. + # Handle fused activation. if output_tensor.qnn_params: + # Calculate the intermediate scale and zero point of the int32 output. data_scale = input_tensor.qnn_params['scale'] weight_scale = weight_tensor.qnn_params['scale'] data_scale_val = get_scalar_from_constant(data_scale) @@ -1442,6 +1727,8 @@ def convert_conv(self, op, conv_type): new_input_scale_val = data_scale_val * weight_scale_val new_input_scale = relay.const(new_input_scale_val, 'float32') new_input_zero_point = relay.const(0, 'int32') + + # Finally requantize out = _qnn.op.requantize(out, input_scale=new_input_scale, input_zero_point=new_input_zero_point, @@ -1449,18 +1736,28 @@ def convert_conv(self, op, conv_type): output_zero_point=output_tensor.qnn_params['zero_point'], out_dtype=output_tensor_type_str) + # Call activation function + output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=output_scale_val, + zero_point=output_zero_point_val, + dtype=output_tensor_type_str) + else: + out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_split(self, op): """split implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.SplitOptions import SplitOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be == 2" @@ -1488,14 +1785,37 @@ def convert_split(self, op): return out + def convert_split_v(self, op): + """SPLIT_V implementation.""" + input_tensors = self.get_input_tensors(op) + + assert len(input_tensors) == 3, "input tensors length should be 3" + + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + in_expr = self.get_expr(input_tensor_idx) + + if self.has_expr(input_tensors[1].tensor_idx): + raise tvm.error.OpNotImplemented("For size_splits parameter of SPLIT_V operator, " + "only constant values are supported.") + size_splits = list(self.get_tensor_value(input_tensors[1])) + size_splits = tuple(np.cumsum(size_splits)[:-1]) + + axis_tensor = input_tensors[2] + split_axis = self.get_tensor_value(axis_tensor) + + out = _op.split(in_expr, size_splits, axis=int(split_axis)) + # Relay does not like a TupleWrapper of 1 element, further this + # only shows up with tf1.13 if we use a split with num_splits==1. + # In tf 1.14 this doesn't appear as it is automatically a reshape + # operation. 
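        # e.g. (assumed sizes): size_splits = [2, 3, 5] on axis 1 becomes
        # _op.split(in_expr, (2, 5), axis=1); when num_splits == 1 the single
        # element TupleWrapper is unwrapped just below.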
+ if isinstance(out, _expr.TupleWrapper) and out.size == 1: + out = out[0] + + return out + def convert_slice(self, op): """Convert TFLite SLICE""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be == 3" input_tensor = input_tensors[0] @@ -1517,14 +1837,20 @@ def convert_slice(self, op): return out + def convert_select(self, op): + """Convert TFLite SELECT""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be == 3" + cond = self.get_tensor_expr(input_tensors[0]) + x = self.get_tensor_expr(input_tensors[1]) + y = self.get_tensor_expr(input_tensors[2]) + + out = _op.where(cond, x, y) + + return out + def convert_transpose(self, op): """transpose implementation.""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" input_tensor = input_tensors[0] @@ -1545,13 +1871,11 @@ def convert_transpose(self, op): def convert_cast(self, op): """Convert TFLite CAST""" try: - from tflite.Operator import Operator from tflite.BuiltinOptions import BuiltinOptions from tflite.CastOptions import CastOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1569,12 +1893,6 @@ def convert_cast(self, op): def convert_tile(self, op): """tile implementation.""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" input_tensor = input_tensors[0] @@ -1591,12 +1909,6 @@ def convert_tile(self, op): def convert_topk_v2(self, op): """ Convert TFLite TOPK_v2 """ - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" input_tensor = input_tensors[0] @@ -1611,14 +1923,11 @@ def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.ActivationFunctionType import ActivationFunctionType - from tflite.Operator import Operator from tflite.Pool2DOptions import Pool2DOptions from tflite.Padding import Padding except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1674,27 +1983,37 @@ def convert_pool2d(self, op, pool_type): assert self.has_same_qnn_params(input_tensor, output_tensor), \ "qnn.op.max_pool2d requires input and output qnn params to be same" out = _op.nn.max_pool2d(in_expr, **params) + elif pool_type == "l2": + # L2_POOL_2D is equivalent to square_root(avg_pool(square(in_data))) + # TFLite does not have support for 
quantised L2_POOL_2D op. + assert not input_tensor.qnn_params, \ + "As TFLite does not have support for quantized L2_POOL_2D, \ + Quantized input is not expected." + exp_type = self.get_tensor_type_str(output_tensor.tensor.Type()) + square_exp = _op.power(in_expr, relay.const(2, exp_type)) + avg_pool_exp = _op.nn.avg_pool2d(square_exp, **params) + out = _op.sqrt(avg_pool_exp) else: raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool')) - # If we have fused activations - if fused_activation_fn != ActivationFunctionType.NONE: - if input_tensor.qnn_params: - raise tvm.error.OpNotImplemented( - 'Operator {} with fused activation is not supported yet.' - .format('qnn.op.pool2d')) + # Handle fused activations + if output_tensor.qnn_params: + scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point']) + out = self.convert_qnn_fused_activation_function(\ + expr=out, + fused_activation_fn=fused_activation_fn, + scale=scale_val, + zero_point=zero_point_val, + dtype=output_tensor_type_str) + else: out = self.convert_fused_activation_function(out, fused_activation_fn) + return out def convert_pad(self, op): """Convert TFLite PAD""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -1740,7 +2059,6 @@ def convert_floor_mod(self, op): def convert_mirror_pad(self, op): """Convert TFLite MIRROR_PAD""" try: - from tflite.Operator import Operator from tflite.BuiltinOptions import BuiltinOptions from tflite.MirrorPadOptions import MirrorPadOptions except ImportError: @@ -1751,7 +2069,6 @@ def convert_mirror_pad(self, op): raise tvm.error.OpNotImplemented( 'TFlite quantized MIRROR_PAD operator is not supported yet.') - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -1779,12 +2096,10 @@ def convert_pack(self, op): """Convert TFLite pack""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.PackOptions import PackOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) >= 1, "input tensors should greater than 1" in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] @@ -1806,12 +2121,10 @@ def convert_unpack(self, op): """Convert TFLite unpack""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.UnpackOptions import UnpackOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = input_tensors[0] @@ -1848,12 +2161,7 @@ def convert_unpack(self, op): def convert_batch_to_space_nd(self, op): """batch_to_space_nd implementation.""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be 
3" @@ -1901,12 +2209,6 @@ def convert_batch_to_space_nd(self, op): def convert_space_to_batch_nd(self, op): """space_to_batch_nd implementation.""" - try: - from tflite.Operator import Operator - except ImportError: - raise ImportError("The tflite package must be installed") - - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be 3" @@ -1960,12 +2262,10 @@ def convert_depth_to_space(self, op): """Convert TFLite DEPTH_TO_SPACE""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.DepthToSpaceOptions import DepthToSpaceOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -1985,12 +2285,10 @@ def convert_space_to_depth(self, op): """Convert TFLite SPACE_TO_DEPTH""" try: from tflite.BuiltinOptions import BuiltinOptions - from tflite.Operator import Operator from tflite.SpaceToDepthOptions import SpaceToDepthOptions except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" @@ -2006,14 +2304,38 @@ def convert_space_to_depth(self, op): return out - def convert_prelu(self, op): - """Convert TFLite PReLU""" + def convert_sparse_to_dense(self, op): + """Convert TFLite SPARSE_TO_DENSE""" try: - from tflite.Operator import Operator + from tflite.TensorType import TensorType except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 4, "input tensors length should be 4" + + indices, values = input_tensors[0], input_tensors[2] + default_value = input_tensors[3] + output_shape = input_tensors[1] + + for t in input_tensors: + assert not t.qnn_params, "Quantized input is not expected." 
+ + for t in [indices, output_shape]: + t_type = t.tensor.Type() + assert t_type in (TensorType.INT32, TensorType.INT64) + + out = _op.sparse_to_dense( + self.get_tensor_expr(indices), + list(self.get_tensor_value(output_shape)), + self.get_tensor_expr(values), + self.get_tensor_expr(default_value) + ) + + return out + + def convert_prelu(self, op): + """Convert TFLite PReLU""" input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" @@ -2033,13 +2355,11 @@ def convert_transpose_conv(self, op): try: from tflite.BuiltinOptions import BuiltinOptions from tflite.TensorType import TensorType - from tflite.Operator import Operator from tflite.TransposeConvOptions import TransposeConvOptions from tflite.Padding import Padding except ImportError: raise ImportError("The tflite package must be installed") - assert isinstance(op, Operator) input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 3, "input tensors length should be 3" @@ -2106,35 +2426,57 @@ def convert_transpose_conv(self, op): return out + def convert_quantize(self, op): + """Convert TFLite Quantize""" + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + output_tensor = output_tensors[0] + + # The output must be quantized + assert output_tensor.qnn_params + # Quantize the input + out = self.quantize(in_expr, output_tensor) + + return out + + def convert_dequantize(self, op): + """Convert TFLite Dequantize""" + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + # The input must be quantized + assert input_tensor.qnn_params + # Dequantize the input. + out = self.dequantize(in_expr, input_tensor) + + return out + def convert_detection_postprocess(self, op): """Convert TFLite_Detection_PostProcess""" - _option_names = [ - "w_scale", - "max_detections", - "_output_quantized", - "detections_per_class", - "x_scale", - "nms_score_threshold", - "num_classes", - "max_classes_per_detection", - "use_regular_nms", - "y_scale", - "h_scale", - "_support_output_type_float_in_quantized_op", - "nms_iou_threshold" - ] - - custom_options = get_custom_options(op, _option_names) - if custom_options["use_regular_nms"]: - raise tvm.error.OpAttributeUnImplemented( - "use_regular_nms=True is not yet supported for operator {}." - .format("TFLite_Detection_PostProcess") - ) + flexbuffer = op.CustomOptionsAsNumpy().tobytes() + custom_options = FlexBufferDecoder(flexbuffer).decode() + + if "use_regular_nms" in custom_options: + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." 
+ .format("TFLite_Detection_PostProcess") + ) inputs = self.get_input_tensors(op) assert len(inputs) == 3, "inputs length should be 3" cls_pred = self.get_expr(inputs[1].tensor_idx) loc_prob = self.get_expr(inputs[0].tensor_idx) + batch_size = inputs[1].tensor.Shape(0) anchor_values = self.get_tensor_value(inputs[2]) anchor_boxes = len(anchor_values) anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type()) @@ -2162,7 +2504,7 @@ def convert_detection_postprocess(self, op): loc_prob = _op.concatenate( [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2 ) - loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4]) + loc_prob = _op.reshape(loc_prob, [batch_size, anchor_boxes*4]) # anchor coords are in yxhw format # need to convert to ltrb @@ -2202,13 +2544,17 @@ def convert_detection_postprocess(self, op): ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob, anchor_expr, **multibox_transform_loc_attrs) - ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs) + ret = _op.vision.non_max_suppression(ret[0], ret[1], ret[1], **non_max_suppression_attrs) ret = _op.vision.get_valid_counts(ret, 0) valid_count = ret[0] + # keep only the top 'max_detections' rows + ret = _op.strided_slice(ret[1], + [0, 0, 0], + [batch_size, custom_options["max_detections"], anchor_boxes]) # the output needs some reshaping to match tflite - ret = _op.split(ret[1], 6, axis=2) - cls_ids = ret[0] - scores = ret[1] + ret = _op.split(ret, 6, axis=2) + cls_ids = _op.reshape(ret[0], [batch_size, -1]) + scores = _op.reshape(ret[1], [batch_size, -1]) boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2) ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) return ret @@ -2219,6 +2565,31 @@ def get_expr(self, input_tensor_idx): def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) + def get_tensor_expr(self, tensor): + """ Returns constant expr for constant else a tensor expr""" + if self.has_expr(tensor.tensor_idx): + # In most cases, we can assume that TOCO fuses elemwise operators + # with constants - it means both will be tensors. + expr = self.get_expr(tensor.tensor_idx) + else: + # However, in some corner cases, the elemwise operator is not fused, + # we can receive as constant. + type_str = self.get_tensor_type_str(tensor.tensor.Type()) + expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str) + + return expr + + +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + + def build_str_map(obj): """Build string map of TFLite enum int value @@ -2284,98 +2655,13 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") -def get_custom_options(op, option_names): - """Get the options of a custom operator. - - This implements partial flexbuffer deserialization to be able - to read custom options. It is not intended to be a general - purpose flexbuffer deserializer and as such only supports a - limited number of types and assumes the data is a flat map. - - Parameters - ---------- - op: - A custom TFlite operator. - option_names: list - A complete list of the custom option names. 
- - Returns - ------- - options: dict - A dictionary of the custom options. - - """ - import struct - from enum import IntEnum - - class _FlexBufferType(IntEnum): - """Flexbuffer type schema from flexbuffers.h""" - FBT_NULL = 0 - FBT_INT = 1 - FBT_UINT = 2 - FBT_FLOAT = 3 - # Types above stored inline, types below store an offset. - FBT_KEY = 4 - FBT_STRING = 5 - FBT_INDIRECT_INT = 6 - FBT_INDIRECT_UINT = 7 - FBT_INDIRECT_FLOAT = 8 - FBT_MAP = 9 - FBT_VECTOR = 10 # Untyped. - FBT_VECTOR_INT = 11 # Typed any size (stores no type table). - FBT_VECTOR_UINT = 12 - FBT_VECTOR_FLOAT = 13 - FBT_VECTOR_KEY = 14 - FBT_VECTOR_STRING = 15 - FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). - FBT_VECTOR_UINT2 = 17 - FBT_VECTOR_FLOAT2 = 18 - FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). - FBT_VECTOR_UINT3 = 20 - FBT_VECTOR_FLOAT3 = 21 - FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). - FBT_VECTOR_UINT4 = 23 - FBT_VECTOR_FLOAT4 = 24 - FBT_BLOB = 25 - FBT_BOOL = 26 - FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type - - buffer = op.CustomOptionsAsNumpy().tobytes() - value_vector_offset = buffer[-3] - buffer = buffer[:-3] - num_bytes = 4 # Assume all values are stored in 32 bit width - value_vector_size = struct.unpack( - "> 2) - value_offset = -value_vector_offset + i*num_bytes - value_bytes = buffer[value_offset:value_offset+num_bytes] - if flex_type == _FlexBufferType.FBT_BOOL: - value = bool(value_bytes[0]) - if flex_type == _FlexBufferType.FBT_INT: - value = struct.unpack("> 2) + value_bytes = self.buffer[end + i * byte_width: end + (i + 1) * byte_width] + if value_type == FlexBufferType.FBT_BOOL: + value = bool(value_bytes[0]) + elif value_type == FlexBufferType.FBT_INT: + value = struct.unpack("> 2) + byte_width = 1 << BitWidth(root_packed_type & 3) + + if root_type == FlexBufferType.FBT_MAP: + return self.decode_map(root_end, byte_width, root_byte_width) + raise NotImplementedError("Flexbuffer Decoding is partially imlpemented.") diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index b3054d67885b..ce0df9532d66 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -17,9 +17,9 @@ #pylint: disable=wildcard-import, redefined-builtin """Relay core operators.""" # operator defs -from .op import get, register, register_compute, register_gradient, \ +from .op import get, register_compute, register_gradient, \ register_pattern, register_alter_op_layout, register_legalize, \ - Op, OpPattern, OpStrategy, debug, register_external_compiler + OpPattern, OpStrategy, debug, register_external_compiler from . import strategy # Operators diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index e1e6fd3a1139..5a20480a222b 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -18,7 +18,11 @@ # pylint: disable=invalid-name,unused-argument from __future__ import absolute_import +from tvm.te.hybrid import script +from tvm.runtime import convert + from . import strategy +from . 
import op as _reg from .op import OpPattern, register_pattern from .op import register_strategy @@ -29,3 +33,67 @@ # topk register_strategy("topk", strategy.topk_strategy) register_pattern("topk", OpPattern.OPAQUE) + +@script +def _topk_shape_func_input_data(data, k, axis): + ndim = len(data.shape) + val_out = output_tensor((ndim,), "int64") + indices_out = output_tensor((ndim,), "int64") + + for i in const_range(ndim): + if i != axis: + val_out[i] = int64(data.shape[i]) + indices_out[i] = int64(data.shape[i]) + else: + if k[0] < 1: + val_out[i] = int64(data.shape[i]) + indices_out[i] = int64(data.shape[i]) + else: + val_out[i] = int64(k[0]) + indices_out[i] = int64(k[0]) + return val_out, indices_out + +@script +def _topk_shape_func_input_shape(data_shape, k, axis): + ndim = data_shape.shape[0] + val_out = output_tensor((ndim,), "int64") + indices_out = output_tensor((ndim,), "int64") + + for i in const_range(ndim): + if i != axis: + val_out[i] = int64(data_shape[i]) + indices_out[i] = int64(data_shape[i]) + else: + if k < 1: + val_out[i] = int64(data_shape[i]) + indices_out[i] = int64(data_shape[i]) + else: + val_out[i] = int64(k) + indices_out[i] = int64(k) + return val_out, indices_out + +@_reg.register_shape_func("topk", True) +def topk_shape_func(attrs, inputs, _): + """ + Shape func for topk. + """ + axis = attrs.axis + if attrs.k is not None: + if axis < 0: + axis += inputs[0].shape[0] + val_out, indices_out = \ + _topk_shape_func_input_shape(inputs[0], attrs.k, convert(axis)) + else: + if axis < 0: + axis += len(inputs[0].shape) + val_out, indices_out = \ + _topk_shape_func_input_data(inputs[0], inputs[1], convert(axis)) + ret_type = attrs.ret_type + if ret_type == "both": + ret = [val_out, indices_out] + elif ret_type == "values": + ret = [val_out] + else: + ret = [indices_out] + + return ret diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 6bddaa1337f6..cd9e4ed050d2 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -16,21 +16,29 @@ # under the License. 
#pylint: disable=invalid-name, unused-argument, len-as-condition """Backend compiler related feature registration""" -import topi -from tvm.runtime import convert from tvm.te.hybrid import script -from topi.util import get_const_tuple +import topi + from .op import register_compute, register_shape_func from .op import register_broadcast_schedule, register_injective_schedule from .op import register_pattern, OpPattern register_broadcast_schedule("log") +register_broadcast_schedule("log2") +register_broadcast_schedule("log10") register_broadcast_schedule("tan") register_broadcast_schedule("cos") +register_broadcast_schedule("cosh") register_broadcast_schedule("sin") +register_broadcast_schedule("sinh") +register_broadcast_schedule("acos") +register_broadcast_schedule("acosh") +register_broadcast_schedule("asin") +register_broadcast_schedule("asinh") register_broadcast_schedule("atan") +register_broadcast_schedule("atanh") register_broadcast_schedule("exp") register_broadcast_schedule("erf") register_broadcast_schedule("sqrt") @@ -84,7 +92,7 @@ # zeros @register_compute("zeros") def zeros_compute(attrs, inputs, output_type): - assert not inputs + assert len(inputs) == 1 return [topi.full(output_type.shape, output_type.dtype, 0.0)] register_broadcast_schedule("zeros") @@ -101,7 +109,7 @@ def zeros_like_compute(attrs, inputs, output_type): # ones @register_compute("ones") def ones_compute(attrs, inputs, output_type): - assert not inputs + assert len(inputs) == 1 return [topi.full(output_type.shape, output_type.dtype, 1.0)] register_broadcast_schedule("ones") @@ -123,20 +131,10 @@ def clip_compute(attrs, inputs, output_type): register_injective_schedule("clip") -@script -def _cast_shape_function(x): - out_ndim = len(x) - out = output_tensor((out_ndim,), "int64") - for i in const_range(out_ndim): - out[i] = x[i] - return out - -def cast_shape_func(attrs, inputs, out_ndims): - return [_cast_shape_function(*inputs)] - +# full @script def _full_shape_func(shape): - out_ndim = len(shape) + out_ndim = shape.shape[0] out = output_tensor((out_ndim,), "int64") for i in const_range(out_ndim): out[i] = int64(shape[i]) @@ -144,10 +142,15 @@ def _full_shape_func(shape): def full_shape_func(attrs, inputs, out_ndims): """ - Shape func for zeros, zeros_like, ones, ones_like. + Shape func for full. + """ + return [_full_shape_func(inputs[1])] + +def no_data_full_shape_func(attrs, inputs, out_ndims): + """ + Shape func for zeros and ones. 
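    Unlike full, zeros and ones take the target shape as their only runtime
    input, so the hybrid script reads it from inputs[0] rather than inputs[1].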
""" - shape = get_const_tuple(attrs.shape) - return [_full_shape_func(convert(shape))] + return [_full_shape_func(inputs[0])] @script def _broadcast_shape_func(x, y, ndim): @@ -189,13 +192,14 @@ def elemwise_shape_func(attrs, inputs, _): """ return [topi.math.identity(inputs[0])] -register_shape_func("cast", False, cast_shape_func) -register_shape_func("zeros", False, full_shape_func) +register_shape_func("cast", False, elemwise_shape_func) +register_shape_func("zeros", True, no_data_full_shape_func) register_shape_func("zeros_like", False, elemwise_shape_func) -register_shape_func("ones", False, full_shape_func) +register_shape_func("ones", True, no_data_full_shape_func) register_shape_func("ones_like", False, elemwise_shape_func) -register_shape_func("full", False, full_shape_func) +register_shape_func("full", True, full_shape_func) register_shape_func("full_like", False, elemwise_shape_func) +register_shape_func("broadcast_to", True, full_shape_func) register_shape_func("add", False, broadcast_shape_func) register_shape_func("subtract", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 33a193799288..0deb87a60e34 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -27,12 +27,15 @@ from .reduce import sum as _sum from .tensor import ( cos, + cosh, exp, less, negative, ones_like, power, sin, + sinh, + sqrt, zeros_like, equal, shape_of, @@ -61,6 +64,24 @@ def log_grad(orig, grad): return [grad * ones_like(x) / x] +@register_gradient("log2") +def log2_grad(orig, grad): + """Returns [grad * 1 / (log(2) * x)]""" + x = orig.args[0] + ones = ones_like(x) + two = const(2.0) + return [grad * ones / (log(two) * x)] + + +@register_gradient("log10") +def log10_grad(orig, grad): + """Returns [grad * 1 / (log(10) * x)]""" + x = orig.args[0] + ones = ones_like(x) + ten = const(10.0) + return [grad * ones / (log(ten) * x)] + + @register_gradient("tan") def tan_grad(orig, grad): """Returns [grad / (cos^2(x))]""" @@ -76,18 +97,74 @@ def cos_grad(orig, grad): return [grad * (-ones * sin(x))] +@register_gradient("cosh") +def cosh_grad(orig, grad): + """Returns [grad * sinh(x)]""" + x = orig.args[0] + return [grad * sinh(x)] + + @register_gradient("sin") def sin_grad(orig, grad): """Returns [grad * cos(x)]""" x = orig.args[0] return [grad * cos(x)] + +@register_gradient("sinh") +def sinh_grad(orig, grad): + """Returns [grad * cosh(x)]""" + x = orig.args[0] + return [grad * cosh(x)] + + +@register_gradient("acos") +def acos_grad(orig, grad): + """Returns [grad * -1/((1 - (x ^ 2)) ^ 1/2)]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * (-ones / sqrt(ones - (x * x)))] + + +@register_gradient("acosh") +def acosh_grad(orig, grad): + """Returns [grad * 1/((x - 1) ^ 1/2 * (x + 1) ^ 1/2)]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / sqrt((x * x) - ones)] + + +@register_gradient("asin") +def asin_grad(orig, grad): + """Returns [grad * 1/((1 - (x ^ 2)) ^ (1/2))]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / sqrt(ones - (x * x))] + + +@register_gradient("asinh") +def asinh_grad(orig, grad): + """Returns [grad * 1/((1 + (x ^ 2)) ^ (1/2))]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / sqrt(ones + (x * x))] + + @register_gradient("atan") def atan_grad(orig, grad): """Returns [grad * 1 / (1 + x ^ 2)]""" x = orig.args[0] - a = const(2.0) - return [grad * ones_like(x) / (ones_like(x) + power(x, a))] + ones = ones_like(x) + return [grad 
* ones / (ones + (x * x))] + + +@register_gradient("atanh") +def atanh_grad(orig, grad): + """Returns [grad * 1 / (1 - x ^ 2)]""" + x = orig.args[0] + ones = ones_like(x) + return [grad * ones / (ones - (x * x))] + @register_gradient("exp") def exp_grad(orig, grad): @@ -155,14 +232,14 @@ def divide_grad(orig, grad): @register_gradient("zeros") def zeros_grad(orig, grad): - """Returns []""" - return [] + """Returns [shape]""" + return [orig.args[0]] @register_gradient("ones") def ones_grad(orig, grad): - """Returns []""" - return [] + """Returns [shape]""" + return [orig.args[0]] @register_gradient("zeros_like") @@ -313,8 +390,10 @@ def conv2d_grad(orig, grad): assert padded_weight_grad_h >= filter_h assert padded_weight_grad_w >= filter_w if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: - backward_weight = strided_slice(backward_weight, begin=[0, 0, 0, 0], - end=[None, None, filter_h, filter_w]) + backward_weight = strided_slice(backward_weight, + begin=const([0, 0, 0, 0], dtype="int64"), + end=const([out_channel, in_channel // attrs.groups, + filter_h, filter_w], dtype="int64")) return [backward_data, backward_weight] @@ -395,14 +474,15 @@ def bias_add_grad(orig, grad): def dense_grad(orig, grad): """Returns [grad' @ weight, data @ grad']""" data, weight = orig.args - return [collapse_sum_like(transpose(grad) * weight, data), - collapse_sum_like(data * transpose(grad), weight)] - + return [collapse_sum_like(_nn.dense(grad, transpose(weight), + units=weight.checked_type.shape[1]), data), + collapse_sum_like(_nn.dense(transpose(grad), transpose(data), + units=data.checked_type.shape[1]), weight)] @register_gradient("reshape") def reshape_grad(orig, grad): """Gradient of reshape""" - return [reshape_like(grad, orig.args[0])] + return [reshape_like(grad, orig.args[0]), orig.args[1]] @register_gradient("cast") diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index ee23fcefe010..f134b8251afa 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -26,6 +26,7 @@ from . import op as _reg from . 
import strategy from .op import OpPattern +from ._tensor import elemwise_shape_func _reg.register_broadcast_schedule("broadcast_to") _reg.register_broadcast_schedule("broadcast_to_like") @@ -50,11 +51,13 @@ _reg.register_injective_schedule("transpose") _reg.register_injective_schedule("stack") _reg.register_injective_schedule("_contrib_reverse_reshape") +_reg.register_injective_schedule("gather") _reg.register_injective_schedule("gather_nd") _reg.register_injective_schedule("sequence_mask") _reg.register_injective_schedule("one_hot") _reg.register_reduce_schedule("collapse_sum_like") _reg.register_injective_schedule("unravel_index") +_reg.register_injective_schedule("sparse_to_dense") # concatenate _reg.register_schedule("concatenate", strategy.schedule_concatenate) @@ -87,6 +90,14 @@ def compute_argwhere(attrs, inputs, output_type): _reg.register_schedule("argwhere", strategy.schedule_argwhere) +# scatter +@_reg.register_compute("scatter") +def compute_scatter(attrs, inputs, output_type): + """Compute definition of scatter""" + return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)] + +_reg.register_schedule("scatter", strategy.schedule_scatter) + ##################### # Shape functions # ##################### @@ -99,8 +110,77 @@ def _arange_shape_func(start, stop, step): @_reg.register_shape_func("arange", True) def arange_shape_func(attrs, inputs, _): + """ + Shape func for arange + """ return [_arange_shape_func(*inputs)] +@script +def _strided_slice_shape_func_input_data(data, begin, end, strides, + slice_mode): + ndim = len(data.shape) + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + cbegin = 0 + cend = data.shape[i] + cstride = 1 + if strides.shape[0] > i: + cstride = strides[i] + if begin.shape[0] > i: + cbegin = begin[i] + if end.shape[0] <= i: + cend = data.shape[i] + elif slice_mode != 0: + cstride = 1 + if end[i] < 0: + cend = data.shape[i] + else: + cend = cbegin + end[i] + else: + cend = end[i] + assert cstride != 0, "Strides can't be zero." + out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride))) + return out + +@script +def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode): + ndim = data_shape.shape[0] + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + cbegin = int64(0) + cend = int64(data_shape[i]) + cstride = int64(1) + if len(strides) > i: + cstride = int64(strides[i]) + if len(begin) > i: + cbegin = int64(begin[i]) + if len(end) <= i: + cend = int64(data_shape[i]) + elif slice_mode != 0: + cstride = int64(1) + if end[i] < 0: + cend = int64(data_shape[i]) + else: + cend = cbegin + int64(end[i]) + else: + cend = int64(end[i]) + assert cstride != 0, "Strides can't be zero." 
+ out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride))) + return out + + +@_reg.register_shape_func("strided_slice", True) +def strided_slice_shape_func(attrs, inputs, _): + """ + Shape func for strided_slice + """ + slice_mode = convert(0 if attrs.slice_mode == "end" else 1) + # data independent if begin, end and strides exist + if attrs.begin and attrs.end and attrs.strides: + return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end, + attrs.strides, slice_mode)] + return [_strided_slice_shape_func_input_data(*inputs, slice_mode)] + @script def _concatenate_shape_func(inputs, axis): ndim = inputs[0].shape[0] @@ -120,11 +200,83 @@ def _concatenate_shape_func(inputs, axis): @_reg.register_shape_func("concatenate", False) def concatenate_shape_func(attrs, inputs, _): axis = get_const_int(attrs.axis) + if axis < 0: + axis += inputs[0].shape[0] return [_concatenate_shape_func(inputs, convert(axis))] @script -def _reshape_shape_func(data_shape, newshape, ndim): +def _reshape_shape_func_input_shape(data_shape, newshape, ndim): + out = output_tensor((ndim,), "int64") + src_idx = 0 + dst_idx = 0 + infer_idx = -1 + copy = False + skip = 0 + for i in const_range(len(newshape)): + if skip > 0: + skip -= 1 + elif newshape[i] > 0: + out[dst_idx] = int64(newshape[i]) + src_idx += 1 + dst_idx += 1 + elif newshape[i] == 0: + out[dst_idx] = data_shape[src_idx] + src_idx += 1 + dst_idx += 1 + elif newshape[i] == -1: + assert infer_idx < 0, "One and only one dim can be inferred" + out[dst_idx] = int64(1) + infer_idx = i + dst_idx += 1 + elif newshape[i] == -2: + copy = True + elif newshape[i] == -3: + assert data_shape.shape[0] - src_idx > 1, \ + "Not enough dims in input shape for -3" + out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1] + src_idx += 2 + dst_idx += 1 + elif newshape[i] == -4: + assert len(newshape) - i > 2, "Not enough dims in new shape for -4" + if newshape[i+1] == -1: + assert newshape[i+2] != -1, "Split dims cannot both be -1." 
+ out[dst_idx] = data_shape[src_idx] // int64(newshape[i+2]) + out[dst_idx+1] = int64(newshape[i+2]) + else: + out[dst_idx] = int64(newshape[i+1]) + if newshape[i+2] == -1: + out[dst_idx+1] = data_shape[src_idx] // int64(newshape[i+1]) + else: + out[dst_idx+1] = int64(newshape[i+2]) + assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\ + "Product of split dims doesn't match to input dim" + src_idx += 1 + dst_idx += 2 + skip = 2 + else: + assert False, "Invalid special values in new shape" + if len(data_shape.shape) > 0: + # if data is not constant, we can then handle -1 and -2 + if copy: + for i in range(src_idx, data_shape.shape[0]): + out[dst_idx] = data_shape[i] + dst_idx += 1 + if infer_idx >= 0: + old_size = int64(1) + for i in const_range(data_shape.shape[0]): + old_size *= data_shape[i] + new_size = int64(1) + for i in const_range(out.shape[0]): + new_size *= out[i] + out[infer_idx] = old_size // new_size + return out + +@script +def _reshape_shape_func_input_data(data, newshape, ndim): out = output_tensor((ndim,), "int64") + data_shape = allocate((len(data.shape),), "int64") + for x in const_range(len(data.shape)): + data_shape[x] = int64(data.shape[x]) src_idx = 0 dst_idx = 0 infer_idx = -1 @@ -189,10 +341,13 @@ def _reshape_shape_func(data_shape, newshape, ndim): out[infer_idx] = old_size // new_size return out -@_reg.register_shape_func("reshape", False) +@_reg.register_shape_func("reshape", True) def reshape_shape_func(attrs, inputs, out_ndims): - newshape = get_const_tuple(attrs.newshape) - return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])] + if attrs.newshape is None: + return [_reshape_shape_func_input_data(*inputs, out_ndims[0])] + return [_reshape_shape_func_input_shape(inputs[0], + convert(attrs.newshape), + out_ndims[0])] @script def _take_no_axis_shape_func(indices_shape, out_ndim): @@ -308,6 +463,8 @@ def argwhere_shape_func(attrs, inputs, out_ndims): return [_argwhere_shape_func_5d(inputs[0])] return ValueError("Does not support rank higher than 5 in argwhere") +_reg.register_shape_func("scatter", False, elemwise_shape_func) + @script def _layout_transform_shape_func(data_shape, out_layout_len, diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 17fab80118af..d31e89a49f43 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,7 +17,7 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make -from ..expr import TupleWrapper +from ..expr import TupleWrapper, const def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies @@ -48,7 +48,8 @@ def argsort(data, axis=-1, is_ascend=1, dtype="int32"): return _make.argsort(data, axis, is_ascend, dtype) -def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): +def topk(data, k=1, axis=-1, ret_type="both", + is_ascend=False, dtype="int32"): """Get the top k elements in an input tensor along the given axis. ret_type specifies the return type, can be one of ("both", "values", "indices"). @@ -58,7 +59,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): data : relay.Expr The input data tensor. - k : int, optional + k : int or relay.Expr, optional Number of top elements to select. Return all elements if k < 1. 
axis : int, optional @@ -81,6 +82,8 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): out : relay.Expr or List[relay.Expr] The computed result. """ + if isinstance(k, int): + k = const(k, "int64") out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) if ret_type == "both": return TupleWrapper(out, 2) diff --git a/python/tvm/relay/op/contrib/__init__.py b/python/tvm/relay/op/contrib/__init__.py index 3a3f6d5aa304..0e1b4b024a5a 100644 --- a/python/tvm/relay/op/contrib/__init__.py +++ b/python/tvm/relay/op/contrib/__init__.py @@ -19,3 +19,4 @@ from .register import get_pattern_table, register_pattern_table from .dnnl import * +from .coreml import * diff --git a/python/tvm/relay/op/contrib/coreml.py b/python/tvm/relay/op/contrib/coreml.py new file mode 100644 index 000000000000..dc14c2a13089 --- /dev/null +++ b/python/tvm/relay/op/contrib/coreml.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""CoreML codegen supported operators.""" +import tvm.ir +from tvm.contrib.target.coreml import _convert_map +from ...expr import Constant + + +def _register_coreml_op(op_name): + """Register a function to check whether the given operator is supported by Core ML. + + Parameters + ---------- + op_name : str + The name of the operator that will be registered. + + """ + def _check_supported(attrs, args): + if op_name == 'nn.conv2d': + if not isinstance(args[1], Constant): + return False + if attrs['kernel_layout'] not in ['HWIO', 'OIHW']: + return False + return True + + tvm.ir.register_op_attr(op_name, "target.coremlcompiler", _check_supported) + + +for op in _convert_map: + _register_coreml_op(op) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 71ef430ec9c6..27574a80cc5b 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -32,8 +32,8 @@ - The other way is to implement the function by themselves to check the attributes of the op and decide if it should be offloaded to DNNL. """ -from ... import expr as _expr -from ... import op as _op +import tvm.ir +from ...dataflow_pattern import wildcard, is_op from .register import register_pattern_table @@ -51,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True): f : callable A function that returns if the operator is supported by DNNL.
""" - @_op.register(op_name, "target.dnnl") + @tvm.ir.register_op_attr(op_name, "target.dnnl") def _func_wrapper(attrs, args): return supported @@ -68,15 +68,15 @@ def _func_wrapper(attrs, args): def make_pattern(with_bias=True): - data = _expr.var("data") - weight = _expr.var("weight") - bias = _expr.var("bias") - conv = _op.nn.conv2d(data, weight) + data = wildcard() + weight = wildcard() + bias = wildcard() + conv = is_op('nn.conv2d')(data, weight) if with_bias: - conv_out = _op.add(conv, bias) + conv_out = is_op('add')(conv, bias) else: conv_out = conv - return _op.nn.relu(conv_out) + return is_op('nn.relu')(conv_out) @register_pattern_table("dnnl") diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index ba9d62ae397b..bcb110fee7d6 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -19,6 +19,7 @@ from __future__ import absolute_import import topi +from topi.util import get_const_tuple from .. import op as reg from .. import strategy from ..op import OpPattern @@ -37,6 +38,18 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.resize") +@reg.register_compute("image.resize3d") +def compute_resize3d(attrs, inputs, out_type): + size = attrs.size + layout = attrs.layout + method = attrs.method + coord_trans = attrs.coordinate_transformation_mode + out_dtype = attrs.out_dtype + return [topi.image.resize3d(inputs[0], size, layout, method, coord_trans, out_dtype)] + +reg.register_injective_schedule("image.resize3d") + + # crop and resize @reg.register_compute("image.crop_and_resize") def compute_crop_and_resize(attrs, inputs, out_type): @@ -55,3 +68,22 @@ def compute_crop_and_resize(attrs, inputs, out_type): # dilation2d reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy) reg.register_pattern("image.dilation2d", OpPattern.OUT_ELEMWISE_FUSABLE) + + +# affine_grid +@reg.register_compute("image.affine_grid") +def compute_affine_grid(attrs, inputs, out_dtype): + target_shape = get_const_tuple(attrs.target_shape) + return [topi.image.affine_grid(inputs[0], target_shape)] + +reg.register_injective_schedule("image.affine_grid") + + +# grid_sample +@reg.register_compute("image.grid_sample") +def compute_grid_sample(attrs, inputs, out_dtype): + method = attrs.method + layout = attrs.layout + return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)] + +reg.register_injective_schedule("image.grid_sample") diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 097322c9eaf6..62889e0b674e 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -64,6 +64,52 @@ def resize(data, return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) +def resize3d(data, + size, + layout="NCDHW", + method="trilinear", + coordinate_transformation_mode="half_pixel", + out_dtype=None): + """Image resize 3D operator. + + This operator takes data as input and does 3D scaling to the given scale factor. + In the default case, where the data_layout is `NCDHW` + with data of shape (n, c, d, h, w) + out will have a shape (n, c, size[0], size[1], size[2]) + + method indicates the algorithm to be used while calculating the out value + and method can be one of ("trilinear", "nearest_neighbor") + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + size: Tuple of Expr + The out size to which the image will be resized. 
+ + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to be used [nearest_neighbor, trilinear]. + + coordinate_transformation_mode : string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + [half_pixel, align_corners, asymmetric] + + out_dtype : str, optional + Type to return. If left None returns the same type as input. + + Returns + ------- + result: relay.Expr + The resized result. + """ + return _make.resize3d(data, size, layout, method, coordinate_transformation_mode, out_dtype) + + def crop_and_resize(data, boxes, box_indices, @@ -169,3 +215,67 @@ def dilation2d(data, return _make.dilation2d(data, weight, strides, padding, dilations, data_layout, kernel_layout, out_dtype) + + +def affine_grid(data, target_shape=None): + """affine_grid operator that generates 2D sampling grid. + + This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform + sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine + transformation is then applied on the sampling grid. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, 2, 3]. The affine matrix. + + target_shape: list/tuple of two int + Specifies the output shape (H, W). + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, target_height, target_width] + """ + return _make.affine_grid(data, target_shape) + +def grid_sample(data, grid, method='bilinear', layout='NCHW'): + """Applies bilinear sampling to input feature map. + + Given :math:`data` and :math:`grid`, then the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) + + :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and + :math:`G()` denotes the interpolation function. + The out-boundary points will be padded with zeros. The shape of the output will be + (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + grid : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + + method : str + The interpolation method. Only 'bilinear' is supported. + + layout : str + The layout of input data and the output. + + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, in_channel, out_height, out_width] + """ + return _make.grid_sample(data, grid, method, layout) diff --git a/python/tvm/relay/op/memory/memory.py b/python/tvm/relay/op/memory/memory.py index 509db354b42c..4092545d552c 100644 --- a/python/tvm/relay/op/memory/memory.py +++ b/python/tvm/relay/op/memory/memory.py @@ -40,7 +40,7 @@ def invoke_tvm_op(func, inputs, outputs): """ return _make.invoke_tvm_op(func, inputs, outputs) -def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): +def alloc_tensor(storage, offset, shape, dtype='float32', assert_shape=None): """Allocate a tensor with the provided shape, and dtype. Parameters @@ -48,6 +48,9 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): storage : tvm.relay.Expr The storage to allocate from. + offset : tvm.relay.Expr + The offset to allocate from.
+ shape : tvm.relay.Expr The shape of the tensor to allocate. @@ -61,7 +64,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): result : tvm.relay.Expr The alloc_tensor expression. """ - return _make.alloc_tensor(storage, shape, dtype, assert_shape) + return _make.alloc_tensor(storage, offset, shape, dtype, assert_shape) def alloc_storage(size, alignment, ctx, dtype_hint='float32'): """Allocate a piece of tensor storage. diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5f6aa898711b..1c76f57a6343 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -69,7 +69,7 @@ def compute_sparse_dense(attrs, inputs, out_type): """Compute definition of sparse_dense""" return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])] -reg.register_schedule("nn.sparse_dense", strategy.schedule_sparse_dense) +reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy) reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) @@ -118,7 +118,7 @@ def legalize_conv2d(attrs, inputs, types): return topi.nn.conv2d_legalize(attrs, inputs, types) @reg.register_convert_op_layout("nn.conv2d") -def convert_conv2d(attrs, inputs, tinfos, desired_layout): +def convert_conv2d(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv2d op. Parameters @@ -129,8 +129,9 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): The args of the Relay expr to be legalized tinfos : list of types List of input and output types - desired_layout : str - The desired layout + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and kernel inputs respectively. Returns ------- @@ -141,11 +142,20 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): from tvm import relay data, weight = inputs new_attrs = dict(attrs) - new_attrs['data_layout'] = desired_layout - if desired_layout == 'NCHW': + assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + assert desired_data_layout != "default", "Data layout cannot be default" + new_attrs['data_layout'] = desired_data_layout + + if desired_kernel_layout != "default": + new_attrs['kernel_layout'] = desired_kernel_layout + return relay.nn.conv2d(data, weight, **new_attrs) + + # Handle default kernel layouts + if desired_data_layout == 'NCHW': new_attrs['kernel_layout'] = 'OIHW' return relay.nn.conv2d(data, weight, **new_attrs) - elif desired_layout == 'NHWC': + elif desired_data_layout == 'NHWC': # Check for depthwise convolution. if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape, attrs['kernel_layout'], attrs['groups']): @@ -153,9 +163,8 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): else: new_attrs['kernel_layout'] = 'HWIO' return relay.nn.conv2d(data, weight, **new_attrs) - else: - assert "Layout %s is not yet supported." % (desired_layout) - return None + + raise ValueError("Layout %s is not yet supported." 
% desired_data_layout) # conv2d_transpose @@ -183,6 +192,31 @@ def legalize_conv2d_transpose(attrs, inputs, types): return topi.nn.conv2d_transpose_legalize(attrs, inputs, types) +# conv3d_transpose +reg.register_strategy("nn.conv3d_transpose", strategy.conv3d_transpose_strategy) +reg.register_pattern("nn.conv3d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) + +@reg.register_legalize("nn.conv3d_transpose") +def legalize_conv3d_transpose(attrs, inputs, types): + """Legalize conv3d_transpose op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current Transposed convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + return topi.nn.conv3d_transpose_legalize(attrs, inputs, types) + + # conv3d reg.register_strategy("nn.conv3d", strategy.conv3d_strategy) reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE) @@ -193,7 +227,7 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type): return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type) @reg.register_convert_op_layout("nn.conv3d") -def convert_conv3d(attrs, inputs, tinfos, desired_layout): +def convert_conv3d(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for conv3d op. Parameters @@ -204,8 +238,9 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layout): The args of the Relay expr to be legalized tinfos : list of types List of input and output types - desired_layout : str - The desired layout + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and kernel inputs respectively. Returns ------- @@ -216,16 +251,25 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layout): from tvm import relay data, weight = inputs new_attrs = dict(attrs) - new_attrs['data_layout'] = desired_layout - if desired_layout == 'NCDHW': + assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs" + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + assert desired_data_layout != "default", "Data layout cannot be default" + new_attrs['data_layout'] = desired_data_layout + + if desired_kernel_layout != "default": + new_attrs['kernel_layout'] = desired_kernel_layout + return relay.nn.conv3d(data, weight, **new_attrs) + + # Handle default kernel layouts + if desired_data_layout == 'NCDHW': new_attrs['kernel_layout'] = 'OIDHW' return relay.nn.conv3d(data, weight, **new_attrs) - elif desired_layout == "NDHWC": + elif desired_data_layout == "NDHWC": new_attrs['kernel_layout'] = 'DHWIO' return relay.nn.conv3d(data, weight, **new_attrs) - else: - assert "Layout %s is not yet supported" % desired_layout - return None + + raise ValueError("Layout %s is not yet supported" % desired_data_layout) + # conv3d_winograd related operators reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform", @@ -502,6 +546,15 @@ def compute_cross_entropy(attrs, inputs, out_dtype): reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) +# dilate +@reg.register_compute("nn.dilate") +def compute_dilate(attrs, inputs, out_dtype): + return [topi.nn.dilate(inputs[0], attrs.strides)] + +reg.register_broadcast_schedule("nn.dilate") +reg.register_pattern("nn.dilate", OpPattern.INJECTIVE) + + # cross_entropy_with_logits @reg.register_compute("nn.cross_entropy_with_logits") def compute_cross_entropy_with_logits(attrs, inputs, out_dtype): 
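The convert_conv2d and convert_conv3d hunks above switch FTVMConvertOpLayout callbacks from a single desired_layout string to a two-element desired_layouts list, [data_layout, kernel_layout], where "default" means "derive the kernel layout from the chosen data layout". A minimal sketch of a callback written against that convention, using the hypothetical name convert_conv2d_like (it is not registered by this patch and only mirrors the handling shown above):

    from tvm import relay

    def convert_conv2d_like(attrs, inputs, tinfos, desired_layouts):
        # desired_layouts carries [data_layout, kernel_layout]; "default" asks the
        # callback to pick a kernel layout that matches the chosen data layout.
        assert len(desired_layouts) == 2
        data_layout, kernel_layout = map(str, desired_layouts)
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = data_layout
        if kernel_layout != "default":
            new_attrs['kernel_layout'] = kernel_layout
        elif data_layout == "NCHW":
            new_attrs['kernel_layout'] = 'OIHW'
        else:  # assume NHWC data, mirroring the nn.conv2d handling above
            new_attrs['kernel_layout'] = 'HWIO'
        data, weight = inputs
        return relay.nn.conv2d(data, weight, **new_attrs)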
@@ -535,6 +588,11 @@ def compute_space_to_depth(attrs, inputs, out_dtype): reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE) +# correlation +reg.register_strategy("nn.correlation", strategy.correlation_strategy) +reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE) + + ##################### # Shape functions # ##################### @@ -697,6 +755,21 @@ def pad_shape_func(attrs, inputs, _): pad_width.append(get_const_tuple(pair)) return [_pad_shape_func(inputs[0], convert(pad_width))] +@script +def _dilate_shape_func(data_shape, strides): + out = output_tensor((data_shape.shape[0],), "int64") + for i in const_range(out.shape[0]): + out[i] = (data_shape[i] - 1) * strides[i] + 1 + + return out + +@reg.register_shape_func("nn.dilate", False) +def dilate_shape_func(attrs, inputs, _): + """ + Shape function for dilate op. + """ + return [_dilate_shape_func(inputs[0], convert(attrs.strides))] + reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) reg.register_shape_func("nn.softmax", False, elemwise_shape_func) reg.register_shape_func("nn.relu", False, elemwise_shape_func) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index d0a81bccd085..34d07dce2863 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -19,7 +19,7 @@ from tvm.relay import expr from . import _make -from .util import get_pad_tuple2d, get_pad_tuple3d +from .util import get_pad_tuple1d, get_pad_tuple2d, get_pad_tuple3d def conv1d(data, @@ -372,6 +372,76 @@ def contrib_conv3d_winograd_without_weight_transform(data, groups, channels, kernel_size, data_layout, kernel_layout, out_layout, out_dtype) +def conv3d_transpose(data, + weight, + strides=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="", + output_padding=(0, 0, 0), + out_dtype=""): + r"""3D transpose convolution. + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + weight : tvm.relay.Expr + The weight expressions. + + strides : Optional[Tuple[int]] + The strides of convolution. + + padding : Optional[int, Tuple[int]] + The padding of convolution on both sides of inputs before convolution. + + dilation : Optional[int, Tuple[int]] + Specifies the dilation rate to be used for dilated convolution. + + groups : Optional[int] + Number of groups for grouped convolution. + + channels : Optional[int] + Number of output channels of this convolution. + + kernel_size : Optional[int, Tuple[int]] + The spatial of the convolution kernel. + + data_layout : Optional[str] + Layout of the input. + + kernel_layout : Optional[str] + Layout of the weight. + + out_layout : Optional[str] + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : Optional[str] + Specifies the output data type for mixed precision conv3d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. 
+ """ + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + padding = get_pad_tuple3d(padding) + + return _make.conv3d_transpose(data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype) def conv2d_transpose(data, weight, @@ -601,10 +671,11 @@ def max_pool1d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(pool_size, int): + pool_size = (pool_size,) if isinstance(strides, int): strides = (strides,) - if isinstance(padding, int): - padding = (padding,) + padding = get_pad_tuple1d(padding) return _make.max_pool1d(data, pool_size, strides, padding, layout, ceil_mode) @@ -661,6 +732,11 @@ def max_pool2d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + padding = get_pad_tuple2d(padding) return _make.max_pool2d(data, pool_size, strides, padding, layout, ceil_mode) @@ -709,6 +785,11 @@ def max_pool3d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + padding = get_pad_tuple3d(padding) return _make.max_pool3d(data, pool_size, strides, padding, layout, ceil_mode) @@ -761,10 +842,11 @@ def avg_pool1d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(pool_size, int): + pool_size = (pool_size,) if isinstance(strides, int): strides = (strides,) - if isinstance(padding, int): - padding = (padding,) + padding = get_pad_tuple1d(padding) return _make.avg_pool1d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad) @@ -826,6 +908,11 @@ def avg_pool2d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + padding = get_pad_tuple2d(padding) return _make.avg_pool2d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad) @@ -878,6 +965,11 @@ def avg_pool3d(data, result : tvm.relay.Expr The computed result. """ + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + padding = get_pad_tuple3d(padding) return _make.avg_pool3d(data, pool_size, strides, padding, layout, ceil_mode, count_include_pad) @@ -1347,6 +1439,25 @@ def pad(data, return _make.pad(data, pad_width, pad_value, pad_mode) +def dilate(data, strides): + """Dilate data with zeros. + + Parameters + ---------- + data : tvm.relay.Expr + n-D, can be any layout. + + strides : + Dilation stride on each dimension, 1 means no dilation. + + Returns + ------- + Output : tvm.relay.Expr + The computed result + """ + return _make.dilate(data, strides) + + def mirror_pad(data, pad_width, mode="SYMMETRIC"): @@ -1708,6 +1819,75 @@ def layer_norm(data, return _make.layer_norm(data, gamma, beta, axis, epsilon, center, scale) +def group_norm(data, + gamma, + beta, + num_groups, + axis=1, + epsilon=1e-5, + center=True, + scale=True): + r""" + Group normalization normalizes over group of channels for each training examples. + We can say that, Group Norm is in between Instance Norm and Layer Norm. 
When we put + all the channels into a single group, group normalization becomes Layer normalization. + And, when we put each channel into a different group, it becomes Instance normalization. + + https://arxiv.org/pdf/1803.08494.pdf + + Applies group normalization to the n-dimensional input array by separating the input channels + into 'num_groups' groups, each containing 'num_channels / num_groups' channels. + The mean and standard-deviation are calculated separately over each group. gamma and + beta are learnable per-channel affine transform parameter vectors of size num_channels. + + .. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + + Unlike batch normalization, the mean and var are computed along a group of channels. + + If the input has size k on axis 1, then both gamma and beta have shape (k,). + + .. note:: + + This operator can be optimized away for inference. + + Parameters + ---------- + data : tvm.relay.Expr + Input to which group_norm will be applied. + + gamma : tvm.relay.Expr + The gamma scale factor. + + beta : tvm.relay.Expr + The beta offset factor. + + num_groups : int + The number of groups to separate the channels into. + + axis : int, optional, default=1 + The axis of the channels. + + epsilon : double, optional, default=1e-5 + Small float added to variance to avoid dividing by zero. + + center : boolean, optional, default=True + If True, add offset of beta to the normalized tensor. If False, + beta is ignored. + + scale : boolean, optional, default=True + If True, multiply by gamma. If False, gamma is not used. + + Returns + ------- + result : tvm.relay.Expr + The normalized data. + """ + return _make.group_norm(data, gamma, beta, num_groups, axis, epsilon, center, scale) + + def batch_matmul(x, y): r""" Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data @@ -2582,3 +2762,155 @@ def adaptive_avg_pool3d(data, """ output_size = [] or output_size return _make.adaptive_avg_pool3d(data, output_size, layout) + + +def global_max_pool3d(data, + layout="NCDHW"): + r"""3D global maximum pooling operator. + + This operator takes data as input and does 3D max value calculation + across each window represented by DxWxH. + + In the default case, where the data_layout is `NCDHW` + a data Tensor with shape `(batch_size, in_channels, depth, height, width)`, + to produce an output Tensor with the following rule: + + with data of shape (b, c, d, h, w) + .. math:: + + \mbox{out}(b, c, 1, 1, 1) = \max_{l=0, \ldots, d}, \max_{m=0, \ldots, h}, + \max_{n=0, \ldots, w} \mbox{data}(b, c, l, m, n) + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + layout : str, optional + Layout of the input. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + output_size = [1, 1, 1] + return _make.adaptive_max_pool3d(data, output_size, layout) + + +def global_avg_pool3d(data, + layout="NCDHW"): + r"""3D global average pooling operator. + + This operator takes data as input and does 3D average value calculation + across each window represented by DxWxH. + + In the default case, where the data_layout is `NCDHW` + a data Tensor with shape `(batch_size, in_channels, depth, height, width)`, + to produce an output Tensor with the following rule: + + with data of shape (b, c, d, h, w) + + ..
math:: + \mbox{out}(b, c, 1, 1, 1) = \frac{1}{d * h * w} \sum_{l=0}^{d-1} \sum_{m=0}^{h-1} + \sum_{n=0}^{w-1} \mbox{data}(b, c, l, m, n) + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. + + layout : str, optional + Layout of the input. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + output_size = [1, 1, 1] + return _make.adaptive_avg_pool3d(data, output_size, layout) + + +def correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, + is_multiply, layout): + r"""Applies correlation to inputs. + + The correlation layer performs multiplicative patch comparisons between two feature maps. + Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and + :math:`c` being their width, height, and number of channels, the correlation layer lets the + network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`. + + For now we consider only a single comparison of two patches. The 'correlation' of two patches + centered at :math:`x_{1}` in the first map and :math:`x_{2}` in the second map is then defined + as: + + .. math:: + + c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} \langle f_{1}(x_{1} + o), f_{2}(x_{2} + o) \rangle + + for a square patch of size :math:`K:=2k+1`. + + Note that the equation above is identical to one step of a convolution in neural networks, but + instead of convolving data with a filter, it convolves data with other data. For this + reason, it has no training weights. + + Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all + patch combinations involves :math:`w^{2}*h^{2}` such computations. + + Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes + correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`, + by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize + :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood + centered around :math:`x_{1}`. + + The final output is defined by the following expression: + + .. math:: + + out[n, q, i, j] = c(x_{i, j}, x_{q}) + + where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` + denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
+ + Parameters + ---------- + data1 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + data2 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neighborhood centered around data1 + + padding : int or a list/tuple of 2 or 4 ints + Padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + is_multiply: bool + operation type is either multiplication or subtraction + + layout: str + layout of data1, data2 and the output + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + if isinstance(padding, int): + padding = (padding, padding) + return _make.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, + padding, is_multiply, layout) diff --git a/python/tvm/relay/op/nn/util.py b/python/tvm/relay/op/nn/util.py index 1fdcad73c74e..fc687cfe070e 100644 --- a/python/tvm/relay/op/nn/util.py +++ b/python/tvm/relay/op/nn/util.py @@ -19,6 +19,37 @@ from tvm.ir import container +def get_pad_tuple1d(padding): + """Common code to get the 1 dimensional pad option + Parameters + ---------- + padding : Union[int, Tuple[int, ...]] + Padding size + Returns + ------- + pad_left : int + Padding size on left + pad_right : int + Padding size on right. + """ + # compute the padding size + if isinstance(padding, container.Array): + padding = list(padding) + if isinstance(padding, (tuple, list)): + if len(padding) == 1: + pad_w = padding[0] * 2 + elif len(padding) == 2: + return padding[0], padding[1] + else: + raise ValueError("Size of padding can only be 1 or 2") + elif isinstance(padding, int): + pad_w = padding * 2 + else: + raise ValueError("Unknown padding option %s" % padding) + pad_left = (pad_w + 1) // 2 + return pad_left, pad_w - pad_left + + def get_pad_tuple2d(padding): """Common code to get the pad option Parameters diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index e6bd6bf230dd..7fad9a258f2b 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -17,61 +17,13 @@ #pylint: disable=unused-argument,invalid-name """The base node types for the Relay language.""" import tvm._ffi +import tvm.ir from tvm.driver import lower, build -from ..expr import RelayExpr from ...target import get_native_generic_func, GenericFunc from ...runtime import Object from . import _make -@tvm._ffi.register_object("relay.Op") -class Op(RelayExpr): - """A Relay operator definition.""" - - def __init__(self): - raise RuntimeError("Cannot create op, use get instead") - - def get_attr(self, attr_name): - """Get additional attribute about the operator. - - Parameters - ---------- - attr_name : str - The attribute name. - - Returns - ------- - value : object - The attribute value - """ - return _OpGetAttr(self, attr_name) - - def set_attr(self, attr_name, value, plevel=10): - """Set attribute about the operator. - - Parameters - ---------- - attr_name : str - The attribute name - - value : object - The attribute value - - plevel : int - The priority level - """ - _OpSetAttr(self, attr_name, value, plevel) - - def reset_attr(self, attr_name): - """Reset attribute about the operator.
- - Parameters - ---------- - attr_name : str - The attribute name - """ - _OpResetAttr(self, attr_name) - def get(op_name): """Get the Op for a given name @@ -86,37 +38,7 @@ def get(op_name): op : Op The op of the corresponding name """ - return _GetOp(op_name) - - -def register(op_name, attr_key, value=None, level=10): - """Register an operator property of an operator. - - - Parameters - ---------- - op_name : str - The name of operator - - attr_key : str - The attribute name. - - value : object, optional - The value to set - - level : int, optional - The priority level - - Returns - ------- - fregister : function - Register function if value is not specified. - """ - def _register(v): - """internal register function""" - _Register(op_name, attr_key, v, level) - return v - return _register(value) if value is not None else _register + return tvm.ir.Op.get(op_name) class OpPattern(object): @@ -258,7 +180,7 @@ def register_compute(op_name, compute=None, level=10): level : int The priority level """ - return register(op_name, "FTVMCompute", compute, level) + return tvm.ir.register_op_attr(op_name, "FTVMCompute", compute, level) def register_strategy(op_name, fstrategy=None, level=10): @@ -279,7 +201,7 @@ def register_strategy(op_name, fstrategy=None, level=10): if not isinstance(fstrategy, GenericFunc): assert hasattr(fstrategy, "generic_func_node") fstrategy = fstrategy.generic_func_node - return register(op_name, "FTVMStrategy", fstrategy, level) + return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level) def register_schedule(op_name, schedule, level=10): @@ -360,7 +282,7 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10): level : int The priority level """ - return register(op_name, "FTVMAlterOpLayout", alter_layout, level) + return tvm.ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level) def register_convert_op_layout(op_name, convert_layout=None, level=10): @@ -377,7 +299,7 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10): level : int The priority level """ - return register(op_name, "FTVMConvertOpLayout", convert_layout, level) + return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level) def register_legalize(op_name, legal_op=None, level=10): @@ -394,7 +316,7 @@ def register_legalize(op_name, legal_op=None, level=10): level : int The priority level """ - return register(op_name, "FTVMLegalize", legal_op, level) + return tvm.ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level) def register_pattern(op_name, pattern, level=10): @@ -411,7 +333,7 @@ def register_pattern(op_name, pattern, level=10): level : int The priority level """ - return register(op_name, "TOpPattern", pattern, level) + return tvm.ir.register_op_attr(op_name, "TOpPattern", pattern, level) def register_gradient(op_name, fgradient=None, level=10): @@ -428,7 +350,7 @@ def register_gradient(op_name, fgradient=None, level=10): level : int The priority level """ - return register(op_name, "FPrimalGradient", fgradient, level) + return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level) def register_shape_func(op_name, data_dependant, shape_func=None, level=10): @@ -450,7 +372,7 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): The priority level """ get(op_name).set_attr("TShapeDataDependant", data_dependant, level) - return register(op_name, "FShapeFunc", shape_func, level) + return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level) def 
register_external_compiler(op_name, fexternal=None, level=10): @@ -469,7 +391,7 @@ def register_external_compiler(op_name, fexternal=None, level=10): level : int The priority level """ - return register(op_name, "FTVMExternalCompiler", fexternal, level) + return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level) @tvm._ffi.register_func("relay.op.compiler._lower") diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index a47be7673830..429c4f1b9940 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -189,6 +189,9 @@ class TransposeAttrs(Attrs): class ReshapeAttrs(Attrs): """Attributes for transform.reshape""" +@tvm._ffi.register_object("relay.attrs.GatherAttrs") +class GatherAttrs(Attrs): + """Attributes for transform.gather""" @tvm._ffi.register_object("relay.attrs.TakeAttrs") class TakeAttrs(Attrs): @@ -349,7 +352,20 @@ class BinaryDenseAttrs(Attrs): class Conv2DTransposeAttrs(Attrs): """Attributes used in Transposed Conv2D operators""" +@tvm._ffi.register_object("relay.attrs.Conv3DTransposeAttrs") +class Conv3DTransposeAttrs(Attrs): + """Attributes used in Transposed Conv3D operators""" + +@tvm._ffi.register_object("relay.attrs.DilateAttrs") +class DilateAttrs(Attrs): + """Attributes used in dilate operators""" + @tvm._ffi.register_object("relay.attrs.SubPixelAttrs") class SubPixelAttrs(Attrs): """Attributes used in depth to space and space to depth operators""" + + +@tvm._ffi.register_object("relay.attrs.CorrelationAttrs") +class CorrelationAttrs(Attrs): + """Attributes used in correlation operators""" diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index d3226012e887..988c94928d33 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -18,7 +18,7 @@ # pylint: disable=redefined-builtin from . import _make -from .tensor import sqrt +from .tensor import sqrt, log, exp from .transform import squeeze from ..expr import Tuple, TupleWrapper @@ -475,3 +475,40 @@ def prod(data, axis=None, keepdims=False, exclude=False): """ axis = [axis] if isinstance(axis, int) else axis return _make.prod(data, axis, keepdims, exclude) + + +def logsumexp(data, axis=None, keepdims=False): + """Compute the log of the sum of exponentials of input elements over given axes. + + This function is more numerically stable than log(sum(exp(input))). + It avoids overflows caused by taking the exp of large inputs and underflows + caused by taking the log of small inputs. + + Parameters + ---------- + data : relay.Expr + The input data + + axis : None or int or tuple of int + Axis or axes along which a logsumexp operation is performed. + The default, axis=None, will compute the log of the sum of exponentials of all elements + in the input array. If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + + Returns + ------- + result : relay.Expr + The computed result.
+ """ + + axis = [axis] if isinstance(axis, int) else axis + max_x = max(data, axis, True) + exp_x = exp(data - max_x) + sum_x = sum(exp_x, axis, True) + out_x = log(sum_x) + max_x + if not keepdims: + out_x = squeeze(out_x, axis) + return out_x diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/relay/op/strategy/__init__.py index 59adf8262664..8d0543ba30af 100644 --- a/python/tvm/relay/op/strategy/__init__.py +++ b/python/tvm/relay/op/strategy/__init__.py @@ -26,6 +26,5 @@ from . import hls from . import mali from . import bifrost -from . import opengl from . import rocm from . import intel_graphics diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index bcef8ab43a24..6bdec67617e1 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -20,24 +20,25 @@ import logging import topi +from ....target import arm_isa from .generic import * from .. import op as _op logger = logging.getLogger('strategy') -@schedule_injective.register("arm_cpu") +@schedule_injective.register(["arm_cpu", "micro_dev"]) def schedule_injective_arm_cpu(_, outs, target): """schedule injective ops for arm cpu""" with target: return topi.arm_cpu.schedule_injective(outs) -@schedule_concatenate.register("arm_cpu") +@schedule_concatenate.register(["arm_cpu", "micro_dev"]) def schedule_concatenate_arm_cpu(_, outs, target): """schedule concatenate for arm cpu""" with target: return topi.arm_cpu.schedule_concatenate(outs) -@conv2d_strategy.register("arm_cpu") +@conv2d_strategy.register(["arm_cpu", "micro_dev"]) def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d arm cpu strategy""" strategy = _op.OpStrategy() @@ -51,6 +52,8 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") + isa = arm_isa.IsaAnalyzer(target) + if groups == 1: if layout == "NCHW": if kernel_layout == "OIHW": @@ -59,16 +62,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack), name="conv2d_nchw_spatial_pack.arm_cpu") + # Intel x86 conv2d schedule. 
strategy.add_implementation( wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") + # check if winograd algorithm is applicable _, _, kh, kw = get_const_tuple(kernel.shape) pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (kh, kw)) - if kh == 3 and kw == 3 and stride_h == 1 and stride_w == 1 and \ - dilation_h == 1 and dilation_w == 1: + is_winograd_applicable = "float" in data.dtype and \ + "float" in kernel.dtype and \ + kh == 3 and kw == 3 and \ + stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1 + if is_winograd_applicable: strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_winograd), wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_winograd), @@ -96,11 +105,22 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn), name="conv2d_hwcn.generic") elif layout == "NHWC": - assert kernel_layout == "HWIO" - strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), - wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), - name="conv2d_nhwc_spatial_pack.arm_cpu") + channels = data.shape[3] + if "SMLAD" in isa and (channels % 4) == 0 and kernel_layout == "HWOI": + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_direct_simd), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd), + name='conv2d_direct_simd.micro_dev') + elif kernel_layout == "HWIO": + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack), + wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack), + name="conv2d_nhwc_spatial_pack.arm_cpu") + else: + raise RuntimeError("Unsupported kernel layout {} for conv2d NHWC". 
+ format(kernel_layout)) + + else: raise RuntimeError("Unsupported conv2d layout {} for arm cpu".format(layout)) elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups): @@ -226,7 +246,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out format(layout)) return strategy -@conv2d_transpose_strategy.register("arm_cpu") +@conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"]) def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target): """conv2d_transpose arm cpu strategy""" layout = attrs.data_layout diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 4e5088f0d85e..e0091a18de72 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -58,7 +58,7 @@ def schedule_pool_grad_cuda(attrs, outs, target): def schedule_adaptive_pool_cuda(attrs, outs, target): """schedule adaptive pooling ops for cuda""" with target: - return topi.cuda.schedule_adaptive_pool(outs) + return topi.cuda.schedule_adaptive_pool(outs, attrs.layout) @softmax_strategy.register(["cuda", "gpu"]) def softmax_strategy_cuda(attrs, inputs, out_type, target): @@ -136,8 +136,32 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.cuda.conv2d_nhwc), wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc), name="conv2d_nhwc.cuda") - N, _, _, _ = get_const_tuple(data.shape) - _, _, CI, CO = get_const_tuple(kernel.shape) + N, H, W, _ = get_const_tuple(data.shape) + KH, KW, CI, CO = get_const_tuple(kernel.shape) + # Winograd shape related judgment + judge_winograd_tensorcore, judge_winograd_shape = winograd_judge(N, H, W, KH, KW, + CI, CO, padding, + stride_h, stride_w, + dilation_h, dilation_w, + pre_flag=False) + if judge_winograd_shape: + if target.target_name == "cuda" and \ + nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ + judge_winograd_tensorcore: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_tensorcore), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore), + name="conv2d_nhwc_winograd_tensorcore.cuda", + plevel=5) + else: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_direct), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_direct), + name="conv2d_nhwc_winograd_direct.cuda", + plevel=5) if target.target_name == "cuda": if nvcc.have_tensorcore(tvm.gpu(0).compute_version): if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ @@ -161,7 +185,9 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \ padding[1] == padding[3]: strategy.add_implementation( - wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True), + wrap_compute_conv2d(topi.cuda.conv2d_cudnn, + need_data_layout=True, + has_groups=True), wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn), name="conv2d_cudnn.cuda", plevel=15) @@ -181,6 +207,20 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout)) else: # group_conv2d + # add cudnn implementation, if any + cudnn_impl = False + if target.target_name == "cuda" and "cudnn" in target.libs: + if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \ + padding[1] == padding[3]: + strategy.add_implementation( + wrap_compute_conv2d(topi.cuda.conv2d_cudnn, + need_data_layout=True, + has_groups=True), + wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn), + 
name="conv2d_cudnn.cuda", + plevel=15) + cudnn_impl = True + if layout == 'NCHW': # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8. assert kernel_layout == "OIHW" @@ -194,7 +234,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True), wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8), name="group_conv2d_NCHWc_int8.cuda") - else: + elif not cudnn_impl: raise RuntimeError("Unsupported group_conv2d layout {}".format(layout)) return strategy @@ -204,6 +244,9 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty dilation = attrs.get_int_tuple("dilation") groups = attrs.get_int("groups") layout = attrs.data_layout + data, kernel = inputs + stride_h, stride_w = attrs.get_int_tuple("strides") + padding = attrs.get_int_tuple("padding") assert dilation == (1, 1), "Do not support dilate now" assert groups == 1, "Do not supoort arbitrary group number" strategy = _op.OpStrategy() @@ -213,6 +256,30 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty wrap_topi_schedule( topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform), name="conv2d_nchw_winograd_without_weight_transform.cuda") + elif layout == "NHWC": + N, H, W, _ = get_const_tuple(data.shape) + alpha, _, CI, CO = get_const_tuple(kernel.shape) + dilation_h, dilation_w = dilation + judge_winograd_tensorcore, _ = winograd_judge(N, H, W, alpha, alpha, CI, CO, + padding, stride_h, stride_w, + dilation_h, dilation_w, + pre_flag=True) + if target.target_name == "cuda" and \ + nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \ + judge_winograd_tensorcore: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_tensorcore_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform), + name="conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") + else: + strategy.add_implementation( + wrap_compute_conv2d( + topi.cuda.conv2d_nhwc_winograd_direct_without_weight_transform), + wrap_topi_schedule( + topi.cuda.schedule_conv2d_nhwc_winograd_direct_without_weight_transform), + name="conv2d_nhwc_winograd_direct_without_weight_transform.cuda") else: raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}". 
format(layout)) @@ -246,6 +313,24 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): name="conv2d_transpose_nchw.cuda") return strategy + +@conv3d_transpose_strategy.register(["cuda", "gpu"]) +def conv3d_transpose_strategy_cuda(attrs, inputs, out_type, target): + """conv3d_transpose cuda strategy""" + layout = attrs.data_layout + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + assert layout == "NCDHW", "only support ncdhw for now" + assert dilation == (1, 1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv3d_transpose(topi.cuda.conv3d_transpose_ncdhw), + wrap_topi_schedule(topi.cuda.schedule_conv3d_transpose_ncdhw), + name="conv3d_transpose_ncdhw.cuda") + return strategy + + @conv3d_strategy.register(["cuda", "gpu"]) def conv3d_strategy_cuda(attrs, inputs, out_type, target): """conv3d cuda strategy""" @@ -373,15 +458,16 @@ def dense_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_dense_large_batch), name="dense_large_batch.cuda", plevel=5) - if nvcc.have_tensorcore(tvm.gpu(0).compute_version): - if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \ - or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \ - or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0): - strategy.add_implementation( - wrap_compute_dense(topi.cuda.dense_tensorcore), - wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore), - name="dense_tensorcore.cuda", - plevel=20) + if target.target_name == "cuda": + if nvcc.have_tensorcore(tvm.gpu(0).compute_version): + if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \ + or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \ + or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0): + strategy.add_implementation( + wrap_compute_dense(topi.cuda.dense_tensorcore), + wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore), + name="dense_tensorcore.cuda", + plevel=20) if target.target_name == "cuda" and "cublas" in target.libs: strategy.add_implementation( wrap_compute_dense(topi.cuda.dense_cublas), @@ -395,7 +481,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): """batch_matmul cuda strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_batch_matmul(topi.nn.batch_matmul), + wrap_compute_batch_matmul(topi.cuda.batch_matmul), wrap_topi_schedule(topi.cuda.schedule_batch_matmul), name="batch_matmul.cuda", plevel=10) @@ -407,6 +493,19 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target): plevel=15) return strategy + +@sparse_dense_strategy.register(["cuda", "gpu"]) +def sparse_dense_strategy_cuda(attrs, inputs, out_type, target): + """sparse dense cuda strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_sparse_dense(topi.cuda.sparse_dense), + wrap_topi_schedule(topi.cuda.schedule_sparse_dense), + name="sparse_dense.cuda", + plevel=10) + return strategy + + @argsort_strategy.register(["cuda", "gpu"]) def argsort_strategy_cuda(attrs, inputs, out_type, target): """argsort cuda strategy""" @@ -499,3 +598,38 @@ def proposal_strategy_cuda(attrs, inputs, out_type, target): wrap_topi_schedule(topi.cuda.schedule_proposal), name="proposal.cuda") return strategy + +def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h, + stride_w, dilation_h, dilation_w, pre_flag): + """Winograd judgement about tensorcore and shape""" + if H % 8 == 0: + tile_size = 4 + else: + tile_size = 2 + if pre_flag: + alpha = KH + KH = KW = alpha + 1 - 
tile_size + pt, pl, pb, pr = topi.nn.get_pad_tuple(padding, (KH, KW)) + OH = (H + pt + pb - KH) // stride_h + 1 + OW = (W + pl + pr - KW) // stride_w + 1 + nH, nW = (OH + tile_size - 1) // tile_size, (OW + tile_size - 1) // tile_size + P = N * nH * nW + judge_winograd_tensorcore = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + judge_winograd_shape = 2 < KH < 8 and 2 < KW < 8 and KH == KW and \ + stride_h == 1 and stride_w == 1 and \ + dilation_h == 1 and dilation_w == 1 + return judge_winograd_tensorcore, judge_winograd_shape + +@correlation_strategy.register(["cuda", "gpu"]) +def correlation_strategy_cuda(attrs, inputs, out_type, target): + """correlation cuda strategy""" + layout = attrs.layout + assert layout == "NCHW", "Only support NCHW layout" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_correlation(topi.cuda.correlation_nchw), + wrap_topi_schedule(topi.cuda.schedule_correlation_nchw), + name="correlation.cuda") + return strategy diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index c3eadce2b8dd..b1fb421c3e2e 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -345,6 +345,44 @@ def conv2d_transpose_strategy(attrs, inputs, out_type, target): name="conv2d_transpose_nchw.generic") return strategy + +# conv3d_transpose +def wrap_compute_conv3d_transpose(topi_compute): + """wrap conv3d_transpose topi compute""" + def compute_conv3d_transpose(attrs, inputs, out_dtype): + """Compute definition of conv3d_transpose""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if out_dtype in ("same", "") + else out_dtype) + out = topi_compute( + inputs[0], inputs[1], strides, padding, out_dtype) + output_padding = get_const_tuple(attrs.output_padding) + out = topi.nn.pad(out, + [0, 0, 0, 0, 0], + [0, 0, output_padding[0], output_padding[1], output_padding[2]]) + return [out] + return compute_conv3d_transpose + + +@override_native_generic_func("conv3d_transpose_strategy") +def conv3d_transpose_strategy(attrs, inputs, out_type, target): + """conv3d_transpose generic strategy""" + logger.warning("conv3d_transpose is not optimized for this platform.") + layout = attrs.data_layout + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + assert layout == "NCDHW", "only support ncdhw for now" + assert dilation == (1, 1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw), + wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw), + name="conv3d_transpose_ncdhw.generic") + return strategy + # conv3d def wrap_compute_conv3d(topi_compute, need_layout=False): """wrap conv3d topi compute""" @@ -561,12 +599,22 @@ def batch_matmul_strategy(attrs, inputs, out_type, target): name="batch_matmul.generic") return strategy -# sparse_dense -@generic_func -def schedule_sparse_dense(attrs, outs, target): - """schedule sparse_dense""" - with target: - return topi.generic.schedule_sparse_dense(outs) +# sparse dense +def wrap_compute_sparse_dense(topi_compute): + """wrap sparse dense topi compute""" + def _compute_sparse_dense(attrs, inputs, out_type): + return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])] 
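To make winograd_judge concrete, a worked instance with illustrative shapes (not taken from the patch):

# NHWC conv2d, N=8, H=W=56, KH=KW=3, CI=CO=64,
# padding=(1, 1), stride=1, dilation=1, pre_flag=False
# H % 8 == 0                          -> tile_size = 4
# OH = OW = (56 + 1 + 1 - 3) // 1 + 1 = 56
# nH = nW = (56 + 4 - 1) // 4         = 14
# P  = N * nH * nW = 8 * 14 * 14      = 1568
# P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0 -> judge_winograd_tensorcore = True
# 2 < KH < 8, KH == KW, unit stride/dilation    -> judge_winograd_shape = True
# so on a TensorCore-capable GPU the conv2d_nhwc_winograd_tensorcore
# implementation registered above becomes a candidate.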
+ return _compute_sparse_dense + +@override_native_generic_func("sparse_dense_strategy") +def sparse_dense_strategy(attrs, inputs, out_type, target): + """sparse dense generic strategy""" + logger.warning("sparse dense is not optimized for this platform.") + strategy = _op.OpStrategy() + strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense), + wrap_topi_schedule(topi.generic.schedule_sparse_dense), + name="sparse_dense.generic") + return strategy # sparse_transpose @generic_func @@ -598,7 +646,9 @@ def argsort_strategy(attrs, inputs, out_type, target): def wrap_compute_topk(topi_compute): """Wrap topk compute""" def _compute_topk(attrs, inputs, out_type): - k = get_const_int(attrs.k) + k = inputs[1] + if attrs.k is not None: + k = attrs.k axis = get_const_int(attrs.axis) ret_type = attrs.ret_type is_ascend = bool(get_const_int(attrs.is_ascend)) @@ -693,9 +743,13 @@ def _compute_nms(attrs, inputs, out_type): score_index = get_const_int(attrs.score_index) id_index = get_const_int(attrs.id_index) invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom)) - return [topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, - force_suppress, top_k, coord_start, score_index, - id_index, return_indices, invalid_to_bottom)] + if return_indices: + return topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, + force_suppress, top_k, coord_start, score_index, id_index, + return_indices, invalid_to_bottom) + return [topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, + force_suppress, top_k, coord_start, score_index, id_index, + return_indices, invalid_to_bottom)] return _compute_nms @override_native_generic_func("non_max_suppression_strategy") @@ -768,6 +822,13 @@ def schedule_argwhere(attrs, outs, target): with target: return topi.generic.schedule_argwhere(outs) +# scatter +@generic_func +def schedule_scatter(attrs, outs, target): + """schedule scatter""" + with target: + return topi.generic.schedule_scatter(outs) + # bitserial_conv2d def wrap_compute_bitserial_conv2d(topi_compute): """wrap bitserial_conv2d topi compute""" @@ -829,3 +890,30 @@ def bitserial_dense_strategy(attrs, inputs, out_type, target): wrap_topi_schedule(topi.generic.schedule_bitserial_dense), name="bitserial_dense.generic") return strategy + +# correlation +def wrap_compute_correlation(topi_compute): + """wrap correlation topi compute""" + def _compute_correlation(attrs, inputs, out_type): + kernel_size = attrs.kernel_size + max_displacement = attrs.max_displacement + stride1 = attrs.stride1 + stride2 = attrs.stride2 + padding = get_const_tuple(attrs.padding) + is_multiply = attrs.is_multiply + return [topi_compute(inputs[0], inputs[1], kernel_size, max_displacement, stride1, stride2, + padding, is_multiply)] + return _compute_correlation + +@override_native_generic_func("correlation_strategy") +def correlation_strategy(attrs, inputs, out_type, target): + """correlation generic strategy""" + logger.warning("correlation is not optimized for this platform.") + layout = attrs.layout + assert layout == "NCHW", "Only support NCHW layout" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_correlation(topi.nn.correlation_nchw), + wrap_topi_schedule(topi.generic.schedule_correlation_nchw), + name="correlation.generic") + return strategy diff --git a/python/tvm/relay/op/strategy/opengl.py b/python/tvm/relay/op/strategy/opengl.py deleted file mode 100644 index 12c288c83b7e..000000000000 --- 
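For orientation, the four inputs consumed by the sparse_dense compute wrapper above are assumed to be the dense activations followed by the three arrays of the CSR/BSR weight; the annotated sketch below is illustrative, not part of the patch:

def _compute_sparse_dense(attrs, inputs, out_type):
    dense_data = inputs[0]      # (M, K) dense activations
    weight_data = inputs[1]     # non-zero values of the sparse weight
    weight_indices = inputs[2]  # column (or block-column) indices
    weight_indptr = inputs[3]   # row pointers
    return [topi.nn.sparse_dense(dense_data, weight_data,
                                 weight_indices, weight_indptr)]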
a/python/tvm/relay/op/strategy/opengl.py +++ /dev/null @@ -1,83 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Definition of OpenGL operator strategy.""" -# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import -import topi -from .generic import * -from .. import op as _op - -@schedule_injective.register("opengl") -def schedule_injective_opengl(attrs, outs, target): - """schedule injective ops for opengl""" - with target: - return topi.opengl.schedule_injective(outs) - -@schedule_concatenate.register("opengl") -def schedule_concatenate_opengl(attrs, outs, target): - """schedule concatenate for opengl""" - with target: - return topi.opengl.schedule_injective(outs) - -@schedule_pool.register("opengl") -def schedule_pool_opengl(attrs, outs, target): - """schedule pooling ops for opengl""" - with target: - return topi.opengl.schedule_pool(outs, attrs.layout) - -@schedule_adaptive_pool.register("opengl") -def schedule_adaptive_pool_opengl(attrs, outs, target): - """schedule adative pooling ops for opengl""" - with target: - return topi.opengl.schedule_adaptive_pool(outs) - -@softmax_strategy.register("opengl") -def softmax_strategy_opengl(attrs, inputs, out_type, target): - """softmax opengl strategy""" - strategy = _op.OpStrategy() - strategy.add_implementation( - wrap_compute_softmax(topi.nn.softmax), - wrap_topi_schedule(topi.opengl.schedule_softmax), - name="softmax.opengl") - return strategy - -@schedule_log_softmax.register("opengl") -def schedule_log_softmax_opengl(attrs, outs, target): - """schedule log_softmax for opengl""" - with target: - return topi.opengl.schedule_softmax(outs) - -@conv2d_strategy.register("opengl") -def conv2d_strategy_opengl(attrs, inputs, out_type, target): - """conv2d opengl strategy""" - strategy = _op.OpStrategy() - groups = attrs.groups - layout = attrs.data_layout - assert groups == 1, "Don't support group conv2d on OpenGL" - assert layout == "NCHW", "Only support conv2d layout NCHW for OpenGL" - strategy.add_implementation(wrap_compute_conv2d(topi.nn.conv2d), - wrap_topi_schedule(topi.opengl.schedule_conv2d_nchw), - name="conv2d_nchw.opengl") - return strategy - -@dense_strategy.register("opengl") -def dense_strategy_opengl(attrs, inputs, out_type, target): - """dense opengl strategy""" - strategy = _op.OpStrategy() - strategy.add_implementation(wrap_compute_dense(topi.nn.dense), - wrap_topi_schedule(topi.opengl.schedule_dense), - name="dense.opengl") - return strategy diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py index 6cda346e5068..b1213f1acbf1 100644 --- a/python/tvm/relay/op/strategy/rocm.py +++ b/python/tvm/relay/op/strategy/rocm.py @@ -36,6 +36,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): layout = attrs.data_layout 
stride_h, stride_w = attrs.get_int_tuple("strides") kernel_layout = attrs.kernel_layout + padding = attrs.get_int_tuple("padding") if dilation_h < 1 or dilation_w < 1: raise ValueError("dilation should be positive value") @@ -77,7 +78,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target): else: raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout)) # add miopen implementation - if "miopen" in target.libs and layout == "NCHW": + if "miopen" in target.libs and layout == "NCHW" and padding[0] == padding[2] and \ + padding[1] == padding[3]: strategy.add_implementation( wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True), wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen), diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index ba0b3d20b549..b02db416bdc8 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import import logging +import re import topi from tvm.te import SpecializedCondition from .generic import * @@ -25,6 +26,9 @@ logger = logging.getLogger('strategy') +_NCHWc_matcher = re.compile("^NCHW[0-9]+c$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") + @schedule_injective.register("cpu") def schedule_injective_cpu(attrs, outs, target): """schedule injective ops for x86""" @@ -96,6 +100,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.x86.conv2d_nchw), wrap_topi_schedule(topi.x86.schedule_conv2d_nchw), name="conv2d_nchw.x86") + elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc + assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio + return conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWIO" logger.warning("For x86 target, NCHW layout is recommended for conv2d.") @@ -128,6 +135,9 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw), name="depthwise_conv2d_nchw.generic") + elif _NCHWc_matcher.match(layout): # check if layout is NCHWxc + assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio + return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": assert kernel_layout == "HWOI" logger.warning("depthwise_conv2d NHWC layout is not optimized for x86.") @@ -192,6 +202,24 @@ def conv2d_transpose_strategy_cpu(attrs, inputs, out_type, target): name="conv2d_transpose_nchw.x86") return strategy + +@conv3d_transpose_strategy.register("cpu") +def conv3d_transpose_strategy_cpu(attrs, inputs, out_type, target): + """conv3d_transpose x86 strategy""" + layout = attrs.data_layout + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + assert layout == "NCDHW", "only support ncdhw for now" + assert dilation == (1, 1, 1), "not support dilate now" + assert groups == 1, "only support groups == 1 for now" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv3d_transpose(topi.x86.conv3d_transpose_ncdhw), + wrap_topi_schedule(topi.x86.schedule_conv3d_transpose_ncdhw), + name="conv3d_transpose_ncdhw.x86") + return strategy + + @conv3d_strategy.register("cpu") def conv3d_strategy_cpu(attrs, inputs, out_type, target): """conv3d generic strategy""" @@ -266,11 +294,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): plevel=15) 
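The x86 strategy now dispatches already-packed layouts by pattern-matching the layout strings; a quick self-contained illustration of what those regular expressions accept:

import re

_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

assert _NCHWc_matcher.match("NCHW8c")       # packed data layout -> NCHWc strategy
assert _OIHWio_matcher.match("OIHW8i8o")    # matching packed kernel layout
assert not _NCHWc_matcher.match("NCHW")     # plain NCHW keeps the existing path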
return strategy -@schedule_sparse_dense.register("cpu") -def schedule_sparse_dense_cpu(attrs, outs, target): - """schedule sparse_dense for x86""" - with target: - return topi.x86.schedule_sparse_dense(outs) +@sparse_dense_strategy.register("cpu") +def sparse_dense_strategy_cpu(attrs, inputs, out_type, target): + """sparse dense x86 strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense), + wrap_topi_schedule(topi.x86.schedule_sparse_dense), + name="sparse_dense.x86", + plevel=10) + return strategy + @roi_align_strategy.register("cpu") def roi_align_strategy_cpu(attrs, inputs, out_type, target): diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 162f83b1f52a..c60dbee6dd64 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -20,7 +20,7 @@ from tvm.runtime import TVMContext as _TVMContext from . import _make -from ..expr import Tuple +from ..expr import Tuple, const # We create a wrapper function for each operator in the @@ -47,6 +47,36 @@ def log(data): """ return _make.log(data) +def log2(data): + """Compute elementwise log to the base 2 of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.log2(data) + +def log10(data): + """Compute elementwise log to the base 10 of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.log10(data) + def tan(data): """Compute elementwise tan of data. @@ -77,6 +107,21 @@ def cos(data): """ return _make.cos(data) +def cosh(data): + """Compute elementwise cosh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.cosh(data) + def sin(data): """Compute elementwise sin of data. @@ -92,6 +137,81 @@ def sin(data): """ return _make.sin(data) +def sinh(data): + """Compute elementwise sinh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.sinh(data) + +def acos(data): + """Compute elementwise acos of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.acos(data) + +def acosh(data): + """Compute elementwise acosh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.acosh(data) + +def asin(data): + """Compute elementwise asin of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.asin(data) + +def asinh(data): + """Compute elementwise asinh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.asinh(data) + def atan(data): """Compute elementwise atan of data. @@ -107,6 +227,21 @@ def atan(data): """ return _make.atan(data) +def atanh(data): + """Compute elementwise atanh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.atanh(data) + def exp(data): """Compute elementwise exp of data. 
@@ -793,7 +928,7 @@ def zeros(shape, dtype): Parameters ---------- - shape : tuple of int + shape : tuple of int or relay.Expr The shape of the target. dtype : data type @@ -804,6 +939,8 @@ def zeros(shape, dtype): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.zeros(shape, dtype) @@ -828,7 +965,7 @@ def ones(shape, dtype): Parameters ---------- - shape : tuple of int + shape : tuple of int or relay.Expr The shape of the target. dtype : data type @@ -839,6 +976,8 @@ def ones(shape, dtype): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.ones(shape, dtype) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index d7a7b4f02d3d..05958fc39196 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -20,6 +20,7 @@ from . import _make from ..expr import TupleWrapper, const +from ...tir import expr as _expr def cast(data, dtype): @@ -201,7 +202,7 @@ def reshape(data, newshape): data : relay.Expr The input data to the operator. - newshape : Union[int, Tuple[int], List[int]] + newshape : Union[int, Tuple[int], List[int]] or relay.Expr The new shape. Should be compatible with the original shape. Returns @@ -210,8 +211,19 @@ def reshape(data, newshape): The reshaped result. """ if isinstance(newshape, int): - newshape = [newshape] - return _make.reshape(data, list(newshape)) + newshape = const([newshape]) + if isinstance(newshape, (tuple, list)): + tempshape = [] + for shape in newshape: + if isinstance(shape, _expr.IntImm): + tempshape.append(shape.value) + else: + try: + tempshape.append(int(shape)) + except ValueError as err: + raise RuntimeError('Unrecognized shape type: %s' % err) + newshape = const(tempshape) + return _make.reshape(data, newshape) def argwhere(condition): """Find the indices of elements of a tensor that are @@ -236,6 +248,30 @@ def argwhere(condition): """ return _make.argwhere(condition) +def scatter(data, indices, updates, axis): + """Update data at positions defined by indices with values in updates + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to update. + + axis : int + The axis to scatter on + + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.scatter(data, indices, updates, axis) + def reshape_like(data, shape_like): """Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes @@ -297,7 +333,7 @@ def full(fill_value, shape=(), dtype=""): fill_value : relay.Expr The value to fill. Must be a scalar. - shape : tuple of int + shape : tuple of int or relay.Expr The shape of the target. dtype : data type, optional (defaults to data type of the fill value) @@ -308,6 +344,8 @@ def full(fill_value, shape=(), dtype=""): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.full(fill_value, shape, dtype) @@ -500,7 +538,7 @@ def where(condition, x, y): Returns ------- result : relay.Expr - The selected array. + The selected array. Examples -------- @@ -525,7 +563,7 @@ def broadcast_to(data, shape): data : relay.Expr The input tensor. - shape : shape + shape : tuple of int or relay.Expr Provide the shape to broadcast to. 
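A small numeric illustration of the scatter semantics documented above, assuming the usual convention that the output starts as a copy of data and the updates are written at the positions given by indices along axis:

data    = [[0, 0], [0, 0]]
indices = [[1, 0], [0, 1]]
updates = [[1, 2], [3, 4]]
# axis = 0: out[indices[i][j], j] = updates[i][j]
# relay.scatter(data, indices, updates, 0) -> [[3, 2], [1, 4]]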
Returns @@ -533,6 +571,8 @@ def broadcast_to(data, shape): result : relay.Expr The resulting tensor. """ + if isinstance(shape, (list, tuple)): + shape = const(list(shape), "int32") return _make.broadcast_to(data, shape) def broadcast_to_like(data, broadcast_type): @@ -605,7 +645,7 @@ def split(data, indices_or_sections, axis=0): return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) -def strided_slice(data, begin, end, strides=None): +def strided_slice(data, begin, end, strides=None, slice_mode="end"): """Strided slice of an array. Parameters @@ -613,23 +653,36 @@ def strided_slice(data, begin, end, strides=None): data : relay.Expr The source array to be sliced. - begin: list of int + begin : relay.Expr, Tuple[int], or List[int] The indices to begin with in the slicing. - end: list of int + end : relay.Expr, Tuple[int], or List[int] Indices indicating end of the slice. - strides: list of int, optional + strides : relay.Expr, Tuple[int], or List[int], optional Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. + slice_mode : str, optional + The slice mode [end, size]. + end: The ending indices for the slice [default]. + size: The input strides will be ignored, input end in this mode indicates + the size of a slice starting at the location specified by begin. If end[i] + is -1, all remaining elements in that dimension are included in the slice. + Returns ------- ret : relay.Expr The computed result. """ - strides = strides or [] - return _make.strided_slice(data, list(begin), list(end), list(strides)) + strides = strides or const([1], dtype="int32") + if isinstance(begin, (tuple, list)): + begin = const(list(begin)) + if isinstance(end, (tuple, list)): + end = const(list(end)) + if isinstance(strides, (tuple, list)): + strides = const(list(strides)) + return _make.strided_slice(data, begin, end, strides, slice_mode) def strided_set(data, v, begin, end, strides=None): @@ -643,13 +696,13 @@ def strided_set(data, v, begin, end, strides=None): v : relay.Expr The data to be set. - begin: relay.Expr + begin: relay.Expr, Tuple[int], or List[int] The indices to begin with in the slicing. - end: relay.Expr + end: relay.Expr, Tuple[int], or List[int] Indices indicating end of the slice. - strides: relay.Expr, optional + strides: relay.Expr, Tuple[int], or List[int], optional Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. @@ -659,6 +712,12 @@ def strided_set(data, v, begin, end, strides=None): The computed result. """ strides = strides or const([1], dtype="int32") + if isinstance(begin, (tuple, list)): + begin = const(list(begin)) + if isinstance(end, (tuple, list)): + end = const(list(end)) + if isinstance(strides, (tuple, list)): + strides = const(list(strides)) return _make.strided_set(data, v, begin, end, strides) @@ -741,6 +800,43 @@ def reverse_reshape(data, newshape): return _make._contrib_reverse_reshape(data, list(newshape)) +def gather(data, axis, indices): + """Gather values along given axis from given indices. + + E.g. for a 3D tensor, output is computed as: + + .. code-block:: python + + out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 + out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 + out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 + + ``indices`` must have same shape as ``data``, except at dimension ``axis`` + which must just be not null. Output will have same shape as ``indices``. 
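A usage sketch of the new slice_mode parameter (shapes are illustrative):

from tvm import relay

x = relay.var("x", shape=(4, 10), dtype="float32")
# slice_mode="size": `end` holds the extent of each sliced axis and -1 means
# "everything that remains"; `strides` is ignored in this mode.
y = relay.strided_slice(x, begin=[1, 2], end=[2, -1], slice_mode="size")
# y has shape (2, 8): two rows starting at row 1, columns 2 through 9.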
+ + Parameters + ---------- + data: relay.Expr + The input data to the operator. + + axis: int + The axis along which to index. + + indices: relay.Expr + The indices of values to gather. + + Examples + -------- + .. code-block:: python + + data = [[1, 2], [3, 4]] + axis = 1 + indices = [[0, 0], [1, 0]] + relay.gather(data, axis, indices) = [[1, 1], [4, 3]] + """ + return _make.gather(data, axis, indices) + + def gather_nd(data, indices): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -884,3 +980,34 @@ def unravel_index(indices, shape): """ return _make.unravel_index(indices, shape) + +def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0): + """Converts a sparse representation into a dense tensor. + + Example:: + - sparse_to_dense([[0, 0], [1, 1]], [2, 2], [3, 3], 0) = [[3, 0], [0, 3]] + + Parameters + ---------- + sparse_indices : relay.Expr + A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values. + + output_shape : relay.Expr + A list of integers. Shape of the dense output tensor. + + sparse_values : relay.Expr + A 0-D or 1-D tensor containing the sparse values for the sparse indices. + + default_value : relay.Expr + A 0-D tensor containing the default value for the remaining locations. + Defaults to 0. + + Returns + ------- + result : relay.Expr + Dense tensor of shape output_shape. Has the same type as sparse_values. + """ + + if default_value == 0: + default_value = const(0) + return _make.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) diff --git a/python/tvm/relay/op/vision/_vision.py b/python/tvm/relay/op/vision/_vision.py index 6e2008ad74c0..f6c4f811f13d 100644 --- a/python/tvm/relay/op/vision/_vision.py +++ b/python/tvm/relay/op/vision/_vision.py @@ -18,6 +18,8 @@ """Definition of vision ops""" from __future__ import absolute_import +import topi +from tvm.te.hybrid import script from .. import op as reg from .. 
import strategy from ..op import OpPattern @@ -40,3 +42,38 @@ # non-maximum suppression reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy) reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE) + +@script +def _get_valid_counts_shape_func(data_shape): + valid_counts_shape = output_tensor((1,), "int64") + out_tensor_shape = output_tensor((data_shape.shape[0],), "int64") + out_indices_shape = output_tensor((2,), "int64") + + valid_counts_shape[0] = data_shape[0] + for i in const_range(data_shape.shape[0]): + out_tensor_shape[i] = data_shape[i] + out_indices_shape[0] = data_shape[0] + out_indices_shape[1] = data_shape[1] + + return valid_counts_shape, out_tensor_shape, out_indices_shape + +@reg.register_shape_func("vision.get_valid_counts", False) +def get_valid_counts_shape_func(attrs, inputs, _): + return _get_valid_counts_shape_func(inputs[0]) + +@script +def _nms_shape_func(data_shape): + out_shape = output_tensor((2,), "int64") + count_shape = output_tensor((2,), "int64") + + out_shape[0] = data_shape[0] + out_shape[1] = data_shape[1] + count_shape[0] = data_shape[0] + count_shape[1] = int64(1) + return out_shape, count_shape + +@reg.register_shape_func("vision.non_max_suppression", False) +def nms_shape_func(attrs, inputs, _): + if attrs.return_indices: + return _nms_shape_func(inputs[0]) + return [topi.math.identity(inputs[0])] diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index 70a9ec9ed5e4..b60b49ab0ccd 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -47,14 +47,18 @@ def get_valid_counts(data, out_tensor : relay.Expr Rearranged data tensor. + + out_indices: relay.Expr + Indices in input data """ return expr.TupleWrapper( _make.get_valid_counts(data, score_threshold, - id_index, score_index), 2) + id_index, score_index), 3) def non_max_suppression(data, valid_count, + indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, @@ -69,12 +73,23 @@ def non_max_suppression(data, Parameters ---------- data : relay.Expr - 3-D tensor with shape [batch_size, num_anchors, 6]. + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. The last dimension should be in format of - [class_id, score, box_left, box_top, box_right, box_bottom]. + [class_id, score, box_left, box_top, box_right, box_bottom] + or [score, box_left, box_top, box_right, box_bottom]. It could + be the second output out_tensor of get_valid_counts. valid_count : relay.Expr - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes. It could be the output + valid_count of get_valid_counts. + + indices: relay.Expr + 2-D tensor with shape [batch_size, num_anchors], represents + the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the + second dimension are like the output of arange(num_anchors) + if get_valid_counts is not used before non_max_suppression. max_output_size : int, optional Max number of output valid boxes for each instance. @@ -106,10 +121,24 @@ def non_max_suppression(data, Returns ------- - out : relay.Expr - 3-D tensor with shape [batch_size, num_anchors, 6]. + out : relay.Expr or relay.Tuple + return relay.Expr if return_indices is disabled, a 3-D tensor + with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]. 
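Putting the pieces together: get_valid_counts now yields three values, and its third output feeds the new indices argument of non_max_suppression. A usage sketch with illustrative shapes and thresholds:

from tvm import relay

boxes = relay.var("boxes", shape=(1, 2500, 6), dtype="float32")
gvc = relay.vision.get_valid_counts(boxes, score_threshold=0.0)
valid_count, sorted_boxes, box_indices = gvc[0], gvc[1], gvc[2]
out = relay.vision.non_max_suppression(sorted_boxes, valid_count, box_indices,
                                       iou_threshold=0.5, return_indices=False)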
+ if return_indices is True, return relay.Tuple of two 2-D tensors, with + shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively. """ - return _make.non_max_suppression(data, valid_count, max_output_size, - iou_threshold, force_suppress, top_k, - coord_start, score_index, id_index, - return_indices, invalid_to_bottom) + out = _make.non_max_suppression(data, + valid_count, + indices, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, + return_indices, + invalid_to_bottom) + if return_indices: + return expr.TupleWrapper(out, 2) + return out diff --git a/python/tvm/relay/op/vision/rcnn.py b/python/tvm/relay/op/vision/rcnn.py index d160228d300a..1798ae946dc0 100644 --- a/python/tvm/relay/op/vision/rcnn.py +++ b/python/tvm/relay/op/vision/rcnn.py @@ -101,10 +101,10 @@ def proposal(cls_prob, [im_height, im_width, im_scale] scales : list/tuple of float - Scales of anchor windoes. + Scales of anchor windows. ratios : list/tuple of float - Ratios of anchor windoes. + Ratios of anchor windows. feature_stride : int The size of the receptive field each unit in the convolution layer of the rpn, for example diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index 243eace0fb94..5b2ecc27b998 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -555,21 +555,21 @@ def define_tensor_array_gather(self): self.prelude.mod[gather_var] = \ Function([tensor_array, indices], body, output_tensor_type_var(), []) - def define_tensor_get_data(self, data_shape): + def define_tensor_get_data(self): """Defines a function to get a Tensor from tensor_t with given shape. """ tensor_get_data_name = self.get_name("tensor_get_data") tensor_get_data_var = self._create_global_var(tensor_get_data_name) setattr(self.prelude, tensor_get_data_name, tensor_get_data_var) - - tensor_type_var, tensor_constructor = self._get_adt_by_shape(data_shape) + tensor_type_var = self.get_var('tensor_t') + tensor_constructor = self.get_var('tensor_constructor') t = Var('tensor', tensor_type_var()) tvar = Var('t') case =\ Clause(PatternConstructor(tensor_constructor, [PatternVar(tvar)]), tvar) self.prelude.mod[tensor_get_data_var] = \ Function([t], Match(t, [case], False), - TensorType(data_shape, self.dtype), []) + TensorType(self.shape, self.dtype), []) def register(self): """Register all tensor array ops in Prelude""" @@ -586,6 +586,7 @@ def register(self): self.define_tensor_array_concat() self.define_tensor_array_stack() self.define_tensor_array_gather() + self.define_tensor_get_data() def _get_adt_by_shape(self, shape): """Get ADT type and constructor with given shape.""" diff --git a/python/tvm/relay/qnn/op/layout_conversions.py b/python/tvm/relay/qnn/op/layout_conversions.py index f5850b8748ad..caa4c56f5abb 100644 --- a/python/tvm/relay/qnn/op/layout_conversions.py +++ b/python/tvm/relay/qnn/op/layout_conversions.py @@ -22,7 +22,7 @@ @reg.register_convert_op_layout("qnn.conv2d") -def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout): +def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layouts): """Convert Layout pass registration for QNN conv2d op. 
Parameters @@ -33,8 +33,9 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout): The args of the Relay expr to be legalized tinfos : list of types List of input and output types - desired_layout : str - The desired layout + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data and kernel inputs respectively. Returns ------- @@ -43,11 +44,18 @@ def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout): """ # pylint: disable=import-outside-toplevel from tvm import relay - assert desired_layout == 'NCHW', \ - "Currently only transformation to NCHW layout is supported." - if desired_layout == 'NCHW': - new_attrs = dict(attrs) - new_attrs['data_layout'] = desired_layout - new_attrs['kernel_layout'] = 'OIHW' + assert len(desired_layouts) == 2, "A desired layout is expected for both of qnn.conv2d's inputs" + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + assert desired_data_layout != "default", "Data layout cannot be default" + + new_attrs = dict(attrs) + new_attrs['data_layout'] = desired_data_layout + + if desired_data_layout == 'NCHW': + if desired_kernel_layout != "default": + new_attrs['kernel_layout'] = desired_kernel_layout + else: + new_attrs['kernel_layout'] = 'OIHW' return relay.qnn.op.conv2d(*inputs, **new_attrs) - return None + + raise ValueError('Layout %s is not yet supported' % desired_data_layout) diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index b1c19092b4c7..d3b0e44a1a13 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -16,12 +16,11 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Backend QNN related feature registration""" -from __future__ import absolute_import +import numpy as np import tvm from tvm import relay from .. import op as reg -from ...frontend.util import get_scalar_from_constant ################################################# # Register the functions for different operators. @@ -54,6 +53,15 @@ def qnn_dense_legalize(attrs, inputs, types): # Helper functions. ################### +def get_scalar_from_constant(expr): + """ Returns scalar value from Relay constant scalar. """ + assert isinstance(expr, relay.Constant) and not expr.data.shape, \ + "Expr is not a constant scalar." + value = expr.data.asnumpy() + assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \ + "value must be float32/int32" + return np.asscalar(value) + # Helper function for lowering in the abscence of fast Int8 arithmetic units. def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op): """ Converts QNN operators into a sequence of Relay operators that are friendly to HW that do diff --git a/python/tvm/relay/qnn/op/op.py b/python/tvm/relay/qnn/op/op.py index 6da15ebb479e..720bac4297ef 100644 --- a/python/tvm/relay/qnn/op/op.py +++ b/python/tvm/relay/qnn/op/op.py @@ -16,7 +16,7 @@ # under the License. 
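The per-op desired_layouts convention above is driven from the ConvertLayout pass; a usage sketch, where 'default' keeps the op's preferred kernel layout and mod is a placeholder for an existing Relay IRModule:

from tvm import relay

desired_layouts = {'qnn.conv2d': ['NCHW', 'default'],
                   'nn.conv2d': ['NCHW', 'default']}
mod = relay.transform.ConvertLayout(desired_layouts)(mod)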
#pylint: disable=unused-argument """The register functions for the QNN dialect.""" -from tvm.relay.op.op import register +import tvm.ir def register_qnn_legalize(op_name, legal_op=None, level=10): """Register legal transformation function for a QNN op @@ -32,4 +32,4 @@ def register_qnn_legalize(op_name, legal_op=None, level=10): level : int The priority level """ - return register(op_name, "FTVMQnnLegalize", legal_op, level) + return tvm.ir.register_op_attr(op_name, "FTVMQnnLegalize", legal_op, level) diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py index 5c1baef4db94..5a3106d1e787 100644 --- a/python/tvm/relay/qnn/op/qnn.py +++ b/python/tvm/relay/qnn/op/qnn.py @@ -18,7 +18,7 @@ """QNN dialect operators.""" from __future__ import absolute_import as _abs -from tvm.relay.expr import Tuple +from tvm.relay.expr import Tuple, TupleWrapper from tvm.relay.op.nn.util import get_pad_tuple2d from . import _make @@ -156,7 +156,7 @@ def concatenate(data, Parameters ---------- - data : Union(List[relay.Expr], Tuple[relay.Expr]) + data : Union(List[relay.Expr], Tuple[relay.Expr], TupleWrapper[relay.Expr]) The list of quantized tensors. input_scales : List[relay.Expr] @@ -180,15 +180,16 @@ def concatenate(data, The concatenated quantized tensor. """ - data = list(data) - if not data: - raise ValueError("relay.concatenate requires data to be non-empty.") + if isinstance(data, (list, tuple)): + data = Tuple(data) + elif isinstance(data, TupleWrapper): + data = data.tuple_value if not isinstance(axis, int): raise ValueError("For now, we only support integer axis") input_scales = list(input_scales) input_zero_points = list(input_zero_points) - return _make.concatenate(Tuple(data), + return _make.concatenate(data, Tuple(input_scales), Tuple(input_zero_points), output_scale, diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 2658a0aa7dad..952a86466300 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -19,17 +19,16 @@ import warnings import topi import tvm._ffi - +from tvm.relay.op import op as _reg from .. import expr as _expr from .. import analysis as _analysis from .. import op as _op -from ..op import op as _reg from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op -@_reg.register_compute("relay.op.annotation.simulated_quantize") +@_op.register_compute("relay.op.annotation.simulated_quantize") def simulated_quantize_compute(attrs, inputs, out_type): """Compiler for simulated_quantize.""" assert len(inputs) == 4 @@ -106,8 +105,8 @@ def frewrite_with_guard(ref_call, new_args, ctx): if not current_qconfig().guard(ref_call): return default_rewrite(ref_call, new_args, ctx) return func(ref_call, new_args, ctx) - _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) - return frewrite_with_guard + + return tvm.ir.register_op_attr(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) return _register(frewrite) if frewrite is not None else _register @@ -174,11 +173,14 @@ def conv2d_rewrite(ref_call, new_args, ctx): return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) -# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant -# @register_annotate_function("nn.dense") +@register_annotate_function("nn.dense") def dense_rewrite(ref_call, new_args, ctx): """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of dense will be quantized to weight field. 
Output would be in activation field.""" + + if current_qconfig().skip_dense_layer: + return None + if quantize_context().check_to_skip(ref_call): return None diff --git a/python/tvm/relay/quantize/_calibrate.py b/python/tvm/relay/quantize/_calibrate.py index 9794698a0447..9590e87534d1 100644 --- a/python/tvm/relay/quantize/_calibrate.py +++ b/python/tvm/relay/quantize/_calibrate.py @@ -28,7 +28,6 @@ from .. import op as _op from .. import expr as _expr from .. import analysis as _analysis -from .. import transform as _transform from .. import build_module as _build_module from ...contrib import graph_runtime from .kl_divergence import _find_scale_by_kl @@ -45,7 +44,7 @@ def _get_profile_runtime(mod): target = 'llvm' ctx = tvm.context(target) - with _transform.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = _build_module.build(func, target=target) runtime = graph_runtime.create(graph, lib, ctx) runtime.set_input(**params) @@ -139,10 +138,14 @@ def _make_const(val): const_params[nclip_min] = _make_const(- (valid_range - 1)) const_params[nclip_max] = _make_const((valid_range - 1)) - func = mod['main'] - _analysis.post_order_visit(func, visit_func) - func = _expr.bind(func, const_params) - return IRModule.from_expr(func) + main_func = mod['main'] + _analysis.post_order_visit(main_func, visit_func) + main_func = _expr.bind(main_func, const_params) + func_dict = {} + for global_var, func in mod.functions.items(): + if global_var.name_hint != 'main': + func_dict[global_var] = func + return IRModule.from_expr(main_func, func_dict) # weight scale functions diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index bb3db99eed79..a607f4ea50b8 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -19,14 +19,11 @@ import tvm from .. import expr as _expr from .. import analysis as _analysis -from ..op import op as _reg from . import _quantize from .quantize import _forward_op def register_partition_function(op_name, frewrite=None, level=10): - def _register(func): - return _reg._Register(op_name, "FQPartitionRewrite", func, level) - return _register(frewrite) if frewrite is not None else _register + return tvm.ir.register_op_attr(op_name, "FQPartitionRewrite", frewrite, level) @tvm._ffi.register_object("relay.QPartitionExpr") diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 958d0dc5d6ce..28ebf7f3032b 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -78,6 +78,7 @@ class QConfig(Object): "calibrate_mode": "global_scale", "global_scale": 8.0, "weight_scale": "power2", + "skip_dense_layer": True, "skip_conv_layers": [0], "do_simulation": False, "round_for_shift": True, @@ -121,7 +122,7 @@ def __enter__(self): return self def __exit__(self, ptype, value, trace): - _quantize._ExitQConfigScope(self) + _quantize._ExitQConfigScope() def __setattr__(self, name, value): if name in QConfig._node_defaults: @@ -157,6 +158,9 @@ def qconfig(**kwargs): of two. max: Find the maximum of the absolute value of the tensor + skip_dense_layer: boolean + Whether to skip all nn.dense layer type. By default are skipped. + skip_conv_layers: list Specifying which layers to be skipped. Provide a list of indices that indicate which conv2d layers to leave untouched. Start from 0. 
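With skip_dense_layer defaulting to True, dense layers are now left unquantized unless explicitly requested. A usage sketch (mod and params are placeholders for an existing module and its parameters):

from tvm import relay

with relay.quantize.qconfig(skip_dense_layer=False, skip_conv_layers=[0]):
    quantized_mod = relay.quantize.quantize(mod, params)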
diff --git a/python/tvm/relay/tensorrt.py b/python/tvm/relay/tensorrt.py index d4a90e5c3683..2d2d0e407019 100644 --- a/python/tvm/relay/tensorrt.py +++ b/python/tvm/relay/tensorrt.py @@ -21,11 +21,11 @@ import os import numpy as np import tvm +import tvm.ir import tvm.relay.transform as transform from tvm import relay from tvm.relay.expr import Call, Constant, Tuple, GlobalVar from tvm.relay.build_module import bind_params_by_name -from tvm.relay import op as reg from tvm.relay.transform import _ffi_api from tvm.relay.expr_functor import ExprMutator @@ -116,7 +116,7 @@ def IsTrtRuntimeAvailable(): return GetTrtVersion() != () def _register_external_op_helper(op_name, supported=True): - @reg.register(op_name, "target.tensorrt") + @tvm.ir.register_op_attr(op_name, "target.tensorrt") def _func_wrapper(attrs, args): if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -125,7 +125,7 @@ def _func_wrapper(attrs, args): return _func_wrapper def _register_external_op_helper_func(op_name, func, trt_version): - @reg.register(op_name, "target.tensorrt") + @tvm.ir.register_op_attr(op_name, "target.tensorrt") def _func_wrapper(attrs, args): if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -162,7 +162,7 @@ def register_tensorrt_annotations(trt_version, use_implicit_batch=True): #_register_external_op_helper("split") #_register_external_op_helper("slice_like") - @reg.register("add", "target.tensorrt") + @tvm.ir.register_op_attr("add", "target.tensorrt") def add_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -175,7 +175,7 @@ def add_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.batch_norm", "target.tensorrt") + @tvm.ir.register_op_attr("nn.batch_norm", "target.tensorrt") def batch_norm_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -185,7 +185,7 @@ def batch_norm_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.softmax", "target.tensorrt") + @tvm.ir.register_op_attr("nn.softmax", "target.tensorrt") def softmax_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -195,7 +195,7 @@ def softmax_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.conv2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.conv2d", "target.tensorrt") def conv2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -211,7 +211,7 @@ def conv2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.dense", "target.tensorrt") + @tvm.ir.register_op_attr("nn.dense", "target.tensorrt") def dense_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -226,7 +226,7 @@ def dense_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False 
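The same attribute-registration pattern extends to ops not yet covered by the helpers below; a sketch with a purely illustrative op name (an op can only carry one registration per attribute, so pick one that is not already annotated for target.tensorrt):

import tvm

@tvm.ir.register_op_attr("nn.prelu", "target.tensorrt")
def _prelu_whitelist_fn(attrs, args):  # pylint: disable=unused-variable
    # Mirror the float32-only check used by the other handlers.
    if any([x.checked_type.dtype != "float32" for x in args]):
        print("Only float32 inputs are supported for TensorRT.")
        return False
    return True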
return True - @reg.register("nn.bias_add", "target.tensorrt") + @tvm.ir.register_op_attr("nn.bias_add", "target.tensorrt") def bias_add_whitelist_fn(attrs, args): # pylint: disable=unused-variable # TODO(trevmorr): BiasAddSimplifier creates a pattern which cannot be # converted to TRT without binding params and constant folding. @@ -241,7 +241,7 @@ def bias_add_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.max_pool2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.max_pool2d", "target.tensorrt") def max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -251,7 +251,7 @@ def max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.avg_pool2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.avg_pool2d", "target.tensorrt") def avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -268,7 +268,7 @@ def avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.global_max_pool2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.global_max_pool2d", "target.tensorrt") def global_max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -278,7 +278,7 @@ def global_max_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-varia return False return True - @reg.register("nn.global_avg_pool2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.global_avg_pool2d", "target.tensorrt") def global_avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -288,7 +288,7 @@ def global_avg_pool_2d_whitelist_fn(attrs, args): # pylint: disable=unused-varia return False return True - @reg.register("expand_dims", "target.tensorrt") + @tvm.ir.register_op_attr("expand_dims", "target.tensorrt") def expand_dims_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -298,7 +298,7 @@ def expand_dims_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("squeeze", "target.tensorrt") + @tvm.ir.register_op_attr("squeeze", "target.tensorrt") def squeeze_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -311,7 +311,7 @@ def squeeze_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("concatenate", "target.tensorrt") + @tvm.ir.register_op_attr("concatenate", "target.tensorrt") def concatenate_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.dtype != "float32" for x in args[0].checked_type.fields]): print("Only float32 inputs are supported for TensorRT.") @@ -328,7 +328,7 @@ def concatenate_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.conv2d_transpose", "target.tensorrt") + 
@tvm.ir.register_op_attr("nn.conv2d_transpose", "target.tensorrt") def conv2d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -350,7 +350,7 @@ def conv2d_transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variabl return False return True - @reg.register("transpose", "target.tensorrt") + @tvm.ir.register_op_attr("transpose", "target.tensorrt") def transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -360,7 +360,7 @@ def transpose_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("reshape", "target.tensorrt") + @tvm.ir.register_op_attr("reshape", "target.tensorrt") def reshape_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -390,7 +390,7 @@ def reshape_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("nn.pad", "target.tensorrt") + @tvm.ir.register_op_attr("nn.pad", "target.tensorrt") def pad_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -433,7 +433,7 @@ def trt_5_1_5_whitelist_fn(attrs, args, op_name, trt_version): _register_external_op_helper_func("atan", trt_5_1_5_whitelist_fn, trt_version) _register_external_op_helper_func("ceil", trt_5_1_5_whitelist_fn, trt_version) - @reg.register("strided_slice", "target.tensorrt") + @tvm.ir.register_op_attr("strided_slice", "target.tensorrt") def strided_slice_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -456,7 +456,7 @@ def strided_slice_whitelist_fn(attrs, args): # pylint: disable=unused-variable return False return True - @reg.register("image.resize", "target.tensorrt") + @tvm.ir.register_op_attr("image.resize", "target.tensorrt") def resize_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -469,7 +469,7 @@ def resize_whitelist_fn(attrs, args): # pylint: disable=unused-variable # TODO(trevmorr): coordinate transform method return True - @reg.register("nn.adaptive_max_pool2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.adaptive_max_pool2d", "target.tensorrt") def adapative_max_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -479,7 +479,7 @@ def adapative_max_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-var return False return True - @reg.register("nn.adaptive_avg_pool2d", "target.tensorrt") + @tvm.ir.register_op_attr("nn.adaptive_avg_pool2d", "target.tensorrt") def adapative_avg_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -489,7 +489,7 @@ def adapative_avg_pool2d_whitelist_fn(attrs, args): # pylint: disable=unused-var return False return True - 
@reg.register("nn.upsampling", "target.tensorrt") + @tvm.ir.register_op_attr("nn.upsampling", "target.tensorrt") def upsampling_whitelist_fn(attrs, args): # pylint: disable=unused-variable if any([x.checked_type.dtype != "float32" for x in args]): print("Only float32 inputs are supported for TensorRT.") @@ -687,7 +687,7 @@ def EnableTrt(mod, params=None, trt_version=None, use_implicit_batch=True, SimplifySliceLikePass(), RemoveDropoutPass(), transform.RemoveUnusedFunctions(), - transform.ConvertLayout('NCHW'), + transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']}), transform.FoldConstant(), LegalizeLayoutTranformPass(), transform.AnnotateTarget('tensorrt'), diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 61a04ec392dd..351f15364966 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -190,7 +190,7 @@ def convert_func_node(self, func: Function, name_var=None): if name_var is None: func_name = self.generate_function_name('_anon_func') if isinstance(name_var, GlobalVar): - func_name = name_var.name_hint + func_name = str(name_var.name_hint) if isinstance(name_var, Var): func_name = self.get_var_name(name_var) @@ -411,7 +411,7 @@ def visit_var(self, var: Expr): def visit_global_var(self, gvar: Expr): # we don't need to add numbers to global var names because # the *names* are checked for uniqueness in the mod - return (Name(gvar.name_hint, Load()), []) + return (Name(str(gvar.name_hint), Load()), []) def visit_let(self, letexp: Expr): @@ -493,7 +493,7 @@ def visit_call(self, call: Expr): func = call.op fields, field_defs = self.convert_fields(call.args) - if isinstance(func, relay.Op): + if isinstance(func, tvm.ir.Op): raise Exception('Operators should have been lowered and eliminated') if isinstance(func, relay.Constructor): diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1a231eb1aaed..dc7937c0b346 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -183,11 +183,16 @@ def get_workload_official(model_url, model_sub_path): model_path = download_testdata(model_url, model_tar_name, module=['tf', 'official']) dir_path = os.path.dirname(model_path) - import tarfile if model_path.endswith("tgz") or model_path.endswith("gz"): + import tarfile tar = tarfile.open(model_path) tar.extractall(path=dir_path) tar.close() + elif model_path.endswith("zip"): + import zipfile + zip_object = zipfile.ZipFile(model_path) + zip_object.extractall(path=dir_path) + zip_object.close() else: raise RuntimeError('Could not decompress the file: ' + model_path) return os.path.join(dir_path, model_sub_path) diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 93d4341635a0..138a36611c6f 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,5 +18,4 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * - from . 
import memory_alloc diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 611fb1babf55..6c081cbac0de 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -28,19 +28,21 @@ from ..backend import compile_engine from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu +from ..op.memory import alloc_storage +def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): + offset = expr.const(0, dtype="int64") + return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) def is_primitive(call): return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \ hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1 class ManifestAllocPass(ExprMutator): - """A pass for explictly manifesting all memory allocations in Relay.""" + """A pass for explicitly manifesting all memory allocations in Relay.""" def __init__(self, target_host): self.invoke_tvm = op.memory.invoke_tvm_op - self.alloc_storage = op.memory.alloc_storage - self.alloc_tensor = op.memory.alloc_tensor self.shape_func = op.memory.shape_func self.scopes = [ScopeBuilder()] self.target_host = target_host @@ -94,17 +96,16 @@ def make_static_allocation(self, scope, tensor_type, i): """Allocate a tensor with a statically known shape.""" shape = [int(sh) for sh in tensor_type.shape] if len(shape) == 0: - shape = expr.const(np.array([]).astype( - self.compute_dtype), dtype=self.compute_dtype) + shape = expr.const(np.empty((), dtype=self.compute_dtype), dtype=self.compute_dtype) else: shape = expr.const(np.array(shape), dtype=self.compute_dtype) size = self.compute_storage(tensor_type) alignment = self.compute_alignment(tensor_type.dtype) dtype = tensor_type.dtype - sto = scope.let("storage_{0}".format(i), self.alloc_storage( + sto = scope.let("storage_{0}".format(i), alloc_storage( size, alignment, self.default_context, dtype)) # TODO(@jroesch): There is a bug with typing based on the constant shape. - tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape) + tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape) return scope.let("tensor_{0}".format(i), tensor) def visit_let(self, let): @@ -172,14 +173,14 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): size = self.compute_storage_in_relay( out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) - sto = scope.let("storage_{i}".format(i=i), self.alloc_storage( + sto = scope.let("storage_{i}".format(i=i), alloc_storage( size, alignment, self.default_context, out_type.dtype)) storages.append(sto) outs = [] sh_ty_storage = zip(out_shapes, out_types, storages) for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage): - alloc = self.alloc_tensor( + alloc = alloc_tensor( storage, out_shape, out_type.dtype, @@ -204,6 +205,7 @@ def visit_call(self, call): # Because we are in ANF we do not need to visit the arguments. scope = self.current_scope() new_args = [self.visit(arg) for arg in call.args] + ins = expr.Tuple(new_args) ret_type = call.checked_type out_types = flatten_tuple_type(ret_type) @@ -233,7 +235,7 @@ def __init__(self, target_host): self.target_host = target_host def transform_function(self, func, mod, _): - # TODO(@jroesch): Is there a way to do one shot initilization? + # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? 
mod.import_from_std("core.rly") ea = ManifestAllocPass(self.target_host) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py new file mode 100644 index 000000000000..8f21af9292a9 --- /dev/null +++ b/python/tvm/relay/transform/memory_plan.py @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks +""" +A pass for manifesting explicit memory allocations. +""" +from typing import Optional, Dict, List, Tuple +from collections import defaultdict +import attr + +from ..expr_functor import ExprMutator +from .. import op, expr +from ..function import Function +from ... import register_func, ir, cpu +from ..._ffi.runtime_ctypes import TVMContext +from ... import IRModule +from .. import transform +from . import function_pass + + +def is_primitive(call): + return ( + hasattr(call, "op") + and hasattr(call.op, "attrs") + and hasattr(call.op.attrs, "Primitive") + and int(call.op.attrs.Primitive) == 1 + ) + + +@attr.s(auto_attribs=True) +class Region: + """ + Represents a control-free allocation region. + + The below pass groups sets of allocations into regions, + then replaces the region with a single allocation. + """ + var: expr.Var + size: expr.Expr + alignment: Optional[expr.Expr] + dtype: Optional[str] + ctx: TVMContext + offsets: Dict[expr.Var, Tuple[expr.Expr, expr.Expr]] + + @staticmethod + def empty(region_no): + zero = expr.const(0, dtype="int64") + assert len(zero.data.shape) == 0 + region_var = expr.var(f"region{region_no}") + return Region(region_var, zero, None, None, None, {}) + + def grow( + self, old_storage: expr.Var, + size: expr.Expr, alignment: expr.Expr, + ctx: TVMContext, + dtype: str) -> None: + """Grow the region by a given allocation as well as track the old storage + for later rewriting the program to use the allocated region. + """ + if self.dtype: + assert self.dtype == dtype, "must have matching dtypes in a region" + else: + self.dtype = dtype + + if self.alignment: + assert ir.structural_equal( + self.alignment, alignment + ), "must have matching alignments in a region" + else: + self.alignment = alignment + + if self.ctx: + assert (self.ctx.device_type == ctx.device_type and + self.ctx.device_id == ctx.device_id), "must have matching context" + else: + assert ctx + self.ctx = ctx + + new_size = (size + self.alignment - expr.const(1, "int64")) \ + / self.alignment * self.alignment + + # Record the offset at which we allocate the storage. 
+ offset_var: expr.RelayExpr = expr.var(f"offset{len(self.offsets)}") + self.offsets[old_storage] = (offset_var, self.size) + + self.size = self.size + new_size + + def offset_for(self, alloc: expr.Expr) -> expr.Expr: + return self.offsets.get(alloc, [None])[0] + + def to_expr(self, body: expr.Expr) -> expr.Expr: + """ + Generate the prelude code for a region, wrapping the body in it. + + The prelude contains the single allocation for a region, and + all offset computations. + """ + + if self.ctx is None: + self.ctx = cpu(0) + + # Generate bindings for each and every size computation + # we must do this to maintain ANF. + bindings: List[Tuple[expr.Expr, expr.Expr]] = [] + + # First compute the total size. + total_size = expr.var(f"total_size{hash(body)}") + bindings.append((total_size, self.size)) + + # Allocate the entire region with a single call. + alloc = op.memory.alloc_storage(total_size, self.alignment, self.ctx, self.dtype) + bindings.append((self.var, alloc)) + + # Generate variables which contain all of the offset math. + # Ensure we constant evaluate away all the math here. + # + # In theory we can support dynamic offsets but this + # requires another round of memory planning and + # potentially colaescing. + for alloc in self.offsets: + (var, offset) = self.offsets[alloc] + bindings.append((var, offset)) + + body = mk_let(bindings, body) + return body + + +def iterative_let(let, each_binding, kont): + bindings = [] + while isinstance(let, expr.Let): + lhs = let.var + rhs = let.value + bindings.append(each_binding(lhs, rhs)) + let = let.body + + return kont(bindings, let) + + + +def mk_let(bindings, body): + for var, value in reversed(bindings): + assert var + assert value + assert body + body = expr.Let(var, value, body) + + return body + +def const_eval(mod, exp): + mod = IRModule.from_expr(exp, type_defs=mod.type_definitions) + mod = transform.FoldConstant()(mod) + return mod["main"] + +class StorageCoalesce(ExprMutator): + """ + A pass for coalescing allocations into region/arena allocations. + + After this pass each allocation comes from the same backing storage, + but will never overlap even in time, i.e. the allocations are just + packed into a contiguous block of memory. + + A secondary part of memory planning will perform liveness analysis to + overlap these in time, i.e when an early tensor dies we will attempt + to reuse its slot. 
+ """ + + def __init__(self): + super().__init__() + self.regions = [] + + def enter_scope(self) -> None: + region_no = len(self.regions) + self.regions.append(defaultdict(lambda: Region.empty(region_no))) + + def exit_scope(self, body: expr.Expr) -> expr.Expr: + """When leaving a scope build a region allocation for the scope.""" + dtype_region = self.regions.pop() + for _, region in reversed(list(dtype_region.items())): + if len(region.offsets) != 0: + body = region.to_expr(body) + + return body + + def current_region(self, dtype) -> Region: + current_scope = self.regions[-1] + return current_scope[dtype] + + def new_region_and_offset(self, old_storage): + for dtype_region in reversed(self.regions): + for dtype in dtype_region: + region = dtype_region[dtype] + offset = region.offset_for(old_storage) + if offset: + return region, offset + + raise Exception("could not find offset in any valid region") + + def visit_function(self, fn): + """Transform the function body to use region allocation scheme.""" + func = fn + if getattr(func.attrs, "Primitive", 0) == 1: + return super().visit_function(func) + else: + self.enter_scope() + body = self.visit(func.body) + body = self.exit_scope(body) + return Function( + func.params, + body, + func.ret_type, + func.type_params, + func.attrs, + ) + + def visit_if(self, ite): + self.enter_scope() + true_branch = self.visit(ite.true_branch) + true_branch = self.exit_scope(true_branch) + + self.enter_scope() + false_branch = self.visit(ite.false_branch) + false_branch = self.exit_scope(false_branch) + + return expr.If(ite.cond, true_branch, false_branch) + + + def mk_let(self, dynamic_regions): + """Let bind the dynamic regions""" + def _mk_let(bindings, body): + for var, value in reversed(bindings): + assert var + assert value is not None + assert body + body = expr.Let(var, value, body) + if var in dynamic_regions: + body = self.exit_scope(body) + + return body + + return _mk_let + + def visit_let(self, let): + dynamic_regions = [] + def _each_binding(lhs, rhs): + if isinstance(rhs, expr.Call) and rhs.op == op.op.get( + "memory.alloc_storage" + ): + return self.process_alloc_storage(dynamic_regions, lhs, rhs) + elif isinstance(rhs, expr.Call) and rhs.op == op.op.get( + "memory.alloc_tensor" + ): + return self.process_alloc_tensor(lhs, rhs) + else: + return lhs, rhs + + result = iterative_let(let, _each_binding, self.mk_let(dynamic_regions)) + assert result + return result + + def process_alloc_storage(self, dynamic_regions, lhs, call): + """Process alloc_storage""" + size, alignment = call.args + dtype = call.attrs.dtype + ctx = TVMContext(call.attrs.device_type, call.attrs.device_id) + + if not isinstance(size, expr.Constant): + self.enter_scope() + dynamic_regions.append(lhs) + + region = self.current_region(dtype) + region.grow(lhs, size, alignment, ctx, dtype) + return lhs, region.var + + def process_alloc_tensor(self, lhs, call): + """Process alloc tensor. 
Region and offset are computed""" + storage, old_offset, shape = call.args + region, offset = self.new_region_and_offset(storage) + + assert ( + old_offset.data.asnumpy().item() == 0 + ), "no offsets should yet be allocated" + return ( + lhs, + expr.Call(call.op, [region.var, offset, shape], call.attrs), + ) + +class LiftConst(ExprMutator): + """An internal pass to lift constants to the top level of function.""" + def __init__(self): + self.i = 0 + self.constants = [] + self.top_level = True + super().__init__() + + def visit_constant(self, const): + var = expr.var(f"const{self.i}") + self.i += 1 + self.constants.append((var, const)) + return var + + def visit_function(self, fn): + if int(getattr(fn.attrs, "Primitive", 0)) == 1: + return fn + + outer_constant = self.constants + self.constants = [] + # Populates self.constants. + body = self.visit(fn.body) + body = mk_let(self.constants, body) + self.constants = outer_constant + + return Function( + fn.params, + body, + fn.ret_type, + fn.type_params, + fn.attrs) + + def visit_let(self, let): + bindings = [] + while isinstance(let, expr.Let): + new_var = self.visit(let.var) + new_val = self.visit(let.value) + bindings.append((new_var, new_val)) + let = let.body + + new_body = self.visit(let) + return mk_let(bindings, new_body) + +@function_pass(opt_level=0) +class MemoryPlan: + """An explicit pass wrapper around StorageCoalesce.""" + + def transform_function(self, func, mod, _): + mod.import_from_std("core.rly") + sc = StorageCoalesce() + func = sc.visit(func) + return func + +register_func("relay.transform.MemoryPlan", MemoryPlan) + +@function_pass(opt_level=0) +class LiftConstants: + """An explicit pass wrapper around LiftConst.""" + + def transform_function(self, func, mod, _): + mod.import_from_std("core.rly") + func = LiftConst().visit(func) + return func + + +register_func("relay.transform.LiftConstants", LiftConstants) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 292c5fd39acb..8f4ec1046500 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -21,6 +21,7 @@ import types import inspect import functools +import warnings import tvm.ir from tvm import te @@ -31,11 +32,12 @@ def build_config(opt_level=2, - fallback_device=_nd.cpu(), required_pass=None, disabled_pass=None, trace=None): - """Configure the build behavior by setting config variables. + """Configure the build behavior by setting config variables. This function + will be deprecated in TVM v0.7. Instead, we should directly use + tvm.transform.PassContext. Parameters ---------- @@ -59,10 +61,6 @@ def build_config(opt_level=2, "FastMath": 4 } - fallback_device : int, str, or tvmContext, optional - The fallback device. It is also used as the default device for - operators without specified device during heterogeneous execution. - required_pass: set of str, optional Optimization passes that are required regardless of optimization level. @@ -77,9 +75,9 @@ def build_config(opt_level=2, pass_context: PassContext The pass context for optimizations. """ - return tvm.ir.transform.PassContext( - opt_level, fallback_device, required_pass, - disabled_pass, trace) + warnings.warn("relay.build_config will be deprecated. 
Please use \ + tvm.transform.PassContext directly", DeprecationWarning) + return tvm.transform.PassContext(opt_level, required_pass, disabled_pass, trace) @tvm._ffi.register_object("relay.FunctionPass") @@ -324,7 +322,7 @@ def AlterOpLayout(): return _ffi_api.AlterOpLayout() -def ConvertLayout(desired_layout): +def ConvertLayout(desired_layouts): """ Given a dest layout, this pass transforms the expr such that most of the ops input data layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one at the start and one at the end. @@ -341,15 +339,18 @@ def ConvertLayout(desired_layout): Parameters ---------- - desired_layout : str - The desired layout for the transformed expr. + desired_layouts : map of op_name to list of layouts + Specify a mapping of operator names to a list of layouts to convert to, in the order + defined by the operator. An example for nn.conv2d could be: {"nn.conv2d", ["NHWC", "OHWI]}, + where the first item in the list specifies the data layout and the second specifies the + kernel layout. Returns ------- pass: FunctionPass The pass. """ - return _ffi_api.ConvertLayout(desired_layout) + return _ffi_api.ConvertLayout(desired_layouts) def Legalize(legalize_map_attr_name="FTVMLegalize"): @@ -377,7 +378,7 @@ def MergeComposite(pattern_table): Parameters ---------- - pattern_table : list(tuple) + pattern_table : List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Function]] A list of (pattern_name, pattern, check) tuples. The order of the patterns in the list will determine the order of priority in which they are matched. @@ -653,7 +654,8 @@ def to_cps(func, mod=None): result: tvm.relay.Function The output function. """ - return _ffi_api.to_cps(func, mod) + use_mod = mod if mod is not None else tvm.ir.IRModule() + return _ffi_api.to_cps(func, use_mod) def un_cps(func): @@ -839,3 +841,43 @@ def visit_var(self, var): return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype)) return var return ChangeBatchMutator().visit(func) + + +def DenseToSparse(weight_name, weight_shape): + """ + Rewrite qualified ```nn.dense operation``` to ```nn.sparse_dense``` + This pass is used in ```data_dep_optimization.bsr_dense``` + Parameters of this pass is generated by ```analysis.sparse_dense.process_params``` + + Parameters + ---------- + weight_name: Array[String] + Names of weights which qualified sparse contrains + + weight_shape: Array[Array[IntImm]] + Weights shape in BSR format. + + Returns + ------- + ret : tvm.transform.Pass + The registered DenseToSparse pass. + """ + return _ffi_api.DenseToSparse(weight_name, weight_shape) + +def SimplifyFCTranspose(target_weight_name): + """ + Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)``` + This pass is used in ```data_dep_optimization.simplify_fc_transpose``` + + Parameters + ---------- + weight_name: Array[String] + Names of weights which qualified ```y = nn.dense(x, transpose(w, [1, 0]))``` + This parameter is generated by ```analysis.search_fc_transpose``` function + + Returns + ------- + ret : tvm.transform.Pass + The registered SimplifyFCTranspose pass. 
+ """ + return _ffi_api.SimplifyFCTranspose(target_weight_name) diff --git a/python/tvm/rpc/__init__.py b/python/tvm/rpc/__init__.py index 5f959eb44745..b64ba33d9e09 100644 --- a/python/tvm/rpc/__init__.py +++ b/python/tvm/rpc/__init__.py @@ -26,4 +26,6 @@ """ from .server import Server -from .client import RPCSession, LocalSession, TrackerSession, connect, connect_tracker +from .client import connect, connect_tracker +from .client import RPCSession, LocalSession, PopenSession, TrackerSession +from .minrpc import with_minrpc diff --git a/python/tvm/tir/ir_pass.py b/python/tvm/rpc/_ffi_api.py similarity index 64% rename from python/tvm/tir/ir_pass.py rename to python/tvm/rpc/_ffi_api.py index 239b1fb98dd0..1a7cc739b5c1 100644 --- a/python/tvm/tir/ir_pass.py +++ b/python/tvm/rpc/_ffi_api.py @@ -14,15 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Namespace of IR pass functions. - -This namespace is used for developers. While you do not see any declarations. -The functions are automatically exported from C++ side via PackedFunc. - -Each api is a PackedFunc that can be called in a positional argument manner. -You can read "include/tvm/tir/ir_pass.h" for the function signature and -"src/api/api_pass.cc" for the PackedFunc's body of these functions. -""" +"""FFI APIs for tvm.rpc""" import tvm._ffi -tvm._ffi._init_api("tvm.ir_pass", __name__) + +tvm._ffi._init_api("rpc", __name__) diff --git a/python/tvm/rpc/base.py b/python/tvm/rpc/base.py index bc81534a12d9..f0e33f8503f2 100644 --- a/python/tvm/rpc/base.py +++ b/python/tvm/rpc/base.py @@ -17,8 +17,6 @@ """Base definitions for RPC.""" # pylint: disable=invalid-name -from __future__ import absolute_import - import socket import time import json @@ -26,7 +24,6 @@ import struct import random import logging -import tvm._ffi from .._ffi.base import py_str @@ -176,7 +173,3 @@ def connect_with_retry(addr, timeout=60, retry_period=5): logger.warning("Cannot connect to tracker %s, retry in %g secs...", str(addr), retry_period) time.sleep(retry_period) - - -# Still use tvm.rpc for the foreign functions -tvm._ffi._init_api("tvm.rpc", "tvm.rpc.base") diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index ed57e0d4276d..2f96c9b62976 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -15,19 +15,20 @@ # specific language governing permissions and limitations # under the License. """RPC client tools""" -from __future__ import absolute_import - import os +import stat import socket import struct import time + import tvm._ffi from tvm.contrib import util from tvm._ffi.base import TVMError from tvm.runtime import ndarray as nd -from tvm.runtime import load_module as _load_module from . import base +from . import server +from . import _ffi_api class RPCSession(object): @@ -38,9 +39,23 @@ class RPCSession(object): # pylint: disable=invalid-name def __init__(self, sess): self._sess = sess - self._tbl_index = base._SessTableIndex(sess) + self._tbl_index = _ffi_api.SessTableIndex(sess) self._remote_funcs = {} + def system_lib(self): + """Get system-wide library module. + + Returns + ------- + module : runtime.Module + The system-wide library module. + + See Also + -------- + tvm.runtime.system_lib + """ + return self.get_function("runtime.SystemLib")() + def get_function(self, name): """Get function from the session. @@ -145,7 +160,7 @@ def load_module(self, path): m : Module The remote module containing remote function. 
""" - return base._LoadRemoteModule(self._sess, path) + return _ffi_api.LoadRemoteModule(self._sess, path) def cpu(self, dev_id=0): """Construct CPU device.""" @@ -167,14 +182,14 @@ def metal(self, dev_id=0): """Construct Metal device.""" return self.context(8, dev_id) - def opengl(self, dev_id=0): - """Construct OpenGL device.""" - return self.context(11, dev_id) - def ext_dev(self, dev_id=0): """Construct extension device.""" return self.context(12, dev_id) + def webgpu(self, dev_id=0): + """Construct WebGPU device.""" + return self.context(15, dev_id) + class LocalSession(RPCSession): """RPCSession interface backed by local environment. @@ -183,28 +198,41 @@ class LocalSession(RPCSession): need to be ran both locally and remotely. """ def __init__(self): - # pylint: disable=super-init-not-called - self.context = nd.context - self.get_function = tvm._ffi.get_global_func - self._temp = util.tempdir() + self._temp = server._server_env([]) + RPCSession.__init__(self, _ffi_api.LocalSession()) - def upload(self, data, target=None): - if isinstance(data, bytearray): - if not target: - raise ValueError("target must present when file is a bytearray") - blob = data - else: - blob = bytearray(open(data, "rb").read()) - if not target: - target = os.path.basename(data) - with open(self._temp.relpath(target), "wb") as f: - f.write(blob) - def download(self, path): - return bytearray(open(self._temp.relpath(path), "rb").read()) +@tvm._ffi.register_func("rpc.PopenSession") +def _popen_session(binary): + temp = util.tempdir() + + if isinstance(binary, (bytes, bytearray)): + path_exec = temp.relpath("server.minrpc") + with open(path_exec, "wb") as outfile: + outfile.write(binary) + os.chmod(path_exec, stat.S_IXUSR | stat.S_IRUSR) + path_exec = os.path.abspath(path_exec) + else: + path_exec = os.path.abspath(binary) + if not os.path.isfile(path_exec): + raise RuntimeError(f"{path_exec} does not exist.") + if not os.access(path_exec, os.X_OK): + raise RuntimeError(f"{path_exec} is not executable.") + + sess = _ffi_api.CreatePipeClient(path_exec) + return sess - def load_module(self, path): - return _load_module(self._temp.relpath(path)) + +class PopenSession(RPCSession): + """RPCSession interface backed by popen. + + Parameters + ---------- + binary : List[Union[str, bytes]] + The binary to be executed. + """ + def __init__(self, binary): + RPCSession.__init__(self, _popen_session(binary)) class TrackerSession(object): @@ -378,7 +406,7 @@ def request_and_run(self, key, max_retry, str(last_err))) -def connect(url, port, key="", session_timeout=0): +def connect(url, port, key="", session_timeout=0, session_constructor_args=None): """Connect to RPC Server Parameters @@ -397,15 +425,43 @@ def connect(url, port, key="", session_timeout=0): the connection when duration is longer than this value. When duration is zero, it means the request must always be kept alive. + session_constructor_args: List + List of additional arguments to passed as the remote session constructor. + The first element of the list is always a string specifying the name of + the session constructor, the following args are the positional args to that function. + Returns ------- sess : RPCSession The connected session. + + Examples + -------- + Normal usage + .. 
code-block:: python + + client = rpc.connect(server_url, server_port, server_key) + + Session_constructor can be used to customize the session in the remote + The following code connects to a remote internal server via a proxy + by constructing another RPCClientSession on the proxy machine and use that + as the serving session of the proxy endpoint. + + .. code-block:: python + + client_via_proxy = rpc.connect( + proxy_server_url, proxy_server_port, proxy_server_key, + session_constructor_args=[ + "rpc.Connect", internal_url, internal_port, internal_key]) + """ try: if session_timeout: key += " -timeout=%s" % str(session_timeout) - sess = base._Connect(url, port, key) + session_constructor_args = session_constructor_args if session_constructor_args else [] + if not isinstance(session_constructor_args, (list, tuple)): + raise TypeError("Expect the session constructor to be a list or tuple") + sess = _ffi_api.Connect(url, port, key, *session_constructor_args) except NameError: raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py new file mode 100644 index 000000000000..760c5362f11d --- /dev/null +++ b/python/tvm/rpc/minrpc.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utils to path.""" +import os +from tvm._ffi import libinfo +from tvm.contrib import cc + + +def find_minrpc_server_libpath(server="posix_popen_server"): + """Get the path of minrpc server libary. + + Parameters + ---------- + server : str + The kind of built in minrpc server. + + Returns + ------- + path : str + The path to the min server library. + """ + curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) + source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", "..")) + + path = os.path.join( + source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server)) + + candidates = [path] + if not os.path.isfile(path): + raise RuntimeError("Cannot find minserver %s, in candidates %s" % (server, candidates)) + return path + + +def with_minrpc(compile_func, + server="posix_popen_server", + runtime="libtvm"): + """Attach the compiler function with minrpc related options. + + Parameters + ---------- + compile_func : Union[str, Callable[[str, str, Optional[str]], None]] + The compilation function to decorate. + + server : str + The server type. + + runtime : str + The runtime library. + + Returns + ------- + fcompile : function + The return compilation. + """ + server_path = find_minrpc_server_libpath(server) + runtime_path = libinfo.find_lib_path( + [runtime, runtime + ".so", runtime + ".dylib"])[0] + + runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) + options = ["-std=c++14"] + # Make sure the rpath to the libtvm is set so we can do local tests. 
+ # Note that however, this approach won't work on remote. + # Always recommend to to link statically. + options += ["-Wl,-rpath=" + runtime_dir] + options += ["-I" + path for path in libinfo.find_include_path()] + fcompile = cc.cross_compiler( + compile_func, + options=options, + add_files=[server_path, runtime_path]) + fcompile.__name__ = "with_minrpc" + fcompile.need_system_lib = True + return fcompile diff --git a/python/tvm/rpc/proxy.py b/python/tvm/rpc/proxy.py index c3a3647948ee..994e230b982a 100644 --- a/python/tvm/rpc/proxy.py +++ b/python/tvm/rpc/proxy.py @@ -42,6 +42,7 @@ raise ImportError( "RPCProxy module requires tornado package %s. Try 'pip install tornado'." % error_msg) +from . import _ffi_api from . import base from .base import TrackerCode from .server import _server_env @@ -129,7 +130,7 @@ def close_pair(self): def on_close_event(self): """on close event""" assert not self._done - logging.info("RPCProxy:on_close %s ...", self.name()) + logging.info("RPCProxy:on_close_event %s ...", self.name()) if self.match_key: key = self.match_key if self._proxy._client_pool.get(key, None) == self: @@ -157,10 +158,12 @@ def on_message(self, message): self.on_data(message) def on_close(self): + logging.info("RPCProxy: on_close %s ...", self.name()) + self._close_process = True + if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None - logging.info("%s Close socket..", self.name()) self.on_close_event() @@ -186,6 +189,7 @@ def send_data(self, message): self.on_error(err) def on_close(self): + logging.info("RPCProxy: on_close %s ...", self.name()) if self.forward_proxy: self.forward_proxy.signal_close() self.forward_proxy = None @@ -549,7 +553,7 @@ def _fsend(data): data = bytes(data) conn.write_message(data, binary=True) return len(data) - on_message = base._CreateEventDrivenServer( + on_message = _ffi_api.CreateEventDrivenServer( _fsend, "WebSocketProxyServer", "%toinit") return on_message diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py index 627d67a0a835..15a3c7de789d 100644 --- a/python/tvm/rpc/server.py +++ b/python/tvm/rpc/server.py @@ -43,6 +43,7 @@ from tvm._ffi.libinfo import find_lib_path from tvm.runtime.module import load_module as _load_module from tvm.contrib import util +from . import _ffi_api from . import base from . 
base import TrackerCode @@ -56,7 +57,7 @@ def _server_env(load_library, work_path=None): temp = util.tempdir() # pylint: disable=unused-variable - @tvm._ffi.register_func("tvm.rpc.server.workpath") + @tvm._ffi.register_func("tvm.rpc.server.workpath", override=True) def get_workpath(path): return temp.relpath(path) @@ -81,7 +82,7 @@ def _serve_loop(sock, addr, load_library, work_path=None): """Server loop""" sockfd = sock.fileno() temp = _server_env(load_library, work_path) - base._ServerLoop(sockfd) + _ffi_api.ServerLoop(sockfd) if not work_path: temp.remove() logger.info("Finish serving %s", addr) @@ -325,9 +326,12 @@ def __init__(self, key="", load_library=None, custom_addr=None, - silent=False): + silent=False, + utvm_dev_id=None, + utvm_dev_config_args=None, + ): try: - if base._ServerLoop is None: + if _ffi_api.ServerLoop is None: raise RuntimeError("Please compile with USE_RPC=1") except NameError: raise RuntimeError("Please compile with USE_RPC=1") @@ -355,6 +359,10 @@ def __init__(self, cmd += ["--custom-addr", custom_addr] if silent: cmd += ["--silent"] + if utvm_dev_id is not None: + assert utvm_dev_config_args is not None + cmd += [f"--utvm-dev-id={utvm_dev_id}"] + cmd += [f"--utvm-dev-config-args={utvm_dev_config_args}"] # prexec_fn is not thread safe and may result in deadlock. # python 3.2 introduced the start_new_session parameter as diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 235ef0cf219e..21c06c517bd7 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -20,12 +20,12 @@ from .packed_func import PackedFunc from .object import Object from .object_generic import ObjectGeneric, ObjectTypes -from .ndarray import NDArray, DataType, TypeCode, TVMContext +from .ndarray import NDArray, DataType, DataTypeCode, TVMContext from .module import Module # function exposures from .object_generic import convert_to_object, convert, const from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl -from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev +from .ndarray import vpi, rocm, ext_dev, micro_dev from .module import load_module, enabled, system_lib from .container import String diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index a719dcd4eaf0..ae87534cd211 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -16,9 +16,10 @@ # under the License. """Runtime container structures.""" import tvm._ffi -from tvm._ffi.base import string_types -from tvm.runtime import Object, ObjectTypes -from tvm.runtime import _ffi_api +from .object import Object, PyNativeObject +from .object_generic import ObjectTypes +from . import _ffi_api + def getitem_helper(obj, elem_getter, length, idx): """Helper function to implement a pythonic getitem function. @@ -60,7 +61,7 @@ def getitem_helper(obj, elem_getter, length, idx): return elem_getter(obj, idx) -@tvm._ffi.register_object("vm.ADT") +@tvm._ffi.register_object("runtime.ADT") class ADT(Object): """Algebatic data type(ADT) object. @@ -112,64 +113,26 @@ def tuple_object(fields=None): @tvm._ffi.register_object("runtime.String") -class String(Object): - """The string object. +class String(str, PyNativeObject): + """TVM runtime.String object, represented as a python str. Parameters ---------- - string : str - The string used to construct a runtime String object - - Returns - ------- - ret : String - The created object. + content : str + The content string used to construct the object. 
""" - def __init__(self, string): - self.__init_handle_by_constructor__(_ffi_api.String, string) - - def __str__(self): - return _ffi_api.GetStdString(self) - - def __len__(self): - return _ffi_api.GetStringSize(self) - - def __hash__(self): - return _ffi_api.StringHash(self) - - def __eq__(self, other): - if isinstance(other, string_types): - return self.__str__() == other - - if not isinstance(other, String): - return False - - return _ffi_api.CompareString(self, other) == 0 - - def __ne__(self, other): - return not self.__eq__(other) - - def __gt__(self, other): - return _ffi_api.CompareString(self, other) > 0 - - def __lt__(self, other): - return _ffi_api.CompareString(self, other) < 0 - - def __getitem__(self, key): - return self.__str__()[key] - - def startswith(self, string): - """Check if the runtime string starts with a given string - - Parameters - ---------- - string : str - The provided string - - Returns - ------- - ret : boolean - Return true if the runtime string starts with the given string, - otherwise, false. - """ - return self.__str__().startswith(string) + __slots__ = ["__tvm_object__"] + + def __new__(cls, content): + """Construct from string content.""" + val = str.__new__(cls, content) + val.__init_tvm_object_by_constructor__(_ffi_api.String, content) + return val + + # pylint: disable=no-self-argument + def __from_tvm_object__(cls, obj): + """Construct from a given tvm object.""" + content = _ffi_api.GetFFIString(obj) + val = str.__new__(cls, content) + val.__tvm_object__ = obj + return val diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 90e7cfe4f9d4..322c881cefb3 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -111,7 +111,6 @@ def __call__(self, *args): # pylint: disable=not-callable return self.entry_func(*args) - def __repr__(self): return "Module(%s, %x)" % (self.type_key, self.handle.value) @@ -147,9 +146,6 @@ def imported_modules(self): nmod = _ffi_api.ModuleImportsSize(self) return [_ffi_api.ModuleGetImport(self, i) for i in range(nmod)] - def is_empty(self): - return _ffi_api.IsEmpty(self) - def save(self, file_name, fmt=""): """Save the module to file. @@ -250,6 +246,7 @@ def _dso_exportable(self): def export_library(self, file_name, fcompile=None, + addons=None, **kwargs): """Export the module and its imported device code one library. @@ -280,19 +277,6 @@ def export_library(self, if isinstance(file_name, Path): file_name = str(file_name) - if self.is_empty(): - logging.info("The lib generated by the NNVM compiler does not contain optimized " - "functions for any operators. This usually happens when an external " - "accelerator, e.g. TensorRT, is employed along with TVM to compile " - "the model, and all the operators in the model are supported by the " - "external accelerator at runtime. Therefore, " - "the NNVM compiler skipped optimizing them at the compile time.") - if os.path.isfile(file_name): - logging.warning("Lib file %s exists, and will be overwritten by the newly created" - " lib with the same name.", file_name) - open(file_name, 'w').close() - return - if self.type_key == "stackvm": if not file_name.endswith(".stackvm"): raise ValueError("Module[%s]: can only be saved as stackvm format." 
@@ -302,7 +286,7 @@ def export_library(self, modules = self._collect_dso_modules() temp = _util.tempdir() - files = [] + files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None @@ -332,9 +316,12 @@ def export_library(self, if llvm_target_triple is None and hasattr(fcompile, "get_target_triple"): llvm_target_triple = fcompile.get_target_triple() + if getattr(fcompile, "need_system_lib", False) and not is_system_lib: + raise ValueError("%s need --system-lib option" % str(fcompile)) + if self.imported_modules: if enabled("llvm") and llvm_target_triple: - path_obj = temp.relpath("devc.o") + path_obj = temp.relpath("devc." + object_format) m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple) m.save(path_obj) files.append(path_obj) @@ -398,16 +385,6 @@ def load_module(path, fmt=""): This function will automatically call cc.create_shared if the path is in format .o or .tar """ - if os.stat(path).st_size == 0: - logging.info("The lib generated by the NNVM compiler does not contain optimized " - "functions for any operators. This usually happens when an external " - "accelerator, e.g. TensorRT, is employed along with TVM to compile " - "the model, and all the operators in the model are supported by the " - "external accelerator at runtime. Therefore, the NNVM compiler skipped " - "optimizing them at the compile time. The TVM runtime " - "will create an empty Module as a dummy module.") - return _ffi_api.CreateEmptyModule() - # High level handling for .o and .tar file. # We support this to be consistent with RPC module load. if path.endswith(".o"): diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index ee7ab7b5d11f..060673dc19c6 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -22,7 +22,7 @@ from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE from tvm._ffi.runtime_ctypes import DataType, TVMContext, TVMArray, TVMArrayHandle -from tvm._ffi.runtime_ctypes import TypeCode, tvm_shape_index_t +from tvm._ffi.runtime_ctypes import DataTypeCode, tvm_shape_index_t try: # pylint: disable=wrong-import-position @@ -36,7 +36,7 @@ from tvm._ffi._ctypes.ndarray import NDArrayBase -@tvm._ffi.register_object +@tvm._ffi.register_object("runtime.NDArray") class NDArray(NDArrayBase): """Lightweight NDArray class of TVM runtime. @@ -219,7 +219,7 @@ def context(dev_type, dev_id=0): """ if isinstance(dev_type, string_types): if '-device=micro_dev' in dev_type: - dev_type = 'micro_dev' + dev_type = TVMContext.STR2MASK['micro_dev'] else: dev_type = dev_type.split()[0] if dev_type not in TVMContext.STR2MASK: @@ -409,22 +409,6 @@ def vulkan(dev_id=0): return TVMContext(7, dev_id) -def opengl(dev_id=0): - """Construct a OpenGL device - - Parameters - ---------- - dev_id : int, optional - The integer device id - - Returns - ------- - ctx : TVMContext - The created context - """ - return TVMContext(11, dev_id) - - def ext_dev(dev_id=0): """Construct a extension device @@ -478,6 +462,22 @@ def hexagon(dev_id=0): return TVMContext(14, dev_id) +def webgpu(dev_id=0): + """Construct a webgpu device. 
+ + Parameters + ---------- + dev_id : int, optional + The integer device id + + Returns + ------- + ctx : TVMContext + The created context + """ + return TVMContext(15, dev_id) + + cl = opencl mtl = metal diff --git a/python/tvm/runtime/object.py b/python/tvm/runtime/object.py index a55eeb0cb3ee..ec9f22b82c25 100644 --- a/python/tvm/runtime/object.py +++ b/python/tvm/runtime/object.py @@ -27,11 +27,11 @@ if _FFI_MODE == "ctypes": raise ImportError() from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic - from tvm._ffi._cy3.core import ObjectBase + from tvm._ffi._cy3.core import ObjectBase, PyNativeObject except (RuntimeError, ImportError): # pylint: disable=wrong-import-position,unused-import from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic - from tvm._ffi._ctypes.object import ObjectBase + from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject def _new_object(cls): @@ -41,6 +41,7 @@ def _new_object(cls): class Object(ObjectBase): """Base class for all tvm's runtime objects.""" + __slots__ = [] def __repr__(self): return _ffi_node_api.AsRepr(self) @@ -57,7 +58,7 @@ def __getattr__(self, name): "%s has no attribute %s" % (str(type(self)), name)) def __hash__(self): - return _ffi_api.ObjectHash(self) + return _ffi_api.ObjectPtrHash(self) def __eq__(self, other): return self.same_as(other) @@ -78,13 +79,10 @@ def __getstate__(self): def __setstate__(self, state): # pylint: disable=assigning-non-slot, assignment-from-no-return handle = state['handle'] + self.handle = None if handle is not None: - json_str = handle - other = _ffi_node_api.LoadJSON(json_str) - self.handle = other.handle - other.handle = None - else: - self.handle = None + self.__init_handle_by_constructor__( + _ffi_node_api.LoadJSON, handle) def _move(self): """Create an RValue reference to the object and mark the object as moved. diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index ac20b67e8299..8f559ae24aac 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -21,7 +21,7 @@ from tvm._ffi.runtime_ctypes import ObjectRValueRef from . import _ffi_node_api, _ffi_api -from .object import ObjectBase, _set_class_object_generic +from .object import ObjectBase, PyNativeObject, _set_class_object_generic from .ndarray import NDArrayBase from .packed_func import PackedFuncBase, convert_to_tvm_func from .module import Module @@ -34,11 +34,11 @@ def asobject(self): raise NotImplementedError() -ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef) +ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject) def convert_to_object(value): - """Convert a python value to corresponding object type. + """Convert a Python value to corresponding object type. Parameters ---------- diff --git a/python/tvm/runtime/packed_func.py b/python/tvm/runtime/packed_func.py index a04e32be0ea2..af4265a66ad1 100644 --- a/python/tvm/runtime/packed_func.py +++ b/python/tvm/runtime/packed_func.py @@ -44,8 +44,6 @@ class PackedFunc(PackedFuncBase): The compiled module returns Function. TVM backend also registers and exposes its API as Functions. - For example, the developer function exposed in tvm.ir_pass are actually - C++ functions that are registered as PackedFunc The following are list of common usage scenario of tvm.runtime.PackedFunc. 
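Taken together, the rpc and runtime.Module changes above enable a self-contained "export and spawn" flow: compile a system-lib module, let with_minrpc wrap the host compiler so the minrpc server sources are linked in (with need_system_lib enforced by export_library), then open the resulting executable through PopenSession. A rough sketch under those assumptions; the add-one kernel and the use of cc.create_executable as the linking helper are illustrative, not part of this patch:

.. code-block:: python

    import numpy as np
    import tvm
    from tvm import te, rpc
    from tvm.contrib import cc, util

    # Trivial kernel, built as a system lib so the minrpc server can look it up.
    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    mod = tvm.build(te.create_schedule(B.op), [A, B], "llvm --system-lib", name="add_one")

    temp = util.tempdir()
    path = temp.relpath("add_one.minrpc")
    # with_minrpc bundles the posix_popen_server sources and sets need_system_lib,
    # so export_library rejects modules that were not built with --system-lib.
    mod.export_library(path, rpc.with_minrpc(cc.create_executable))

    remote = rpc.PopenSession(path)   # spawn the binary, speak RPC over its pipes
    fadd = remote.system_lib().get_function("add_one")
    ctx = remote.cpu(0)
    a = tvm.nd.array(np.arange(16, dtype="float32"), ctx)
    b = tvm.nd.array(np.zeros(16, dtype="float32"), ctx)
    fadd(a, b)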
diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 6b86ff0d0c66..2553fedb9869 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -55,10 +55,9 @@ We can also use other specific function in this module to create specific targets. """ from .target import Target, create -from .target import cuda, rocm, mali, intel_graphics, opengl, arm_cpu, rasp, vta, bifrost, hexagon +from .target import cuda, rocm, mali, intel_graphics, arm_cpu, rasp, vta, bifrost, hexagon from .generic_func import GenericFunc from .generic_func import generic_func, get_native_generic_func, override_native_generic_func from . import datatype from . import codegen from .intrin import register_intrin_rule -from .build_config import BuildConfig, build_config diff --git a/python/tvm/target/arm_isa.py b/python/tvm/target/arm_isa.py new file mode 100644 index 000000000000..c40296e50713 --- /dev/null +++ b/python/tvm/target/arm_isa.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Defines functions to analyze available opcodes in the ARM ISA.""" + + +ARM_ISA_MAP = { + 'armv7e-m': ['SMLAD'], +} + + +class IsaAnalyzer(object): + + def __init__(self, target): + self.target = target + # TODO: actually parse -mcpu + arch = 'armv7e-m' + self._isa_map = ARM_ISA_MAP[arch] + + def __contains__(self, instruction): + return instruction in self._isa_map diff --git a/python/tvm/target/build_config.py b/python/tvm/target/build_config.py deleted file mode 100644 index 8aae6be54a8b..000000000000 --- a/python/tvm/target/build_config.py +++ /dev/null @@ -1,248 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Target dependent BuildConfig for low-level passes.""" -# TODO(tvm-team) consolidate with PassContext -import tvm._ffi -import tvm.ir - -from tvm.runtime import Object -from tvm.tir import Stmt -from . import _ffi_api - - -class DumpIR(object): - """ - Dump IR for each pass. - With it, you can dump ir just like gcc/llvm. - - How to use: - ----------- - .. 
code-block:: python - - with tvm.target.build_config(dump_pass_ir=True) - run() - """ - scope_level = 0 - def __init__(self): - self._pass_id = 0 - self._recover_list = [] - - def decorate(self, func): - """ decorate the pass function""" - def dump(*args, **kwargs): - """dump function""" - retv = func(*args, **kwargs) - if not isinstance(retv, (Stmt,)): - return retv - fname = func.func_name if hasattr(func, 'func_name') else func.__name__ - pname = str(self._pass_id) + "_" + fname + "_ir.cc" - with open(pname, "a") as f: - out = retv - f.write(str(out)) - self._pass_id += 1 - return retv - return dump - - def decorate_irpass(self): - """decorate ir_pass and ScheduleOps""" - self._old_sgpass = tvm.te.schedule.ScheduleOps - tvm.te.schedule.ScheduleOps = self.decorate(tvm.te.schedule.ScheduleOps) - vset = vars(tvm.tir.ir_pass) - k = v = 0 - def recover(): - vset[k] = v - for k, v in vset.items(): - self._recover_list.append(recover) - vset[k] = self.decorate(v) if isinstance(v, tvm.runtime.PackedFunc) else v - - def decorate_custompass(self, custom_pass): - """decorate given list of custom passes, and return decorated passes""" - custom_pass = custom_pass if custom_pass else [] - pass_list = [] - for idx, x in enumerate(custom_pass): - x[1].__name__ = "custom{}_phase{}".format(idx, x[0]) - pass_list += [(x[0], self.decorate(x[1]))] - return pass_list - - def enter(self): - """only decorate outermost nest""" - if DumpIR.scope_level > 0: - return - self.decorate_irpass() - self._pass_id = 0 - DumpIR.scope_level += 1 - - def exit(self): - """recover outermost nest""" - if DumpIR.scope_level > 1: - return - # recover decorated functions - for f in self._recover_list: - f() - tvm.te.schedule.ScheduleOps = self._old_sgpass - DumpIR.scope_level -= 1 - - -@tvm._ffi.register_object -class BuildConfig(Object): - """Configuration scope to set a build config option. - - Note - ---- - This object is backed by object protocol in C++, with arguments that can be - exchanged between python and C++. - - Do not construct directly, use build_config instead. - - The fields that are backed by the C++ object are immutable once an instance - is constructed. See _object_defaults for the fields. 
- """ - - _object_defaults = { - "auto_unroll_max_step": 0, - "auto_unroll_max_depth": 8, - "auto_unroll_max_extent": 0, - "unroll_explicit": True, - "detect_global_barrier": False, - "partition_const_loop": False, - "offset_factor": 0, - "data_alignment": -1, - "restricted_func": True, - "double_buffer_split_loop": 1, - "dump_pass_ir": False, - "instrument_bound_checkers": False, - "disable_select_rewriting": False, - "disable_vectorize": False, - "disable_assert": False - } - _dump_ir = DumpIR() - - # pylint: disable=no-member - def __init__(self, handle): - """Initialize the function with handle - - Parameters - ---------- - handle : SymbolHandle - the handle to the underlying C++ Symbol - """ - super(BuildConfig, self).__init__(handle) - self.handle = handle - - @property - def add_lower_pass(self): - size = _ffi_api.BuildConfigGetAddLowerPassInfo(self) - result = [] - for i in range(size): - phase = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, True) - func = _ffi_api.BuildConfigGetAddLowerPassInfo(self, i, False) - result += [(phase, func)] - return result - - @add_lower_pass.setter - def add_lower_pass(self, value): - add_lower_pass_args = [] - for x in value: - add_lower_pass_args += [x[0], x[1]] - _ffi_api.BuildConfigSetAddLowerPass(self, *add_lower_pass_args) - - def __enter__(self): - # pylint: disable=protected-access - _ffi_api.EnterBuildConfigScope(self) - if self.dump_pass_ir: - BuildConfig._dump_ir.enter() - return self - - def __exit__(self, ptype, value, trace): - if self.dump_pass_ir: - BuildConfig._dump_ir.exit() - _ffi_api.ExitBuildConfigScope(self) - - def __setattr__(self, name, value): - if name in BuildConfig._object_defaults: - raise AttributeError( - "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) - return super(BuildConfig, self).__setattr__(name, value) - - @staticmethod - def current(): - """Get the current build configuration.""" - return _ffi_api.GetCurrentBuildConfig() - - -def build_config(**kwargs): - """Configure the build behavior by setting config variables. - - Parameters - ---------- - auto_unroll_max_step: int, default=0 - Threshold of number of steps in the loop to be automatically unrolled. - This takes inner loop count into consideration. - - auto_unroll_max_depth: int, default=8 - The maximum nested level of loops that can be automatically unrolled. - - unroll_explicit: bool, default=True - Whether explicitly unroll the loop, if set false, the unroll hint will - be passed to the CodeGen phase, which may generate pragma unroll hint. - Set this to be true if CodeGen support unroll pragma and - when we want to be more readable. - - detect_global_barrier: bool, default=True - Whether detect global barrier. - - partition_const_loop: bool, default=False - Whether partition const loop - - data_alignment: int, optional - The alignment of data pointer in bytes. - If -1 is passed, the alignment will be set to TVM's internal default. - - offset_factor: int, default=0 - The factor used in default buffer declaration. - If specified as 0, offset field is not used. - - restricted_func: bool, default=True - Whether build restricted function. - That is each buffer argument to the function are guaranteed - not to overlap. This enables more optimization. - Corresponds to restricted keyword in C99 - - double_buffer_split_loop: int, default=2 - Whether split the loop with factor. If it is zero, no splitting will happen. - It it is bigger than one, the logic will do a split with factor equals the integer - and unroll the inner loop. 
This allows the buffer fetching won't contain condition. - - add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None - phase contains an integer on which optimization pass we apply the pass. - Additional lowering passes to be applied before make_api. - - dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False - - Returns - ------- - config: BuildConfig - The build configuration - """ - node_args = {k: v if k not in kwargs else kwargs[k] - for k, v in BuildConfig._object_defaults.items()} - config = tvm.ir.make_node("BuildConfig", **node_args) - - if "add_lower_pass" in kwargs: - config.add_lower_pass = kwargs["add_lower_pass"] - - return config diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index 328568a360bc..e42ac6b37806 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -88,7 +88,7 @@ def register_op(lower_func, op_name, target, type_name, src_type_name=None): op_name : str The name of the operation which the function computes, given by its - Halide::Internal class name (e.g. Add, LE, Cast). + class name (e.g. Add, LE, Cast). target : str The name of codegen target. @@ -136,8 +136,8 @@ def lower(op): dtype += "x" + str(t.lanes) if isinstance(op, (_Cast, _FloatImm)): return _Call(dtype, extern_func_name, convert([op.value]), - _Call.Extern, None, 0) + _Call.Extern) return _Call(dtype, extern_func_name, convert([op.a, op.b]), - _Call.Extern, None, 0) + _Call.Extern) return lower diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index fd15ff916ae1..3335e12ba5f6 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -172,18 +172,6 @@ def intel_graphics(model='unknown', options=None): return _ffi_api.TargetCreate("opencl", *opts) -def opengl(model='unknown', options=None): - """Returns a OpenGL target. - - Parameters - ---------- - options : str or list of str - Additional options - """ - opts = _merge_opts(["-model=%s" % model], options) - return _ffi_api.TargetCreate("opengl", *opts) - - def arm_cpu(model='unknown', options=None): """Returns a ARM CPU target. This function will also download pre-tuned op parameters when there is none. diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index 0016160860fd..939956c1a005 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -19,7 +19,9 @@ """ # expose all operators in tvm tir.op from tvm.tir import any, all, min_value, max_value, trace -from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil +from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, sqrt, rsqrt, floor, ceil +from tvm.tir import sinh, cosh, log2, log10 +from tvm.tir import asin, asinh, acos, acosh, atan, atanh from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else from tvm.tir import isnan, isfinite, isinf from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod diff --git a/python/tvm/te/hybrid/__init__.py b/python/tvm/te/hybrid/__init__.py index 31acaeb66618..42bcc86f7945 100644 --- a/python/tvm/te/hybrid/__init__.py +++ b/python/tvm/te/hybrid/__init__.py @@ -30,7 +30,7 @@ # 2. Support multi-level HalideIR import inspect import tvm._ffi -from tvm.driver.build_module import form_body +import tvm.te.schedule from tvm._ffi.base import decorate from .module import HybridModule @@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"): The built results is wrapped in a HybridModule. 
The usage of HybridModule is roughly the same as normal TVM-built modules. """ + sch = sch.normalize() + bounds = tvm.te.schedule.InferBound(sch) + stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = form_body(sch) src = _Dump(stmt, inputs, outputs, name) return HybridModule(src, name) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 107f51b8bbcc..913b4534eea6 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -29,10 +29,10 @@ import tvm.tir import tvm.te import tvm.te._ffi_api +import tvm.arith from tvm.tir import expr as _expr from tvm.tir import stmt as _stmt -from tvm.tir import ir_pass as _ir_pass from tvm.te.tensor import Tensor, Operation from tvm.tir import all as _all from tvm.tir import any as _any @@ -160,6 +160,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None): self.outputs = [] # Output tensors' name self.side_effect = set() # Tensors with side effects self.parsed_body = None # The parsed HalideIR body + self.analyzer = tvm.arith.Analyzer() self.returned = False # If this function has a valid return @@ -211,7 +212,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.make_by_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.Realize(_buf.op, 0, _dtype, _domain, _true, body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) body = tvm.tir.AttrStmt(_buf.op, 'realize_scope', tvm.runtime.convert(_scope), body) for elem in to_pop: @@ -271,8 +272,7 @@ def visit_Name(self, node): return entry if isinstance(node.ctx, ast.Load) else None if ty is Symbol.BufferVar: if isinstance(node.ctx, ast.Load): - return tvm.tir.Call(entry.dtype, entry.name, [tvm.runtime.const(0, 'int32')], \ - _expr.Call.Halide, entry.op, entry.value_index) + return tvm.tir.ProducerLoad(entry, [tvm.runtime.const(0, 'int32')]) return entry, [tvm.runtime.const(0, 'int32')] # Do I need any assertion here? 
return entry @@ -304,10 +304,10 @@ def visit_AugAssign(self, node): args = [tvm.runtime.const(0, 'int32')] _internal_assert(isinstance(buf, Tensor), "LHS is supposed to be Tensor!") - read = tvm.tir.Call(buf.dtype, buf.name, args, _expr.Call.Halide, buf.op, buf.value_index) + read = tvm.tir.ProducerLoad(buf, args) value = HybridParser._binop_maker[type(node.op)](read, rhs) - return tvm.tir.Provide(buf.op, 0, value, args) + return tvm.tir.ProducerStore(buf, value, args) def visit_Assign(self, node): @@ -326,7 +326,7 @@ def visit_Assign(self, node): _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") lhs = node.targets[0] if isinstance(rhs, _expr.PrimExpr): - rhs = _ir_pass.Simplify(rhs) + rhs = self.analyzer.simplify(rhs) if isinstance(lhs, ast.Name): #TODO: support defined intermediate buffer later lhs_ = lhs @@ -358,13 +358,13 @@ def visit_Assign(self, node): lhs = self.visit(lhs_) if lhs is not None: buf, args = lhs - return tvm.tir.Provide(buf.op, 0, rhs, args) + return tvm.tir.ProducerStore(buf, rhs, args) return util.make_nop() lhs, args = self.visit(lhs) _internal_assert(isinstance(lhs, Tensor), \ "An array access's LHS is expected to be a expr.Call!") - res = tvm.tir.Provide(lhs.op, lhs.value_index, rhs, args) + res = tvm.tir.ProducerStore(lhs, rhs, args) return res @@ -391,8 +391,7 @@ def visit_Subscript(self, node): arr = arr[i.value] return arr if isinstance(node.ctx, ast.Load): - return tvm.tir.Call(arr.dtype, arr.name, args, - _expr.Call.Halide, arr.op, arr.value_index) + return tvm.tir.ProducerLoad(arr, args) return arr, args def visit_With(self, node): @@ -410,7 +409,7 @@ def visit_With(self, node): def visit_If(self, node): - cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) + cond = self.analyzer.simplify(self.visit(node.test)) # Return no IfThenElse if proven if isinstance(cond, _expr.IntImm): @@ -501,8 +500,8 @@ def visit_For(self, node): _name = node.target.id if isinstance(for_type, tuple): - low = _ir_pass.CanonicalSimplify(low) - ext = _ir_pass.CanonicalSimplify(ext) + low = self.analyzer.simplify(low) + ext = self.analyzer.simplify(ext) _internal_assert(isinstance(low, _expr.ConstExpr) and isinstance(ext, _expr.ConstExpr), \ "Const range should start from a const " + \ diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py index 6c019893bf20..891d7baf893e 100644 --- a/python/tvm/te/hybrid/util.py +++ b/python/tvm/te/hybrid/util.py @@ -72,19 +72,18 @@ def _pruned_source(func): def replace_io(body, rmap): """Replacing tensors usage according to the dict given""" # pylint: disable=import-outside-toplevel - from tvm.tir import ir_pass + from tvm.tir import stmt_functor def replace(op): - if isinstance(op, _stmt.Provide) and op.func in rmap.keys(): - buf = rmap[op.func] - return _stmt.Provide(buf.op, op.value_index, op.value, op.args) - if isinstance(op, _expr.Call) and op.func in rmap.keys(): - buf = rmap[op.func] - return _expr.Call(buf.dtype, buf.name, op.args, \ - _expr.Call.Halide, buf.op, buf.value_index) + if isinstance(op, _stmt.ProducerStore) and op.producer.op in rmap.keys(): + buf = rmap[op.producer.op] + return _stmt.ProducerStore(buf, op.value, op.indices) + if isinstance(op, _expr.ProducerLoad) and op.producer.op in rmap.keys(): + buf = rmap[op.producer.op] + return _expr.ProducerLoad(buf, op.indices) return None - return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call']) + return stmt_functor.ir_transform(body, None, replace, ['tir.ProducerStore', 'tir.ProducerLoad']) def 
_is_tvm_arg_types(args): diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py index f8bbe09725f2..b61195472560 100644 --- a/python/tvm/te/schedule.py +++ b/python/tvm/te/schedule.py @@ -509,13 +509,6 @@ def double_buffer(self): """ _ffi_api.StageDoubleBuffer(self) - def opengl(self): - """The special OpenGL schedule - - Maps each output element to a pixel. - """ - _ffi_api.StageOpenGL(self) - @tvm._ffi.register_object class SpecializedCondition(Object): diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 739268aba4a5..7d73bf42ab7d 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -19,7 +19,7 @@ import tvm._ffi from tvm.runtime import Object, ObjectGeneric, convert_to_object -from tvm.tir import expr as _expr +from tvm.tir import expr as _expr, DataProducer from . import _ffi_api @@ -52,7 +52,7 @@ class TensorIntrinCall(Object): @tvm._ffi.register_object -class Tensor(Object, _expr.ExprOp): +class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" def __call__(self, *indices): @@ -69,9 +69,8 @@ def __call__(self, *indices): else: raise ValueError("The indices must be expression") - return _expr.Call(self.dtype, self.op.name, - args, _expr.Call.Halide, - self.op, self.value_index) + return _expr.ProducerLoad(self, args) + def __getitem__(self, indices): return TensorSlice(self, indices) diff --git a/python/tvm/te/tensor_intrin.py b/python/tvm/te/tensor_intrin.py index c5c2afef1c93..cd488a7fbd14 100644 --- a/python/tvm/te/tensor_intrin.py +++ b/python/tvm/te/tensor_intrin.py @@ -20,7 +20,6 @@ from tvm.runtime import Object, convert from tvm.ir import Range -from tvm.target import BuildConfig from .tensor import PlaceholderOp from . import tensor as _tensor @@ -68,7 +67,9 @@ def __call__(self, *args, **kwargs): def decl_tensor_intrin(op, fcompute, name="tensor_intrin", - binds=None, scalar_params=None): + binds=None, + scalar_params=None, + default_buffer_params=None): """Declare a tensor intrinsic function. Parameters @@ -104,6 +105,9 @@ def decl_tensor_intrin(op, scalar_params: a list of variables used by op, whose values will be passed as scalar_inputs when the tensor intrinsic is called. + default_buffer_params: Optional[dict] + Dictionary of buffer arguments to be passed when constructing a buffer. + Returns ------- intrin: TensorIntrin @@ -122,12 +126,11 @@ def decl_tensor_intrin(op, if not isinstance(t.op, PlaceholderOp): raise ValueError("Do not yet support composition op") - cfg = BuildConfig.current() + default_buffer_params = {} if default_buffer_params is None else default_buffer_params for t in tensors: buf = (binds[t] if t in binds else tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name, - data_alignment=cfg.data_alignment, - offset_factor=cfg.offset_factor)) + **default_buffer_params)) binds_list.append(buf) if scalar_params: diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 0f50636d68d8..5a3d394c098f 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -20,6 +20,8 @@ import logging import numpy as np import tvm +import tvm.arith +import tvm.tir import tvm._ffi @@ -168,4 +170,23 @@ def compare_derivative(j, n_der, grad): x_name, grad.shape, dist, max_diff, avg_diff) + +def assert_prim_expr_equal(lhs, rhs): + """Assert that lhs and rhs are equal to each other. + + Parameters + ---------- + lhs : tvm.tir.PrimExpr + The left operand. + + rhs : tvm.tir.PrimExpr + The right operand.
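+
+    Example
+    -------
+    A minimal sketch of the intended usage (the variable name ``x`` is only
+    illustrative); the call raises ``ValueError`` when the simplified
+    difference is not zero:
+
+    .. code-block:: python
+
+        x = tvm.tir.Var("x", "int32")
+        tvm.testing.assert_prim_expr_equal(x + 1, 1 + x)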
+ """ + ana = tvm.arith.Analyzer() + res = ana.simplify(lhs - rhs) + equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 + if not equal: + raise ValueError("{} and {} are not equal".format(lhs, rhs)) + + + tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index d2238ad754ac..982b31cc2f54 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -19,16 +19,17 @@ from tvm.ir import PrimExpr from tvm.runtime import const -from .buffer import Buffer, decl_buffer +from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not -from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let +from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle, Call, Let from .expr import IterVar, Any from .stmt import Stmt, LetStmt, AssertStmt, For -from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt +from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt +from .stmt import Free, ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .function import PrimFunc @@ -36,7 +37,9 @@ from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern from .op import call_llvm_intrin, all, any, min_value, max_value, trace from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp -from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2 +from .op import sin, sinh, asin, asinh +from .op import cos, cosh, acos, acosh +from .op import tan, tanh, atan, atan2, atanh from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else from .op import isnan, isfinite, isinf, copysign @@ -44,6 +47,6 @@ from .op import comm_reducer, min, max, sum from . import ir_builder -from . import ir_pass from . import transform from . import analysis +from . import stmt_functor diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 448d0e6c5f8e..1a3eb4806677 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -57,12 +57,52 @@ def expr_deep_equal(lhs, rhs): return _ffi_api.expr_deep_equal(lhs, rhs) -def verify_memory(mod): +def verify_ssa(func): + """Verify if the func is in SSA form. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The module to be verified. + + Returns + ------- + result : bool + The result of verification. + """ + return _ffi_api.verify_ssa(func) + + +def verify_memory(func): + """Verify if func contains illegal host side direct memory access. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The module to be verified. + + Returns + ------- + result : bool + The result of verification. + """ + return _ffi_api.verify_memory(func) + + +def verify_gpu_code(func, constraints): """Verify if module contains illegal host side direct memory access. Parameters ---------- - mod: tvm.IRModule + func: tvm.tir.PrimFunc The module to be verified. + + constraints : Dict[str, int] + The attribute constraints. + + Returns + ------- + result : bool + The result of verification. 
""" - _ffi_api.verify_memory(mod) + return _ffi_api.verify_gpu_code(func, constraints) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 0c7753e4d8ec..11bfb4c55921 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -24,7 +24,7 @@ from . import _ffi_api -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Buffer") class Buffer(Object): """Symbolic data buffer in TVM. @@ -245,3 +245,8 @@ def decl_buffer(shape, return _ffi_api.Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor, buffer_type) + + +@tvm._ffi.register_object("tir.DataProducer") +class DataProducer(Object): + pass diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index fd8c7a942297..161647377e37 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -20,7 +20,7 @@ from tvm.runtime import Object from . import _ffi_api -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Layout") class Layout(Object): """Layout is composed of upper cases, lower cases and numbers, where upper case indicates a primal axis and @@ -77,7 +77,7 @@ def factor_of(self, axis): return _ffi_api.LayoutFactorOf(self, axis) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BijectiveLayout") class BijectiveLayout(Object): """Bijective mapping for two layouts (src-layout and dst-layout). It provides shape and index conversion between each other. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 4cbece363f71..f8cb05431a5b 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -29,7 +29,7 @@ """ import tvm._ffi -from tvm.runtime import Object, ObjectGeneric, DataType, TypeCode, const +from tvm.runtime import Object, ObjectGeneric, DataType, DataTypeCode, const from tvm.ir import PrimExpr import tvm.ir._ffi_api from . import generic as _generic @@ -47,13 +47,13 @@ def _dtype_is_int(value): if isinstance(value, int): return True return (isinstance(value, ExprOp) and - DataType(value.dtype).type_code == TypeCode.INT) + DataType(value.dtype).type_code == DataTypeCode.INT) def _dtype_is_float(value): if isinstance(value, float): return True return (isinstance(value, ExprOp) and - DataType(value.dtype).type_code == TypeCode.FLOAT) + DataType(value.dtype).type_code == DataTypeCode.FLOAT) class ExprOp(object): """Operator overloading for Expr like expressions.""" @@ -144,7 +144,7 @@ def __rxor__(self, other): def __invert__(self): if _dtype_is_float(self): raise RuntimeError("Cannot use ~ operator on float type Expr.") - return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) + return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic) def __lt__(self, other): return _ffi_api._OpLT(self, other) @@ -321,7 +321,7 @@ def __init__(self, name, dtype): _ffi_api.SizeVar, name, dtype) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.IterVar") class IterVar(Object, ExprOp): """Represent iteration variable. @@ -373,7 +373,7 @@ def __init__(self, dom, var, iter_type, thread_tag=""): _ffi_api.IterVar, dom, var, iter_type, thread_tag) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.CommReducer") class CommReducer(Object): """Communicative reduce operator @@ -396,7 +396,7 @@ def __init__(self, lhs, rhs, result, identity_element): _ffi_api.CommReducer, lhs, rhs, result, identity_element) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Reduce") class Reduce(PrimExprWithOp): """Reduce node. 
@@ -475,7 +475,7 @@ def __bool__(self): return self.__nonzero__() -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.StringImm") class StringImm(ConstExpr): """String constant. @@ -499,7 +499,7 @@ def __ne__(self, other): return self.value != other -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): """Cast expression. @@ -516,7 +516,7 @@ def __init__(self, dtype, value): _ffi_api.Cast, dtype, value) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Add") class Add(BinaryOpExpr): """Add node. @@ -533,7 +533,7 @@ def __init__(self, a, b): _ffi_api.Add, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Sub") class Sub(BinaryOpExpr): """Sub node. @@ -550,7 +550,7 @@ def __init__(self, a, b): _ffi_api.Sub, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Mul") class Mul(BinaryOpExpr): """Mul node. @@ -567,7 +567,7 @@ def __init__(self, a, b): _ffi_api.Mul, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Div") class Div(BinaryOpExpr): """Div node. @@ -584,7 +584,7 @@ def __init__(self, a, b): _ffi_api.Div, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Mod") class Mod(BinaryOpExpr): """Mod node. @@ -601,7 +601,7 @@ def __init__(self, a, b): _ffi_api.Mod, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.FloorDiv") class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -618,7 +618,7 @@ def __init__(self, a, b): _ffi_api.FloorDiv, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. @@ -635,7 +635,7 @@ def __init__(self, a, b): _ffi_api.FloorMod, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Min") class Min(BinaryOpExpr): """Min node. @@ -652,7 +652,7 @@ def __init__(self, a, b): _ffi_api.Min, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Max") class Max(BinaryOpExpr): """Max node. @@ -669,7 +669,7 @@ def __init__(self, a, b): _ffi_api.Max, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.EQ") class EQ(CmpExpr): """EQ node. @@ -686,7 +686,7 @@ def __init__(self, a, b): _ffi_api.EQ, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.NE") class NE(CmpExpr): """NE node. @@ -703,7 +703,7 @@ def __init__(self, a, b): _ffi_api.NE, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.LT") class LT(CmpExpr): """LT node. @@ -720,7 +720,7 @@ def __init__(self, a, b): _ffi_api.LT, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.LE") class LE(CmpExpr): """LE node. @@ -737,7 +737,7 @@ def __init__(self, a, b): _ffi_api.LE, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.GT") class GT(CmpExpr): """GT node. @@ -754,7 +754,7 @@ def __init__(self, a, b): _ffi_api.GT, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.GE") class GE(CmpExpr): """GE node. @@ -771,7 +771,7 @@ def __init__(self, a, b): _ffi_api.GE, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.And") class And(LogicalExpr): """And node. @@ -788,7 +788,7 @@ def __init__(self, a, b): _ffi_api.And, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Or") class Or(LogicalExpr): """Or node. @@ -805,7 +805,7 @@ def __init__(self, a, b): _ffi_api.Or, a, b) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Not") class Not(LogicalExpr): """Not node. 
@@ -819,7 +819,7 @@ def __init__(self, a): _ffi_api.Not, a) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Select") class Select(PrimExprWithOp): """Select node. @@ -847,7 +847,7 @@ def __init__(self, condition, true_value, false_value): _ffi_api.Select, condition, true_value, false_value) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Load") class Load(PrimExprWithOp): """Load node. @@ -871,7 +871,7 @@ def __init__(self, dtype, buffer_var, index, predicate=None): _ffi_api.Load, dtype, buffer_var, index, *args) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BufferLoad") class BufferLoad(PrimExprWithOp): """Buffer load node. @@ -888,7 +888,24 @@ def __init__(self, buffer, indices): _ffi_api.BufferLoad, buffer, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.ProducerLoad") +class ProducerLoad(PrimExprWithOp): + """Producer load node. + + Parameters + ---------- + producer : DataProducer + The buffer to be loaded. + + indices : List[PrimExpr] + The buffer indices. + """ + def __init__(self, producer, indices): + self.__init_handle_by_constructor__( + _ffi_api.ProducerLoad, producer, indices) + + +@tvm._ffi.register_object("tir.Ramp") class Ramp(PrimExprWithOp): """Ramp node. @@ -908,7 +925,7 @@ def __init__(self, base, stride, lanes): _ffi_api.Ramp, base, stride, lanes) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Broadcast") class Broadcast(PrimExprWithOp): """Broadcast node. @@ -925,7 +942,7 @@ def __init__(self, value, lanes): _ffi_api.Broadcast, value, lanes) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Shuffle") class Shuffle(PrimExprWithOp): """Shuffle node. @@ -942,7 +959,7 @@ def __init__(self, vectors, indices): _ffi_api.Shuffle, vectors, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Call") class Call(PrimExprWithOp): """Call node. @@ -959,25 +976,18 @@ class Call(PrimExprWithOp): call_type : int The type of the call - - func : Operation, optional - Operation if call_type is Halide - - value_index : int - The output value index """ Extern = 0 ExternCPlusPlus = 1 PureExtern = 2 - Halide = 3 Intrinsic = 4 PureIntrinsic = 5 - def __init__(self, dtype, name, args, call_type, func, value_index): + def __init__(self, dtype, name, args, call_type): self.__init_handle_by_constructor__( - _ffi_api.Call, dtype, name, args, call_type, func, value_index) + _ffi_api.Call, dtype, name, args, call_type) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Let") class Let(PrimExprWithOp): """Let node. @@ -997,7 +1007,7 @@ def __init__(self, var, value, body): _ffi_api.Let, var, value, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Any") class Any(PrimExpr): """Any node. """ diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 4ec1a71f345e..47ad94f503d8 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -67,3 +67,19 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs) + + def with_body(self, new_body): + """Create a new PrimFunc with the same set signatures but a new body. + + Parameters + ---------- + new_body : Stmt + The new body. + + Returns + ------- + new_func : PrimFunc + The created new function. 
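+
+        Example
+        -------
+        A small sketch (here ``func`` is an existing PrimFunc and ``new_body``
+        is a statement produced elsewhere, e.g. by a stmt_functor transform):
+
+        .. code-block:: python
+
+            updated = func.with_body(new_body)
+            # params, buffer_map and attrs are carried over unchanged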
+ """ + return PrimFunc( + self.params, new_body, self.ret_type, self.buffer_map, self.attrs) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 0c4c36888eb5..47ba2e2c805c 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -21,7 +21,6 @@ from . import stmt as _stmt from . import expr as _expr -from . import ir_pass as _pass class WithScope(object): @@ -212,7 +211,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): self.nidx += 1 self._seq_stack.append([]) loop_var = _expr.Var(name, dtype=dtype) - extent = end if begin == 0 else _pass.Simplify(end - begin) + extent = end if begin == 0 else (end - begin) def _exit_cb(): if for_type == "serial": for_type_id = 0 @@ -381,7 +380,7 @@ def likely(self, expr): The expression will likely tag. """ return _expr.Call(expr.dtype, "likely", [expr], - _expr.Call.PureIntrinsic, None, 0) + _expr.Call.PureIntrinsic) def get(self): """Return the builded IR. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index ce3edee12f8c..929d422ccc43 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -30,9 +30,9 @@ def _pack_buffer(buf): """ assert buf.shape shape = Call("handle", "tvm_stack_make_shape", buf.shape, - Call.Intrinsic, None, 0) + Call.Intrinsic) strides = Call("handle", "tvm_stack_make_shape", buf.strides, - Call.Intrinsic, None, 0) if buf.strides else 0 + Call.Intrinsic) if buf.strides else 0 pack_args = [buf.data, shape, strides, @@ -40,7 +40,7 @@ def _pack_buffer(buf): const(0, dtype=buf.dtype), buf.elem_offset] return Call("handle", "tvm_stack_make_array", - pack_args, Call.Intrinsic, None, 0) + pack_args, Call.Intrinsic) def call_packed(*args): """Build expression by call an external packed function. @@ -68,7 +68,7 @@ def call_packed(*args): """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] return Call( - "int32", "tvm_call_packed", call_args, Call.Intrinsic, None, 0) + "int32", "tvm_call_packed", call_args, Call.Intrinsic) def call_pure_intrin(dtype, func_name, *args): @@ -95,7 +95,7 @@ def call_pure_intrin(dtype, func_name, *args): """ args = convert(args) return Call( - dtype, func_name, convert(args), Call.PureIntrinsic, None, 0) + dtype, func_name, convert(args), Call.PureIntrinsic) def call_intrin(dtype, func_name, *args): @@ -122,7 +122,7 @@ def call_intrin(dtype, func_name, *args): """ args = convert(args) return Call( - dtype, func_name, convert(args), Call.Intrinsic, None, 0) + dtype, func_name, convert(args), Call.Intrinsic) def call_pure_extern(dtype, func_name, *args): @@ -145,7 +145,7 @@ def call_pure_extern(dtype, func_name, *args): The call expression. """ return Call( - dtype, func_name, convert(args), Call.PureExtern, None, 0) + dtype, func_name, convert(args), Call.PureExtern) def call_extern(dtype, func_name, *args): @@ -168,7 +168,7 @@ def call_extern(dtype, func_name, *args): The call expression. 
""" return Call( - dtype, func_name, convert(args), Call.Extern, None, 0) + dtype, func_name, convert(args), Call.Extern) def call_llvm_intrin(dtype, name, *args): @@ -278,7 +278,7 @@ def trace(args, trace_action="tvm.default_trace_action"): call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) return tvm.tir.Call( - args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic, None, 0) + args[-1].dtype, "tvm_call_trace_packed", call_args, tvm.tir.Call.Intrinsic) @@ -522,6 +522,38 @@ def cosh(x): return call_pure_intrin(x.dtype, "cosh", x) +def acos(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "acos", x) + + +def acosh(x): + """Take acos of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "acosh", x) + + def sin(x): """Take sin of input x. @@ -539,7 +571,7 @@ def sin(x): def sinh(x): - """Take sin of input x. + """Take sinh of input x. Parameters ---------- @@ -554,6 +586,38 @@ def sinh(x): return call_pure_intrin(x.dtype, "sinh", x) +def asin(x): + """Take asin of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "asin", x) + + +def asinh(x): + """Take asinh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "asinh", x) + + def atan(x): """Take atan of input x. @@ -570,6 +634,22 @@ def atan(x): return call_pure_intrin(x.dtype, "atan", x) +def atanh(x): + """Take atanh of input x. + + Parameters + ---------- + x : PrimExpr + Input argument. + + Returns + ------- + y : PrimExpr + The result. + """ + return call_pure_intrin(x.dtype, "atanh", x) + + def atan2(x1, x2): """Take arctan2(x1, x2). diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index c5b2a7957319..4536580737e5 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -36,7 +36,7 @@ class Stmt(Object): """Base class of all the statements.""" -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.LetStmt") class LetStmt(Stmt): """LetStmt node. @@ -56,7 +56,7 @@ def __init__(self, var, value, body): _ffi_api.LetStmt, var, value, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. @@ -76,7 +76,7 @@ def __init__(self, condition, message, body): _ffi_api.AssertStmt, condition, message, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.For") class For(Stmt): """For node. @@ -116,7 +116,7 @@ def __init__(self, for_type, device_api, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Store") class Store(Stmt): """Store node. @@ -140,7 +140,7 @@ def __init__(self, buffer_var, value, index, predicate=None): _ffi_api.Store, buffer_var, value, index, *args) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.BufferStore") class BufferStore(Stmt): """Buffer store node. @@ -160,30 +160,50 @@ def __init__(self, buffer, value, indices): _ffi_api.BufferStore, buffer, value, indices) -@tvm._ffi.register_object -class Provide(Stmt): - """Provide node. +@tvm._ffi.register_object("tir.BufferRealize") +class BufferRealize(Stmt): + """Buffer realize node. 
Parameters ---------- - func : Operation - The operation to create the function. + buffer : Buffer + The buffer. + + bounds : List[Range] + The bounds of the buffer region to be realized. + + condition : PrimExpr + The realize condition. + + body : Stmt + The body of the statement. + """ + def __init__(self, buffer, bounds, condition, body): + self.__init_handle_by_constructor__( + _ffi_api.BufferRealize, buffer, bounds, condition, body) + - value_index : int - The output value index +@tvm._ffi.register_object("tir.ProducerStore") +class ProducerStore(Stmt): + """ProducerStore node. + + Parameters + ---------- + producer : DataProducer + The data producer. value : PrimExpr The value to be stored. - args : list of Expr - The index arguments of the Provide. + indices : list of Expr + The index arguments of the store. """ - def __init__(self, func, value_index, value, args): + def __init__(self, producer, value, indices): self.__init_handle_by_constructor__( - _ffi_api.Provide, func, value_index, value, args) + _ffi_api.ProducerStore, producer, value, indices) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Allocate") class Allocate(Stmt): """Allocate node. @@ -215,7 +235,7 @@ def __init__(self, extents, condition, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. @@ -238,7 +258,7 @@ def __init__(self, node, attr_key, value, body): _ffi_api.AttrStmt, node, attr_key, value, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Free") class Free(Stmt): """Free node. @@ -252,20 +272,14 @@ def __init__(self, buffer_var): _ffi_api.Free, buffer_var) -@tvm._ffi.register_object -class Realize(Stmt): - """Realize node. +@tvm._ffi.register_object("tir.ProducerRealize") +class ProducerRealize(Stmt): + """ProducerRealize node. Parameters ---------- - func : Operation - The operation to create the function. - - value_index : int - The output value index - - dtype : str - The data type of the operation. + producer : DataProducer + The data producer. bounds : list of range The bound of realize condition : PrimExpr The realize condition body : Stmt The realize body """ def __init__(self, - func, - value_index, - dtype, + producer, bounds, condition, body): self.__init_handle_by_constructor__( - _ffi_api.Realize, func, value_index, dtype, - bounds, condition, body) + _ffi_api.ProducerRealize, producer, bounds, condition, body) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. @@ -308,7 +319,7 @@ def __len__(self): return len(self.seq) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.IfThenElse") class IfThenElse(Stmt): """IfThenElse node. @@ -328,7 +339,7 @@ def __init__(self, condition, then_case, else_case): _ffi_api.IfThenElse, condition, then_case, else_case) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Evaluate") class Evaluate(Stmt): """Evaluate node. @@ -342,27 +353,21 @@ def __init__(self, value): _ffi_api.Evaluate, value) -@tvm._ffi.register_object +@tvm._ffi.register_object("tir.Prefetch") class Prefetch(Stmt): """Prefetch node. Parameters ---------- - func : Operation - The operation to create the function. - - value_index : int - The output value index - - dtype : str - The data type to be prefetched. + buffer : Buffer + The buffer to be prefetched. bounds : list of Range The bounds to be prefetched.
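+
+    Example
+    -------
+    A minimal sketch (the buffer shape and bounds are illustrative only):
+
+    .. code-block:: python
+
+        buf = tvm.tir.decl_buffer((16,), "float32", name="A")
+        bounds = [tvm.ir.Range.make_by_min_extent(0, 16)]
+        prefetch = tvm.tir.Prefetch(buf, bounds)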
""" - def __init__(self, func, value_index, dtype, bounds): + def __init__(self, buffer, bounds): self.__init_handle_by_constructor__( - _ffi_api.Prefetch, func, value_index, dtype, bounds) + _ffi_api.Prefetch, buffer, bounds) def stmt_seq(*args): diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py new file mode 100644 index 000000000000..cea8d1474621 --- /dev/null +++ b/python/tvm/tir/stmt_functor.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Statement functor utilities for IR transformations""" +from . import _ffi_api + + +def ir_transform(stmt, preorder, postorder, only_enable=None): + """Recursively visit and transform ir nodes in post DFS order. + + Parameters + ---------- + stmt : Stmt + The input to be transformed. + + preorder: function + The function called in before recursive mutation + If preorder returns None, then the transform will proceed to recursive call. + If preorder returns a not None Stmt/Expr, the transformer will simply return it and + won't do further recursion. + + postorder : function + The function called after recursive mutation. + + only_enable : Optional[List[str]] + List of types that we only enable. + + Returns + ------- + result : Stmt + The result. + """ + return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) + + +def post_order_visit(stmt, fvisit): + """Recursively visit the ir in post DFS order node, apply fvisit + Each node is guaranteed to be visited only once. + + Parameters + ---------- + fvisit: function + The visitor function. + """ + return _ffi_api.PostOrderVisit(stmt, fvisit) + + +def substitute(node, vmap): + """ Substitute the var specified by vmap. + + Parameters + ---------- + node: ObjectRef + The input. + + vmap : Dict[Var, PrimExpr] + The variable mapping. + + Returns + ------- + result : Stmt + The result. + """ + return _ffi_api.Substitute(node, vmap) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 9f64a93a4860..a5af3537473f 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -60,6 +60,206 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter") +def InjectPrefetch(): + """Inject prefetch instructions into stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectPrefetch() + + +def StorageFlatten(cache_line_size, create_bound_attribute=False): + """Flatten the multi-dimensional read/write to 1D. + + + Parameters + ---------- + cache_line_size: int + The size of CPU cache line. + + create_bound_attribute: + Whether to create bound attributes. 
+ + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) + + +def InjectCopyIntrin(pragma_key, fintrin): + """Inject copy intrinsics for loops annotated with the given pragma. + + Parameters + ---------- + pragma_key : str + The pragma key for hint of copy. + + fintrin : function + The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) + + +def CoProcSync(): + """Detect and insert sync points to co-processor. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.CoProcSync() + + +def LiftAttrScope(attr_key): + """Lift common attrs with attr_key to outer scope. + + Parameters + ---------- + attr_key : str + The attribute key to be checked. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LiftAttrScope(attr_key) + + +def LoopPartition(): + """Partition loops in the stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LoopPartition() + + +def VectorizeLoop(enable_vectorize=True): + """Lower vectorization loops. + + Parameters + ---------- + enable_vectorize : bool + Whether vectorization is enabled. + Will lower to scalar loop when it is turned off. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VectorizeLoop(enable_vectorize) + + +def InjectVirtualThread(): + """Inject virtual thread loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectVirtualThread() + + +def InjectDoubleBuffer(): + """Inject double buffer statements. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectDoubleBuffer() + + +def StorageRewrite(): + """Rewrite storage allocation pattern. + + Moves the allocation to outer most possible scope. + Trying to share space between allocations to make + a static allocation plan when possible. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.StorageRewrite() + + +def UnrollLoop(): + """Unroll the constant loop marked by unroll. + + This pass also automatically attaches a pragma unroll tag to loops which meet the standard. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.UnrollLoop() + + +def RemoveNoOp(): + """Remove No Op from the Stmt. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RemoveNoOp() + + +def RewriteUnsafeSelect(): + """Detect and rewrite unsafe select that contains memory access. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RewriteUnsafeSelect() + + +def Simplify(): + """Run arithmetic simplifications on the statements and expressions. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.Simplify() + + +def InstrumentBoundCheckers(): + """Instruments bound checkers. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InstrumentBoundCheckers() + + def LowerCustomDatatypes(): """Lower custom datatypes. @@ -101,6 +301,17 @@ def SplitHostDevice(): return _ffi_api.SplitHostDevice() +def DecorateDeviceScope(): + """Decorate the function's body as a device function.
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.DecorateDeviceScope() + + def SkipAssert(): """Skip assert stmt. @@ -227,3 +438,14 @@ def NarrowDataType(target_bits): Run this pass after StorageFlatten. """ return _ffi_api.NarrowDataType(target_bits) + + +def VerifyMemory(): + """Verify if func contains illegal host side direct memory access. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VerifyMemory() diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml index 3c51bb384c68..5a1f1d27514f 100644 --- a/rust/.rustfmt.toml +++ b/rust/.rustfmt.toml @@ -29,3 +29,4 @@ merge_derives = true use_try_shorthand = false use_field_init_shorthand = false force_explicit_abi = true + diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8467f6a92ea8..6849c039f86f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -22,9 +22,12 @@ members = [ "runtime", "runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_dso", + "runtime/tests/test_wasm32", "runtime/tests/test_nn", "frontend", "frontend/tests/basics", "frontend/tests/callback", - "frontend/examples/resnet" + "frontend/examples/resnet", + "tvm-sys", + "tvm-rt" ] diff --git a/rust/common/build.rs b/rust/common/build.rs index b3ae7b6d1837..07326f41f801 100644 --- a/rust/common/build.rs +++ b/rust/common/build.rs @@ -51,6 +51,7 @@ fn main() { .layout_tests(false) .derive_partialeq(true) .derive_eq(true) + .derive_default(true) .generate() .expect("unable to generate bindings") .write_to_file(PathBuf::from("src/c_runtime_api.rs")) diff --git a/rust/common/src/array.rs b/rust/common/src/array.rs index d0a66a62b8bf..a8f4f989c146 100644 --- a/rust/common/src/array.rs +++ b/rust/common/src/array.rs @@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray { shape: arr.shape().as_ptr() as *const i64 as *mut i64, strides: arr.strides().as_ptr() as *const isize as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/common/src/lib.rs b/rust/common/src/lib.rs index 2ae64e7a32b3..33b2993bf3da 100644 --- a/rust/common/src/lib.rs +++ b/rust/common/src/lib.rs @@ -31,8 +31,13 @@ pub mod ffi { include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); - pub type BackendPackedCFunc = - extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int; + pub type BackendPackedCFunc = extern "C" fn( + args: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + out_ret_value: *mut TVMValue, + out_ret_tcode: *mut u32, + ) -> c_int; } pub mod array; diff --git a/rust/common/src/packed_func.rs b/rust/common/src/packed_func.rs index f3bac39b6a10..65434b928269 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/common/src/packed_func.rs @@ -94,52 +94,52 @@ macro_rules! 
TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMTypeCode_kTVMNullptr => Null, - TVMTypeCode_kTVMDataType => DataType($value.v_type), - TVMTypeCode_kTVMContext => Context($value.v_ctx), - TVMTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + TVMArgTypeCode_kTVMNullptr => Null, + TVMArgTypeCode_kTVMDataType => DataType($value.v_type), + TVMArgTypeCode_kTVMContext => Context($value.v_ctx), + TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), $( $tvm_type => { $from_tvm_type } ),+ _ => unimplemented!("{}", type_code), } } } - pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) { + pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { use $name::*; match self { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMDataType), - Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext), + Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext), String(val) => { ( TVMValue { v_handle: val.as_ptr() as *mut c_void }, - TVMTypeCode_kTVMStr, + TVMArgTypeCode_kTVMStr, ) } - Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMOpaqueHandle), + Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), ArrayHandle(val) => { ( TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMTypeCode_kTVMNDArrayHandle, + TVMArgTypeCode_kTVMNDArrayHandle, ) }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kTVMObjectHandle), + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMModuleHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), FuncHandle(val) => ( TVMValue { v_handle: *val }, - TVMTypeCode_kTVMPackedFuncHandle + TVMArgTypeCode_kTVMPackedFuncHandle ), NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle), + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), $( $self_type($val) => { $from_self_type } ),+ } } @@ -155,14 +155,14 @@ TVMPODValue! 
{ Str(&'a CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } }, match &self { Bytes(val) => { - (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes) + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kTVMStr) } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } } } @@ -188,14 +188,14 @@ TVMPODValue! { Str(&'static CStr), }, match value { - TVMTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } }, match &self { Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kTVMBytes ) } + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kTVMStr ) } + { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } } } diff --git a/rust/frontend/examples/resnet/src/build_resnet.py b/rust/frontend/examples/resnet/src/build_resnet.py index 49c67bf1c4f3..a09a0c3a56eb 100644 --- a/rust/frontend/examples/resnet/src/build_resnet.py +++ b/rust/frontend/examples/resnet/src/build_resnet.py @@ -75,8 +75,8 @@ def build(target_dir): num_layers=18, batch_size=batch_size, image_shape=image_shape) # compile the model - with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build_module.build(net, target, params=params) + with tvm.transform.PassContext(opt_level=opt_level): + graph, lib, params = relay.build_module.build(net, target, params=params) # save the model artifacts lib.save(deploy_lib) diff --git a/rust/frontend/src/context.rs b/rust/frontend/src/context.rs index 6d08e391fc78..e1e3bf82e80f 100644 --- a/rust/frontend/src/context.rs +++ b/rust/frontend/src/context.rs @@ -211,7 +211,6 @@ impl_ctxs!((cpu, 1); (metal, 8); (vpi, 9); (rocm, 10); - (opengl, 11); (ext_dev, 12)); impl<'a> From<&'a str> for TVMContext { diff --git a/rust/frontend/src/function.rs b/rust/frontend/src/function.rs index 8411b03592d1..88d6cc80fe1c 100644 --- a/rust/frontend/src/function.rs +++ b/rust/frontend/src/function.rs @@ -204,7 +204,7 @@ impl<'a, 'm> Builder<'a, 'm> { ensure!(self.func.is_some(), errors::FunctionNotFoundError); let num_args = self.arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = + let (mut values, mut type_codes): (Vec, Vec) = self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() }; @@ -257,9 +257,9 @@ unsafe extern "C" fn tvm_callback( for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; - if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int - || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int - || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int + if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int { 
check_call!(ffi::TVMCbArgToReturn( &mut value as *mut _, diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml index 784b35e2fdae..97ebeca0d713 100644 --- a/rust/macros/Cargo.toml +++ b/rust/macros/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "tvm-macros" +name = "old-tvm-macros" version = "0.1.1" license = "Apache-2.0" description = "Procedural macros of the TVM crate." @@ -32,5 +32,5 @@ proc-macro = true [dependencies] goblin = "0.0.24" proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +quote = "^1.0" +syn = { version = "1.0.17", features = ["full", "extra-traits"] } diff --git a/rust/runtime/Cargo.toml b/rust/runtime/Cargo.toml index eb531f96e5be..cc149d4d1620 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/runtime/Cargo.toml @@ -39,7 +39,7 @@ serde = "1.0" serde_derive = "1.0" serde_json = "1.0" tvm-common = { version = "0.1", path = "../common" } -tvm-macros = { version = "0.1", path = "../macros" } +old-tvm-macros = { version = "0.1", path = "../macros" } [target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] libloading = "0.5" diff --git a/rust/runtime/src/array.rs b/rust/runtime/src/array.rs index 2b6c7c217e28..c38b3ff8e527 100644 --- a/rust/runtime/src/array.rs +++ b/rust/runtime/src/array.rs @@ -297,6 +297,7 @@ impl<'a> Tensor<'a> { self.strides.as_ref().unwrap().as_ptr() } as *mut i64, byte_offset: 0, + ..Default::default() } } } diff --git a/rust/runtime/src/graph.rs b/rust/runtime/src/graph.rs index 518bf724f319..71541ba27826 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/runtime/src/graph.rs @@ -382,7 +382,18 @@ named! { // Converts a bytes to String. named! { name, - map_res!(length_data!(le_u64), |b: &[u8]| String::from_utf8(b.to_vec())) + do_parse!( + len_l: le_u32 >> + len_h: le_u32 >> + data: take!(len_l) >> + ( + if len_h == 0 { + String::from_utf8(data.to_vec()).unwrap() + } else { + panic!("Too long string") + } + ) + ) } // Parses a TVMContext diff --git a/rust/runtime/src/lib.rs b/rust/runtime/src/lib.rs index de1b79d21d15..07aaaae2fb24 100644 --- a/rust/runtime/src/lib.rs +++ b/rust/runtime/src/lib.rs @@ -41,6 +41,7 @@ extern crate num_cpus; extern crate serde; #[macro_use] extern crate serde_derive; +extern crate old_tvm_macros as tvm_macros; extern crate serde_json; extern crate tvm_common; diff --git a/rust/runtime/src/module/mod.rs b/rust/runtime/src/module/mod.rs index 856dd78193bc..cb4d7776dd0b 100644 --- a/rust/runtime/src/module/mod.rs +++ b/rust/runtime/src/module/mod.rs @@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box< (val, code as i32) }) .unzip(); - let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32); + let ret: TVMRetValue = TVMRetValue::default(); + let (mut ret_val, mut ret_type_code) = ret.to_tvm_value(); + let exit_code = func( + values.as_ptr(), + type_codes.as_ptr(), + values.len() as i32, + &mut ret_val, + &mut ret_type_code, + ); if exit_code == 0 { - Ok(TVMRetValue::default()) + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code)) } else { Err(tvm_common::errors::FuncCallError::get_with_context( func_name.clone(), diff --git a/rust/runtime/src/threading.rs b/rust/runtime/src/threading.rs index f473bbf3990a..b8be01270ae7 100644 --- a/rust/runtime/src/threading.rs +++ b/rust/runtime/src/threading.rs @@ -18,7 +18,6 @@ */ use std::{ - env, os::raw::{c_int, c_void}, sync::{ atomic::{AtomicUsize, Ordering}, @@ -27,6 +26,9 @@ use std::{ thread::{self, JoinHandle}, }; +#[cfg(not(target_arch = "wasm32"))] +use 
std::env; + use crossbeam::channel::{bounded, Receiver, Sender}; use tvm_common::ffi::TVMParallelGroupEnv; @@ -147,7 +149,10 @@ impl ThreadPool { fn run_worker(queue: Receiver) { loop { - let task = queue.recv().expect("should recv"); + let task = match queue.recv() { + Ok(v) => v, + Err(_) => break, + }; let result = task.run(); if result == ::min_value() { break; diff --git a/rust/runtime/src/workspace.rs b/rust/runtime/src/workspace.rs index 8344dfbb1adf..65ad25324cae 100644 --- a/rust/runtime/src/workspace.rs +++ b/rust/runtime/src/workspace.rs @@ -64,7 +64,7 @@ impl WorkspacePool { .iter() .fold(None, |cur_ws_idx: Option, &idx| { let ws_size = self.workspaces[idx].size(); - if !ws_size >= size { + if ws_size < size { return cur_ws_idx; } cur_ws_idx.or(Some(idx)).and_then(|cur_idx| { @@ -92,9 +92,8 @@ impl WorkspacePool { break; } } - if let Some(ws_idx) = ws_idx { - self.free.push(ws_idx); - } + let ws_idx = ws_idx.ok_or_else(|| format_err!("Invalid pointer"))?; + self.free.push(ws_idx); Ok(()) } } @@ -135,6 +134,5 @@ pub extern "C" fn TVMBackendFreeWorkspace( Ok(()) => 0, Err(_) => -1, }) as c_int - }); - 0 + }) } diff --git a/rust/runtime/tests/test_wasm32/.cargo/config b/rust/runtime/tests/test_wasm32/.cargo/config new file mode 100644 index 000000000000..6b77899cb333 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/.cargo/config @@ -0,0 +1,2 @@ +[build] +target = "wasm32-wasi" diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml new file mode 100644 index 000000000000..1d3373a9e60f --- /dev/null +++ b/rust/runtime/tests/test_wasm32/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "test-wasm32" +version = "0.0.0" +license = "Apache-2.0" +authors = ["TVM Contributors"] + +[dependencies] +ndarray="0.12" +tvm-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs new file mode 100644 index 000000000000..8b72be290267 --- /dev/null +++ b/rust/runtime/tests/test_wasm32/build.rs @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::PathBuf, process::Command}; + +fn main() { + let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + out_dir.push("lib"); + + if !out_dir.is_dir() { + std::fs::create_dir(&out_dir).unwrap(); + } + + let obj_file = out_dir.join("test.o"); + let lib_file = out_dir.join("libtest_wasm32.a"); + + let output = Command::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/build_test_lib.py" + )) + .arg(&out_dir) + .output() + .expect("Failed to execute command"); + assert!( + obj_file.exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8"); + let output = Command::new(ar) + .arg("rcs") + .arg(&lib_file) + .arg(&obj_file) + .output() + .expect("Failed to execute command"); + assert!( + lib_file.exists(), + "Could not create archive: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-lib=static=test_wasm32"); + println!("cargo:rustc-link-search=native={}", out_dir.display()); +} diff --git a/tests/python/unittest/test_tensorrt.py b/rust/runtime/tests/test_wasm32/src/build_test_lib.py old mode 100644 new mode 100755 similarity index 57% rename from tests/python/unittest/test_tensorrt.py rename to rust/runtime/tests/test_wasm32/src/build_test_lib.py index 8e79653cc0cc..6016c60c4ea3 --- a/tests/python/unittest/test_tensorrt.py +++ b/rust/runtime/tests/test_wasm32/src/build_test_lib.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -14,22 +15,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os -from unittest import mock -import tempfile + +"""Prepares a simple TVM library for testing.""" + +from os import path as osp +import sys import tvm +from tvm import te -def test_empty_library_export(): - with tempfile.TemporaryDirectory() as temp_dir: - temp_file_path = os.path.join(temp_dir, "tmp_lib") - print(temp_file_path) - with mock.patch.object(tvm.runtime.Module, "is_empty") as is_empty_mock: - is_empty_mock.return_value = True - module = tvm.runtime.Module(None) - module.export_library(temp_file_path) - assert(os.path.isfile(temp_file_path)) - +def main(): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') + s = tvm.te.create_schedule(C.op) + s[C].parallel(s[C].op.axis[0]) + print(tvm.lower(s, [A, B, C], simple_mode=True)) + tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o')) -if __name__ == "__main__": - test_empty_library_export() +if __name__ == '__main__': + main() diff --git a/rust/runtime/tests/test_wasm32/src/main.rs b/rust/runtime/tests/test_wasm32/src/main.rs new file mode 100644 index 000000000000..a46cfa979bec --- /dev/null +++ b/rust/runtime/tests/test_wasm32/src/main.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern "C" { + static __tvm_module_ctx: i32; +} + +#[no_mangle] +unsafe fn __get_tvm_module_ctx() -> i32 { + // Refer a symbol in the libtest_wasm32.a to make sure that the link of the + // library is not optimized out. + __tvm_module_ctx +} + +extern crate ndarray; +#[macro_use] +extern crate tvm_runtime; + +use ndarray::Array; +use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; + +fn main() { + // try static + let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]); + let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]); + let mut c = Array::from_vec(vec![0f32; 4]); + let e = Array::from_vec(vec![2f32, 2., 4., 4.]); + let mut a_dl: DLTensor = (&mut a).into(); + let mut b_dl: DLTensor = (&mut b).into(); + let mut c_dl: DLTensor = (&mut c).into(); + + let syslib = SystemLibModule::default(); + let add = syslib + .get_function("default_function") + .expect("main function not found"); + call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap(); + assert!(c.all_close(&e, 1e-8f32)); +} diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml new file mode 100644 index 000000000000..7abc9ae64f7c --- /dev/null +++ b/rust/tvm-macros/Cargo.toml @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-macros" +version = "0.1.1" +license = "Apache-2.0" +description = "Procedural macros of the TVM crate." +repository = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["tvm"] +authors = ["TVM Contributors"] +edition = "2018" + +[lib] +proc-macro = true + +[dependencies] +goblin = "0.0.24" +proc-macro2 = "^1.0" +quote = "^1.0" +syn = { version = "1.0.17", features = ["full", "extra-traits"] } diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs new file mode 100644 index 000000000000..8833d6084574 --- /dev/null +++ b/rust/tvm-macros/src/external.rs @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use proc_macro2::Span; +use quote::quote; +use syn::parse::{Parse, ParseStream, Result}; + +use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; + +struct External { + tvm_name: String, + ident: Ident, + generics: Generics, + inputs: Vec, + ret_type: ReturnType, +} + +impl Parse for External { + fn parse(input: ParseStream) -> Result { + let method: TraitItemMethod = input.parse()?; + assert_eq!(method.attrs.len(), 1); + let sig = method.sig; + let tvm_name = method.attrs[0].parse_meta()?; + let tvm_name = match tvm_name { + Meta::List(meta_list) => { + let name = meta_list.path.get_ident().expect("name"); + assert_eq!(name.to_string(), "name".to_string()); + match meta_list.nested.first() { + Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(), + _ => panic!(), + } + } + _ => panic!(), + }; + assert_eq!(method.default, None); + assert!(method.semi_token != None); + let ident = sig.ident; + let generics = sig.generics; + let inputs = sig.inputs.iter().map(|param| param.clone()).collect(); + let ret_type = sig.output; + + Ok(External { + tvm_name, + ident, + generics, + inputs, + ret_type, + }) + } +} + +struct ExternalInput { + externs: Vec, +} + +impl Parse for ExternalInput { + fn parse(input: ParseStream) -> Result { + let mut externs: Vec = Vec::new(); + + loop { + if input.is_empty() { + break; + } + externs.push(input.parse()?); + } + + Ok(ExternalInput { externs }) + } +} + +pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let ext_input = syn::parse_macro_input!(input as ExternalInput); + + let tvm_rt_crate = crate::util::get_tvm_rt_crate(); + + let err_type = quote! { #tvm_rt_crate::Error }; + + let mut items = Vec::new(); + + for external in &ext_input.externs { + let name = &external.ident; + let global_name = format!("global_{}", external.ident); + let global_name = Ident::new(&global_name, Span::call_site()); + let ext_name = &external.tvm_name; + + let ty_params: Vec = external + .generics + .params + .iter() + .map(|ty_param| match ty_param { + syn::GenericParam::Type(param) => param.clone(), + _ => panic!(), + }) + .collect(); + + let args = &external.inputs; + + let (args, tys): (Vec, Vec) = args + .iter() + .map(|arg| match arg { + FnArg::Typed(pat_type) => match &*pat_type.pat { + Pat::Ident(pat_ident) => { + let ident: Ident = pat_ident.ident.clone(); + let ty: Type = *pat_type.ty.clone(); + (ident, ty) + } + _ => panic!(), + }, + _ => panic!(), + }) + .unzip(); + + let ret_type = match &external.ret_type { + ReturnType::Type(_, rtype) => *rtype.clone(), + _ => panic!(), + }; + + let global = quote! 
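+        // Illustrative expansion (a sketch, not the literal macro output): for a declaration such as
+        //     #[name("runtime.GetDeviceAttr")]
+        //     fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32;
+        // the quoted block below produces roughly
+        //     static global_get_device_attr: ::once_cell::sync::Lazy<Function> =
+        //         ::once_cell::sync::Lazy::new(|| Function::get("runtime.GetDeviceAttr").expect("..."));
+        // (with `#tvm_rt_crate` resolving to either `crate` or `tvm_rt`), and the `wrapper` quoted
+        // further down adds a typed `fn get_device_attr(...)` that calls the global via `to_boxed_fn`.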
{ + #[allow(non_upper_case_globals)] + static #global_name: ::once_cell::sync::Lazy<#tvm_rt_crate::Function> = + ::once_cell::sync::Lazy::new(|| { + #tvm_rt_crate::Function::get(#ext_name) + .expect(concat!("unable to load external function", stringify!(#ext_name), "from TVM registry.")) + }); + }; + + items.push(global); + + let wrapper = quote! { + pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> { + let func_ref: #tvm_rt_crate::Function = #global_name.clone(); + let func_ref: Box Result<#ret_type, #err_type>> = func_ref.to_boxed_fn(); + let res: #ret_type = func_ref(#(#args),*)?; + Ok(res) + } + }; + + items.push(wrapper); + } + + proc_macro::TokenStream::from(quote! { + #(#items + )* + }) +} diff --git a/rust/tvm-macros/src/import_module.rs b/rust/tvm-macros/src/import_module.rs new file mode 100644 index 000000000000..6b059ae363f8 --- /dev/null +++ b/rust/tvm-macros/src/import_module.rs @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use quote::quote; +use std::{fs::File, io::Read}; +use syn::parse::{Parse, ParseStream, Result}; +use syn::LitStr; + +use std::path::PathBuf; + +struct ImportModule { + importing_file: LitStr, +} + +impl Parse for ImportModule { + fn parse(input: ParseStream) -> Result { + let importing_file: LitStr = input.parse()?; + Ok(ImportModule { importing_file }) + } +} + +pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let import_module_args = syn::parse_macro_input!(input as ImportModule); + + let manifest = + std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); + + let mut path = PathBuf::new(); + path.push(manifest); + path = path.join(import_module_args.importing_file.value()); + + let mut fd = File::open(&path) + .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); + let mut buffer = Vec::new(); + fd.read_to_end(&mut buffer).unwrap(); + + let fn_names = match goblin::Object::parse(&buffer).unwrap() { + goblin::Object::Elf(elf) => elf + .syms + .iter() + .filter_map(|s| { + if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { + return None; + } + match elf.strtab.get(s.st_name) { + Some(Ok(name)) if name != "" => { + Some(syn::Ident::new(name, proc_macro2::Span::call_site())) + } + _ => None, + } + }) + .collect::>(), + goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { + obj.symbols() + .filter_map(|s| match s { + Ok((name, ref nlist)) + if nlist.is_global() + && nlist.n_sect != 0 + && !name.ends_with("tvm_module_ctx") => + { + Some(syn::Ident::new( + if name.starts_with('_') { + // Mach objects prepend a _ to globals. + &name[1..] 
+ } else { + &name + }, + proc_macro2::Span::call_site(), + )) + } + _ => None, + }) + .collect::>() + } + _ => panic!("Unsupported object format."), + }; + + let extern_fns = quote! { + mod ext { + extern "C" { + #( + pub(super) fn #fn_names( + args: *const tvm_runtime::ffi::TVMValue, + type_codes: *const std::os::raw::c_int, + num_args: std::os::raw::c_int + ) -> std::os::raw::c_int; + )* + } + } + }; + + let fns = quote! { + use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError}; + #extern_fns + + #( + pub fn #fn_names(args: &[ArgValue]) -> Result { + let (values, type_codes): (Vec, Vec) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = unsafe { + ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) + }; + if exit_code == 0 { + Ok(RetValue::default()) + } else { + Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) + } + } + )* + }; + + proc_macro::TokenStream::from(fns) +} diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs new file mode 100644 index 000000000000..603e1ceaafcc --- /dev/null +++ b/rust/tvm-macros/src/lib.rs @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro::TokenStream; + +mod external; +mod import_module; +mod object; +mod util; + +#[proc_macro] +pub fn import_module(input: TokenStream) -> TokenStream { + import_module::macro_impl(input) +} + +#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +pub fn macro_impl(input: TokenStream) -> TokenStream { + // let input = proc_macro2::TokenStream::from(input); + TokenStream::from(object::macro_impl(input)) +} + +#[proc_macro] +pub fn external(input: TokenStream) -> TokenStream { + external::macro_impl(input) +} diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs new file mode 100644 index 000000000000..bee22c367189 --- /dev/null +++ b/rust/tvm-macros/src/object.rs @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::DeriveInput; +use syn::Ident; + +use crate::util::get_tvm_rt_crate; + +pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { + let tvm_rt_crate = get_tvm_rt_crate(); + let derive_input = syn::parse_macro_input!(input as DeriveInput); + let payload_id = derive_input.ident; + + let mut type_key = None; + let mut ref_name = None; + let base = Some(Ident::new("base", Span::call_site())); + + for attr in derive_input.attrs { + if attr.path.is_ident("type_key") { + type_key = Some(attr.parse_meta().expect("foo")) + } + + if attr.path.is_ident("ref_name") { + ref_name = Some(attr.parse_meta().expect("foo")) + } + } + + let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key { + match name_value.lit { + syn::Lit::Str(type_key) => type_key, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name { + match name_value.lit { + syn::Lit::Str(ref_name) => ref_name, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_id = Ident::new(&ref_name.value(), Span::call_site()); + let base = base.expect("should be present"); + + let expanded = quote! { + unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { + const TYPE_KEY: &'static str = #type_key; + + fn as_object<'s>(&'s self) -> &'s Object { + &self.#base.as_object() + } + } + + #[derive(Clone)] + pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); + + impl #tvm_rt_crate::object::ToObjectRef for #ref_id { + fn to_object_ref(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } + } + + impl std::ops::Deref for #ref_id { + type Target = #payload_id; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().unwrap() + } + } + + impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { + type Error = #tvm_rt_crate::Error; + + fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let oref: ObjectRef = ret_val.try_into()?; + let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?; + let ptr = ptr.downcast::<#payload_id>()?; + Ok(#ref_id(Some(ptr))) + } + } + + impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + #tvm_rt_crate::ArgValue:: + ObjectHandle(std::ptr::null::() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> { + fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> { + let oref: #ref_id = object_ref.clone(); + #tvm_rt_crate::ArgValue::<'a>::from(oref) + } + } + + impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { + type Error = #tvm_rt_crate::Error; + + fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id { + type Error = #tvm_rt_crate::Error; + + fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl From<#ref_id> 
for #tvm_rt_crate::RetValue { + fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + #tvm_rt_crate::RetValue::ObjectHandle(std::ptr::null::() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + }; + + TokenStream::from(expanded) +} diff --git a/rust/tvm-macros/src/util.rs b/rust/tvm-macros/src/util.rs new file mode 100644 index 000000000000..1e720f04dfef --- /dev/null +++ b/rust/tvm-macros/src/util.rs @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro2::TokenStream; +use quote::quote; +use std::env; + +pub fn get_tvm_rt_crate() -> TokenStream { + if env::var("CARGO_PKG_NAME").unwrap() == "tvm-rt" { + quote!(crate) + } else { + quote!(tvm_rt) + } +} diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore new file mode 100644 index 000000000000..2430329c78b6 --- /dev/null +++ b/rust/tvm-rt/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml new file mode 100644 index 000000000000..465ae583ab6c --- /dev/null +++ b/rust/tvm-rt/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-rt" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust bindings for the TVM runtime API." 
+repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +ndarray = "0.12" +num-traits = "0.2" +tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } +tvm-macros = { version = "0.1", path = "../tvm-macros" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[dev-dependencies] +anyhow = "^1.0" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md new file mode 100644 index 000000000000..7c87939db301 --- /dev/null +++ b/rust/tvm-rt/README.md @@ -0,0 +1,60 @@ + + + + + + + + + + + + + + + + + +# TVM Runtime Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime. +Currently this is tested on `1.42.0` and above. + +## What Does This Crate Offer? + +TVM is an end-to-end deep learning compiler which takes high level machine learning +models or tensor computations and lowers them into executable code for a variety +of heterogenous devices (e.g., CPU, GPU). + +This crate provides access to the APIs for manipulating runtime data structures, +as well as TVM's cross-language Object system which functions similarly to systems +such as COM, enabling cross-language interoperability. + +## Installations + +Please follow TVM [installation](https://tvm.apache.org/docs/install/index.html) instructions, +`export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +### Example of registering a cross-language closure. + +One can use `register!` macro to expose a Rust closure with arguments which implement `TryFrom` +and return types which implement `Into`. Once registered with TVM these functions can be +accessed via Python or C++, or any other language which implements the TVM packed function convention +see `docs.tvm.ai` for more information. + +```rust +use tvm_rt::{ArgValue, RetValue}; +use tvm_rt::function::{Function, Result, register}; + +fn sum(x: i64, y: i64, z: i64) -> i64 { + x + y + z +} + +fn main() { + register(sum, "mysum".to_owned()).unwrap(); + let func = Function::get("mysum").unwrap(); + let boxed_fn = func.to_boxed_fn:: Result>(); + let ret = boxed_fn(10, 20, 30).unwrap(); + assert_eq!(ret, 60); +} +``` diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs new file mode 100644 index 000000000000..b0fea33c6c61 --- /dev/null +++ b/rust/tvm-rt/src/context.rs @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::os::raw::c_void; +use std::ptr; + +use crate::errors::Error; + +use tvm_sys::ffi; + +pub use tvm_sys::context::*; + +trait ContextExt { + /// Checks whether the context exists or not. + fn exist(&self) -> bool; + fn sync(&self) -> Result<(), Error>; + fn max_threads_per_block(&self) -> isize; + fn warp_size(&self) -> isize; + fn max_shared_memory_per_block(&self) -> isize; + fn compute_version(&self) -> isize; + fn device_name(&self) -> isize; + fn max_clock_rate(&self) -> isize; + fn multi_processor_count(&self) -> isize; + fn max_thread_dimensions(&self) -> isize; +} + +macro_rules! impl_device_attrs { + ($(($attr_name:ident, $attr_kind:expr));+) => { + $( + fn $attr_name(&self) -> isize { + get_device_attr(self.device_type as i32, self.device_id as i32, 0) + .expect("should not fail") as isize + } + + )+ + }; +} + +crate::external! { + #[name("runtime.GetDeviceAttr")] + fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32; +} + +impl ContextExt for Context { + fn exist(&self) -> bool { + let exists = get_device_attr(self.device_type as i32, self.device_id as i32, 0) + .expect("should not fail"); + + exists != 0 + } + + /// Synchronize the context stream. + fn sync(&self) -> Result<(), Error> { + check_call!(ffi::TVMSynchronize( + self.device_type as i32, + self.device_id as i32, + ptr::null_mut() as *mut c_void + )); + Ok(()) + } + + impl_device_attrs!((max_threads_per_block, 1); + (warp_size, 2); + (max_shared_memory_per_block, 3); + (compute_version, 4); + (device_name, 5); + (max_clock_rate, 6); + (multi_processor_count, 7); + (max_thread_dimensions, 8)); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sync() { + let ctx = Context::cpu(0); + assert!(ctx.sync().is_ok()) + } +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs new file mode 100644 index 000000000000..0b45ebf445bf --- /dev/null +++ b/rust/tvm-rt/src/errors.rs @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use crate::DataType; +use thiserror::Error; + +#[derive(Debug, Error)] +#[error("Function was not set in `function::Builder`")] +pub struct FunctionNotFoundError; + +#[derive(Debug, Error)] +#[error("Expected type `{expected}` but found `{actual}`")] +pub struct TypeMismatchError { + pub expected: String, + pub actual: String, +} + +#[derive(Debug, Error)] +pub enum NDArrayError { + #[error("Missing NDArray shape.")] + MissingShape, + #[error("Cannot convert from an empty array.")] + EmptyArray, + #[error("Invalid datatype when attempting to convert ndarray.")] + InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError), + #[error("a shape error occurred in the Rust ndarray library")] + ShapeError(#[from] ndarray::ShapeError), + #[error("Expected type `{expected}` but found `{actual}`")] + DataTypeMismatch { + expected: DataType, + actual: DataType, + }, +} + +#[derive(Debug, Error)] +pub enum Error { + #[error("{0}")] + Downcast(#[from] tvm_sys::errors::ValueDowncastError), + #[error("raw pointer passed across boundary was null")] + Null, + #[error("failed to load module due to invalid path {0}")] + ModuleLoadPath(String), + #[error("failed to convert String into CString due to embedded nul character")] + ToCString(#[from] std::ffi::NulError), + #[error("failed to convert CString into String")] + FromCString(#[from] std::ffi::IntoStringError), + #[error("Handle `{0}` is null.")] + NullHandle(String), + #[error("{0}")] + NDArray(#[from] NDArrayError), + #[error("{0}")] + CallFailed(String), +} + +impl Error { + pub fn downcast(actual_type: String, expected_type: &'static str) -> Error { + Self::Downcast(tvm_sys::errors::ValueDowncastError { + actual_type, + expected_type, + }) + } +} diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs new file mode 100644 index 000000000000..cb8777a6227b --- /dev/null +++ b/rust/tvm-rt/src/function.rs @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. 
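+//!
+//! A rough calling sketch (assumes a function has already been registered under the
+//! name `mysum`, as in the `register` example further down, and relies on the
+//! `From`/`TryFrom` conversions on `ArgValue`/`RetValue` provided by `tvm-sys`):
+//!
+//! ```no_run
+//! use std::convert::TryInto;
+//! use tvm_rt::function::Function;
+//!
+//! // Look up the registered global and call it through the low-level `invoke` path.
+//! let func = Function::get("mysum").expect("function is registered");
+//! let args = vec![10i64.into(), 20i64.into(), 30i64.into()];
+//! let ret: i64 = func.invoke(args).unwrap().try_into().unwrap();
+//! assert_eq!(ret, 60);
+//! ```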
+ +use std::convert::TryFrom; +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, + ptr, str, +}; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::errors::Error; + +use super::to_boxed_fn::ToBoxedFn; +use super::to_function::{ToFunction, Typed}; + +pub type Result = std::result::Result; + +/// Wrapper around TVM function handle which includes `is_global` +/// indicating whether the function is global or not, and `is_cloned` showing +/// not to drop a cloned function from Rust side. +/// The value of these fields can be accessed through their respective methods. +#[derive(Debug, Hash)] +pub struct Function { + pub(crate) handle: ffi::TVMFunctionHandle, + // whether the registered function is global or not. + is_global: bool, + from_rust: bool, +} + +unsafe impl Send for Function {} +unsafe impl Sync for Function {} + +impl Function { + pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { + Function { + handle, + is_global: false, + from_rust: false, + } + } + + /// For a given function, it returns a function by name. + pub fn get>(name: S) -> Option { + let name = CString::new(name.as_ref()).unwrap(); + let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; + + check_call!(ffi::TVMFuncGetGlobal( + name.as_ptr() as *const c_char, + &mut handle as *mut _ + )); + + if handle.is_null() { + None + } else { + Some(Function { + handle, + is_global: true, + from_rust: false, + }) + } + } + + pub fn get_boxed>(name: S) -> Option> + where + F: ToBoxedFn, + { + Self::get(name).map(|f| f.to_boxed_fn::()) + } + + /// Returns the underlying TVM function handle. + pub fn handle(&self) -> ffi::TVMFunctionHandle { + self.handle + } + + /// Returns `true` if the underlying TVM function is global and `false` otherwise. + pub fn is_global(&self) -> bool { + self.is_global + } + + /// Calls the function that created from `Builder`. 
+ pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { + let num_args = arg_buf.len(); + let (mut values, mut type_codes): (Vec, Vec) = + arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); + let mut ret_val = ffi::TVMValue { v_int64: 0 }; + let mut ret_type_code = 0i32; + + check_call!(ffi::TVMFuncCall( + self.handle, + values.as_mut_ptr() as *mut ffi::TVMValue, + type_codes.as_mut_ptr() as *mut c_int, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); + + Ok(RetValue::from_tvm_value(ret_val, ret_type_code as u32)) + } + + pub fn to_boxed_fn(self) -> Box + where + F: ToBoxedFn, + { + F::to_boxed_fn(self) + } +} + +impl Clone for Function { + fn clone(&self) -> Function { + Self { + handle: self.handle, + is_global: self.is_global, + from_rust: true, + } + } +} + +// impl Drop for Function { +// fn drop(&mut self) { +// if !self.is_global && !self.is_cloned { +// check_call!(ffi::TVMFuncFree(self.handle)); +// } +// } +// } + +impl From for RetValue { + fn from(func: Function) -> RetValue { + RetValue::FuncHandle(func.handle) + } +} + +impl TryFrom for Function { + type Error = Error; + + fn try_from(ret_value: RetValue) -> Result { + match ret_value { + RetValue::FuncHandle(handle) => Ok(Function::new(handle)), + _ => Err(Error::downcast( + format!("{:?}", ret_value), + "FunctionHandle", + )), + } + } +} + +impl<'a> From for ArgValue<'a> { + fn from(func: Function) -> ArgValue<'a> { + ArgValue::FuncHandle(func.handle) + } +} + +impl<'a> TryFrom> for Function { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result { + match arg_value { + ArgValue::FuncHandle(handle) => Ok(Function::new(handle)), + _ => Err(Error::downcast( + format!("{:?}", arg_value), + "FunctionHandle", + )), + } + } +} + +impl<'a> TryFrom<&ArgValue<'a>> for Function { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result { + match arg_value { + ArgValue::FuncHandle(handle) => Ok(Function::new(*handle)), + _ => Err(Error::downcast( + format!("{:?}", arg_value), + "FunctionHandle", + )), + } + } +} + +/// Registers a Rust function with an arbitrary type signature in +/// the TVM registry. +/// +/// +/// A function is convertible if and only if its arguments and return types are convertible +/// to and from TVM values respectively. +/// +/// Use [`register_override`] if control of overriding existing global TVM function +/// is required, this function will panic if a function is already registered. +/// +/// ## Example +/// +/// ``` +/// # use tvm_rt::{ArgValue, RetValue}; +/// # use tvm_rt::function::{Function, Result, register}; +/// +/// fn sum(x: i64, y: i64, z: i64) -> i64 { +/// x + y + z +/// } +/// +/// register(sum, "mysum".to_owned()).unwrap(); +/// let func = Function::get("mysum").unwrap(); +/// let boxed_fn = func.to_boxed_fn:: Result>(); +/// let ret = boxed_fn(10, 20, 30).unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register>(f: F, name: S) -> Result<()> +where + F: ToFunction, + F: Typed, +{ + register_override(f, name, false) +} + +/// Register a function with explicit control over whether to override an existing registration or not. +/// +/// See `register` for more details on how to use the registration API. 
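+///
+/// A rough usage sketch (the function and the `"double"` name are illustrative only):
+///
+/// ```no_run
+/// use tvm_rt::function::register_override;
+///
+/// fn double(x: i64) -> i64 { x * 2 }
+///
+/// // `true` replaces any function already registered under this name.
+/// register_override(double, "double".to_owned(), true).unwrap();
+/// ```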
+pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> +where + F: ToFunction, + F: Typed, +{ + let func = f.to_function(); + let name = name.into(); + // Not sure about this code + let handle = func.handle(); + let name = CString::new(name)?; + check_call!(ffi::TVMFuncRegisterGlobal( + name.into_raw(), + handle, + override_ as c_int + )); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::Function; + + static CANARY: &str = "runtime.ModuleLoadFromFile"; + + #[test] + fn get_fn() { + assert!(Function::get(CANARY).is_some()); + assert!(Function::get("does not exists!").is_none()); + } + + #[test] + fn register_and_call_closure0() { + use crate::function; + use function::Result; + + fn constfn() -> i64 { + return 10; + } + + function::register_override(constfn, "constfn".to_owned(), true).unwrap(); + + let func = Function::get_boxed:: Result, _>("constfn").unwrap(); + let ret = func().unwrap(); + assert_eq!(ret, 10); + } + + #[test] + fn register_and_call_closure1() { + use crate::function::{self}; + + fn ident(x: i64) -> i64 { + return x; + } + + function::register_override(ident, "ident".to_owned(), true).unwrap(); + let func = Function::get_boxed:: Result, _>("ident").unwrap(); + assert_eq!(func(60).unwrap(), 60); + } +} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs new file mode 100644 index 000000000000..10f8317bf7bd --- /dev/null +++ b/rust/tvm-rt/src/lib.rs @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime. +//! +//! The TVM runtime API contains the data structures used by higher-level TVM executors. +//! Specifically it exposes the basic types such as NDArray, as well as the more general object system. +//! The TVM object system enables cross-language interoperability including that of closures for all +//! supported languages including C++, and Python. + +pub mod object; +pub mod string; + +pub use object::*; +pub use string::*; + +use std::{ + ffi::{CStr, CString}, + str, +}; + +pub use crate::{ + context::{Context, DeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, +}; + +pub use function::{ArgValue, RetValue}; +pub use tvm_sys::byte_array::ByteArray; +pub use tvm_sys::datatype::DataType; +use tvm_sys::ffi; + +pub use tvm_macros::external; + +// Macro to check the return call to TVM runtime shared library. + +#[macro_export] +macro_rules! tvm_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + Err($crate::get_last_error().into()) + } else { + Ok(()) + } + }}; +} + +#[macro_export] +macro_rules! 
check_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + panic!("{}", $crate::get_last_error()); + } + }}; +} + +/// Gets the last error message. +pub fn get_last_error() -> &'static str { + unsafe { + match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { + Ok(s) => s, + Err(_) => "Invalid UTF-8 message", + } + } +} + +pub(crate) fn set_last_error(err: &E) { + let c_string = CString::new(err.to_string()).unwrap(); + unsafe { + ffi::TVMAPISetLastError(c_string.as_ptr()); + } +} + +#[macro_use] +pub mod function; +pub mod context; +pub mod errors; +pub mod module; +pub mod ndarray; +pub mod to_boxed_fn; +mod to_function; +pub mod value; + +/// Outputs the current TVM version. +pub fn version() -> &'static str { + match str::from_utf8(ffi::TVM_VERSION) { + Ok(s) => s, + Err(_) => "Invalid UTF-8 string", + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn print_version() { + println!("TVM version: {}", version()); + } + + #[test] + fn set_error() { + let err = errors::NDArrayError::EmptyArray; + set_last_error(&err); + assert_eq!( + get_last_error().trim(), + errors::NDArrayError::EmptyArray.to_string() + ); + } +} diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs new file mode 100644 index 000000000000..b540c1ba9981 --- /dev/null +++ b/rust/tvm-rt/src/module.rs @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Provides the [`Module`] type and methods for working with runtime TVM modules. + +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, + path::Path, + ptr, +}; + +use tvm_sys::ffi; + +use crate::errors::Error; +use crate::{errors, function::Function}; + +const ENTRY_FUNC: &str = "__tvm_main__"; + +/// Wrapper around TVM module handle which contains an entry function. +/// The entry function can be applied to an imported module through [`entry_func`]. +/// +/// [`entry_func`]:struct.Module.html#method.entry_func +#[derive(Debug, Clone)] +pub struct Module { + pub(crate) handle: ffi::TVMModuleHandle, + entry_func: Option, +} + +crate::external! { + #[name("runtime.RuntimeEnabled")] + fn runtime_enabled(target: CString) -> i32; + + #[name("runtime.ModuleLoadFromFile")] + fn load_from_file(file_name: CString, format: CString) -> Module; +} + +impl Module { + pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { + Self { + handle, + entry_func: None, + } + } + + pub fn entry(&mut self) -> Option { + if self.entry_func.is_none() { + self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); + } + self.entry_func.clone() + } + + /// Gets a function by name from a registered module. 
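+    ///
+    /// A rough usage sketch (the library path and function name are illustrative):
+    ///
+    /// ```no_run
+    /// use tvm_rt::Module;
+    ///
+    /// // Load a compiled TVM module from disk and look up a packed function by name.
+    /// let module = Module::load(&"deploy_lib.so").unwrap();
+    /// let func = module.get_function("default_function", false).unwrap();
+    /// ```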
+ pub fn get_function(&self, name: &str, query_import: bool) -> Result { + let name = CString::new(name)?; + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMModGetFunction( + self.handle, + name.as_ptr() as *const c_char, + query_import as c_int, + &mut fhandle as *mut _ + )); + + if !fhandle.is_null() { + return Err(errors::Error::NullHandle(name.into_string()?.to_string())); + } + + Ok(Function::new(fhandle)) + } + + /// Imports a dependent module such as `.ptx` for gpu. + pub fn import_module(&self, dependent_module: Module) { + check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) + } + + /// Loads a module shared library from path. + pub fn load>(path: &P) -> Result { + let ext = CString::new( + path.as_ref() + .extension() + .unwrap_or_else(|| std::ffi::OsStr::new("")) + .to_str() + .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, + )?; + + let cpath = CString::new( + path.as_ref() + .to_str() + .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, + )?; + + let module = load_from_file(cpath, ext)?; + Ok(module) + } + + /// Checks if a target device is enabled for a module. + pub fn enabled(&self, target: &str) -> bool { + let target = CString::new(target).unwrap(); + let enabled = runtime_enabled(target).unwrap(); + enabled != 0 + } + + /// Returns the underlying module handle. + pub fn handle(&self) -> ffi::TVMModuleHandle { + self.handle + } +} + +impl Drop for Module { + fn drop(&mut self) { + check_call!(ffi::TVMModFree(self.handle)); + } +} diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs new file mode 100644 index 000000000000..b7ae4622849d --- /dev/null +++ b/rust/tvm-rt/src/ndarray.rs @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements the [`NDArray`] type for working with *TVM tensors* or +//! coverting from a Rust's ndarray to TVM `NDArray`. +//! +//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! # use tvm_rt::{NDArray, Context, DataType}; +//! # use ndarray::{Array, ArrayD}; +//! # use std::str::FromStr; +//! use std::convert::TryFrom; +//! +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! 
let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::convert::TryInto; +use std::ffi::c_void; +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +use crate::errors::NDArrayError; + +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, ByteArray, Context, DataType}; + +use ndarray::{Array, ArrayD}; +use num_traits::Num; + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +/// +/// Wrapper around TVM array handle. +#[derive(Debug)] +pub enum NDArray { + Borrowed { handle: ffi::TVMArrayHandle }, + Owned { handle: *mut c_void }, +} + +impl NDArray { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { + NDArray::Borrowed { handle } + } + + pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { + NDArray::Owned { handle } + } + + pub fn as_dltensor(&self) -> &DLTensor { + let ptr: *mut DLTensor = match self { + NDArray::Borrowed { ref handle } => *handle, + NDArray::Owned { ref handle } => *handle as *mut DLTensor, + }; + + unsafe { std::mem::transmute(ptr) } + } + + pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { + match self { + NDArray::Borrowed { handle } => *handle, + NDArray::Owned { handle } => *handle as *mut DLTensor, + } + } + + pub fn is_view(&self) -> bool { + if let &NDArray::Borrowed { .. } = self { + true + } else { + false + } + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = self.as_dltensor(); + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option { + self.shape().map(|v| v.iter().product()) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> Context { + self.as_dltensor().ctx.into() + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> DataType { + self.as_dltensor().dtype.into() + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + self.as_dltensor() + .ndim + .try_into() + .expect("number of dimensions must always be positive") + } + + /// Returns the strides of the underlying NDArray. + pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let sz = self.ndim() * mem::size_of::(); + let strides_ptr = self.as_dltensor().strides as *const usize; + let slc = slice::from_raw_parts(strides_ptr, sz); + Some(slc) + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result { + Ok(match self.strides() { + None => true, + Some(strides) => { + // NDArrayError::MissingShape in case shape is not determined + self.shape() + .ok_or(NDArrayError::MissingShape)? + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + self.as_dltensor().byte_offset as isize + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. 
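+    /// The contents are first copied into a temporary CPU-resident `NDArray`, so this
+    /// also works for arrays that live on a non-CPU context.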
+ /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let mut shape = [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.to_vec::().unwrap(), data); + /// ``` + pub fn to_vec(&self) -> Result, NDArrayError> { + if !self.shape().is_some() { + return Err(NDArrayError::EmptyArray); + } + let earr = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + Context::cpu(0), + self.dtype(), + ); + let target = self.copy_to_ndarray(earr)?; + let arr = target.as_dltensor(); + let sz = self.size().ok_or(NDArrayError::MissingShape)?; + let mut v: Vec = Vec::with_capacity(sz * mem::size_of::()); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) + } + + /// Converts the NDArray to [`ByteArray`]. + pub fn to_bytearray(&self) -> Result { + let v = self.to_vec::()?; + Ok(ByteArray::from(v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. + pub fn copy_from_buffer(&mut self, data: &mut [T]) { + check_call!(ffi::TVMArrayCopyFromBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + data.len() * mem::size_of::() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + if self.dtype() != target.dtype() { + return Err(NDArrayError::DataTypeMismatch { + expected: self.dtype(), + actual: target.dtype(), + }); + } + + check_call!(ffi::TVMArrayCopyFromTo( + self.as_raw_dltensor(), + target.as_raw_dltensor(), + ptr::null_mut() as ffi::TVMStreamHandle + )); + + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &Context) -> Result { + let tmp = NDArray::empty( + self.shape().ok_or(NDArrayError::MissingShape)?, + *target, + self.dtype(), + ); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. + pub fn from_rust_ndarray( + rnd: &ArrayD, + ctx: Context, + dtype: DataType, + ) -> Result { + let shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. 
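+    ///
+    /// A small sketch in the style of the other examples in this module:
+    ///
+    /// ```
+    /// # use tvm_rt::{Context, DataType, NDArray};
+    /// # use std::str::FromStr;
+    /// let arr = NDArray::empty(&[2, 3], Context::cpu(0), DataType::from_str("float32").unwrap());
+    /// assert_eq!(arr.ndim(), 2);
+    /// assert_eq!(arr.size(), Some(6));
+    /// ```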
+ pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + let dtype: tvm_sys::ffi::DLDataType = dtype.into(); + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::Borrowed { handle: handle } + } +} + +macro_rules! impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &NDArray) -> Result, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + } + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = NDArrayError; + + fn try_from(nd: &mut NDArray) -> Result, Self::Error> { + if !nd.shape().is_some() { + return Err(NDArrayError::MissingShape); + }; + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(NDArrayError::MissingShape)?, + nd.to_vec::<$type>()?, + )?) + } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if let &mut NDArray::Owned { .. } = self { + check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); + } + } +} + +mod sealed { + /// Private trait to prevent other traits from being implemeneted in downstream crates. + pub trait Sealed {} +} + +/// A trait for the supported 32-bits numerical types in frontend. +pub trait Num32: Num + sealed::Sealed { + const BITS: u8 = 32; +} + +macro_rules! 
impl_num32 { + ($($type:ty),+) => { + $( + impl sealed::Sealed for $type {} + impl Num32 for $type {} + )+ + }; +} + +impl_num32!(i32, u32, f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let shape = &mut [1, 2, 3]; + let ctx = Context::cpu(0); + let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!( + ndarray.size().unwrap(), + shape.to_vec().into_iter().product() + ); + assert_eq!(ndarray.ndim(), 3); + assert!(ndarray.strides().is_none()); + assert_eq!(ndarray.byte_offset(), 0); + } + + #[test] + fn copy() { + let shape = &mut [4]; + let mut data = vec![1i32, 2, 3, 4]; + let ctx = Context::cpu(0); + let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert!(ndarray.to_vec::().is_ok()); + ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.to_vec::().unwrap(), data); + assert_eq!(ndarray.ndim(), 1); + assert!(ndarray.is_contiguous().is_ok()); + assert_eq!(ndarray.byte_offset(), 0); + let shape = vec![4]; + let e = NDArray::empty( + &shape, + Context::cpu(0), + DataType::from_str("int32").unwrap(), + ); + let nd = ndarray.copy_to_ndarray(e); + assert!(nd.is_ok()); + assert_eq!(nd.unwrap().to_vec::().unwrap(), data); + } + + #[test] + #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = Context::cpu(0); + let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } + + #[test] + fn rust_ndarray() { + let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + .unwrap() + .into_dyn(); + let nd = + NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + .unwrap(); + assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); + assert!(rnd.all_close(&a, 1e-8f32)); + } +} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs new file mode 100644 index 000000000000..c49f84e2d916 --- /dev/null +++ b/rust/tvm-rt/src/object/mod.rs @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::convert::TryFrom; +use std::convert::TryInto; +use std::ffi::CString; + +use crate::errors::Error; +use crate::external; + +use tvm_sys::{ArgValue, RetValue}; + +mod object_ptr; + +pub use object_ptr::{IsObject, Object, ObjectPtr}; + +#[derive(Clone)] +pub struct ObjectRef(pub Option>); + +impl ObjectRef { + pub fn null() -> ObjectRef { + ObjectRef(None) + } +} + +pub trait ToObjectRef { + fn to_object_ref(&self) -> ObjectRef; +} + +impl ToObjectRef for ObjectRef { + fn to_object_ref(&self) -> ObjectRef { + self.clone() + } +} + +impl TryFrom for ObjectRef { + type Error = Error; + + fn try_from(ret_val: RetValue) -> Result { + let optr = ret_val.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl From for RetValue { + fn from(object_ref: ObjectRef) -> RetValue { + use std::ffi::c_void; + let object_ptr = object_ref.0; + match object_ptr { + None => RetValue::ObjectHandle(std::ptr::null::() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> std::convert::TryFrom> for ObjectRef { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result { + let optr = arg_value.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result { + // TODO(@jroesch): remove the clone + let value: ArgValue<'a> = arg_value.clone(); + ObjectRef::try_from(value) + } +} + +impl<'a> From for ArgValue<'a> { + fn from(object_ref: ObjectRef) -> ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => ArgValue::ObjectHandle(std::ptr::null::() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> From<&ObjectRef> for ArgValue<'a> { + fn from(object_ref: &ObjectRef) -> ArgValue<'a> { + let oref: ObjectRef = object_ref.clone(); + ArgValue::<'a>::from(oref) + } +} + +external! { + #[name("ir.DebugPrint")] + fn debug_print(object: ObjectRef) -> CString; +} + +// external! { +// #[name("ir.TextPrinter")] +// fn as_text(object: ObjectRef) -> CString; +// } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs new file mode 100644 index 000000000000..40e218454f6a --- /dev/null +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::convert::TryFrom; +use std::ffi::CString; +use std::ptr::NonNull; +use std::sync::atomic::AtomicI32; + +use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::{ArgValue, RetValue}; + +use crate::errors::Error; + +type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); + +#[derive(Debug)] +#[repr(C)] +pub struct Object { + pub type_index: u32, + // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. + // NB: in general we should not touch this in Rust. + pub(self) ref_count: AtomicI32, + pub fdeleter: Deleter, +} + +unsafe extern "C" fn delete(object: *mut Object) { + let typed_object: *mut T = std::mem::transmute(object); + T::typed_delete(typed_object); +} + +fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { + let mut is_derived = 0; + crate::check_call!(ffi::TVMObjectDerivedFrom( + child_type_index, + parent_type_index, + &mut is_derived + )); + + if is_derived == 0 { + false + } else { + true + } +} + +impl Object { + fn new(type_index: u32, deleter: Deleter) -> Object { + Object { + type_index, + // Note: do not touch this field directly again, this is + // a critical section, we write a 1 to the atomic which will now + // be managed by the C++ atomics. + // In the future we should probably use C-atomcis. + ref_count: AtomicI32::new(0), + fdeleter: deleter, + } + } + + fn get_type_index() -> u32 { + let type_key = T::TYPE_KEY; + let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { + return 0; + } else { + let mut index = 0; + unsafe { + let index_ptr = std::mem::transmute(&mut index); + if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 { + panic!(crate::get_last_error()) + } + } + return index; + } + } + + pub fn base_object() -> Object { + let index = Object::get_type_index::(); + Object::new(index, delete::) + } + + pub(self) fn inc_ref(&self) { + unsafe { + let raw_ptr = std::mem::transmute(self); + assert_eq!(TVMObjectRetain(raw_ptr), 0); + } + } + + pub(self) fn dec_ref(&self) { + unsafe { + let raw_ptr = std::mem::transmute(self); + assert_eq!(TVMObjectFree(raw_ptr), 0); + } + } +} + +pub unsafe trait IsObject { + const TYPE_KEY: &'static str; + + fn as_object<'s>(&'s self) -> &'s Object; + + unsafe extern "C" fn typed_delete(object: *mut Self) { + let object = Box::from_raw(object); + drop(object) + } +} + +unsafe impl IsObject for Object { + const TYPE_KEY: &'static str = "Object"; + + fn as_object<'s>(&'s self) -> &'s Object { + self + } +} + +#[repr(C)] +pub struct ObjectPtr { + pub ptr: NonNull, +} + +fn inc_ref(ptr: NonNull) { + unsafe { ptr.as_ref().as_object().inc_ref() } +} + +fn dec_ref(ptr: NonNull) { + unsafe { ptr.as_ref().as_object().dec_ref() } +} + +impl ObjectPtr { + fn from_raw(object_ptr: *mut Object) -> Option> { + let non_null = NonNull::new(object_ptr); + non_null.map(|ptr| ObjectPtr { ptr }) + } +} + +impl Clone for ObjectPtr { + fn clone(&self) -> Self { + inc_ref(self.ptr); + ObjectPtr { ptr: self.ptr } + } +} + +impl Drop for ObjectPtr { + fn drop(&mut self) { + dec_ref(self.ptr); + } +} + +impl ObjectPtr { + pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut T + where + T: 'a, + { + unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr() } + } + + pub fn new(object: T) -> ObjectPtr { + let object_ptr = Box::new(object); + let object_ptr = Box::leak(object_ptr); + let ptr = NonNull::from(object_ptr); + inc_ref(ptr); + ObjectPtr { ptr } + } + + pub fn 
count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. + self.as_object() + .ref_count + .load(std::sync::atomic::Ordering::SeqCst) + } + + fn as_object<'s>(&'s self) -> &'s Object { + unsafe { self.ptr.as_ref().as_object() } + } + + pub fn upcast(&self) -> ObjectPtr { + ObjectPtr { + ptr: self.ptr.cast(), + } + } + + pub fn downcast(&self) -> Result, Error> { + let child_index = Object::get_type_index::(); + let object_index = self.as_object().type_index; + + let is_derived = if child_index == object_index { + true + } else { + // TODO(@jroesch): write tests + derived_from(object_index, child_index) + }; + + if is_derived { + Ok(ObjectPtr { + ptr: self.ptr.cast(), + }) + } else { + Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) + } + } +} + +impl std::ops::Deref for ObjectPtr { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { self.ptr.as_ref() } + } +} + +impl<'a, T: IsObject> From> for RetValue { + fn from(object_ptr: ObjectPtr) -> RetValue { + let raw_object_ptr = ObjectPtr::leak(object_ptr); + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + RetValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom for ObjectPtr { + type Error = Error; + + fn try_from(ret_value: RetValue) -> Result, Self::Error> { + match ret_value { + RetValue::ObjectHandle(handle) => { + let handle: *mut Object = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.downcast() + } + _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), + } + } +} + +impl<'a, T: IsObject> From> for ArgValue<'a> { + fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { + let raw_object_ptr = ObjectPtr::leak(object_ptr); + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + ArgValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom> for ObjectPtr { + type Error = Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.downcast() + } + _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), + } + } +} + +impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr { + type Error = Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.downcast() + } + _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), + } + } +} + +#[cfg(test)] +mod tests { + use super::{Object, ObjectPtr}; + use anyhow::{ensure, Result}; + use std::convert::TryInto; + use tvm_sys::{ArgValue, RetValue}; + + #[test] + fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) + } + + #[test] + fn roundtrip_retvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::()); + let ret_value: RetValue = ptr.clone().into(); + let ptr2: ObjectPtr = ret_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } + + #[test] + fn roundtrip_argvalue() -> Result<()> { + let ptr = 
ObjectPtr::new(Object::base_object::()); + let arg_value: ArgValue = ptr.clone().into(); + let ptr2: ObjectPtr = arg_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } + + fn test_fn(o: ObjectPtr) -> ObjectPtr { + assert_eq!(o.count(), 2); + return o; + } + + #[test] + fn test_ref_count_boundary() { + use super::*; + use crate::function::{register, Function, Result}; + let ptr = ObjectPtr::new(Object::base_object::()); + let stay = ptr.clone(); + assert_eq!(ptr.count(), 2); + register(test_fn, "my_func").unwrap(); + let func = Function::get("my_func").unwrap(); + let func = func.to_boxed_fn::) -> Result>>(); + func(ptr).unwrap(); + assert_eq!(stay.count(), 1); + } +} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs new file mode 100644 index 000000000000..26758b1170e7 --- /dev/null +++ b/rust/tvm-rt/src/string.rs @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::ffi::{CString, NulError}; +use std::os::raw::c_char; + +use super::errors::Error; +use super::{Object, ObjectPtr, ObjectRef}; + +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "String"] +#[type_key = "runtime.String"] +pub struct StringObj { + base: Object, + data: *const c_char, + size: u64, +} + +impl String { + pub fn new(string: std::string::String) -> Result { + let cstring = CString::new(string)?; + + // The string is being corrupted. 
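+        // Note: `as_bytes().len()` below measures the string *excluding* the trailing NUL,
+        // and `into_raw()` hands ownership of the NUL-terminated buffer to the StringObj,
+        // so the buffer is not freed when the local `cstring` goes out of scope.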
+ // why is this wrong + let length = cstring.as_bytes().len(); + + let string_obj = StringObj { + base: Object::base_object::(), + data: cstring.into_raw(), + size: length as u64, + }; + + let object_ptr = ObjectPtr::new(string_obj); + Ok(String(Some(object_ptr))) + } + + pub fn to_cstring(&self) -> Result { + use std::slice; + let ptr = self.0.as_ref().unwrap().data; + let size = self.0.as_ref().unwrap().size; + unsafe { + let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize); + CString::new(slice) + } + } + + pub fn to_string(&self) -> Result { + let string = self.to_cstring()?.into_string()?; + Ok(string) + } +} + +// #[cfg(test)] +// mod tests { +// use super::String; +// use crate::object::debug_print; +// use crate::ToObjectRef; +// use anyhow::{ensure, Result}; + +// #[test] +// fn test_string_debug() -> Result<()> { +// let s = String::new("foo".to_string()).unwrap(); +// let object_ref = s.to_object_ref(); +// println!("about to call"); +// let string = debug_print(object_ref)?; +// println!("after call"); +// ensure!( +// string.into_string().expect("is cstring").contains("foo"), +// "string content is invalid" +// ); +// Ok(()) +// } +// } diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs new file mode 100644 index 000000000000..f0e5e80ff2ad --- /dev/null +++ b/rust/tvm-rt/src/to_boxed_fn.rs @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides a method for converting type erased TVM functions +//! into a boxed Rust closure. +//! +//! To call a registered function check the [`ToBoxedFn::to_boxed_fn`] method. +//! +//! See the tests and examples repository for more examples. 
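As a quick illustration, here is a minimal sketch mirroring the `to_boxed_fn0` test at the bottom of this file: the `boxed0` name, `register_override`, and `Function::get` come straight from that test, while the concrete boxed-closure type and the `?`-based error handling are reconstructions (the paths follow the in-crate test; from outside the crate they would go through `tvm_rt` instead of `crate`).

```
use crate::function::{self, Function, Result};

fn boxed0() -> i64 {
    10
}

fn example() -> Result<()> {
    // Register a plain Rust function under a global name, then look it up again.
    function::register_override(boxed0, "boxed0".to_owned(), true)?;
    let func = Function::get("boxed0").unwrap();
    // Recover a typed, boxed closure from the type-erased packed-function handle.
    let typed_func: Box<dyn Fn() -> Result<i64>> = func.to_boxed_fn();
    assert_eq!(typed_func()?, 10);
    Ok(())
}
```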
+ +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::{errors, Module}; + +use super::function::{Function, Result}; + +pub trait ToBoxedFn { + fn to_boxed_fn(func: Function) -> Box; +} + +use std::convert::{TryFrom, TryInto}; + +impl ToBoxedFn for dyn Fn() -> Result +where + errors::Error: From, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move || { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A) -> Result +where + errors::Error: From, + A: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B) -> Result +where + errors::Error: From, + A: Into>, + B: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A, b: B| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + builder.arg(b.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B, C) -> Result +where + errors::Error: From, + A: Into>, + B: Into>, + C: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A, b: B, c: C| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl ToBoxedFn for dyn Fn(A, B, C, D) -> Result +where + errors::Error: From, + A: Into>, + B: Into>, + C: Into>, + D: Into>, + O: TryFrom, +{ + fn to_boxed_fn(func: Function) -> Box { + Box::new(move |a: A, b: B, c: C, d: D| { + let mut builder = Builder::default(); + builder.func = Some(func.clone()); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + builder.arg(d.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +/// Function builder in order to create and call functions. +/// +/// *Note:* Currently TVM functions accept *at most* one return value. +#[derive(Default)] +pub struct Builder<'a> { + pub func: Option, + pub arg_buf: Vec>, + pub ret_buf: Option, +} + +impl<'a, 'm> Builder<'a> { + pub fn new( + func: Option, + arg_buf: Vec>, + ret_buf: Option, + ) -> Self { + Self { + func, + arg_buf, + ret_buf, + } + } + + pub fn get_function(&mut self, name: &'m str) -> &mut Self { + self.func = Function::get(name); + self + } + + /// Pushes a [`ArgValue`] into the function argument buffer. + pub fn arg(&mut self, arg: T) -> &mut Self + where + ArgValue<'a>: From, + { + self.arg_buf.push(arg.into()); + self + } + + /// Pushes multiple [`ArgValue`]s into the function argument buffer. + pub fn args(&mut self, args: I) -> &mut Self + where + I: IntoIterator, + ArgValue<'a>: From, + { + args.into_iter().for_each(|arg| { + self.arg(arg); + }); + self + } + + /// Sets an output for a function that requires a mutable output to be provided. + /// See the `basics` in tests for an example. + pub fn set_output(&mut self, ret: T) -> &mut Self + where + RetValue: From, + { + self.ret_buf = Some(ret.into()); + self + } + + pub fn invoke(self) -> Result { + self.func.unwrap().invoke(self.arg_buf) + } +} + +/// Converts a [`Function`] to builder. Currently, this is the best way to work with +/// TVM functions. 
+impl<'a, 'm> From for Builder<'a> { + fn from(func: Function) -> Self { + Builder::new(Some(func), Vec::new(), None) + } +} + +/// Converts a mutable reference of a [`Module`] to [`Builder`]. +impl<'a, 'm> From<&'m mut Module> for Builder<'a> { + fn from(module: &'m mut Module) -> Self { + Builder::new(module.entry(), Vec::new(), None) + } +} +#[cfg(test)] +mod tests { + use crate::function::{self, Function, Result}; + + #[test] + fn to_boxed_fn0() { + fn boxed0() -> i64 { + return 10; + } + + function::register_override(boxed0, "boxed0".to_owned(), true).unwrap(); + let func = Function::get("boxed0").unwrap(); + let typed_func: Box Result> = func.to_boxed_fn(); + assert_eq!(typed_func().unwrap(), 10); + } +} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs new file mode 100644 index 000000000000..4814d098238a --- /dev/null +++ b/rust/tvm-rt/src/to_function.rs @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::convert::{TryFrom, TryInto}; +use std::{ + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use super::{function::Result, Function}; +use crate::errors::Error; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +/// A trait representing whether the function arguments +/// and return type can be assigned to a TVM packed function. +/// +/// By splitting the conversion to function into two traits +/// we are able to improve error reporting, by splitting the +/// conversion of inputs and outputs to this trait. +/// +/// And the implementation of it to `ToFunction`. 
+pub trait Typed { + fn args(i: &[ArgValue<'static>]) -> Result; + fn ret(o: O) -> RetValue; +} + +impl> Typed<(), O> for F +where + F: Fn() -> O, +{ + fn args(_args: &[ArgValue<'static>]) -> Result<()> { + debug_assert!(_args.len() == 0); + Ok(()) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E> Typed<(A,), O> for F +where + F: Fn(A) -> O, + Error: From, + A: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A,)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + Ok((a,)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E> Typed<(A, B), O> for F +where + F: Fn(A, B) -> O, + Error: From, + A: TryFrom, Error = E>, + B: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> { + debug_assert!(args.len() == 2); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + Ok((a, b)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl, E> Typed<(A, B, C), O> for F +where + F: Fn(A, B, C) -> O, + Error: From, + A: TryFrom, Error = E>, + B: TryFrom, Error = E>, + C: TryFrom, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> { + debug_assert!(args.len() == 3); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + let c: C = args[2].clone().try_into()?; + Ok((a, b, c)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +pub trait ToFunction: Sized { + type Handle; + + fn into_raw(self) -> *mut Self::Handle; + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result + where + Self: Typed; + + fn drop(handle: *mut Self::Handle); + + fn to_function(self) -> Function + where + Self: Typed, + { + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + let resource_handle = self.into_raw(); + + check_call!(ffi::TVMFuncCreateFromCFunc( + Some(Self::tvm_callback), + resource_handle as *mut _, + None, // Some(Self::tvm_finalizer), + &mut fhandle as *mut ffi::TVMFunctionHandle, + )); + + Function::new(fhandle) + } + + /// The callback function which is wrapped converted by TVM + /// into a packed function stored in fhandle. 
+ unsafe extern "C" fn tvm_callback( + args: *mut ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: ffi::TVMRetValueHandle, + resource_handle: *mut c_void, + ) -> c_int + where + Self: Typed, + { + #![allow(unused_assignments, unused_unsafe)] + // turning off the incorrect linter complaints + let len = num_args as usize; + let args_list = slice::from_raw_parts_mut(args, len); + let type_codes_list = slice::from_raw_parts_mut(type_codes, len); + let mut local_args: Vec = Vec::new(); + let mut value = ffi::TVMValue { v_int64: 0 }; + let mut tcode = 0; + let resource_handle = resource_handle as *mut Self::Handle; + for i in 0..len { + value = args_list[i]; + tcode = type_codes_list[i]; + if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int + { + check_call!(ffi::TVMCbArgToReturn( + &mut value as *mut _, + &mut tcode as *mut _ + )); + } + let arg_value = ArgValue::from_tvm_value(value, tcode as u32); + local_args.push(arg_value); + } + + let rv = match Self::call(resource_handle, local_args.as_slice()) { + Ok(v) => v, + Err(msg) => { + crate::set_last_error(&msg); + return -1; + } + }; + + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); + let mut ret_type_code = ret_tcode as c_int; + + check_call!(ffi::TVMCFuncSetReturn( + ret, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + 1 as c_int + )); + 0 + } + + /// The finalizer which is invoked when the packed function's + /// reference count is zero. + unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { + let handle = std::mem::transmute(fhandle); + Self::drop(handle) + } +} + +impl ToFunction<(), O> for F +where + F: Fn() -> O + 'static, +{ + type Handle = Box O + 'static>; + + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(Box::new(self)); + Box::into_raw(ptr) + } + + fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result + where + F: Typed<(), O>, + { + // Ideally we shouldn't need to clone, probably doesn't really matter. + let out = unsafe { (*handle)() }; + Ok(F::ret(out)) + } + + fn drop(_: *mut Self::Handle) {} +} + +macro_rules! to_function_instance { + ($(($param:ident,$index:tt),)+) => { + impl ToFunction<($($param,)+), O> for + F where F: Fn($($param,)+) -> O + 'static { + type Handle = Box O + 'static>; + + fn into_raw(self) -> *mut Self::Handle { + let ptr: Box = Box::new(Box::new(self)); + Box::into_raw(ptr) + } + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result where F: Typed<($($param,)+), O> { + // Ideally we shouldn't need to clone, probably doesn't really matter. 
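+                // `F::args` converts the raw `&[ArgValue]` slice into the typed argument
+                // tuple declared by `Typed`, and the `$index` expansion below passes each
+                // tuple element to the wrapped closure positionally.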
+ let args = F::args(args)?; + let out = unsafe { + (*handle)($(args.$index),+) + }; + Ok(F::ret(out)) + } + + fn drop(_: *mut Self::Handle) {} + } + } +} + +to_function_instance!((A, 0),); +to_function_instance!((A, 0), (B, 1),); +to_function_instance!((A, 0), (B, 1), (C, 2),); +to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),); + +#[cfg(test)] +mod tests { + use super::{Function, ToFunction, Typed}; + + fn zero() -> i32 { + 10 + } + + fn helper(f: F) -> Function + where + F: ToFunction, + F: Typed, + { + f.to_function() + } + + #[test] + fn test_to_function0() { + helper(zero); + } + + fn one_arg(i: i32) -> i32 { + i + } + + #[test] + fn test_to_function1() { + helper(one_arg); + } + + fn two_arg(i: i32, j: i32) -> i32 { + i + j + } + + #[test] + fn test_to_function2() { + helper(two_arg); + } +} diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs new file mode 100644 index 000000000000..1812c0cfbe45 --- /dev/null +++ b/rust/tvm-rt/src/value.rs @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module implements [`ArgValue`] and [`RetValue`] types +//! and their conversions needed for the types used in frontend crate. +//! `RetValue` is the owned version of `TVMPODValue`. + +use std::convert::TryFrom; +// use std::ffi::c_void; + +use crate::{ArgValue, Module, NDArray, RetValue}; +use tvm_sys::{errors::ValueDowncastError, ffi::TVMModuleHandle, try_downcast}; + +macro_rules! 
impl_handle_val { + ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => { + impl<'a> From<&'a $type> for ArgValue<'a> { + fn from(arg: &'a $type) -> Self { + ArgValue::$variant(arg.handle() as $inner_type) + } + } + + impl<'a> From<&'a mut $type> for ArgValue<'a> { + fn from(arg: &'a mut $type) -> Self { + ArgValue::$variant(arg.handle() as $inner_type) + } + } + + impl<'a> TryFrom> for $type { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(val) }) + } + } + + impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(*val) }) + } + } + + impl From<$type> for RetValue { + fn from(val: $type) -> RetValue { + RetValue::$variant(val.handle() as $inner_type) + } + } + + impl TryFrom for $type { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result<$type, Self::Error> { + try_downcast!(val -> $type, |RetValue::$variant(val)| { $ctor(val) }) + } + } + }; +} + +impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new); + +impl<'a> From<&'a NDArray> for ArgValue<'a> { + fn from(arg: &'a NDArray) -> Self { + match arg { + &NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), + &NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), + } + } +} + +impl<'a> From<&'a mut NDArray> for ArgValue<'a> { + fn from(arg: &'a mut NDArray) -> Self { + match arg { + &mut NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle), + &mut NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle), + } + } +} + +impl<'a> TryFrom> for NDArray { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> NDArray, + |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, + |ArgValue::ArrayHandle(val)| { NDArray::new(val) }) + } +} + +impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> NDArray, + |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) }, + |ArgValue::ArrayHandle(val)| { NDArray::new(*val) }) + } +} + +impl From for RetValue { + fn from(val: NDArray) -> RetValue { + match val { + NDArray::Owned { handle } => RetValue::NDArrayHandle(handle), + _ => panic!("NYI"), + } + } +} + +impl TryFrom for NDArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> NDArray, + |RetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) }, + |RetValue::ArrayHandle(val)| { NDArray::new(val) }) + } +} + +#[cfg(test)] +mod tests { + use std::{convert::TryInto, str::FromStr}; + + use crate::{ByteArray, Context, DataType}; + + use super::*; + + #[test] + fn bytearray() { + let w = vec![1u8, 2, 3, 4, 5]; + let v = ByteArray::from(w.as_slice()); + let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); + assert_eq!( + tvm.data(), + w.iter().copied().collect::>().as_slice() + ); + } + + #[test] + fn ty() { + let t = DataType::from_str("int32").unwrap(); + let tvm: DataType = RetValue::from(t).try_into().unwrap(); + assert_eq!(tvm, t); + } + + #[test] + fn ctx() { + let c = Context::from_str("gpu").unwrap(); + let tvm: Context = RetValue::from(c).try_into().unwrap(); + assert_eq!(tvm, c); + } +} diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml new file 
mode 100644 index 000000000000..fe4d0bf987bf --- /dev/null +++ b/rust/tvm-sys/Cargo.toml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-sys" +version = "0.1.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +edition = "2018" + +[features] +bindings = [] + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +ndarray = "0.12" +enumn = "^0.1" + +[build-dependencies] +bindgen = "0.51" diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs new file mode 100644 index 000000000000..85e16bead085 --- /dev/null +++ b/rust/tvm-sys/build.rs @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +extern crate bindgen; + +use std::path::PathBuf; + +use std::env; + +fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + crate_dir + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); + + if cfg!(feature = "bindings") { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-search={}/build", tvm_home); + } + + // @see rust-bindgen#550 for `blacklist_type` + bindgen::Builder::default() + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) + .clang_arg(format!("-I{}/include/", tvm_home)) + .blacklist_type("max_align_t") + .layout_tests(false) + .derive_partialeq(true) + .derive_eq(true) + .generate() + .expect("unable to generate bindings") + .write_to_file(PathBuf::from("src/c_runtime_api.rs")) + .expect("can not write the bindings!"); +} diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs new file mode 100644 index 000000000000..1627e9e22860 --- /dev/null +++ b/rust/tvm-sys/src/array.rs @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + mem, + os::raw::{c_int, c_void}, +}; + +use crate::ffi::{ + DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, + DLDeviceType_kDLCPU, DLTensor, +}; + +/// `From` conversions to `DLTensor` for `ndarray::Array`. +/// Takes a reference to the `ndarray` since `DLTensor` is not owned. +macro_rules! 
impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext { + device_type: DLDeviceType_kDLCPU, + device_id: 0, + }, + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const i64 as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs new file mode 100644 index 000000000000..9bd95262820f --- /dev/null +++ b/rust/tvm-sys/src/byte_array.rs @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use std::convert::TryFrom; +use std::os::raw::c_char; + +use crate::errors::ValueDowncastError; +use crate::ffi::TVMByteArray; +use crate::{ArgValue, RetValue}; + +/// A newtype wrapping a raw TVM byte-array. +/// +/// ## Example +/// +/// ``` +/// let v = b"hello"; +/// let barr = tvm_sys::ByteArray::from(&v); +/// assert_eq!(barr.len(), v.len()); +/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); +/// ``` +pub struct ByteArray { + /// The raw FFI ByteArray. 
+ array: TVMByteArray, +} + +impl ByteArray { + /// Gets the underlying byte-array + pub fn data(&self) -> &'static [u8] { + unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size) } + } + + /// Gets the length of the underlying byte-array + pub fn len(&self) -> usize { + self.array.size + } + + /// Converts the underlying byte-array to `Vec` + pub fn to_vec(&self) -> Vec { + self.data().to_vec() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +// Needs AsRef for Vec +impl> From for ByteArray { + fn from(arg: T) -> Self { + let arg = arg.as_ref(); + ByteArray { + array: TVMByteArray { + data: arg.as_ptr() as *const c_char, + size: arg.len(), + }, + } + } +} + +impl TryFrom> for ByteArray { + type Error = ValueDowncastError; + + fn try_from(val: ArgValue<'static>) -> Result { + match val { + ArgValue::Bytes(array) => Ok(ByteArray { array: *array }), + _ => Err(ValueDowncastError { + expected_type: "ByteArray", + actual_type: format!("{:?}", val), + }), + } + } +} + +impl From for RetValue { + fn from(val: ByteArray) -> RetValue { + RetValue::Bytes(val.array) + } +} + +impl TryFrom for ByteArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + match val { + RetValue::Bytes(array) => Ok(ByteArray { array }), + _ => Err(ValueDowncastError { + expected_type: "ByteArray", + actual_type: format!("{:?}", val), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn convert() { + let v = vec![1u8, 2, 3]; + let barr = ByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); + let v = b"hello"; + let barr = ByteArray::from(&v); + assert_eq!(barr.len(), v.len()); + assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); + } +} diff --git a/rust/tvm-sys/src/context.rs b/rust/tvm-sys/src/context.rs new file mode 100644 index 000000000000..64b58b9f42c9 --- /dev/null +++ b/rust/tvm-sys/src/context.rs @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! Provides [`Context`] and related device queries. +//! +//! Create a new context for device type and device id. +//! +//! # Example +//! +//! ``` +//! # use tvm_sys::{DeviceType, Context}; +//! let cpu = DeviceType::from("cpu"); +//! let ctx = Context::new(cpu , 0); +//! let cpu0 = Context::cpu(0); +//! assert_eq!(ctx, cpu0); +//! ``` +//! +//! Or from a supported device name. +//! +//! ``` +//! use tvm_sys::Context; +//! let cpu0 = Context::from("cpu"); +//! println!("{}", cpu0); +//! 
``` + +use std::convert::TryFrom; +use std::fmt::{self, Display, Formatter}; +use std::str::FromStr; + +use crate::ffi::{self, *}; +use crate::packed_func::{ArgValue, RetValue}; + +use anyhow::Result; +use enumn::N; +use thiserror::Error; + +/// Device type represents the set of devices supported by +/// [TVM](https://github.com/apache/incubator-tvm). +/// +/// ## Example +/// +/// ``` +/// use tvm_sys::DeviceType; +/// let cpu = DeviceType::from("cpu"); +/// println!("device is: {}", cpu); +///``` + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, N)] +#[repr(i64)] +pub enum DeviceType { + CPU = 1, + GPU, + CPUPinned, + OpenCL, + Vulkan, + Metal, + VPI, + ROCM, + ExtDev, +} + +impl Default for DeviceType { + /// default device is cpu. + fn default() -> Self { + DeviceType::CPU + } +} + +impl From for ffi::DLDeviceType { + fn from(device_type: DeviceType) -> Self { + device_type as Self + } +} + +impl From for DeviceType { + fn from(device_type: ffi::DLDeviceType) -> Self { + Self::n(device_type as _).expect("invalid enumeration value for ffi::DLDeviceType") + } +} + +impl Display for DeviceType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!( + f, + "{}", + match self { + DeviceType::CPU => "cpu", + DeviceType::GPU => "gpu", + DeviceType::CPUPinned => "cpu_pinned", + DeviceType::OpenCL => "opencl", + DeviceType::Vulkan => "vulkan", + DeviceType::Metal => "metal", + DeviceType::VPI => "vpi", + DeviceType::ROCM => "rocm", + DeviceType::ExtDev => "ext_device", + // DeviceType(_) => "rpc", + } + ) + } +} + +impl<'a> From<&'a str> for DeviceType { + fn from(type_str: &'a str) -> Self { + match type_str { + "cpu" => DeviceType::CPU, + "llvm" => DeviceType::CPU, + "stackvm" => DeviceType::CPU, + "gpu" => DeviceType::GPU, + "cuda" => DeviceType::GPU, + "nvptx" => DeviceType::GPU, + "cl" => DeviceType::OpenCL, + "opencl" => DeviceType::OpenCL, + "metal" => DeviceType::Metal, + "vpi" => DeviceType::VPI, + "rocm" => DeviceType::ROCM, + _ => panic!("{:?} not supported!", type_str), + } + } +} + +impl<'a> From<&DeviceType> for ArgValue<'a> { + fn from(dev: &DeviceType) -> Self { + Self::Int(*dev as _) + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub struct Context { + pub device_type: DeviceType, + pub device_id: usize, +} + +impl Context { + pub fn new(device_type: DeviceType, device_id: usize) -> Context { + Context { + device_type, + device_id, + } + } +} + +impl<'a> From<&'a Context> for DLContext { + fn from(ctx: &'a Context) -> Self { + Self { + device_type: ctx.device_type.into(), + device_id: ctx.device_id as i32, + } + } +} + +impl Default for Context { + fn default() -> Self { + Self { + device_type: DLDeviceType_kDLCPU.into(), + device_id: 0, + } + } +} + +#[derive(Debug, Error)] +#[error("unsupported device: {0}")] +pub struct UnsupportedDeviceError(String); + +macro_rules! 
impl_tvm_context { + ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { + /// Creates a Context from a string (e.g., "cpu", "gpu", "ext_dev") + impl FromStr for Context { + type Err = UnsupportedDeviceError; + fn from_str(type_str: &str) -> Result { + Ok(Self { + device_type: match type_str { + $( $( stringify!($dev_name) )|+ => $dev_type.into()),+, + _ => return Err(UnsupportedDeviceError(type_str.to_string())), + }, + device_id: 0, + }) + } + } + + impl Context { + $( + $( + pub fn $dev_name(device_id: usize) -> Self { + Self { + device_type: $dev_type.into(), + device_id: device_id, + } + } + )+ + )+ + } + }; +} + +impl_tvm_context!( + DLDeviceType_kDLCPU: [cpu, llvm, stackvm], + DLDeviceType_kDLGPU: [gpu, cuda, nvptx], + DLDeviceType_kDLOpenCL: [cl], + DLDeviceType_kDLMetal: [metal], + DLDeviceType_kDLVPI: [vpi], + DLDeviceType_kDLROCM: [rocm], + DLDeviceType_kDLExtDev: [ext_dev] +); + +impl<'a> From<&'a str> for Context { + fn from(target: &str) -> Self { + Context::new(DeviceType::from(target), 0) + } +} + +impl From for Context { + fn from(ctx: ffi::DLContext) -> Self { + Context { + device_type: DeviceType::from(ctx.device_type), + device_id: ctx.device_id as usize, + } + } +} + +impl From for ffi::DLContext { + fn from(ctx: Context) -> Self { + ffi::DLContext { + device_type: ctx.device_type.into(), + device_id: ctx.device_id as i32, + } + } +} + +impl Display for Context { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}({})", self.device_type, self.device_id) + } +} + +impl From for RetValue { + fn from(ret_value: Context) -> RetValue { + RetValue::Context(ret_value.into()) + } +} + +impl TryFrom for Context { + type Error = anyhow::Error; + fn try_from(ret_value: RetValue) -> anyhow::Result { + match ret_value { + RetValue::Context(dt) => Ok(dt.into()), + // TODO(@jroesch): improve + _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn context() { + let ctx = Context::cpu(0); + println!("ctx: {}", ctx); + let default_ctx = Context::new(DeviceType::CPU, 0); + assert_eq!(ctx.clone(), default_ctx); + assert_ne!(ctx, Context::gpu(0)); + + let str_ctx = Context::new(DeviceType::GPU, 0); + assert_eq!(str_ctx.clone(), str_ctx); + assert_ne!(str_ctx, Context::new(DeviceType::CPU, 0)); + } +} diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs new file mode 100644 index 000000000000..ccdee3f6f753 --- /dev/null +++ b/rust/tvm-sys/src/datatype.rs @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +use std::any::TypeId; +use std::convert::TryFrom; +use std::str::FromStr; + +use crate::ffi::DLDataType; +use crate::packed_func::RetValue; + +use thiserror::Error; + +const DL_INT_CODE: u8 = 0; +const DL_UINT_CODE: u8 = 1; +const DL_FLOAT_CODE: u8 = 2; +const DL_HANDLE: u8 = 3; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct DataType { + code: u8, + bits: u8, + lanes: u16, +} + +impl DataType { + pub fn new(code: u8, bits: u8, lanes: u16) -> DataType { + DataType { code, bits, lanes } + } + + /// Returns the number of bytes occupied by an element of this `DataType`. + pub fn itemsize(&self) -> usize { + (self.bits as usize * self.lanes as usize) >> 3 + } + + /// Returns whether this `DataType` represents primitive type `T`. + pub fn is_type(&self) -> bool { + if self.lanes != 1 { + return false; + } + let typ = TypeId::of::(); + (typ == TypeId::of::() && self.code == DL_INT_CODE && self.bits == 32) + || (typ == TypeId::of::() && self.code == DL_INT_CODE && self.bits == 64) + || (typ == TypeId::of::() && self.code == DL_UINT_CODE && self.bits == 32) + || (typ == TypeId::of::() && self.code == DL_UINT_CODE && self.bits == 64) + || (typ == TypeId::of::() && self.code == DL_FLOAT_CODE && self.bits == 32) + || (typ == TypeId::of::() && self.code == DL_FLOAT_CODE && self.bits == 64) + } + + pub fn code(&self) -> usize { + self.code as usize + } + + pub fn bits(&self) -> usize { + self.bits as usize + } + + pub fn lanes(&self) -> usize { + self.lanes as usize + } +} + +impl<'a> From<&'a DataType> for DLDataType { + fn from(dtype: &'a DataType) -> Self { + Self { + code: dtype.code as u8, + bits: dtype.bits as u8, + lanes: dtype.lanes as u16, + } + } +} + +impl From for DataType { + fn from(dtype: DLDataType) -> Self { + Self { + code: dtype.code, + bits: dtype.bits, + lanes: dtype.lanes, + } + } +} + +impl From for DLDataType { + fn from(dtype: DataType) -> Self { + Self { + code: dtype.code, + bits: dtype.bits, + lanes: dtype.lanes, + } + } +} + +#[derive(Debug, Error)] +pub enum ParseDataTypeError { + #[error("invalid number: {0}")] + InvalidNumber(std::num::ParseIntError), + #[error("missing data type specifier (e.g., int32, float64)")] + MissingDataType, + #[error("unknown type: {0}")] + UnknownType(String), +} + +/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` +/// such as "int32", "float32" or with lane "float32x1". 
+impl FromStr for DataType { + type Err = ParseDataTypeError; + + fn from_str(type_str: &str) -> Result { + use ParseDataTypeError::*; + + if type_str == "bool" { + return Ok(DataType::new(1, 1, 1)); + } + + let mut type_lanes = type_str.split('x'); + let typ = type_lanes.next().ok_or(MissingDataType)?; + let lanes = type_lanes + .next() + .map(|l| ::from_str_radix(l, 10)) + .unwrap_or(Ok(1)) + .map_err(InvalidNumber)?; + let (type_name, bits) = match typ.find(char::is_numeric) { + Some(idx) => { + let (name, bits_str) = typ.split_at(idx); + ( + name, + u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?, + ) + } + None => (typ, 32), + }; + + let type_code = match type_name { + "int" => DL_INT_CODE, + "uint" => DL_UINT_CODE, + "float" => DL_FLOAT_CODE, + "handle" => DL_HANDLE, + _ => return Err(UnknownType(type_name.to_string())), + }; + + Ok(DataType::new(type_code, bits, lanes)) + } +} + +impl std::fmt::Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if self.bits == 1 && self.lanes == 1 { + return write!(f, "bool"); + } + let mut type_str = match self.code { + DL_INT_CODE => "int", + DL_UINT_CODE => "uint", + DL_FLOAT_CODE => "float", + DL_HANDLE => "handle", + _ => "unknown", + } + .to_string(); + + type_str += &self.bits.to_string(); + if self.lanes > 1 { + type_str += &format!("x{}", self.lanes); + } + f.write_str(&type_str) + } +} + +impl From for RetValue { + fn from(dt: DataType) -> RetValue { + RetValue::DataType((&dt).into()) + } +} + +impl TryFrom for DataType { + type Error = anyhow::Error; + fn try_from(ret_value: RetValue) -> anyhow::Result { + match ret_value { + RetValue::DataType(dt) => Ok(dt.into()), + // TODO(@jroesch): improve + _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), + } + } +} diff --git a/tests/cpp/ir_ssa_test.cc b/rust/tvm-sys/src/errors.rs similarity index 52% rename from tests/cpp/ir_ssa_test.cc rename to rust/tvm-sys/src/errors.rs index 56f178dbcf4e..54fe261ec37e 100644 --- a/tests/cpp/ir_ssa_test.cc +++ b/rust/tvm-sys/src/errors.rs @@ -17,33 +17,30 @@ * under the License. 
*/ -#include -#include -#include +use thiserror::Error; - -TEST(IRSSA, Convert) { - using namespace tvm; - using namespace tvm::tir; - Var x("x"), y; - PrimExpr let = LetNode::make(x, 1, x + 1); - - auto z = EvaluateNode::make(let + let); - CHECK(!tir::VerifySSA(z)); - auto z_ssa = tir::ConvertSSA(z); - CHECK(tir::VerifySSA(z_ssa)); +#[derive(Error, Debug)] +#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")] +pub struct ValueDowncastError { + pub actual_type: String, + pub expected_type: &'static str, } -TEST(IRSSA, Basic) { - using namespace tvm::tir; - using namespace tvm; - Var x("x"), y; - auto z = EvaluateNode::make(x + y); - CHECK(tir::VerifySSA(z)); +#[derive(Error, Debug)] +#[error("Function call `{context:?}` returned error: {message:?}")] +pub struct FuncCallError { + context: String, + message: String, } -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); +impl FuncCallError { + pub fn get_with_context(context: String) -> Self { + Self { + context, + message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } + .to_str() + .expect("failed while attempting to retrieve the TVM error message") + .to_owned(), + } + } } diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs new file mode 100644 index 000000000000..0f455e726d26 --- /dev/null +++ b/rust/tvm-sys/src/lib.rs @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This crate contains the minimal interface over TVM's +//! C runtime API. +//! +//! These common bindings are useful to both runtimes +//! written in Rust, as well as higher level API bindings. +//! +//! See the `tvm-rt` or `tvm` crates for full bindings to +//! the TVM API. + +/// The low-level C runtime FFI API for TVM. 
+pub mod ffi { + #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] + + use std::os::raw::{c_char, c_int, c_void}; + + include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs")); + + pub type BackendPackedCFunc = extern "C" fn( + args: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + out_ret_value: *mut TVMValue, + out_ret_tcode: *mut u32, + ) -> c_int; +} + +pub mod array; +pub mod byte_array; +pub mod context; +pub mod datatype; +pub mod errors; +#[macro_use] +pub mod packed_func; +pub mod value; + +pub use byte_array::ByteArray; +pub use context::{Context, DeviceType}; +pub use datatype::DataType; +pub use errors::*; +pub use packed_func::{ArgValue, RetValue}; diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs new file mode 100644 index 000000000000..a326aa1b8fdf --- /dev/null +++ b/rust/tvm-sys/src/packed_func.rs @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{ + convert::TryFrom, + ffi::{CStr, CString}, + os::raw::c_void, +}; + +use crate::{errors::ValueDowncastError, ffi::*}; + +pub use crate::ffi::TVMValue; + +pub trait PackedFunc: + Fn(&[ArgValue]) -> Result + Send + Sync +{ +} + +impl PackedFunc for T where + T: Fn(&[ArgValue]) -> Result + Send + Sync +{ +} + +/// Calls a packed function and returns a `RetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +/// Constructs a derivative of a TVMPodValue. +macro_rules! TVMPODValue { + { + $(#[$m:meta])+ + $name:ident $(<$a:lifetime>)? { + $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)? + }, + match $value:ident { + $($tvm_type:ident => { $from_tvm_type:expr })+ + }, + match &self { + $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+ + } + $(,)? + } => { + $(#[$m])+ + #[derive(Clone, Debug)] + pub enum $name $(<$a>)? { + Int(i64), + UInt(i64), + Float(f64), + Null, + DataType(DLDataType), + String(CString), + Context(TVMContext), + Handle(*mut c_void), + ArrayHandle(TVMArrayHandle), + ObjectHandle(*mut c_void), + ModuleHandle(TVMModuleHandle), + FuncHandle(TVMFunctionHandle), + NDArrayHandle(*mut c_void), + $($extra_variant($variant_type)),+ + } + + impl $(<$a>)? $name $(<$a>)? 
{ + pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self { + use $name::*; + #[allow(non_upper_case_globals)] + unsafe { + match type_code as _ { + DLDataTypeCode_kDLInt => Int($value.v_int64), + DLDataTypeCode_kDLUInt => UInt($value.v_int64), + DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMArgTypeCode_kTVMNullptr => Null, + TVMArgTypeCode_kTVMDataType => DataType($value.v_type), + TVMArgTypeCode_kTVMContext => Context($value.v_ctx), + TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), + TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), + TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), + TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), + TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), + $( $tvm_type => { $from_tvm_type } ),+ + _ => unimplemented!("{}", type_code), + } + } + } + + pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { + use $name::*; + match self { + Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), + UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), + Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), + DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), + Context(val) => (TVMValue { v_ctx: val.clone() }, TVMArgTypeCode_kTVMContext), + String(val) => { + ( + TVMValue { v_handle: val.as_ptr() as *mut c_void }, + TVMArgTypeCode_kTVMStr, + ) + } + Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), + ArrayHandle(val) => { + ( + TVMValue { v_handle: *val as *const _ as *mut c_void }, + TVMArgTypeCode_kTVMNDArrayHandle, + ) + }, + ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), + ModuleHandle(val) => + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), + FuncHandle(val) => ( + TVMValue { v_handle: *val }, + TVMArgTypeCode_kTVMPackedFuncHandle + ), + NDArrayHandle(val) => + (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), + $( $self_type($val) => { $from_self_type } ),+ + } + } + } + } +} + +TVMPODValue! { + /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way + /// to obtain a `ArgValue` is automatically via `call_packed!`. + ArgValue<'a> { + Bytes(&'a TVMByteArray), + Str(&'a CStr), + }, + match value { + TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } + }, + match &self { + Bytes(val) => { + (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) + } + Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } + } +} + +TVMPODValue! { + /// An owned TVMPODValue. Can be converted from a variety of primitive and object types. + /// Can be downcasted using `try_from` if it contains the desired type. 
+ /// + /// # Example + /// + /// ``` + /// use std::convert::{TryFrom, TryInto}; + /// use tvm_sys::RetValue; + /// + /// let a = 42u32; + /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap(); + /// + /// let s = "hello, world!"; + /// let t: RetValue = s.to_string().into(); + /// assert_eq!(String::try_from(t).unwrap(), s); + /// ``` + RetValue { + Bytes(TVMByteArray), + Str(&'static CStr), + }, + match value { + TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } + TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } + }, + match &self { + Bytes(val) => + { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } + Str(val) => + { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } + } +} + +#[macro_export] +macro_rules! try_downcast { + ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => { + match $val { + $( $pat => { Ok($converter) } )+ + _ => Err($crate::errors::ValueDowncastError { + actual_type: format!("{:?}", $val), + expected_type: stringify!($into), + }), + } + }; +} + +/// Creates a conversion to a `ArgValue` for a primitive type and DLDataTypeCode. +macro_rules! impl_pod_value { + ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => { + $( + impl<'a> From<$type> for ArgValue<'a> { + fn from(val: $type) -> Self { + Self::$variant(val as $inner_ty) + } + } + + impl<'a, 'v> From<&'a $type> for ArgValue<'v> { + fn from(val: &'a $type) -> Self { + Self::$variant(*val as $inner_ty) + } + } + + impl<'a> TryFrom> for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type }) + } + } + + impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type }) + } + } + + impl From<$type> for RetValue { + fn from(val: $type) -> Self { + Self::$variant(val as $inner_ty) + } + } + + impl TryFrom for $type { + type Error = $crate::errors::ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type }) + } + } + )+ + }; +} + +impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); +impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); +impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(DataType, DLDataType, [DLDataType]); +impl_pod_value!(Context, TVMContext, [TVMContext]); + +impl<'a> From<&'a str> for ArgValue<'a> { + fn from(s: &'a str) -> Self { + Self::String(CString::new(s).unwrap()) + } +} + +impl<'a> From for ArgValue<'a> { + fn from(s: String) -> Self { + Self::String(CString::new(s).unwrap()) + } +} + +impl<'a> From<&'a CStr> for ArgValue<'a> { + fn from(s: &'a CStr) -> Self { + Self::Str(s) + } +} + +impl<'a> From for ArgValue<'a> { + fn from(s: CString) -> Self { + Self::String(s) + } +} + +impl<'a> From<&'a TVMByteArray> for ArgValue<'a> { + fn from(s: &'a TVMByteArray) -> Self { + Self::Bytes(s) + } +} + +impl<'a> TryFrom> for &'a str { + type Error = ValueDowncastError; + fn try_from(val: ArgValue<'a>) -> Result { + try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) + } +} + +impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { + type Error = ValueDowncastError; + fn try_from(val: &'a ArgValue<'v>) -> Result { + try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) + } +} + 
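A minimal sketch (editorial, not part of the patch) of how `call_packed!` and the conversions generated by `impl_pod_value!` are meant to compose; the closure's error type here is the crate's `ValueDowncastError` purely for brevity, which may differ from what the `PackedFunc` alias expects:

use std::convert::TryFrom;
use tvm_sys::{call_packed, ArgValue, RetValue, ValueDowncastError};

// Downcast the packed arguments, do some work, and pack the result back up.
fn add(args: &[ArgValue]) -> Result<RetValue, ValueDowncastError> {
    let a = i64::try_from(&args[0])?;
    let b = i64::try_from(&args[1])?;
    Ok(RetValue::from(a + b))
}

fn main() {
    // `call_packed!` converts each argument with `.into()` before the call.
    let ret = call_packed!(add, 1i64, 2i64).unwrap();
    assert_eq!(i64::try_from(ret).unwrap(), 3);
}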
+/// Converts an unspecialized handle to a ArgValue. +impl From<*const T> for ArgValue<'static> { + fn from(ptr: *const T) -> Self { + Self::Handle(ptr as *mut c_void) + } +} + +/// Converts an unspecialized mutable handle to a ArgValue. +impl From<*mut T> for ArgValue<'static> { + fn from(ptr: *mut T) -> Self { + Self::Handle(ptr as *mut c_void) + } +} + +impl<'a> From<&'a mut DLTensor> for ArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + Self::ArrayHandle(arr as *mut DLTensor) + } +} + +impl<'a> From<&'a DLTensor> for ArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + Self::ArrayHandle(arr as *const _ as *mut DLTensor) + } +} + +impl TryFrom for String { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!( + val -> String, + |RetValue::String(s)| { s.into_string().unwrap() }, + |RetValue::Str(s)| { s.to_str().unwrap().to_string() } + ) + } +} + +impl From for RetValue { + fn from(s: String) -> Self { + Self::String(std::ffi::CString::new(s).unwrap()) + } +} + +impl From for RetValue { + fn from(arr: TVMByteArray) -> Self { + Self::Bytes(arr) + } +} + +impl TryFrom for TVMByteArray { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val }) + } +} + +impl Default for RetValue { + fn default() -> Self { + Self::Int(0) + } +} + +impl TryFrom for std::ffi::CString { + type Error = ValueDowncastError; + fn try_from(val: RetValue) -> Result { + try_downcast!(val -> std::ffi::CString, + |RetValue::Str(val)| { val.into() }) + } +} diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs new file mode 100644 index 000000000000..a9ad5f523fde --- /dev/null +++ b/rust/tvm-sys/src/value.rs @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::str::FromStr; + +use crate::ffi::*; + +use thiserror::Error; + +macro_rules! impl_pod_tvm_value { + ($field:ident, $field_ty:ty, $( $ty:ty ),+) => { + $( + impl From<$ty> for TVMValue { + fn from(val: $ty) -> Self { + TVMValue { $field: val as $field_ty } + } + } + + impl From for $ty { + fn from(val: TVMValue) -> Self { + unsafe { val.$field as $ty } + } + } + )+ + }; + ($field:ident, $ty:ty) => { + impl_pod_tvm_value!($field, $ty, $ty); + } +} + +impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); +impl_pod_tvm_value!(v_float64, f64, f32, f64); +impl_pod_tvm_value!(v_type, DLDataType); +impl_pod_tvm_value!(v_ctx, TVMContext); + +#[derive(Debug, Error)] +#[error("unsupported device: {0}")] +pub struct UnsupportedDeviceError(String); + +macro_rules! 
impl_tvm_context { + ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { + /// Creates a TVMContext from a string (e.g., "cpu", "gpu", "ext_dev") + impl FromStr for TVMContext { + type Err = UnsupportedDeviceError; + fn from_str(type_str: &str) -> Result { + Ok(Self { + device_type: match type_str { + $( $( stringify!($dev_name) )|+ => $dev_type ),+, + _ => return Err(UnsupportedDeviceError(type_str.to_string())), + }, + device_id: 0, + }) + } + } + + impl TVMContext { + $( + $( + pub fn $dev_name(device_id: usize) -> Self { + Self { + device_type: $dev_type, + device_id: device_id as i32, + } + } + )+ + )+ + } + }; +} + +impl_tvm_context!( + DLDeviceType_kDLCPU: [cpu, llvm, stackvm], + DLDeviceType_kDLGPU: [gpu, cuda, nvptx], + DLDeviceType_kDLOpenCL: [cl], + DLDeviceType_kDLMetal: [metal], + DLDeviceType_kDLVPI: [vpi], + DLDeviceType_kDLROCM: [rocm], + DLDeviceType_kDLExtDev: [ext_dev] +); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 83dfc64009cf..037c76665d4b 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -20,9 +20,9 @@ /*! * \file tvm/arith/analyzer.cc */ +#include #include #include -#include #include namespace tvm { @@ -33,34 +33,33 @@ Analyzer::Analyzer() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) { -} + int_set(this) {} -void Analyzer::Bind(const Var& var, const PrimExpr& expr) { +void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) { PrimExpr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); - this->const_int_bound.Update(var, this->const_int_bound(new_expr)); - this->modular_set.Update(var, this->modular_set(new_expr)); - this->rewrite_simplify.Update(var, new_expr); - this->canonical_simplify.Update(var, new_expr); + this->const_int_bound.Update(var, this->const_int_bound(new_expr), override); + this->modular_set.Update(var, this->modular_set(new_expr), override); + this->rewrite_simplify.Update(var, new_expr, override); + this->canonical_simplify.Update(var, new_expr, override); } -void Analyzer::Bind(const Var& var, const Range& range) { +void Analyzer::Bind(const Var& var, const Range& range, bool override) { CHECK(range.defined()); if (tir::is_one(range->extent)) { - this->Bind(var, range->min); + this->Bind(var, range->min, override); } else { - this->const_int_bound.Bind(var, range); + this->const_int_bound.Bind(var, range, override); } // skip modular_set // skip rewrite simplify } -void Analyzer::Bind(const Map& variables) { +void Analyzer::Bind(const Map& variables, bool override) { for (const auto& iter : variables) { - this->Bind(iter.first, iter.second); + this->Bind(iter.first, iter.second, override); } } @@ -92,6 +91,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { return false; } +bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { + if (const auto* ptr = expr.as()) { + return ptr->value < upper_bound; + } + auto bd = this->const_int_bound(this->rewrite_simplify(expr)); + if (bd->max_value < upper_bound) return true; + return false; +} + bool Analyzer::CanProve(const PrimExpr& expr) { if (const auto* ptr = expr.as()) { return ptr->value != 0; @@ -115,63 +123,53 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) { return res; } -TVM_REGISTER_GLOBAL("arith.CreateAnalyzer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - using runtime::PackedFunc; - using runtime::TypedPackedFunc; - auto self = std::make_shared(); - auto f = 
[self](std::string name) -> PackedFunc { - if (name == "const_int_bound") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->const_int_bound(args[0]); - }); - } else if (name == "modular_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->modular_set(args[0]); - }); - } else if (name == "const_int_bound_update") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - self->const_int_bound.Update(args[0], args[1], args[2]); - }); - } else if (name == "Simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->Simplify(args[0]); - }); - } else if (name == "rewrite_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->rewrite_simplify(args[0]); - }); - } else if (name == "canonical_simplify") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->canonical_simplify(args[0]); - }); - } else if (name == "int_set") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - *ret = self->int_set(args[0], args[1]); - }); - } else if (name == "bind") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - if (args[1].IsObjectRef()) { - self->Bind(args[0], args[1].operator Range()); - } else { - self->Bind(args[0], args[1].operator PrimExpr()); - } - }); - } else if (name == "enter_constraint_context") { - return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr >( - new With(self.get(), args[0])); - auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { - ctx.reset(); - }; - *ret = PackedFunc(fexit); - }); - } - return PackedFunc(); - }; - *ret = TypedPackedFunc(f); +TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + auto self = std::make_shared(); + auto f = [self](std::string name) -> PackedFunc { + if (name == "const_int_bound") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->const_int_bound(args[0]); }); + } else if (name == "modular_set") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->modular_set(args[0]); }); + } else if (name == "const_int_bound_update") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + self->const_int_bound.Update(args[0], args[1], args[2]); + }); + } else if (name == "Simplify") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { *ret = self->Simplify(args[0]); }); + } else if (name == "rewrite_simplify") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); + } else if (name == "canonical_simplify") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); + } else if (name == "int_set") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); }); + } else if (name == "bind") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + if (args[1].IsObjectRef()) { + self->Bind(args[0], args[1].operator Range()); + } else { + self->Bind(args[0], args[1].operator PrimExpr()); + } + }); + } else if (name == "enter_constraint_context") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + // can't use make_shared due to noexcept(false) decl in destructor, + // see https://stackoverflow.com/a/43907314 + auto 
ctx = std::shared_ptr >( + new With(self.get(), args[0])); + auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable { ctx.reset(); }; + *ret = PackedFunc(fexit); + }); + } + return PackedFunc(); + }; + *ret = TypedPackedFunc(f); }); } // namespace arith diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 26be5d51115f..496eb204f24b 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -21,14 +21,14 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ +#include #include #include -#include #include -#include -#include #include +#include + #include "interval_set.h" namespace tvm { @@ -38,7 +38,7 @@ using namespace tir; // a visitor to find the path to the target variable // from a expression. -class VariablePathFinder: public ExprVisitor { +class VariablePathFinder : public ExprVisitor { public: explicit VariablePathFinder(PrimExpr target) : target_(target) {} @@ -68,17 +68,17 @@ std::vector GetPath(PrimExpr target, PrimExpr expr) { return v.path_; } -enum CompareOp {kGreater, kLess, kEqual}; +enum CompareOp { kGreater, kLess, kEqual }; // a visitor to deduce the bound of a variable from a expression -class BoundDeducer: public ExprVisitor { +class BoundDeducer : public ExprVisitor { public: friend class BoundDeduceInputChecker; friend class Converter; BoundDeducer(PrimExpr target, PrimExpr expr, const std::unordered_map& hint_map, const std::unordered_map& relax_map) - : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} + : target_(target), expr_(expr), hint_map_(hint_map), relax_map_(relax_map) {} void Deduce(); @@ -120,7 +120,7 @@ class BoundDeducer: public ExprVisitor { result_ += op->b; } else { result_ -= op->a; - result_ = - result_; + result_ = -result_; comp_op = ReverseOp(comp_op); } this->VisitExpr(left ? op->a : op->b); @@ -149,7 +149,7 @@ class BoundDeducer: public ExprVisitor { // always use relax bound bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); - result_ = floordiv(result_, operand); // rounding down here + result_ = floordiv(result_, operand); // rounding down here if (!divided) { if (comp_op == kGreater) { @@ -194,7 +194,7 @@ class BoundDeducer: public ExprVisitor { Analyzer analyzer_; }; -class BoundDeduceInputChecker: public ExprVisitor { +class BoundDeduceInputChecker : public ExprVisitor { public: bool Check(BoundDeducer* deducer) { deducer_ = deducer; @@ -220,9 +220,12 @@ void BoundDeducer::Init() { CompareOp BoundDeducer::ReverseOp(CompareOp comp_op) { switch (comp_op) { - case kEqual: return kEqual; // IntSet can not represent range for `NE - case kGreater: return kLess; - case kLess: return kGreater; + case kEqual: + return kEqual; // IntSet can not represent range for `NE + case kGreater: + return kLess; + case kLess: + return kGreater; default: LOG(FATAL) << "Not a valid compare op"; return kGreater; // return some default value @@ -319,18 +322,18 @@ void BoundDeducer::Relax() { // Both LHS and RHS of the EQ should behave as constants e.g. i == j, // can not be resolved when either `i` or `j` or both are variables with // some Range OR `i` and `j` both should be a single point in IntSet - if (comp_op == kEqual && (!analyzer_.CanProve(b.min() == b.max()) - || !analyzer_.CanProve(a.min() == a.max()))) { + if (comp_op == kEqual && + (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) { success_ = false; return; } - expr_ = (comp_op == kGreater) ? a.min() : a.max(); + expr_ = (comp_op == kGreater) ? 
a.min() : a.max(); result_ = (comp_op == kGreater) ? b.max() : b.min(); } IntSet DeduceBound(PrimExpr v, PrimExpr e, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) { + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) { BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success_) return IntSet::nothing(); @@ -348,8 +351,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, // assuming e >= 0, deduce the bound of variable from it. // return empty set to represent deduce failure. -IntSet DeduceBound(PrimExpr v, PrimExpr e, - const Map& hint_map, +IntSet DeduceBound(PrimExpr v, PrimExpr e, const Map& hint_map, const Map& relax_map) { std::unordered_map hmap; for (auto kv : hint_map) { @@ -362,16 +364,11 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e, return DeduceBound(v, e, hmap, rmap); } - TVM_REGISTER_GLOBAL("arith.DeduceBound") -.set_body_typed([]( - PrimExpr v, PrimExpr cond, - const Map hint_map, - const Map relax_map -) { - return DeduceBound(v, cond, hint_map, relax_map); -}); - + .set_body_typed([](PrimExpr v, PrimExpr cond, const Map hint_map, + const Map relax_map) { + return DeduceBound(v, cond, hint_map, relax_map); + }); } // namespace arith } // namespace tvm diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 7a6e772c2935..b81565f3b735 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -22,8 +22,8 @@ * \brief Canonical form based simplification. */ #include -#include #include +#include #include "const_fold.h" #include "pattern_match.h" @@ -37,7 +37,6 @@ using namespace tir; class SumExpr; class SplitExpr; - /*! * \brief Base class of all temporary expression introduced * for canonicalization. @@ -53,10 +52,10 @@ class CanonicalExprNode : public PrimExprNode { virtual PrimExpr Normalize() const = 0; // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - } + void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "arith.CanonicalExpr"; + static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); }; @@ -110,9 +109,7 @@ class SplitExprNode : public CanonicalExprNode { DivMode div_mode{kTruncDiv}; /*! \brief verify that this is a valid entry. */ - void Verify() const { - CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); - } + void Verify() const { CHECK(upper_factor == kPosInf || upper_factor % lower_factor == 0); } PrimExpr NormalizeWithScale(int64_t sscale) const { PrimExpr res = this->index; @@ -134,13 +131,9 @@ class SplitExprNode : public CanonicalExprNode { return res; } - PrimExpr Normalize() const final { - return NormalizeWithScale(1); - } + PrimExpr Normalize() const final { return NormalizeWithScale(1); } - void MulToSelf(int64_t scale) { - this->scale *= scale; - } + void MulToSelf(int64_t scale) { this->scale *= scale; } inline bool IndexEqual(const SplitExpr& other) const; inline bool DivModeCompatibleTo(DivMode mode) const; @@ -185,9 +178,7 @@ class SumExprNode : public CanonicalExprNode { /*! \brief Base value in the summation. */ int64_t base{0}; /*! \brief The expression equals zero. */ - bool IsZero() const { - return base == 0 && args.size() == 0; - } + bool IsZero() const { return base == 0 && args.size() == 0; } /*! * \brief Return the normal Expr that is equivalent to self. * \return The normal expression. 
@@ -197,9 +188,7 @@ class SumExprNode : public CanonicalExprNode { if (this->args.size() == 0) { return make_const(this->dtype, this->base); } - return Normalize_(this->dtype, - SimplifySplitExprs(args), - base); + return Normalize_(this->dtype, SimplifySplitExprs(args), base); } /*! * \brief Whether self is divisible by scale. @@ -238,9 +227,7 @@ class SumExprNode : public CanonicalExprNode { * \brief add constant value to self. * \param value to be added. */ - void AddToSelf(int64_t value) { - this->base += value; - } + void AddToSelf(int64_t value) { this->base += value; } /*! * \brief self += other * scale; * \param other The expression to be added. @@ -256,8 +243,7 @@ class SumExprNode : public CanonicalExprNode { if (args[start]->IndexEqual(other)) break; } for (size_t j = start; j < args.size(); ++j) { - if (!args[j]->IndexEqual(other) || - other->lower_factor > args[j]->lower_factor) { + if (!args[j]->IndexEqual(other) || other->lower_factor > args[j]->lower_factor) { other.CopyOnWrite()->scale *= scale; this->args.insert(this->args.begin() + j, other); return; @@ -285,8 +271,7 @@ class SumExprNode : public CanonicalExprNode { * \param args The original list of arguments. * \return simplified version. */ - static std::vector - SimplifySplitExprs(std::vector args) { + static std::vector SimplifySplitExprs(std::vector args) { // NOTE: This algorithm relies on the factor that args are divided into segments // and each segment is sorted in descending order of lower_factor. for (size_t i = 0; i < args.size(); ++i) { @@ -296,14 +281,12 @@ class SumExprNode : public CanonicalExprNode { SplitExpr& rhs = args[j]; if (!lhs->IndexEqual(rhs)) break; if (lhs->upper_factor < rhs->lower_factor) break; - if (lhs->upper_factor == rhs->upper_factor && - lhs->lower_factor == rhs->lower_factor && + if (lhs->upper_factor == rhs->upper_factor && lhs->lower_factor == rhs->lower_factor && lhs->DivModeCompatibleTo(rhs->div_mode)) { // folding same co-efficient. rhs.CopyOnWrite()->scale += lhs->scale; lhs.CopyOnWrite()->scale = 0; - } else if (lhs->lower_factor == rhs->upper_factor && - rhs->scale != 0 && + } else if (lhs->lower_factor == rhs->upper_factor && rhs->scale != 0 && lhs->scale % rhs->scale == 0 && lhs->lower_factor == (lhs->scale / rhs->scale) * rhs->lower_factor && lhs->DivModeCompatibleTo(rhs->div_mode)) { @@ -384,9 +367,7 @@ class SumExprNode : public CanonicalExprNode { std::stable_sort(args.begin(), args.end(), fcompare); return args; } - static PrimExpr Normalize_(DataType dtype, - const std::vector& args, - int64_t base) { + static PrimExpr Normalize_(DataType dtype, const std::vector& args, int64_t base) { // Positive scales first PrimExpr res = make_const(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { @@ -431,9 +412,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { public: using Rewriter = RewriteSimplifier::Impl; - explicit Impl(Analyzer* parent) - : Rewriter(parent) {} - + explicit Impl(Analyzer* parent) : Rewriter(parent) {} PrimExpr CanonicalSimplify(PrimExpr expr) { expr = operator()(expr); @@ -447,9 +426,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { } // Normal mutation without normalization. 
- PrimExpr CanonicalMutate(PrimExpr expr) { - return Rewriter::VisitExpr(expr); - } + PrimExpr CanonicalMutate(PrimExpr expr) { return Rewriter::VisitExpr(expr); } using Rewriter::VisitExpr_; PrimExpr VisitExpr_(const AddNode* op) final; @@ -485,9 +462,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { * \param out_divisible The result divisible component. * \param out_non_divisible The non-divisible component. */ - void SeparateDivisibleParts(const SumExprNode* psum, - int64_t coeff, - SumExpr* out_divisible, + void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible); /*! * \brief Normalize expr to normal expr. @@ -567,8 +542,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { PrimExpr SimplifyReduceCombiner(const ReduceNode* op); }; -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const AddNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const AddNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -577,7 +551,7 @@ VisitExpr_(const AddNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. @@ -593,8 +567,7 @@ VisitExpr_(const AddNode* op) { return std::move(ret); } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const SubNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const SubNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -603,7 +576,7 @@ VisitExpr_(const SubNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // canonical form simplification. 
@@ -619,9 +592,7 @@ VisitExpr_(const SubNode* op) { return std::move(ret); } - -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const MulNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -630,7 +601,7 @@ VisitExpr_(const MulNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; // x * c @@ -655,15 +626,13 @@ VisitExpr_(const MulNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return MulNode::make(a, b); + return Mul(a, b); } } -void CanonicalSimplifier::Impl:: -SeparateDivisibleParts(const SumExprNode* psum, - int64_t coeff, - SumExpr* out_divisible, - SumExpr* out_non_divisible) { +void CanonicalSimplifier::Impl::SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, + SumExpr* out_divisible, + SumExpr* out_non_divisible) { auto divisible = make_object(); auto non_divisible = make_object(); divisible->dtype = psum->dtype; @@ -685,8 +654,7 @@ SeparateDivisibleParts(const SumExprNode* psum, *out_non_divisible = SumExpr(non_divisible); } -SplitExpr CanonicalSimplifier::Impl:: -SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { +SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { CHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); @@ -727,8 +695,7 @@ SplitDivConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const DivNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -737,7 +704,7 @@ VisitExpr_(const DivNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold
(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -763,8 +730,7 @@ VisitExpr_(const DivNode* op) { } else { // if 0 <= extra < cval, it means the extra can be eliminated. if (TryCompare(temp, cval) != kLT) { - lhs.CopyOnWrite()->AddToSelf( - SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); + lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1); } } return std::move(lhs); @@ -784,12 +750,11 @@ VisitExpr_(const DivNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return DivNode::make(a, b); + return Div(a, b); } } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorDivNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -797,7 +762,7 @@ VisitExpr_(const FloorDivNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; // x / c1 @@ -820,8 +785,7 @@ VisitExpr_(const FloorDivNode* op) { } else { // if 0 <= extra < cval, it means the extra can be eliminated. if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) { - lhs.CopyOnWrite()->AddToSelf( - SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); + lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1); } } return std::move(lhs); @@ -840,12 +804,11 @@ VisitExpr_(const FloorDivNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorDivNode::make(a, b); + return FloorDiv(a, b); } } -SplitExpr CanonicalSimplifier::Impl:: -SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { +SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { CHECK_GT(cval, 0); lhs = ConvertDivMode(lhs, div_mode); @@ -859,16 +822,15 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { // (x / c1) % c2 => (x % (c1 * c2)) / c2 int64_t new_upper_factor = lhs->lower_factor * scaled_cval; // try to see if we can reduce the existing upper modular. - if (lhs->upper_factor == SplitExprNode::kPosInf || - lhs->upper_factor % new_upper_factor == 0) { + if (lhs->upper_factor == SplitExprNode::kPosInf || lhs->upper_factor % new_upper_factor == 0) { // we gained a new upper factor that is smaller // than the original one // Perhaps there are more chances in simplifying the index // Do a recursive call to simplify the mod with the new factor. 
- if (new_upper_factor < lhs->upper_factor && - lhs->upper_factor != SplitExprNode::kPosInf) { - auto updated = ToSplitExpr(this->VisitExpr(ModImpl( - lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { + auto updated = ToSplitExpr(this->VisitExpr( + ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + updated.CopyOnWrite()->scale = lhs->scale; // re-apply the lower_factor if (lhs->lower_factor != 1) { return SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -894,8 +856,7 @@ SplitModConst(SplitExpr lhs, int64_t cval, DivMode div_mode) { return lhs; } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const ModNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -904,7 +865,7 @@ VisitExpr_(const ModNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -939,8 +900,7 @@ VisitExpr_(const ModNode* op) { // (x - 5) % 3 => (x - 2) % 3 if x - 5 >= 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); int64_t new_base = psum->base % cval; - if (cbound->min_value >= 0 && - cbound->min_value - psum->base + new_base >= 0) { + if (cbound->min_value >= 0 && cbound->min_value - psum->base + new_base >= 0) { SumExpr sum_expr = Downcast(a); sum_expr.CopyOnWrite()->base = new_base; return SplitModConst(ToSplitExpr(std::move(sum_expr)), cval, kTruncDiv); @@ -960,12 +920,11 @@ VisitExpr_(const ModNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return ModNode::make(a, b); + return Mod(a, b); } } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const FloorModNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); } @@ -974,7 +933,7 @@ VisitExpr_(const FloorModNode* op) { PrimExpr b = this->CanonicalMutate(op->b); // const folding - PrimExpr const_res = TryConstFold(a, b); + PrimExpr const_res = TryConstFold(a, b); if (const_res.defined()) return const_res; PVar c1; @@ -989,8 +948,7 @@ VisitExpr_(const FloorModNode* op) { return floormod(temp, c1.Eval()); } else { // If temp < cval && temp >=0 then can remove the mod. - if (TryCompare(temp, cval) == kLT && - analyzer_->CanProveGreaterEqual(temp, 0)) { + if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) { return temp; } else { // contonue to use logic below. @@ -1020,13 +978,12 @@ VisitExpr_(const FloorModNode* op) { if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { - return FloorModNode::make(a, b); + return FloorMod(a, b); } } // Simplify reduce expression. 
-PrimExpr CanonicalSimplifier::Impl:: -SimplifyReduceCombiner(const ReduceNode* op) { +PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) { // First simplify the results Array simplified_result; for (const auto& res : op->combiner->result) { @@ -1060,8 +1017,7 @@ SimplifyReduceCombiner(const ReduceNode* op) { // components which have side effects should also be preserved for (size_t i = 0; i < used.size(); ++i) { - if (HasSideEffect(op->source[i]) || - HasSideEffect(op->combiner->identity_element[i]) || + if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || HasSideEffect(op->combiner->result[i])) { mark_used(i); } @@ -1089,14 +1045,11 @@ SimplifyReduceCombiner(const ReduceNode* op) { } } - CommReducer new_combiner = - CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return ReduceNode::make( - new_combiner, new_source, op->axis, op->condition, new_value_index); + CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity); + return Reduce(new_combiner, new_source, op->axis, op->condition, new_value_index); } -PrimExpr CanonicalSimplifier::Impl:: -VisitExpr_(const ReduceNode* op) { +PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) { // Recursively call simplification when necessary. PrimExpr ret = RewriteSimplifier::Impl::VisitExpr_(op); op = ret.as(); @@ -1107,10 +1060,8 @@ VisitExpr_(const ReduceNode* op) { // assumption we would have to perform a single iteration of the loop, i.e. use // `(*op->combiner.get())(op->combineop->identity_element, op->source)[op->value_index]` // instead of `op->source[op->value_index]`. The former may be more difficult to simplify. - return this->VisitExpr( - SelectNode::make(op->condition, - op->source[op->value_index], - op->combiner->identity_element[op->value_index])); + return this->VisitExpr(Select(op->condition, op->source[op->value_index], + op->combiner->identity_element[op->value_index])); } // combiner simplification. ret = SimplifyReduceCombiner(op); @@ -1121,19 +1072,13 @@ PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) { return impl_->CanonicalSimplify(expr); } -void CanonicalSimplifier::Update(const Var& var, - const PrimExpr& info, - bool override) { +void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool override) { impl_->Update(var, info, override); } -CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) - : impl_(new Impl(parent)) { -} +CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} -CanonicalSimplifier::~CanonicalSimplifier() { - delete impl_; -} +CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/compute_expr.h b/src/arith/compute_expr.h deleted file mode 100644 index adb4f3000a29..000000000000 --- a/src/arith/compute_expr.h +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file compute_expr.h - * \brief Utility to invoke certan compute operations. - */ -#ifndef TVM_ARITH_COMPUTE_EXPR_H_ -#define TVM_ARITH_COMPUTE_EXPR_H_ - -#include -#include -#include - -namespace tvm { -namespace arith { - -/*! - * \brief Compute the expression with the given binary op. - * \param lhs The left operand - * \param rhs The right operand - * \tparam Op the computation operator - * \return The result. - */ -template -inline PrimExpr Compute(PrimExpr lhs, PrimExpr rhs) { - return OP::make(lhs, rhs); -} - -/*! - * \brief Compute an reduction with Op - * \param values The input values. - * \param empty_value The value when return if it is empty, can be Expr() - * which will cause an error to be rasied. - * \tparam Op The computation operator - * \return The result. - */ -template -inline PrimExpr ComputeReduce( - const Array& values, PrimExpr empty_value); - -inline bool GetConst(PrimExpr e, int64_t* out) { - if (e.dtype().is_vector()) return false; - const int64_t* v = tir::as_const_int(e); - if (v) { - *out = *v; return true; - } else { - return false; - } -} - -// get a small constant int -inline bool GetConstInt(PrimExpr e, int* out) { - int64_t v1 = 0; - if (GetConst(e, &v1)) { - if (v1 > static_cast( - std::numeric_limits::max())) return false; - *out = static_cast(v1); return true; - } - return false; -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return a + b; -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return a - b; -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return a * b; -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return truncdiv(a, b); -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return truncmod(a, b); -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return max(a, b); -} - -template<> -inline PrimExpr Compute(PrimExpr a, PrimExpr b) { - return min(a, b); -} - -template -inline PrimExpr ComputeReduce(const Array& values, PrimExpr empty_value) { - if (values.size() == 0U) { - CHECK(empty_value.defined()); - return empty_value; - } - PrimExpr res = values[0]; - for (size_t i = 1; i < values.size(); ++i) { - res = Compute(res, values[i]); - } - return res; -} - -} // namespace arith -} // namespace tvm -#endif // TVM_ARITH_COMPUTE_EXPR_H_ diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index a440af994202..876d336454d8 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -26,8 +26,10 @@ #include #include + #include #include + #include "int_operator.h" namespace tvm { @@ -43,10 +45,8 @@ namespace arith { * \note a and b Must already matched data types with each other. * \return nullptr if constant fold fails, otherwise return folded result. */ -template -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { - return PrimExpr(); -} +template +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b); /*! * \brief Try to run unary compute with constant folding. 
@@ -57,7 +57,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { * \note a and b Must already matched data types with each other. * \return nullptr if constant fold fails, otherwise return folded result. */ -template +template inline PrimExpr TryConstFold(PrimExpr a); /*! @@ -70,255 +70,251 @@ inline PrimExpr TryConstFold(PrimExpr a); * \return the checked result. */ inline bool IsIndexType(const DataType& type) { - return type.is_int() && type.lanes() == 1 && - (type.bits() == 32 || type.bits() == 64); + return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64); } - -#define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using tir::FloatImmNode; \ - const IntImmNode* pa = a.as(); \ - const IntImmNode* pb = b.as(); \ - const FloatImmNode* fa = a.as(); \ - const FloatImmNode* fb = b.as(); \ +#define TVM_ARITH_CONST_PROPAGATION(BODY) \ + using tir::FloatImmNode; \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + const FloatImmNode* fa = a.as(); \ + const FloatImmNode* fb = b.as(); \ BODY; - -#define TVM_INDEX_CONST_PROPAGATION(BODY) \ - const IntImmNode* pa = a.as(); \ - const IntImmNode* pb = b.as(); \ - const DataType& ta = a.dtype(); \ - const DataType& tb = b.dtype(); \ - if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ - BODY; \ - } \ - +#define TVM_INDEX_CONST_PROPAGATION(BODY) \ + const IntImmNode* pa = a.as(); \ + const IntImmNode* pb = b.as(); \ + const DataType& ta = a.dtype(); \ + const DataType& tb = b.dtype(); \ + if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \ + BODY; \ + } // specialization of constant folders. -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value + pb->value); - if (pa && pa->value == 0) return b; - if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value + fb->value); - if (fa && fa->value == 0) return b; - if (fb && fb->value == 0) return a; - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value + pb->value); + if (pa && pa->value == 0) return b; + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm(rtype, fa->value + fb->value); + if (fa && fa->value == 0) return b; + if (fb && fb->value == 0) return a; + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value - pb->value); - if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value - fb->value); - if (fb && fb->value == 0) return a; - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value - pb->value); + if (pb && pb->value == 0) return a; + if (fa && fb) return FloatImm(rtype, fa->value - fb->value); + if (fb && fb->value == 0) return a; + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value * pb->value); - if (pa) { - if (pa->value == 1) return b; - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - if (pb->value == 0) return b; - } - if (fa 
&& fb) return FloatImm(rtype, fa->value * fb->value); - if (fa) { - if (fa->value == 1) return b; - if (fa->value == 0) return a; - } - if (fb) { - if (fb->value == 1) return a; - if (fb->value == 0) return b; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, pa->value * pb->value); + if (pa) { + if (pa->value == 1) return b; + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + if (pb->value == 0) return b; + } + if (fa && fb) return FloatImm(rtype, fa->value * fb->value); + if (fa) { + if (fa->value == 1) return b; + if (fa->value == 0) return a; + } + if (fb) { + if (fb->value == 1) return a; + if (fb->value == 0) return b; + } + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - // due to division and mod can have different modes - // NOTE: this will assumes truc div. - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value / pb->value); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm(rtype, fa->value / fb->value); - } - if (fa && fa->value == 0) return a; - if (fb) { - if (fb->value == 1) return a; - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + // due to division and mod can have different modes + // NOTE: this will assumes truc div. + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, pa->value / pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm(rtype, fa->value / fb->value); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value % pb->value); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, pa->value % pb->value); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return tir::make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, arith::floordiv(pa->value, pb->value)); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return a; - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - if (fa && fb && fb->value != 0) { - return FloatImm(rtype, std::floor(fa->value / fb->value)); - } - if 
(fa && fa->value == 0) return a; - if (fb) { - if (fb->value == 1) return a; - CHECK_NE(fb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, arith::floordiv(pa->value, pb->value)); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return a; + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + if (fa && fb && fb->value != 0) { + return FloatImm(rtype, std::floor(fa->value / fb->value)); + } + if (fa && fa->value == 0) return a; + if (fb) { + if (fb->value == 1) return a; + CHECK_NE(fb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) { - CHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, floormod(pa->value, pb->value)); - } - if (pa) { - if (pa->value == 0) return a; - } - if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); - CHECK_NE(pb->value, 0) << "Divide by zero"; - } - }); + const DataType& rtype = a.dtype(); + if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; + return IntImm(rtype, floormod(pa->value, pb->value)); + } + if (pa) { + if (pa->value == 0) return a; + } + if (pb) { + if (pb->value == 1) return tir::make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value)); + }); if (a.same_as(b)) return a; return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); - if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); - }); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value)); + }); if (a.same_as(b)) return a; return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value); + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); - }); + if (pa && pb) 
return IntImm(DataType::UInt(1), pa->value >= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value); + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value); + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value); + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value); + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); - if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); - }); + if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value); + if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value); + }); return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; @@ -328,8 +324,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { +template <> +inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; @@ -339,8 +335,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { return PrimExpr(); } -template<> -inline PrimExpr TryConstFold(PrimExpr a) { +template <> +inline PrimExpr TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { return IntImm(DataType::UInt(1), !(pa->value)); @@ -364,9 +360,7 @@ struct SymbolicLimits { * * \return positive infinity. */ -inline PrimExpr pos_inf() { - return SymbolicLimits::pos_inf_; -} +inline PrimExpr pos_inf() { return SymbolicLimits::pos_inf_; } /*! * \brief Check if value is positive infinity. @@ -374,9 +368,7 @@ inline PrimExpr pos_inf() { * * \return The check result. */ -inline bool is_pos_inf(const PrimExpr& value) { - return value.same_as(SymbolicLimits::pos_inf_); -} +inline bool is_pos_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::pos_inf_); } /*! 
* \brief Opaque expression representing negative infinity. @@ -386,9 +378,7 @@ inline bool is_pos_inf(const PrimExpr& value) { * * \return negative infinity. */ -inline PrimExpr neg_inf() { - return SymbolicLimits::neg_inf_; -} +inline PrimExpr neg_inf() { return SymbolicLimits::neg_inf_; } /*! * \brief Check if value is negative infinity. @@ -396,9 +386,7 @@ inline PrimExpr neg_inf() { * * \return The check result. */ -inline bool is_neg_inf(const PrimExpr& value) { - return value.same_as(SymbolicLimits::neg_inf_); -} +inline bool is_neg_inf(const PrimExpr& value) { return value.same_as(SymbolicLimits::neg_inf_); } } // namespace arith } // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 57dfc157fc21..c33990cd1f4f 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -20,10 +20,12 @@ /*! * \file tvm/arith/const_int_bound.cc */ -#include #include +#include #include + #include + #include "int_operator.h" #include "pattern_match.h" @@ -34,8 +36,7 @@ using namespace tir; TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); -ConstIntBound::ConstIntBound( - int64_t min_value, int64_t max_value) { +ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { auto node = make_object(); node->min_value = min_value; node->max_value = max_value; @@ -46,8 +47,7 @@ ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) { return ConstIntBound(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.ConstIntBound") -.set_body_typed(MakeConstIntBound); +TVM_REGISTER_GLOBAL("arith.ConstIntBound").set_body_typed(MakeConstIntBound); inline void PrintBoundValue(std::ostream& os, int64_t val) { if (val == ConstIntBound::kPosInf) { @@ -60,31 +60,29 @@ inline void PrintBoundValue(std::ostream& os, int64_t val) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ConstIntBound["; - PrintBoundValue(p->stream, op->min_value); - p->stream << ','; - PrintBoundValue(p->stream, op->max_value); - p->stream << ']'; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ConstIntBound["; + PrintBoundValue(p->stream, op->min_value); + p->stream << ','; + PrintBoundValue(p->stream, op->max_value); + p->stream << ']'; + }); // internal entry for const int bound struct ConstIntBoundAnalyzer::Entry { int64_t min_value; int64_t max_value; - bool is_const(int64_t value) const { - return min_value == max_value && min_value == value; - } + bool is_const(int64_t value) const { return min_value == max_value && min_value == value; } bool operator==(const Entry& other) const { return min_value == other.min_value && max_value == other.max_value; } }; -class ConstIntBoundAnalyzer::Impl : - public ExprFunctor { +class ConstIntBoundAnalyzer::Impl + : public ExprFunctor { public: /*! 
\brief additional bound info about expr \in bound */ struct BoundInfo { @@ -94,46 +92,39 @@ class ConstIntBoundAnalyzer::Impl : Entry bound; BoundInfo() {} - BoundInfo(PrimExpr expr, Entry bound) - : expr(expr), bound(bound) { - } + BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {} }; - void Bind(const Var& var, const Range& range) { + void Bind(const Var& var, const Range& range, bool override) { Entry a = VisitExpr(range->min); Entry b = VisitExpr(range->extent); Entry ret; ret.min_value = a.min_value; ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); - Update(var, ret, false); + Update(var, ret, override); } - void Update(const Var& var, - const Entry& info, - bool override) { + void Update(const Var& var, const Entry& info, bool override) { if (!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " - << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) - << ", new=" << ConstIntBound(info.min_value, info.max_value); + CHECK(it->second == info) << "Trying to update var \'" << var << "\'" + << " with a different const bound: " + << "original=" + << ConstIntBound(it->second.min_value, it->second.max_value) + << ", new=" << ConstIntBound(info.min_value, info.max_value); } } var_map_[var] = info; } - void Update(const Var& var, - const ConstIntBound& info, - bool override) { + void Update(const Var& var, const ConstIntBound& info, bool override) { Update(var, MakeBound(info->min_value, info->max_value), override); } // Override visitor behaviors Entry VisitExprDefault_(const Object* op) final { - return Everything( - static_cast(op)->dtype); + return Everything(static_cast(op)->dtype); } Entry VisitExpr(const PrimExpr& expr) final { @@ -147,15 +138,16 @@ class ConstIntBoundAnalyzer::Impl : } } if (bound_) { - const PrimExprNode* op = expr.as(); - auto val = bound_->find(op); + auto val = bound_->find(expr); if (val != bound_->end()) { - CHECK(val->second->min_value == res.min_value && - val->second->max_value == res.max_value) - << "Detected bound for " << expr - << "conflicts with memorization"; + auto everything = Everything(expr->dtype); + CHECK( + (val->second->min_value == res.min_value && val->second->max_value == res.max_value) || + (val->second->min_value == everything.min_value && + val->second->max_value == everything.max_value)) + << "Detected bound for " << expr << "conflicts with memorization"; } - (*bound_)[op] = ConstIntBound(res.min_value, res.max_value); + (*bound_)[expr] = ConstIntBound(res.min_value, res.max_value); } return res; } @@ -176,9 +168,7 @@ class ConstIntBoundAnalyzer::Impl : return Intersect(a, b); } - Entry VisitExpr_(const IntImmNode* op) final { - return MakeBound(op->value, op->value); - } + Entry VisitExpr_(const IntImmNode* op) final { return MakeBound(op->value, op->value); } Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); @@ -223,8 +213,7 @@ class ConstIntBoundAnalyzer::Impl : // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; // other case, we can get close to 0 - return MakeBound(0, - std::min(a.max_value, b_max_cap)); + return MakeBound(0, std::min(a.max_value, b_max_cap)); } else { return MakeBound(std::max(a.min_value, -b_max_cap), std::min(std::max(a.max_value, (int64_t)0), b_max_cap)); @@ -363,11 +352,11 @@ class ConstIntBoundAnalyzer::Impl : private: friend class ConstIntBoundAnalyzer; // internal variable map - 
std::unordered_map var_map_; + std::unordered_map var_map_; // additional bound info std::vector additional_info_; // look up table for memorization - std::unordered_map* bound_{nullptr}; + BoundMapType* bound_{nullptr}; // constants: the limit value means umlimited // NOTE: kNegInf/kPosInf are used to represent infinity. static const constexpr int64_t kNegInf = ConstIntBound::kNegInf; @@ -382,7 +371,7 @@ class ConstIntBoundAnalyzer::Impl : * \tparam F the operator function type. * \return The result. */ - template + template static Entry BinaryOpBoundry(Entry a, Entry b, const F& op) { Entry ret; // The boundary point must be shihft of the original boundary. @@ -560,35 +549,28 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) { return ConstIntBound(ret.min_value, ret.max_value); } -ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, - std::unordered_map* bound) { +ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapType* bound) { impl_->bound_ = bound; Entry ret = impl_->VisitExpr(expr); impl_->bound_ = nullptr; return ConstIntBound(ret.min_value, ret.max_value); } -void ConstIntBoundAnalyzer::Update(const Var& var, - const ConstIntBound& info, - bool override) { +void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) { impl_->Update(var, info, override); } -void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { - impl_->Bind(var, range); +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) { + impl_->Bind(var, range, override); } std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) - : impl_(new Impl()) { -} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} -ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { - delete impl_; -} +ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index cc9c745a24b8..f0634feac083 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -21,12 +21,13 @@ * \file detect_linear_equation.cc * \brief Utility to detect patterns in the expression. 
*/ +#include #include +#include #include -#include #include +#include #include -#include namespace tvm { namespace arith { @@ -44,11 +45,9 @@ struct IntervalEntry { PrimExpr max_value; }; -class LinearEqDetector - : public ExprFunctor { +class LinearEqDetector : public ExprFunctor { public: - explicit LinearEqDetector(Var var) - : var_(var) {} + explicit LinearEqDetector(Var var) : var_(var) {} bool Detect(const PrimExpr& e, LinearEqEntry* ret) { *ret = VisitExpr(e, e); @@ -141,8 +140,7 @@ class LinearEqDetector } }; -Array DetectLinearEquation(const PrimExpr& e, - const Array& vars) { +Array DetectLinearEquation(const PrimExpr& e, const Array& vars) { PrimExpr base = e; Array coeff; @@ -156,10 +154,12 @@ Array DetectLinearEquation(const PrimExpr& e, } std::unordered_set vset; + auto vset_contains = [&](const VarNode* node) { return vset.count(node) != 0; }; + for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset)) { + if (ExprUseVar(coeff[i - 2], vset_contains)) { return Array(); } } @@ -168,9 +168,8 @@ Array DetectLinearEquation(const PrimExpr& e, } // Detect clip condition as min max value -bool DetectClipBound( - const PrimExpr& cond, - std::unordered_map* bmap) { +bool DetectClipBound(const PrimExpr& cond, + std::unordered_map* bmap) { int flag = 0; Var var; auto fvisit = [&bmap, &flag, &var](const ObjectRef& n) { @@ -207,13 +206,14 @@ bool DetectClipBound( return false; } LinearEqEntry ret; + Analyzer analyzer; if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; - ret.coeff = Simplify(ret.coeff); + ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift if (p.min_value.defined()) { - p.min_value = tir::MaxNode::make(p.min_value, -ret.base); + p.min_value = max(p.min_value, -ret.base); } else { p.min_value = -ret.base; } @@ -222,7 +222,7 @@ bool DetectClipBound( if (is_const_int(ret.coeff, -1)) { // -var + shift >=0 -> var <= shift if (p.max_value.defined()) { - p.max_value = tir::MinNode::make(p.max_value, ret.base); + p.max_value = min(p.max_value, ret.base); } else { p.max_value = ret.base; } @@ -231,8 +231,7 @@ bool DetectClipBound( return false; } - -template +template void SplitCommExpr(const PrimExpr& e, std::vector* ret) { if (const OP* op = e.as()) { SplitCommExpr(op->a, ret); @@ -254,14 +253,15 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { for (PrimExpr cond : splits) { if (!DetectClipBound(cond, &rmap)) return Array(); } + Analyzer analyzer; Array ret; for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { - e.min_value = Simplify(e.min_value); + e.min_value = analyzer.Simplify(e.min_value); } if (e.max_value.defined()) { - e.max_value = Simplify(e.max_value); + e.max_value = analyzer.Simplify(e.max_value); } ret.push_back(e.min_value); ret.push_back(e.max_value); @@ -269,12 +269,11 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { return ret; } -TVM_REGISTER_GLOBAL("arith.DetectLinearEquation") -.set_body_typed(DetectLinearEquation); +TVM_REGISTER_GLOBAL("arith.DetectLinearEquation").set_body_typed(DetectLinearEquation); TVM_REGISTER_GLOBAL("arith.DetectClipBound") -.set_body_typed([](const PrimExpr& e, const Array& vars) { - return DetectClipBound(e, vars); -}); + .set_body_typed([](const PrimExpr& e, const Array& vars) { + return DetectClipBound(e, vars); + }); } // namespace arith } // namespace tvm 
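The detect_linear_equation.cc hunks above swap the removed free-standing `Simplify` calls and `tir::MinNode::make`/`tir::MaxNode::make` constructors for a locally constructed `arith::Analyzer` plus the `min`/`max` helpers. The sketch below is illustrative only and is not part of the patch: a minimal example of that Analyzer-based pattern, assuming the header paths `tvm/arith/analyzer.h` and `tvm/tir/op.h` and the `Var`/`Range::make_by_min_extent`/`Bind`/`Simplify` calls already visible in these hunks.

```cpp
// Illustrative sketch only; not part of the patch.
// Assumed headers for Analyzer, Var, Range, max(), make_const().
#include <tvm/arith/analyzer.h>
#include <tvm/tir/op.h>

tvm::PrimExpr SimplifyWithBoundExample() {
  using namespace tvm;
  // A per-use analyzer instance replaces the removed global Simplify() calls.
  arith::Analyzer analyzer;
  tir::Var x("x", DataType::Int(32));
  // Give the analyzer range context for x, much like the For/AttrStmt visitors
  // later in this diff do via analyzer_->Bind(loop_var, Range::make_by_min_extent(...)).
  analyzer.Bind(x, Range::make_by_min_extent(0, 16));
  // With 0 <= x < 16 bound, max(x, 0) can be folded back to x.
  return analyzer.Simplify(max(x, tir::make_const(DataType::Int(32), 0)));
}
```

The same pattern recurs in the later hunks (int_set.cc, ir_mutator_with_analyzer.cc), where a local or member `Analyzer` supplies context-aware simplification in place of the removed `tir::Simplify`.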
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index bda70fb67cba..b44d9f7ff1f5 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -21,14 +21,13 @@ * \file bound_deducer.cc * \brief Utility to deduce bound of expression */ +#include +#include #include -#include #include -#include -#include -#include #include +#include namespace tvm { namespace arith { @@ -36,14 +35,14 @@ namespace arith { using namespace tir; // Find Read region of the tensor in the stmt. -class FuncTouchedDomain final : public StmtExprVisitor { +class BufferTouchedDomain final : public StmtExprVisitor { public: - FuncTouchedDomain(const te::Tensor &tensor, bool consider_calls, bool consider_provides) - : tensor_(tensor), consider_calls_(consider_calls), consider_provides_(consider_provides) {} + BufferTouchedDomain(const Buffer& buffer, bool consider_loads, bool consider_stores) + : buffer_(buffer), consider_loads_(consider_loads), consider_stores_(consider_stores) {} - Domain Find(const Stmt& stmt) { + Region Find(const Stmt& stmt) { operator()(stmt); - Domain ret; + Region ret; Range none; for (size_t i = 0; i < bounds_.size(); ++i) { ret.push_back(arith::Union(bounds_[i]).cover_range(none)); @@ -51,17 +50,15 @@ class FuncTouchedDomain final : public StmtExprVisitor { return ret; } - void VisitStmt_(const ForNode *op) final { + void VisitStmt_(const ForNode* op) final { const VarNode* var = op->loop_var.get(); - dom_map_[var] = IntSet::range( - Range::make_by_min_extent(op->min, op->extent)); + dom_map_[var] = IntSet::range(Range::make_by_min_extent(op->min, op->extent)); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(var); } void VisitStmt_(const LetStmtNode* op) final { - dom_map_[op->var.get()] = - arith::EvalSet(op->value, dom_map_); + dom_map_[op->var.get()] = arith::EvalSet(op->value, dom_map_); StmtExprVisitor::VisitStmt_(op); dom_map_.erase(op->var.get()); } @@ -80,18 +77,16 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - void VisitExpr_(const CallNode* op) final { - if (consider_calls_ && tensor_->op.same_as(op->func) - && tensor_->value_index == op->value_index) { - Touch(op->args); + void VisitExpr_(const BufferLoadNode* op) final { + if (consider_loads_ && buffer_.same_as(op->buffer)) { + Touch(op->indices); } StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const ProvideNode* op) final { - if (consider_provides_ && tensor_->op.same_as(op->func) - && tensor_->value_index == op->value_index) { - Touch(op->args); + void VisitStmt_(const BufferStoreNode* op) final { + if (consider_stores_ && buffer_.same_as(op->buffer)) { + Touch(op->indices); } StmtExprVisitor::VisitStmt_(op); } @@ -106,21 +101,18 @@ class FuncTouchedDomain final : public StmtExprVisitor { } } - const te::Tensor &tensor_; - bool consider_calls_, consider_provides_; + const Buffer& buffer_; + bool consider_loads_, consider_stores_; std::vector > bounds_; std::unordered_map dom_map_; }; -Domain DomainTouched(Stmt stmt, - const te::Tensor &tensor, - bool consider_calls, - bool consider_provides) { - return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt); +Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads, + bool consider_stores) { + return BufferTouchedDomain(buffer, consider_loads, consider_stores).Find(stmt); } -TVM_REGISTER_GLOBAL("arith.DomainTouched") -.set_body_typed(DomainTouched); +TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched); } // namespace arith } // namespace tvm diff --git 
a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 34efa986e985..62858d2dc9e2 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -22,19 +22,18 @@ * \brief The integer constraints data structures. */ #include +#include #include #include -#include -#include #include #include +#include namespace tvm { namespace arith { -IntConstraints::IntConstraints(Array variables, - Map ranges, +IntConstraints::IntConstraints(Array variables, Map ranges, Array relations) { ObjectPtr node = make_object(); if (!variables.defined()) { @@ -46,7 +45,7 @@ IntConstraints::IntConstraints(Array variables, CHECK(relations.defined()); for (const auto& var : variables) { CHECK(var.dtype().is_int() || var.dtype().is_uint()) - << "Variables in IntConstraints must be integers"; + << "Variables in IntConstraints must be integers"; } node->variables = std::move(variables); node->ranges = std::move(ranges); @@ -57,18 +56,13 @@ IntConstraints::IntConstraints(Array variables, TVM_REGISTER_NODE_TYPE(IntConstraintsNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntConstraints(" - << op->variables - << ", " << op->ranges - << ", " << op->relations - << ")"; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraints(" << op->variables << ", " << op->ranges << ", " << op->relations + << ")"; + }); -IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, - IntConstraints dst, +IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, IntConstraints dst, Map src_to_dst, Map dst_to_src) { ObjectPtr node = make_object(); @@ -82,15 +76,12 @@ IntConstraintsTransform::IntConstraintsTransform(IntConstraints src, TVM_REGISTER_NODE_TYPE(IntConstraintsTransformNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntConstraintsTransform(" - << "\n\t" << op->src - << "\n\t" << op->dst - << "\n\t" << op->src_to_dst - << "\n\t" << op->dst_to_src - << "\n)"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntConstraintsTransform(" + << "\n\t" << op->src << "\n\t" << op->dst << "\n\t" << op->src_to_dst << "\n\t" + << op->dst_to_src << "\n)"; + }); } // namespace arith } // namespace tvm diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index 3be34b638777..b69ce4fe5858 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -25,6 +25,7 @@ #define TVM_ARITH_INT_OPERATOR_H_ #include +#include namespace tvm { namespace arith { @@ -38,56 +39,41 @@ namespace arith { * \return Whether overflow can happen. * \tparam Op The integer operator. 
*/ -template -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if ((y > 0) && (x > max_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true; return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if ((y > 0) && (x < min_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true; return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { if (y == 0) return false; if (y > 0) { - if (x < min_value / y) return true; - if (x > max_value / y) return true; + if (x < min_value / y) return true; + if (x > max_value / y) return true; } else { if (y == -1 && x == std::numeric_limits::min()) return true; - if (x > min_value / y) return true; - if (x < max_value / y) return true; + if (x > min_value / y) return true; + if (x < max_value / y) return true; } return false; } -template<> -inline bool WillOverflow(int64_t x, - int64_t y, - int64_t min_value, - int64_t max_value) { +template <> +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { return y == 0; } @@ -97,9 +83,7 @@ inline bool WillOverflow(int64_t x, * \param y The right operand. * \return the result. */ -inline int64_t truncdiv(int64_t x, int64_t y) { - return x / y; -} +inline int64_t truncdiv(int64_t x, int64_t y) { return x / y; } /*! * \brief Compute the truncdiv remainder of two integers. @@ -107,9 +91,7 @@ inline int64_t truncdiv(int64_t x, int64_t y) { * \param y The right operand. * \return the result. */ -inline int64_t truncmod(int64_t x, int64_t y) { - return x % y; -} +inline int64_t truncmod(int64_t x, int64_t y) { return x % y; } /*! * \brief Peform floor division of two integers. @@ -120,13 +102,10 @@ inline int64_t truncmod(int64_t x, int64_t y) { inline int64_t floordiv(int64_t x, int64_t y) { int64_t rdiv = x / y; int64_t rmod = x % y; - bool is_floor_div = - (y >= 0 && rmod >= 0) || - (y < 0 && rmod <= 0); + bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0); return is_floor_div ? rdiv : (rdiv - 1); } - /*! * \brief Compute the floordiv remainder of two integers. * \param x The left operand. @@ -135,12 +114,74 @@ inline int64_t floordiv(int64_t x, int64_t y) { */ inline int64_t floormod(int64_t x, int64_t y) { int64_t rmod = x % y; - bool is_floor_div = - (y >= 0 && rmod >= 0) || - (y < 0 && rmod <= 0); + bool is_floor_div = (y >= 0 && rmod >= 0) || (y < 0 && rmod <= 0); return is_floor_div ? rmod : rmod + y; } +/*! + * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b) + * \param a The first coefficient. + * \param b The second coefficient. + * \param x The solution of x. + * \param y The solution of y. + * \return The GCD of a and b. 
+ */ +inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) { + // Extended Euclidean algorithm + // if a < 0, the problem can be convert into + // |a|* (-x) + b * y = gcd(|a|, b) + // + // initial condition: + // a * 0 + b * 1 = b + // a * 1 + b * 0 = a + int64_t s = 0, old_s = 1; + int64_t r = b, old_r = a >= 0 ? a : -a; + // Iteration (r2 < r1): + // a * x1 + b * y1 = r1 + // a * x2 + b * y2 = r2 + // The above two eqs can derive the following eq (q = r1 / r2) + // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3 + // Because r3 < r2, the iteration can eventually terminate + while (r != 0) { + int64_t q = old_r / r; + int64_t tmp = old_r; + old_r = r; + r = tmp - q * r; + tmp = old_s; + old_s = s; + s = tmp - q * s; + } + + *x = a >= 0 ? old_s : -old_s; + if (b != 0) { + *y = (old_r - (*x) * a) / b; + } else { + *y = 1; + } + + return old_r; +} + +/*! + * \brief Take GCD of a and b. + * \param a The first operand. + * \param b The second operand. + * \return The result. + */ +inline int64_t ZeroAwareGCD(int64_t a, int64_t b) { + if (a < 0) a = -a; + if (b < 0) b = -b; + if (a < b) std::swap(a, b); + if (b == 0) return a; + // perform GCD (greatest common divisor) + // ax + by = gcd(a, b) z if a != 0, b != 0 + while (a % b != 0) { + a = a % b; + std::swap(a, b); + } + return b; +} + } // namespace arith } // namespace tvm #endif // TVM_ARITH_INT_OPERATOR_H_ diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8c5afb1be8b5..b043b355b507 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -22,23 +22,24 @@ * \brief The integer set functions */ #include +#include #include #include -#include -#include #include #include +#include + #include "interval_set.h" #include "pattern_match.h" namespace tvm { namespace arith { +using tir::is_one; +using tir::is_zero; using tir::make_const; using tir::make_zero; -using tir::is_zero; -using tir::is_one; PrimExpr SymbolicLimits::pos_inf_ = Var("pos_inf", DataType::Handle()); PrimExpr SymbolicLimits::neg_inf_ = Var("neg_inf", DataType::Handle()); @@ -54,9 +55,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) { return IntervalSet(min_value, max_value); } -TVM_REGISTER_GLOBAL("arith.IntervalSet") -.set_body_typed(MakeIntervalSet); - +TVM_REGISTER_GLOBAL("arith.IntervalSet").set_body_typed(MakeIntervalSet); IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); @@ -77,43 +76,40 @@ IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } // type traits -template +template struct is_logical_op { static const bool value = false; }; -#define TVM_DECLARE_LOGICAL_OP(OP) \ - template<> \ - struct is_logical_op { \ - static const bool value = true; \ +#define TVM_DECLARE_LOGICAL_OP(OP) \ + template <> \ + struct is_logical_op { \ + static const bool value = true; \ }; -TVM_DECLARE_LOGICAL_OP(AndNode); -TVM_DECLARE_LOGICAL_OP(OrNode); -TVM_DECLARE_LOGICAL_OP(EQNode); -TVM_DECLARE_LOGICAL_OP(NENode); -TVM_DECLARE_LOGICAL_OP(GENode); -TVM_DECLARE_LOGICAL_OP(GTNode); -TVM_DECLARE_LOGICAL_OP(LENode); -TVM_DECLARE_LOGICAL_OP(LTNode); -TVM_DECLARE_LOGICAL_OP(NotNode); +TVM_DECLARE_LOGICAL_OP(And); +TVM_DECLARE_LOGICAL_OP(Or); +TVM_DECLARE_LOGICAL_OP(EQ); +TVM_DECLARE_LOGICAL_OP(NE); +TVM_DECLARE_LOGICAL_OP(GE); +TVM_DECLARE_LOGICAL_OP(GT); +TVM_DECLARE_LOGICAL_OP(LE); +TVM_DECLARE_LOGICAL_OP(LT); +TVM_DECLARE_LOGICAL_OP(Not); /*! * \brief Combine two interval set under arithmetic operations. 
* \note this can possibly relax the set. */ -template -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr res = TryConstFold(a->min_value, b->min_value); - if (!res.defined()) res = Op::make(a->min_value, b->min_value); + if (!res.defined()) res = Op(a->min_value, b->min_value); return IntervalSet::SinglePoint(res); } if (is_logical_op::value) { - return IntervalSet(make_const(a->min_value.dtype(), 0), - make_const(a->min_value.dtype(), 1)); + return IntervalSet(make_const(a->min_value.dtype(), 0), make_const(a->min_value.dtype(), 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -122,47 +118,36 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; PrimExpr min_value = - a->HasLowerBound() && b->HasLowerBound() ? - a->min_value + b->min_value : neg_inf(); + a->HasLowerBound() && b->HasLowerBound() ? a->min_value + b->min_value : neg_inf(); PrimExpr max_value = - a->HasUpperBound() && b->HasUpperBound() ? - a->max_value + b->max_value : pos_inf(); + a->HasUpperBound() && b->HasUpperBound() ? a->max_value + b->max_value : pos_inf(); return IntervalSet(min_value, max_value); } -template<> -inline IntervalSet Combine(Analyzer* analyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; PrimExpr min_value = - a->HasLowerBound() && b->HasUpperBound() ? - a->min_value - b->max_value : neg_inf(); + a->HasLowerBound() && b->HasUpperBound() ? a->min_value - b->max_value : neg_inf(); PrimExpr max_value = - a->HasUpperBound() && b->HasLowerBound() ? - a->max_value - b->min_value : pos_inf(); + a->HasUpperBound() && b->HasLowerBound() ? a->max_value - b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } - -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -183,21 +168,19 @@ inline IntervalSet Combine(Analyzer* analyzer, PrimExpr max_value = a->HasLowerBound() ? 
a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::SelectNode; + using tir::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; - return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); + return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Mul"; return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -218,21 +201,19 @@ inline IntervalSet Combine(Analyzer* analyzer, PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::SelectNode; + using tir::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; - return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); + return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -259,11 +240,8 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } - -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -284,21 +262,19 @@ inline IntervalSet Combine(Analyzer* analyzer, PrimExpr max_value = a->HasLowerBound() ? 
floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::SelectNode; + using tir::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); - return IntervalSet(SelectNode::make(sign, e1, e2), SelectNode::make(sign, e2, e1)); + return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); } } DLOG(WARNING) << "Return Everything in CombineInterval Div"; return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analyzer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -311,6 +287,16 @@ inline IntervalSet Combine(Analyzer* analyzer, LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { + if (divisor.as()) { + // a mod b = a - (a / b) * b if a_max / b == a_min / b + auto qmax = floordiv(a->max_value, divisor); + auto qmin = floordiv(a->min_value, divisor); + if (analyzer->CanProve(qmax == qmin)) { + auto tmax = a->max_value - divisor * qmin; + auto tmin = a->min_value - divisor * qmin; + return IntervalSet(tmin, tmax); + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; @@ -321,30 +307,24 @@ inline IntervalSet Combine(Analyzer* analyzer, return IntervalSet::Everything(); } -template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { - return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); + return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - return IntervalSet(max(a->min_value, b->min_value), - max(a->max_value, b->max_value)); + return IntervalSet(max(a->min_value, b->min_value), max(a->max_value, b->max_value)); } -template<> -inline IntervalSet Combine(Analyzer* analzyer, - IntervalSet a, - IntervalSet b) { +template <> +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; - return IntervalSet(min(a->min_value, b->min_value), - min(a->max_value, b->max_value)); + return IntervalSet(min(a->min_value, b->min_value), min(a->max_value, b->max_value)); } // internal helper function to get an interval set @@ -360,20 +340,12 @@ using namespace tir; // Simplified version of int set evaluator that operates on IntervalSet // We might use better set analysis in the future to replace the intervalset. 
-class IntervalSetEvaluator : - public ExprFunctor { +class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, - const Map& dom_map, - bool eval_vec = false) - : analyzer_(analyzer), - dom_map_(dom_map), - eval_vec_(eval_vec) { - } + IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, bool eval_vec = false) + : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {} - IntervalSet Eval(const PrimExpr& val) { - return this->VisitExpr(val); - } + IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } // evaluate and relax the set IntervalSet Eval(IntervalSet val) { // avoid recursive indefinite recursive expansion. @@ -394,8 +366,7 @@ class IntervalSetEvaluator : auto it = dom_map_.find(var); if (it != dom_map_.end()) { IntervalSet res = ToIntervalSet((*it).second); - if (res->min_value.same_as(var) && - res->max_value.same_as(var)) { + if (res->min_value.same_as(var) && res->max_value.same_as(var)) { return res; } // recursively evaluate mapped result @@ -406,74 +377,39 @@ class IntervalSetEvaluator : } } + IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AddNode* op) final { - return VisitBinaryExpr_(op); - } - - IntervalSet VisitExpr_(const SubNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const SubNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MulNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MulNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const DivNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const DivNode* op) final { return VisitBinaryExpr_
(op); } - IntervalSet VisitExpr_(const ModNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const ModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorDivNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const FloorDivNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const FloorModNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const FloorModNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MinNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MinNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const MaxNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const MaxNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const EQNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const EQNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const NENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const NENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LTNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const LTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const LENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const LENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GTNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const GTNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const GENode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const GENode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const AndNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const AndNode* op) final { return VisitBinaryExpr_(op); } - IntervalSet VisitExpr_(const OrNode* op) final { - return VisitBinaryExpr_(op); - } + IntervalSet VisitExpr_(const OrNode* op) final { return VisitBinaryExpr_(op); } IntervalSet VisitExpr_(const RampNode* op) final { CHECK(eval_vec_); @@ -482,16 +418,12 @@ class IntervalSetEvaluator : if (stride.Match(op->stride)) { DataType t = op->base.dtype(); int64_t vstride = stride.Eval()->value; - if (vstride> 0) { - return Combine( - analyzer_, - base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); + if (vstride > 0) { + return Combine(analyzer_, base, + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { - return Combine( - analyzer_, - base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); + return Combine(analyzer_, base, + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); @@ -516,19 +448,19 @@ class IntervalSetEvaluator : private: // whether set is exactly single point that equals value. 
- bool MatchPoint(const IntervalSet& set, - const PrimExpr& value) const { + bool MatchPoint(const IntervalSet& set, const PrimExpr& value) const { return set->min_value.same_as(value) && set->max_value.same_as(value); } - template + template inline IntervalSet VisitBinaryExpr_(const T* op) { + static_assert(std::is_same::value, "constraint"); IntervalSet a = this->Eval(op->a); IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(analyzer_, a, b); + return Combine(analyzer_, a, b); } // recursive depth @@ -541,9 +473,7 @@ class IntervalSetEvaluator : class IntSetAnalyzer::Impl { public: - explicit Impl(Analyzer* analyzer) - : analyzer_(analyzer) { - } + explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); @@ -553,16 +483,11 @@ class IntSetAnalyzer::Impl { Analyzer* analyzer_; }; -IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) - : impl_(new Impl(parent)) { -} +IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} -IntSetAnalyzer::~IntSetAnalyzer() { - delete impl_; -} +IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } -IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, - const Map& dom_map) { +IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const Map& dom_map) { return impl_->Eval(expr, dom_map); } @@ -570,11 +495,12 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, // TODO(tqchen): revisit IntSet interface as well. Range IntSet::cover_range(Range max_range) const { IntSet temp; + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); CHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return Range::make_by_min_extent( - s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value)); + return Range::make_by_min_extent(s_int->min_value, + analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; } @@ -607,26 +533,30 @@ bool IntSet::is_single_point() const { } bool IntSet::can_prove_positive() const { + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_positive_const(tir::Simplify(s_int->min_value))); + return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value))); } bool IntSet::can_prove_negative() const { + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_negative_const(tir::Simplify(s_int->max_value))); + return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value))); } bool IntSet::can_prove_non_positive() const { + Analyzer analyzer; if (const auto* s_int = (*this).as()) { - auto max = tir::Simplify(s_int->max_value); + auto max = analyzer.Simplify(s_int->max_value); return is_zero(max) || is_negative_const(max); } return false; } bool IntSet::can_prove_non_negative() const { + Analyzer analyzer; if (const IntervalSetNode* s_int = (*this).as()) { - auto min = tir::Simplify(s_int->min_value); + auto min = analyzer.Simplify(s_int->min_value); return is_zero(min) || is_positive_const(min); } return false; @@ -649,17 +579,11 @@ PrimExpr IntSet::point_value() const { return s_int->min_value; } -IntSet IntSet::nothing() { - return IntervalSet::Empty(); -} +IntSet IntSet::nothing() { return IntervalSet::Empty(); } -IntSet IntSet::everything() { - return IntervalSet::Everything(); -} +IntSet IntSet::everything() { return IntervalSet::Everything(); } -IntSet 
IntSet::single_point(PrimExpr x) { - return IntervalSet::SinglePoint(x); -} +IntSet IntSet::single_point(PrimExpr x) { return IntervalSet::SinglePoint(x); } IntSet IntSet::interval(PrimExpr min, PrimExpr max) { if (min.same_as(max)) { @@ -669,8 +593,8 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) { } // Range related code -inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) { - return is_zero(tir::Simplify(lhs - rhs)); +inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { + return is_zero(analyzer->Simplify(lhs - rhs)); } IntSet IntSet::range(Range r) { @@ -685,8 +609,9 @@ bool IntSet::match_range(const Range& b) const { const IntSet& a = *this; const IntervalSetNode* a_int = a.as(); if (!a_int) return false; - return ProveEqual(a_int->min_value, b->min) && - ProveEqual(a_int->max_value, b->extent + b->min - 1); + Analyzer ana; + return ProveEqual(&ana, a_int->min_value, b->min) && + ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } IntSet Union(const Array& sets) { @@ -697,8 +622,7 @@ IntSet Union(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Union(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(tir::Simplify(x->min_value), - tir::Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } IntSet Intersect(const Array& sets) { @@ -709,8 +633,7 @@ IntSet Intersect(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Intersect(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(tir::Simplify(x->min_value), - tir::Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); } Map ConvertDomMap(const Map& dom_map) { @@ -721,8 +644,7 @@ Map ConvertDomMap(const Map& dom_map) { return dmap; } -Map ConvertDomMap( - const std::unordered_map& dom_map) { +Map ConvertDomMap(const std::unordered_map& dom_map) { Map dmap; for (auto kv : dom_map) { dmap.Set(GetRef(kv.first), kv.second); @@ -730,8 +652,7 @@ Map ConvertDomMap( return dmap; } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map) { Analyzer ana; return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } @@ -742,49 +663,40 @@ IntSet IntSet::vector(PrimExpr x) { return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } -IntSet EvalSet(PrimExpr e, - const Map& dom_map) { +IntSet EvalSet(PrimExpr e, const Map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(PrimExpr e, - const std::unordered_map& dom_map) { +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map) { return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; - auto res = m.Eval(IntervalSet(r->min, Simplify(sum))); + auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); return std::move(res); } -IntSet EvalSet(Range r, - const std::unordered_map& dom_map) { +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } -IntSet EvalSet(IntSet s, - const std::unordered_map& dom_map) { +IntSet EvalSet(IntSet s, const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); IntervalSetEvaluator m(&ana, dmap); const IntervalSetNode* s_int = s.as(); - PrimExpr vmax = s_int->HasUpperBound() ? 
- m.Eval(s_int->max_value).max() : s_int->max_value; - PrimExpr vmin = s_int->HasLowerBound() ? - m.Eval(s_int->min_value).min() : s_int->min_value; + PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; + PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value; return IntervalSet(vmin, vmax); } class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntervalSetEvaluator( - Analyzer* analyzer, - const Map& dom_map) + explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} IntervalSet VisitExpr(const PrimExpr& n) final { @@ -796,9 +708,8 @@ class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { ExprIntSetMap expr_map; }; -ExprIntSetMap EvalSetForEachSubExpr( - PrimExpr e, - const std::unordered_map& dom_map) { +ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, + const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); SubExprIntervalSetEvaluator m(&ana, dmap); @@ -806,42 +717,32 @@ ExprIntSetMap EvalSetForEachSubExpr( return m.expr_map; } -IntSet EvalSet(Range r, - const Map& dom_map) { +IntSet EvalSet(Range r, const Map& dom_map) { return EvalSet(r, ConvertDomMap(dom_map)); } TVM_REGISTER_NODE_TYPE(IntervalSetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "IntervalSet" - << "[" << op->min_value << ", " - << op->max_value << ']'; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "IntervalSet" + << "[" << op->min_value << ", " << op->max_value << ']'; + }); -TVM_REGISTER_GLOBAL("arith.intset_single_point") -.set_body_typed(IntSet::single_point); +TVM_REGISTER_GLOBAL("arith.intset_single_point").set_body_typed(IntSet::single_point); -TVM_REGISTER_GLOBAL("arith.intset_vector") -.set_body_typed(IntSet::vector); +TVM_REGISTER_GLOBAL("arith.intset_vector").set_body_typed(IntSet::vector); -TVM_REGISTER_GLOBAL("arith.intset_interval") -.set_body_typed(IntSet::interval); +TVM_REGISTER_GLOBAL("arith.intset_interval").set_body_typed(IntSet::interval); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin") -.set_body_method(&IntSet::min); +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin").set_body_method(&IntSet::min); -TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax") -.set_body_method(&IntSet::max); +TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax").set_body_method(&IntSet::max); -TVM_REGISTER_GLOBAL("arith.IntSetIsNothing") -.set_body_method(&IntSet::is_nothing); +TVM_REGISTER_GLOBAL("arith.IntSetIsNothing").set_body_method(&IntSet::is_nothing); -TVM_REGISTER_GLOBAL("arith.IntSetIsEverything") -.set_body_method(&IntSet::is_everything); +TVM_REGISTER_GLOBAL("arith.IntSetIsEverything").set_body_method(&IntSet::is_everything); } // namespace arith } // namespace tvm diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index 51b500adb412..eb308dd385a4 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -26,7 +26,9 @@ #include #include + #include + #include "const_fold.h" namespace tvm { @@ -53,26 +55,18 @@ class IntervalSetNode : public IntSetNode { } /*! \return Whether the interval has upper bound. */ - bool HasUpperBound() const { - return !is_pos_inf(max_value) && !IsEmpty(); - } + bool HasUpperBound() const { return !is_pos_inf(max_value) && !IsEmpty(); } /*! 
\return Whether the interval has lower bound. */ - bool HasLowerBound() const { - return !is_neg_inf(min_value) && !IsEmpty(); - } + bool HasLowerBound() const { return !is_neg_inf(min_value) && !IsEmpty(); } /*! \return Whether the interval is a single point. */ - bool IsSinglePoint() const { - return min_value.same_as(max_value); - } + bool IsSinglePoint() const { return min_value.same_as(max_value); } /*! \return whether interval represent nothing */ bool IsEmpty() const { // during computations, either extreme could occur. return is_pos_inf(min_value) || is_neg_inf(max_value); } /*! \return whether interval represent everything */ - bool IsEverything() const { - return is_neg_inf(min_value) && is_pos_inf(max_value); - } + bool IsEverything() const { return is_neg_inf(min_value) && is_pos_inf(max_value); } static constexpr const char* _type_key = "arith.IntervalSet"; TVM_DECLARE_FINAL_OBJECT_INFO(IntervalSetNode, IntSetNode); @@ -97,24 +91,18 @@ class IntervalSet : public IntSet { * \param value The value to be represented. * \return The result set. */ - static IntervalSet SinglePoint(PrimExpr value) { - return IntervalSet(value, value); - } + static IntervalSet SinglePoint(PrimExpr value) { return IntervalSet(value, value); } /*! * \brief Create an IntervalSet that represents everything. * \param value The value to be represented. * \return The result set. */ - static IntervalSet Everything() { - return IntervalSet(neg_inf(), pos_inf()); - } + static IntervalSet Everything() { return IntervalSet(neg_inf(), pos_inf()); } /*! * \brief Create an empty eet. * \return The result set. */ - static IntervalSet Empty() { - return IntervalSet(pos_inf(), neg_inf()); - } + static IntervalSet Empty() { return IntervalSet(pos_inf(), neg_inf()); } TVM_DEFINE_OBJECT_REF_COW_METHOD(IntervalSetNode); TVM_DEFINE_OBJECT_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); @@ -136,7 +124,7 @@ TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); * \param b The second set. * \return The result set. */ -TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b); +TVM_DLL IntervalSet Intersect(Analyzer* analzyer, IntervalSet a, IntervalSet b); } // namespace arith } // namespace tvm diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 6e653cec3c3b..84e2093dcf98 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -20,24 +20,22 @@ /*! * \file tvm/arith/ir_mutator_with_analyzer.cc */ -#include -#include #include "ir_mutator_with_analyzer.h" +#include +#include + namespace tvm { namespace arith { using namespace tir; -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const ForNode* op) { - analyzer_->Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); +Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) { + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprMutator::VisitStmt_(op); } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const LetStmtNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -45,8 +43,7 @@ VisitStmt_(const LetStmtNode* op) { // We keep the let-binding here // as sub-class may or maynot choose to replace it. 
Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -56,29 +53,33 @@ VisitStmt_(const LetStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const IfThenElseNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); + PrimExpr real_condition = condition; + if (auto call = condition.as()) { + if (call->is_intrinsic(CallNode::likely)) { + real_condition = call->args[0]; + } + } + Stmt then_case, else_case; { - With ctx(analyzer_, condition); + With ctx(analyzer_, real_condition); then_case = this->VisitStmt(op->then_case); } if (op->else_case.defined()) { - With ctx(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(condition))); - else_case = this->VisitStmt(op->else_case); + With ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition))); + else_case = this->VisitStmt(op->else_case); } - if (is_one(condition)) return then_case; - if (is_zero(condition)) { + if (is_one(real_condition)) return then_case; + if (is_zero(real_condition)) { if (else_case.defined()) { return else_case; } - return EvaluateNode::make(0); + return Evaluate(0); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -90,14 +91,11 @@ VisitStmt_(const IfThenElseNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == tir::attr::thread_extent || - op->attr_key == tir::attr::virtual_thread) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_->Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); Stmt stmt = StmtExprMutator::VisitStmt_(op); return stmt; } else { @@ -105,16 +103,13 @@ VisitStmt_(const AttrStmtNode* op) { } } -Stmt IRMutatorWithAnalyzer:: -VisitStmt_(const AssertStmtNode* op) { +Stmt IRMutatorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr message = this->VisitExpr(op->message); With ctx(analyzer_, condition); Stmt body = this->VisitStmt(op->body); - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { + if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -125,8 +120,7 @@ VisitStmt_(const AssertStmtNode* op) { } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const CallNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else if (op->is_intrinsic(tir::intrinsic::tvm_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); @@ -136,8 +130,7 @@ VisitExpr_(const CallNode* op) { true_value = this->VisitExpr(op->args[1]); } { - With constraint(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(cond))); + With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); false_value = this->VisitExpr(op->args[2]); } if (is_zero(cond)) { @@ -146,21 +139,17 @@ VisitExpr_(const 
CallNode* op) { if (is_one(cond)) { return true_value; } - if (cond.same_as(op->args[0]) && - true_value.same_as(op->args[1]) && + if (cond.same_as(op->args[0]) && true_value.same_as(op->args[1]) && false_value.same_as(op->args[2])) { return GetRef(op); } else { - return CallNode::make(op->dtype, op->name, - {cond, true_value, false_value}, - op->call_type); + return Call(op->dtype, op->name, {cond, true_value, false_value}, op->call_type); } } return StmtExprMutator::VisitExpr_(op); } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const LetNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { analyzer_->Bind(op->var, value); @@ -168,16 +157,14 @@ VisitExpr_(const LetNode* op) { // We keep the let-binding here // as sub-class may or maynot choose to replace it. PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const SelectNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const SelectNode* op) { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr true_value, false_value; { @@ -185,8 +172,7 @@ VisitExpr_(const SelectNode* op) { true_value = VisitExpr(op->true_value); } { - With constraint(analyzer_, - analyzer_->rewrite_simplify(NotNode::make(cond))); + With constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond))); false_value = VisitExpr(op->false_value); } if (is_zero(cond)) { @@ -196,17 +182,15 @@ VisitExpr_(const SelectNode* op) { return true_value; } // normal path - if (cond.same_as(op->condition) && - true_value.same_as(op->true_value) && + if (cond.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { - return SelectNode::make(cond, true_value, false_value); + return Select(cond, true_value, false_value); } } -PrimExpr IRMutatorWithAnalyzer:: -VisitExpr_(const ReduceNode* op) { +PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { analyzer_->Bind(iv->var, iv->dom); diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 394e5db9c93e..004265bbe50a 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -24,8 +24,9 @@ #ifndef TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ #define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_ -#include #include +#include + #include namespace tvm { @@ -42,18 +43,17 @@ namespace arith { */ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { public: - explicit IRMutatorWithAnalyzer(Analyzer* analyzer) - : analyzer_(analyzer) {} + explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} - using StmtExprMutator::VisitStmt_; using StmtExprMutator::VisitExpr_; + using StmtExprMutator::VisitStmt_; // override functions that need to populate the context information. 
- Stmt VisitStmt_(const tir::ForNode* op) override; - Stmt VisitStmt_(const tir::LetStmtNode* op) override; - Stmt VisitStmt_(const tir::IfThenElseNode* op) override; - Stmt VisitStmt_(const tir::AttrStmtNode* op) override; - Stmt VisitStmt_(const tir::AssertStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::ForNode* op) override; + tir::Stmt VisitStmt_(const tir::LetStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; + tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; + tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override; PrimExpr VisitExpr_(const tir::LetNode* op) override; PrimExpr VisitExpr_(const tir::SelectNode* op) override; PrimExpr VisitExpr_(const tir::CallNode* op) override; diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index b2dbe9d10c08..810949b56e1f 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -34,23 +34,18 @@ namespace tir { class IRVisitorWithAnalyzer final : public StmtExprVisitor { public: - PrimExpr Simplify(const PrimExpr& expr) { - return analyzer_.Simplify(expr); - } + PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); + analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 40cd7f8793ee..3457674d4ed3 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -21,13 +21,15 @@ * \file modular_set.cc * \brief Modular set analysis */ -#include #include -#include +#include #include +#include + #include -#include #include +#include + #include "pattern_match.h" namespace tvm { @@ -46,19 +48,15 @@ ModularSet::ModularSet(int64_t coeff, int64_t base) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ModularSet(" - << "coeff=" << op->coeff << ", base=" - << op->base << ')'; - }); - -ModularSet MakeModularSet(int64_t coeff, int64_t base) { - return ModularSet(coeff, base); -} + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ModularSet(" + << "coeff=" << op->coeff << ", base=" << op->base << ')'; + }); + +ModularSet MakeModularSet(int64_t coeff, int64_t base) { return ModularSet(coeff, base); } -TVM_REGISTER_GLOBAL("arith.ModularSet") -.set_body_typed(MakeModularSet); +TVM_REGISTER_GLOBAL("arith.ModularSet").set_body_typed(MakeModularSet); // internal entry for const int bound struct ModularSetAnalyzer::Entry { @@ -77,37 +75,27 @@ struct ModularSetAnalyzer::Entry { this->base = base; } - bool is_const() const { - return coeff == 0; - } + bool is_const() const { return coeff == 0; } - bool operator==(const Entry& other) const { - return coeff == other.coeff && base == 
other.base; - } + bool operator==(const Entry& other) const { return coeff == other.coeff && base == other.base; } bool operator==(const ModularSet& other) const { - return other.defined() && - coeff == other->coeff && base == other->base; + return other.defined() && coeff == other->coeff && base == other->base; } }; -class ModularSetAnalyzer::Impl : - public ExprFunctor { +class ModularSetAnalyzer::Impl : public ExprFunctor { public: - explicit Impl(Analyzer* parent) - : parent_(parent) {} + explicit Impl(Analyzer* parent) : parent_(parent) {} - void Update(const Var& var, - const ModularSet& info, - bool override) { + void Update(const Var& var, const ModularSet& info, bool override) { if (!override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " - << "original=" << ModularSet(it->second.coeff, it->second.base) - << ", new=" << info; + CHECK(it->second == info) << "Trying to update var \'" << var << "\'" + << " with a different const bound: " + << "original=" << ModularSet(it->second.coeff, it->second.base) + << ", new=" << info; } } var_map_[var] = Entry(info->coeff, info->base); @@ -127,17 +115,11 @@ class ModularSetAnalyzer::Impl : } // Override visitor behaviors - Entry VisitExprDefault_(const Object* op) final { - return Everything(); - } + Entry VisitExprDefault_(const Object* op) final { return Everything(); } - Entry VisitExpr_(const CastNode* op) final { - return VisitExpr(op->value); - } + Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } - Entry VisitExpr_(const IntImmNode* op) final { - return Entry(0, op->value); - } + Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); } Entry VisitExpr_(const AddNode* op) final { Entry a = VisitExpr(op->a); @@ -167,9 +149,7 @@ class ModularSetAnalyzer::Impl : return Entry(coeff, a.base * b.base); } - Entry DivByConst(const PrimExpr& lhs, - int64_t val, - bool round_down) { + Entry DivByConst(const PrimExpr& lhs, int64_t val, bool round_down) { Entry a = VisitExpr(lhs); CHECK_NE(val, 0); if (a.coeff % val == 0) { @@ -179,8 +159,7 @@ class ModularSetAnalyzer::Impl : } // positive division have a clear rounding mode. // Only handle case where we clearly know we need to round down. - if (a.base > 0 && val > 0 && - (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { + if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { return Entry(a.coeff / val, a.base / val); } } @@ -254,7 +233,7 @@ class ModularSetAnalyzer::Impl : /*! \brief pointer to parent. */ Analyzer* parent_{nullptr}; // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; /*! * \brief Update var by intersecting entry with var's current set. * \param var The variable. @@ -269,9 +248,7 @@ class ModularSetAnalyzer::Impl : } var_map_[var] = Intersect(old, entry); // reover function. - return [this, old, var]() { - var_map_[var] = old; - }; + return [this, old, var]() { var_map_[var] = old; }; } /*! * \brief Create union of two sets. @@ -293,49 +270,7 @@ class ModularSetAnalyzer::Impl : return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0); } } - /*! - * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b) - * \param a The first coefficient. - * \param b The second coefficient. - * \param x The solution of x. - * \param y The solution of y. - * \return The GCD of a and b. 
- */ - static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) { - // Extended Euclidean algorithm - // if a < 0, the problem can be convert into - // |a|* (-x) + b * y = gcd(|a|, b) - // - // initial condition: - // a * 0 + b * 1 = b - // a * 1 + b * 0 = a - int64_t s = 0, old_s = 1; - int64_t r = b, old_r = a >= 0 ? a : -a; - // Iteration (r2 < r1): - // a * x1 + b * y1 = r1 - // a * x2 + b * y2 = r2 - // The above two eqs can derive the following eq (q = r1 / r2) - // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3 - // Because r3 < r2, the iteration can eventually terminate - while (r != 0) { - int64_t q = old_r / r; - int64_t tmp = old_r; - old_r = r; - r = tmp - q * r; - tmp = old_s; - old_s = s; - s = tmp - q * s; - } - - *x = a >= 0 ? old_s : -old_s; - if (b != 0) { - *y = (old_r - (*x) * a) / b; - } else { - *y = 1; - } - return old_r; - } /*! * \brief Create interect of two sets. * \param a The left operand. @@ -362,39 +297,16 @@ class ModularSetAnalyzer::Impl : return Nothing(); } } - /*! - * \brief Take GCD of a and b. - * \param a The first operand. - * \param b The second operand. - * \return The result. - */ - static int64_t ZeroAwareGCD(int64_t a, int64_t b) { - if (a < 0) a = -a; - if (b < 0) b = -b; - if (a < b) std::swap(a, b); - if (b == 0) return a; - // perform GCD (greatest common divisor) - // ax + by = gcd(a, b) z if a != 0, b != 0 - while (a % b != 0) { - a = a % b; - std::swap(a, b); - } - return b; - } /*! * \brief return everything dtype can represent. * \return Bound that represent everything dtype can represent. */ - static Entry Everything() { - return Entry(1, 0); - } + static Entry Everything() { return Entry(1, 0); } /*! * \brief return an empty set * \return Bound that represent everything dtype can represent. */ - static Entry Nothing() { - return Entry(0, 1); - } + static Entry Nothing() { return Entry(0, 1); } }; ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { @@ -402,9 +314,7 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) { return ModularSet(ret.coeff, ret.base); } -void ModularSetAnalyzer::Update(const Var& var, - const ModularSet& info, - bool override) { +void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) { impl_->Update(var, info, override); } @@ -412,13 +322,9 @@ std::function ModularSetAnalyzer::EnterConstraint(const PrimExpr& constr return impl_->EnterConstraint(constraint); } -ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) - : impl_(new Impl(parent)) { -} +ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} -ModularSetAnalyzer::~ModularSetAnalyzer() { - delete impl_; -} +ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index e81b0881f927..ff01941e4acf 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -65,9 +65,11 @@ #ifndef TVM_ARITH_PATTERN_MATCH_H_ #define TVM_ARITH_PATTERN_MATCH_H_ -#include #include +#include + #include + #include "const_fold.h" namespace tvm { @@ -84,7 +86,7 @@ namespace arith { * * \tparam Derived The type of the derived class. */ -template +template class Pattern { public: /*! @@ -108,30 +110,26 @@ class Pattern { * * \return whether value matches the pattern. */ - template + template bool Match(const NodeType& value) const { derived().InitMatch_(); return derived().Match_(value); } /*! \return Derived instance of current class. 
*/ - const Derived& derived() const { - return *static_cast(this); - } + const Derived& derived() const { return *static_cast(this); } }; /*! * \brief Default deep equality checker * \tparam T the comparison point. */ -template +template class PEqualChecker { public: - bool operator()(const T& lhs, const T& rhs) const { - return lhs == rhs; - } + bool operator()(const T& lhs, const T& rhs) const { return lhs == rhs; } }; -template<> +template <> class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { @@ -140,20 +138,16 @@ class PEqualChecker { } }; -template<> +template <> class PEqualChecker { public: - bool operator()(const IntImm& lhs, const IntImm& rhs) const { - return lhs->value == rhs->value; - } + bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; } }; -template<> -class PEqualChecker { +template <> +class PEqualChecker { public: - bool operator()(const Var& lhs, const Var& rhs) const { - return lhs.same_as(rhs); - } + bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } }; /*! @@ -166,15 +160,13 @@ class PEqualChecker { * \note PVar is not thread safe. * Do not use the same PVar in multiple threads. */ -template -class PVar : public Pattern > { +template +class PVar : public Pattern> { public: // Store PVars by reference in the expression. using Nested = const PVar&; - void InitMatch_() const { - filled_ = false; - } + void InitMatch_() const { filled_ = false; } bool Match_(const T& value) const { if (!filled_) { @@ -186,9 +178,8 @@ class PVar : public Pattern > { } } - template::value>::type> + template ::value>::type> bool Match_(const NodeRefType& value) const { if (const auto* ptr = value.template as()) { return Match_(GetRef(ptr)); @@ -214,21 +205,17 @@ class PVar : public Pattern > { * * \tparam T the type of the hole. */ -template -class PConst : public Pattern > { +template +class PConst : public Pattern> { public: PConst(T value) // NOLINT(*) : value_(value) {} void InitMatch_() const {} - bool Match_(const T& value) const { - return PEqualChecker()(value_, value); - } + bool Match_(const T& value) const { return PEqualChecker()(value_, value); } - T Eval() const { - return value_; - } + T Eval() const { return value_; } private: const T value_; @@ -236,13 +223,12 @@ class PConst : public Pattern > { /*! * \brief Pattern binary expression. - * \tparam NodeType The AST node type. + * \tparam OpType The AST noderef type. * \tparam TA The pattern type of the first operand. * \tparam TB The pattern type of the second operand. 
*/ -template -class PBinaryExpr : - public Pattern > { +template +class PBinaryExpr : public Pattern> { public: PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {} @@ -252,6 +238,7 @@ class PBinaryExpr : } bool Match_(const ObjectRef& node) const { + using NodeType = typename OpType::ContainerType; if (const NodeType* ptr = node.as()) { if (!a_.Match_(ptr->a)) return false; if (!b_.Match_(ptr->b)) return false; @@ -264,9 +251,9 @@ class PBinaryExpr : PrimExpr Eval() const { PrimExpr lhs = a_.Eval(); PrimExpr rhs = b_.Eval(); - PrimExpr ret = TryConstFold(lhs, rhs); + PrimExpr ret = TryConstFold(lhs, rhs); if (ret.defined()) return ret; - return NodeType::make(lhs, rhs); + return OpType(lhs, rhs); } private: @@ -274,12 +261,10 @@ class PBinaryExpr : typename TB::Nested b_; }; -template -class PConstWithTypeLike : - public Pattern > { +template +class PConstWithTypeLike : public Pattern> { public: - PConstWithTypeLike(const TA& ref, int64_t value) - : ref_(ref), value_(value) {} + PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {} void InitMatch_() const {} @@ -291,79 +276,70 @@ class PConstWithTypeLike : } } - PrimExpr Eval() const { - return tir::make_const(ref_.Eval().dtype(), value_); - } + PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); } private: typename TA::Nested ref_; int64_t value_; }; - -#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ - template \ - inline PBinaryExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - CheckStep; \ - return PBinaryExpr(a.derived(), b.derived()); \ - } \ - template \ - inline PBinaryExpr > \ - FuncName(const Pattern& a, int64_t b) { \ - CheckStep; \ - return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ - } \ - template \ - inline PBinaryExpr, TA> \ - FuncName(int64_t b, const Pattern& a) { \ - CheckStep; \ - return FuncName(PConstWithTypeLike(a.derived(), b), a); \ - } - -#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \ - TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) - +#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \ + template \ + inline PBinaryExpr FuncName(const Pattern& a, const Pattern& b) { \ + CheckStep; \ + return PBinaryExpr(a.derived(), b.derived()); \ + } \ + template \ + inline PBinaryExpr> FuncName(const Pattern& a, \ + int64_t b) { \ + CheckStep; \ + return FuncName(a, PConstWithTypeLike(a.derived(), b)); \ + } \ + template \ + inline PBinaryExpr, TA> FuncName(int64_t b, \ + const Pattern& a) { \ + CheckStep; \ + return FuncName(PConstWithTypeLike(a.derived(), b), a); \ + } + +#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) // raise ambiguity error for operator overload of / and % -TVM_PATTERN_BINARY_OP_EX(operator/, tir::DivNode, DivAmbiguityError(a)); -TVM_PATTERN_BINARY_OP_EX(operator%, tir::ModNode, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a)); // arithmetic expressions -TVM_PATTERN_BINARY_OP(operator+, tir::AddNode); -TVM_PATTERN_BINARY_OP(operator-, tir::SubNode); -TVM_PATTERN_BINARY_OP(operator*, tir::MulNode); -TVM_PATTERN_BINARY_OP(min, tir::MinNode); -TVM_PATTERN_BINARY_OP(max, tir::MaxNode); -TVM_PATTERN_BINARY_OP(div, tir::DivNode); -TVM_PATTERN_BINARY_OP(truncdiv, tir::DivNode); -TVM_PATTERN_BINARY_OP(truncmod, tir::ModNode); -TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDivNode); -TVM_PATTERN_BINARY_OP(floormod, tir::FloorModNode); 
+TVM_PATTERN_BINARY_OP(operator+, tir::Add); +TVM_PATTERN_BINARY_OP(operator-, tir::Sub); +TVM_PATTERN_BINARY_OP(operator*, tir::Mul); +TVM_PATTERN_BINARY_OP(min, tir::Min); +TVM_PATTERN_BINARY_OP(max, tir::Max); +TVM_PATTERN_BINARY_OP(div, tir::Div); +TVM_PATTERN_BINARY_OP(truncdiv, tir::Div); +TVM_PATTERN_BINARY_OP(truncmod, tir::Mod); +TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv); +TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod); // logical expressions -TVM_PATTERN_BINARY_OP(operator>, tir::GTNode); -TVM_PATTERN_BINARY_OP(operator>=, tir::GENode); -TVM_PATTERN_BINARY_OP(operator<, tir::LTNode); -TVM_PATTERN_BINARY_OP(operator<=, tir::LENode); -TVM_PATTERN_BINARY_OP(operator==, tir::EQNode); -TVM_PATTERN_BINARY_OP(operator!=, tir::NENode); -TVM_PATTERN_BINARY_OP(operator&&, tir::AndNode); -TVM_PATTERN_BINARY_OP(operator||, tir::OrNode); +TVM_PATTERN_BINARY_OP(operator>, tir::GT); +TVM_PATTERN_BINARY_OP(operator>=, tir::GE); +TVM_PATTERN_BINARY_OP(operator<, tir::LT); +TVM_PATTERN_BINARY_OP(operator<=, tir::LE); +TVM_PATTERN_BINARY_OP(operator==, tir::EQ); +TVM_PATTERN_BINARY_OP(operator!=, tir::NE); +TVM_PATTERN_BINARY_OP(operator&&, tir::And); +TVM_PATTERN_BINARY_OP(operator||, tir::Or); /*! * \brief Pattern not expression. * \tparam TA The pattern type of the true operand. */ -template -class PNotExpr : public Pattern > { +template +class PNotExpr : public Pattern> { public: - explicit PNotExpr(const TA& value) - : value_(value) {} + explicit PNotExpr(const TA& value) : value_(value) {} - void InitMatch_() const { - value_.InitMatch_(); - } + void InitMatch_() const { value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { if (const tir::NotNode* ptr = node.as()) { @@ -374,15 +350,13 @@ class PNotExpr : public Pattern > { } } - PrimExpr Eval() const { - return tir::NotNode::make(value_.Eval()); - } + PrimExpr Eval() const { return tir::Not(value_.Eval()); } private: typename TA::Nested value_; }; -template +template inline PNotExpr operator!(const Pattern& value) { return PNotExpr(value.derived()); } @@ -394,16 +368,11 @@ inline PNotExpr operator!(const Pattern& value) { * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -class PSelectExpr : - public Pattern > { +template +class PSelectExpr : public Pattern> { public: - PSelectExpr(const TCond& condition, - const TA& true_value, - const TB& false_value) - : condition_(condition), - true_value_(true_value), - false_value_(false_value) {} + PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value) + : condition_(condition), true_value_(true_value), false_value_(false_value) {} void InitMatch_() const { condition_.InitMatch_(); @@ -423,8 +392,7 @@ class PSelectExpr : } PrimExpr Eval() const { - return tir::SelectNode::make( - condition_.Eval(), true_value_.Eval(), false_value_.Eval()); + return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } private: @@ -446,13 +414,12 @@ class PSelectExpr : * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. 
*/ -template -inline PSelectExpr -select(const Pattern& condition, - const Pattern& true_value, - const Pattern& false_value) { - return PSelectExpr( - condition.derived(), true_value.derived(), false_value.derived()); +template +inline PSelectExpr select(const Pattern& condition, + const Pattern& true_value, + const Pattern& false_value) { + return PSelectExpr(condition.derived(), true_value.derived(), + false_value.derived()); } /*! @@ -460,13 +427,10 @@ select(const Pattern& condition, * \tparam DType The Pattern type of dtype. * \tparam TA The pattern type of the first operand. */ -template -class PCastExpr : - public Pattern > { +template +class PCastExpr : public Pattern> { public: - PCastExpr(const DType& dtype, const TA& value) - : dtype_(dtype), value_(value) { - } + PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {} void InitMatch_() const { dtype_.InitMatch_(); @@ -483,9 +447,7 @@ class PCastExpr : } } - PrimExpr Eval() const { - return tir::CastNode::make(dtype_.Eval(), value_.Eval()); - } + PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); } private: typename DType::Nested dtype_; @@ -503,9 +465,8 @@ class PCastExpr : * \tparam DType The pattern type of type. * \tparam TA The pattern type of value. */ -template -inline PCastExpr -cast(const Pattern& dtype, const Pattern& value) { +template +inline PCastExpr cast(const Pattern& dtype, const Pattern& value) { return PCastExpr(dtype.derived(), value.derived()); } @@ -515,15 +476,11 @@ cast(const Pattern& dtype, const Pattern& value) { * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ -template -class PRampExpr : - public Pattern > { +template +class PRampExpr : public Pattern> { public: - PRampExpr(const TBase& base, - const TStride& stride, - const TLanes& lanes) - : base_(base), stride_(stride), lanes_(lanes) { - } + PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes) + : base_(base), stride_(stride), lanes_(lanes) {} void InitMatch_() const { base_.InitMatch_(); @@ -542,9 +499,7 @@ class PRampExpr : } } - PrimExpr Eval() const { - return tir::RampNode::make(base_.Eval(), stride_.Eval(), lanes_.Eval()); - } + PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: typename TBase::Nested base_; @@ -565,13 +520,18 @@ class PRampExpr : * \tparam TStride The pattern type of the stride. * \tparam TLanes The pattern type of the lanes. */ -template -inline PRampExpr -ramp(const Pattern& base, - const Pattern& stride, - const Pattern& lanes) { - return PRampExpr( - base.derived(), stride.derived(), lanes.derived()); +template +inline PRampExpr ramp(const Pattern& base, + const Pattern& stride, + const Pattern& lanes) { + return PRampExpr(base.derived(), stride.derived(), lanes.derived()); +} + +template +inline PRampExpr, PConst> ramp(const Pattern& base, + int stride, int lanes) { + return PRampExpr, PConst>( + base.derived(), PConstWithTypeLike(base.derived(), stride), PConst(lanes)); } /*! @@ -579,14 +539,10 @@ ramp(const Pattern& base, * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. 
*/ -template -class PBroadcastExpr : - public Pattern > { +template +class PBroadcastExpr : public Pattern> { public: - PBroadcastExpr(const TA& value, - const TLanes& lanes) - : value_(value), lanes_(lanes) { - } + PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {} void InitMatch_() const { value_.InitMatch_(); @@ -603,9 +559,7 @@ class PBroadcastExpr : } } - PrimExpr Eval() const { - return tir::BroadcastNode::make(value_.Eval(), lanes_.Eval()); - } + PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); } private: typename TA::Nested value_; @@ -623,40 +577,37 @@ class PBroadcastExpr : * \tparam TA The pattern type of the value. * \tparam TLanes The pattern type of the lanes. */ -template -inline PBroadcastExpr -broadcast(const Pattern& value, const Pattern& lanes) { +template +inline PBroadcastExpr broadcast(const Pattern& value, + const Pattern& lanes) { return PBroadcastExpr(value.derived(), lanes.derived()); } // internal namespace namespace detail { // implementation details for CallExpr -template +template struct tuple_for_each_dispatcher { - template - static void run(F& f, const TTuple& tuple) { // NOLINT(*) + template + static void run(F& f, const TTuple& tuple) { // NOLINT(*) f(I, std::get(tuple)); - tuple_for_each_dispatcher< - (I + 1) == std::tuple_size::value, (I + 1), F> - ::run(f, tuple); + tuple_for_each_dispatcher<(I + 1) == std::tuple_size::value, (I + 1), F>::run(f, tuple); } }; -template +template struct tuple_for_each_dispatcher { - template - static void run(F& f, const TTuple& tuple) {} // NOLINT(*) + template + static void run(F& f, const TTuple& tuple) {} // NOLINT(*) }; -template +template inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*) - tuple_for_each_dispatcher::value == 0, 0, F> - ::run(f, tuple); + tuple_for_each_dispatcher::value == 0, 0, F>::run(f, tuple); } struct PCallExprInitMatchFunctor { - template + template void operator()(size_t i, const T& pattern) const { pattern.InitMatch_(); } @@ -666,10 +617,9 @@ struct PCallExprMatchFunctor { const tir::CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const tir::CallNode* call) - : call_(call) {} + explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} - template + template void operator()(size_t i, const T& pattern) { matched_ = matched_ && pattern.Match_(call_->args[i]); } @@ -678,7 +628,7 @@ struct PCallExprMatchFunctor { struct PCallExprEvalArgsFunctor { Array args_; - template + template void operator()(size_t i, const T& pattern) { args_.push_back(pattern.Eval()); } @@ -692,13 +642,10 @@ struct PCallExprEvalArgsFunctor { * \note Op functor contains the name of the function and * the implementation of Eval. */ -template -class PCallExpr : - public Pattern > { +template +class PCallExpr : public Pattern> { public: - explicit PCallExpr(const TArgs&... args) - : args_(args...) { - } + explicit PCallExpr(const TArgs&... args) : args_(args...) 
{} void InitMatch_() const { detail::PCallExprInitMatchFunctor finit; @@ -728,18 +675,16 @@ class PCallExpr : }; // arithemetic intrinsics -#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, \ - tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a, const Pattern& b) { \ - return PCallExpr(a.derived(), b.derived()); \ +#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ + return PCallExpr(a.derived(), b.derived()); \ } TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, "shift_left"); @@ -749,18 +694,16 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, "bitwise_or"); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, "bitwise_xor"); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ - struct OpName { \ - static PrimExpr Eval(Array args) { \ - return tir::CallNode::make(args[0].dtype(), kName, args, \ - tir::CallNode::PureIntrinsic); \ - } \ - static constexpr const char* kName = IntrinStr; \ - }; \ - template \ - inline PCallExpr \ - FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinStr) \ + struct OpName { \ + static PrimExpr Eval(Array args) { \ + return tir::Call(args[0].dtype(), kName, args, tir::CallNode::PureIntrinsic); \ + } \ + static constexpr const char* kName = IntrinStr; \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); @@ -768,9 +711,7 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, "bitwise_not"); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(Array args) { - return tir::CallNode::make( - args[1].dtype(), kName, args, - tir::CallNode::PureIntrinsic); + return tir::Call(args[1].dtype(), kName, args, tir::CallNode::PureIntrinsic); } static constexpr const char* kName = "tvm_if_then_else"; }; @@ -788,13 +729,12 @@ struct PIfThenElseOp { * \tparam TA The pattern type of the true operand. * \tparam TB The pattern type of the false operand. */ -template -inline PCallExpr -if_then_else(const Pattern& cond, - const Pattern& true_value, - const Pattern& false_value) { - return PCallExpr( - cond.derived(), true_value.derived(), false_value.derived()); +template +inline PCallExpr if_then_else(const Pattern& cond, + const Pattern& true_value, + const Pattern& false_value) { + return PCallExpr(cond.derived(), true_value.derived(), + false_value.derived()); } } // namespace arith diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 126310813cc4..ce3f2a6223f2 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -22,12 +22,15 @@ * \brief Rewrite-rule based simplification. */ // Acknowledgement: Most rewrite-rules are from Halide. 
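The simplifier below consumes the pattern DSL defined in pattern_match.h above: PVar placeholders bind sub-expressions during Match(), and Eval() re-materializes the right-hand side from whatever was bound. A minimal usage sketch follows (illustrative only, not part of this patch; it assumes the private src/arith/pattern_match.h header is on the include path, and PatternDemo is a made-up name):

  #include <tvm/tir/op.h>
  #include "pattern_match.h"

  using namespace tvm;

  void PatternDemo() {
    arith::PVar<PrimExpr> x, y;   // match any sub-expression
    arith::PVar<IntImm> c1;       // match only integer constants
    tir::Var a("a"), b("b");
    PrimExpr e = a + b * 2;                    // Add(a, Mul(b, 2))
    if ((x + y * c1).Match(e)) {               // binds x = a, y = b, c1 = 2
      PrimExpr swapped = (y * c1 + x).Eval();  // rebuilds b * 2 + a
      (void)swapped;
    }
  }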
+#include "rewrite_simplify.h" + #include #include + #include + #include "const_fold.h" #include "pattern_match.h" -#include "rewrite_simplify.h" namespace tvm { namespace arith { @@ -35,9 +38,9 @@ namespace arith { using namespace tir; // macro for doing simple rewrite -#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ - if ((SrcExpr).Match(ret)) { \ - return (ResExpr).Eval(); \ +#define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + if ((SrcExpr).Match(ret)) { \ + return (ResExpr).Eval(); \ } // macro for rewrite + recursively rewrite ResExpr @@ -47,15 +50,15 @@ using namespace tir; } // macro rewrite only if CondExor is true after match. -#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ - if ((SrcExpr).Match(ret) && (CondExpr)) { \ - return (ResExpr).Eval(); \ +#define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return (ResExpr).Eval(); \ } // macro rewrite + recursive_rewrite only if CondExor is true after match. -#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ - if ((SrcExpr).Match(ret) && (CondExpr)) { \ - return RecursiveRewrite((ResExpr).Eval()); \ +#define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + if ((SrcExpr).Match(ret) && (CondExpr)) { \ + return RecursiveRewrite((ResExpr).Eval()); \ } // NOTE for developers: @@ -66,8 +69,8 @@ using namespace tir; // // try to prove x equals val -RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl:: -TryCompare(const PrimExpr& x, int64_t val) { +RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, + int64_t val) { PrimExpr diff = this->VisitExpr(x); if (const auto* ptr = diff.as()) { if (ptr->value == val) { @@ -100,26 +103,22 @@ TryCompare(const PrimExpr& x, int64_t val) { return kUnknown; } -void RewriteSimplifier::Impl:: -Update(const Var& var, const PrimExpr& info, bool can_override) { +void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) { if (!can_override) { auto it = var_map_.find(var); if (it != var_map_.end()) { - CHECK(ExprDeepEqual()(it->second, info)) - << "Trying to update var \'" << var << "\'" - << " with a different value: " - << "original=" << it->second - << ", new=" << info; + CHECK(ExprDeepEqual()(it->second, info)) << "Trying to update var \'" << var << "\'" + << " with a different value: " + << "original=" << it->second << ", new=" << info; } } var_map_[var] = info; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const AddNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -129,14 +128,10 @@ VisitExpr_(const AddNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), - ramp(b1 + b2, s1 + s2, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), - ramp(b1 + x, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), - ramp(x + b1, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), - broadcast(x + y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2, s1 + s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x, s1, lanes)); + 
TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) + broadcast(y, lanes), broadcast(x + y, lanes)); } if (IsIndexType(op->dtype)) { @@ -167,14 +162,10 @@ VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); TVM_TRY_REWRITE(min(x, y) + max(y, x), x + y); - TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), - c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), - c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x, y + c1) + c2, min(x + c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x + c1, y) + c2, min(x, y + c2), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, y + c1) + c2, max(x + c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x + c1, y) + c2, max(x, y + c2), c1.Eval()->value == -c2.Eval()->value); // constant folding // NOTE: canonicalization might better at this. @@ -213,15 +204,16 @@ VisitExpr_(const AddNode* op) { } // condition rules. - TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), - select(x, b1 + s1, b2 + s2)); + TVM_TRY_REWRITE(select(x, b1, b2) + select(x, s1, s2), select(x, b1 + s1, b2 + s2)); // default value return ret; } std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { size_t old_literal_size = literal_constraints_.size(); - literal_constraints_.push_back(constraint); + // we will compare the already simplified result with the constraint, + // so simplify the constarint as well + literal_constraints_.push_back(operator()(constraint)); size_t new_literal_size = literal_constraints_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { CHECK_EQ(literal_constraints_.size(), new_literal_size); @@ -230,11 +222,10 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c return frecover; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const SubNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -244,14 +235,10 @@ VisitExpr_(const SubNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), - ramp(b1 - b2, s1 - s2, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), - ramp(b1 - x, s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), - ramp(x - b1, 0 - s1, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), - broadcast(x - y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes)); } if (IsIndexType(op->dtype)) { @@ -293,20 +280,20 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE((y + x) - (z + x), y - z); TVM_TRY_REWRITE((y + x) - (x + 
z), y - z); - TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); - TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); - TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); - TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(x + y, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(y + x, z) - x, min(y, z - x)); + TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); + TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); - TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); - TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); - TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); - TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); - TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); - TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); - TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); - TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); + TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); + TVM_TRY_REWRITE(x - min(z, y + x), max(x - z, 0 - y)); TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x)); @@ -324,10 +311,8 @@ VisitExpr_(const SubNode* op) { // DivMod rules // trucdiv // NOTE: c*(x/c) + x % c == x is true all division mode. - TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), - c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), - c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y, c1.Eval()->value != 0); TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1), @@ -337,45 +322,40 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2, + 
c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); // Proof in the case of floordiv, need positive condition. // let x = a * c3 + r // (x + c1) / c3 - x / c3 => (r + c1) / c3 // NOTE: the use of floormod(c2, c3) was intentional to simplify the const. - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3), truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3), CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) && - c1.Eval()->value >= c2.Eval()->value && - c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x, c3), - truncdiv(truncmod(x, c3) + c1, c3), - CanProveGreaterEqual(x.Eval(), 0) && - c1.Eval()->value >= 0 && - c3.Eval()->value > 0); + c1.Eval()->value >= c2.Eval()->value && c3.Eval()->value > 0); + TVM_TRY_REWRITE_IF( + truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3), + CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0); // floordiv - TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), - c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), - c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y, c1.Eval()->value != 0); TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1), @@ -385,30 +365,29 @@ VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y, c1.Eval()->value != 0); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); - TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2, - c1.Eval()->value != 0 && - c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * 
c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); + TVM_TRY_REWRITE_IF( + floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2, + c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3), floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3), c3.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), - floordiv(floormod(x, c3) + c1, c3), + TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3), c3.Eval()->value > 0); // canonicalization rule @@ -420,20 +399,16 @@ VisitExpr_(const SubNode* op) { } // condition rules. - TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), - select(x, b1 - s1, b2 - s2)); - TVM_TRY_REWRITE(select(x, y, z) - z, - select(x, y - z, ZeroWithTypeLike(z))); - TVM_TRY_REWRITE(select(x, y, z) - y, - select(x, ZeroWithTypeLike(y), z - y)); + TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2)); + TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z))); + TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y)); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MulNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1, b2, s1, s2; @@ -443,12 +418,9 @@ VisitExpr_(const MulNode* op) { PVar lanes; // Vector rules if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), - broadcast(x * y, lanes)); - TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), - ramp(b1 * x, s1 * x, lanes)); - TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), - ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x * y, lanes)); + TVM_TRY_REWRITE(ramp(b1, s1, lanes) * broadcast(x, lanes), ramp(b1 * x, s1 * x, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) * ramp(b1, s1, lanes), ramp(b1 * x, s1 * x, lanes)); } if (IsIndexType(op->dtype)) { @@ -461,18 +433,15 @@ VisitExpr_(const MulNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(x * (c1 * y), (x * y) * c1); TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1); - TVM_TRY_RECURSIVE_REWRITE_IF( - (x - y) * c1, (y - x) * (0 - c1), - c1.Eval()->value < 0); + TVM_TRY_RECURSIVE_REWRITE_IF((x - y) * c1, (y - x) * (0 - c1), c1.Eval()->value < 0); } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const DivNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { 
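  // Reading aid for the rules in this and the following visitors: given the
  // TVM_TRY_REWRITE_IF definition near the top of this file, a rule such as
  //   TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1,
  //                      CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
  // expands to roughly
  //   if ((truncdiv(x + y, x)).Match(ret) &&
  //       (CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0))) {
  //     return (truncdiv(y, x) + 1).Eval();
  //   }
  // so the first pattern that matches the already-simplified `ret` and passes its
  // side condition produces the rewritten expression.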
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as<DivNode>(); - PrimExpr const_res = TryConstFold
(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -490,8 +459,7 @@ VisitExpr_(const DivNode* op) { // Vector rules if (op->dtype.lanes() != 1) { // NOTE: use div as the pattern also works for float. - TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(div(x, y), lanes)); + TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes)); // ramp / bcast if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) { int64_t c1val = c1.Eval()->value; @@ -532,10 +500,8 @@ VisitExpr_(const DivNode* op) { c1.Eval()->value > 0 && c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3), - c1.Eval()->value > 0 && - c2.Eval()->value >= 0 && - c3.Eval()->value > 0 && - CanProveGreaterEqual(x.Eval(), 0)); + c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 && + CanProveGreaterEqual(x.Eval(), 0)); if (truncdiv(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; @@ -551,150 +517,105 @@ VisitExpr_(const DivNode* op) { TVM_TRY_REWRITE(truncdiv(c1 * x, x), c1); // Rules involving 2-operands. - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), - x * truncdiv(c1, c2) + truncdiv(y, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), - min(x * truncdiv(c1, c2), truncdiv(y, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), - max(x * truncdiv(c1, c2), truncdiv(y, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), - truncdiv(y, c2) + x * truncdiv(c1, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), - min(truncdiv(y, c2), x * truncdiv(c1, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), - max(truncdiv(y, c2), x * truncdiv(c1, c2)), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + 
CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); // Rules involving 3-operands. - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y + z, c2), - x * truncdiv(c1, c2) + truncdiv(y + z, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x * c1 - y + z, c2), - x * truncdiv(c1, c2) + truncdiv(z - y, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((z - y).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y - z, c2), - x * truncdiv(c1, c2) + truncdiv(y - z, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y - z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(y + x * c1 + z, c2), - x * truncdiv(c1, c2) + truncdiv(y + z, c2), - c1.Eval()->value > 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), - truncdiv(x, c2) + truncdiv(c1, c2), - c1.Eval()->value > 0 && - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2), + c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + + TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2), + c1.Eval()->value > 0 && c2.Eval()->value > 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 
0)); TVM_TRY_REWRITE_IF(truncdiv(x + y, x), truncdiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + x, x), truncdiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); - - TVM_TRY_REWRITE_IF(truncdiv((x + y) + z, x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv((y + x) + z, x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv(y + (z + x), x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); - TVM_TRY_REWRITE_IF(truncdiv(y + (x + z), x), - truncdiv(y + z, x) + 1, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y + z).Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); + + TVM_TRY_REWRITE_IF( + truncdiv((x + y) + z, x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv((y + x) + z, x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(y + (z + x), x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); + TVM_TRY_REWRITE_IF( + truncdiv(y + (x + z), x), truncdiv(y + z, x) + 1, + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x * y, y), x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y * x, y), x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(x * z + y, z), x + truncdiv(y, z), - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(z * x + y, z), x + truncdiv(y, z), - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + x * z, z), truncdiv(y, z) + x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(truncdiv(y + z * x, z), truncdiv(y, z) + x, - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0) && - CanProveGreaterEqual(z.Eval(), 0)); + CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) && + CanProveGreaterEqual(z.Eval(), 0)); } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const ModNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, 
op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -725,8 +646,7 @@ VisitExpr_(const ModNode* op) { if (ramp_min == ramp_max) { return ramp(truncmod(bmod->base, c2), c1, lanes).Eval(); } else { - return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), - broadcast(c2, lanes)).Eval(); + return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } } @@ -738,41 +658,34 @@ VisitExpr_(const ModNode* op) { // We adopt the default C division uses truncation instead of floordiv. // This means most rules need to check non-negativeness of the operands. TVM_TRY_REWRITE_IF(truncmod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(truncmod(x * c1 + y, c2), truncmod(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual((x * c1).Eval(), 0) && - CanProveGreaterEqual(y.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual((x * c1).Eval(), 0) && + CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(truncmod(x + c1, c2), truncmod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value >= 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value >= 0 && + c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(truncmod(x + y * c1, c2), truncmod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0 && - CanProveGreaterEqual(x.Eval(), 0) && - CanProveGreaterEqual((y * c1).Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 && + CanProveGreaterEqual(x.Eval(), 0) && + CanProveGreaterEqual((y * c1).Eval(), 0)); // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), - truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), + truncmod(x, c1), truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis if (truncmod(x, c1).Match(ret)) { ModularSet mod = analyzer_->modular_set(x.Eval()); int64_t c1val = c1.Eval()->value; - if (mod->coeff % c1val == 0 && - c1val > 0 && - CanProveGreaterEqual(x.Eval(), 0)) { + if (mod->coeff % c1val == 0 && c1val > 0 && CanProveGreaterEqual(x.Eval(), 0)) { return truncmod(mod->base, c1).Eval(); } } @@ -780,11 +693,10 @@ VisitExpr_(const ModNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const FloorDivNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y, z, b1; @@ -836,67 +748,43 @@ VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE(floordiv(c1 * x, x), c1); // Rules involving 2-operands. 
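The 2-operand rules below all encode the same identity: for a positive divisor, floordiv distributes over a sum whose scaled term is an exact multiple of the divisor, i.e. floordiv(x * c1 + y, c2) == x * floordiv(c1, c2) + floordiv(y, c2) whenever c2 > 0 and c1 % c2 == 0 (and likewise through min/max). A minimal standalone check of that identity with plain int64 arithmetic (the floordiv64 helper is an assumption of this sketch, not TVM's floordiv or its PVar/PConst pattern types):

#include <cassert>
#include <cstdint>

// Floor division on int64_t, i.e. rounding toward negative infinity.
int64_t floordiv64(int64_t a, int64_t b) {
  int64_t q = a / b;                              // C++ '/' truncates toward zero
  if ((a % b) != 0 && ((a < 0) != (b < 0))) --q;  // adjust when signs differ
  return q;
}

int main() {
  const int64_t c1 = 6, c2 = 3;  // c2 > 0 and c1 % c2 == 0, as the rule requires
  for (int64_t x = -4; x <= 4; ++x) {
    for (int64_t y = -7; y <= 7; ++y) {
      assert(floordiv64(x * c1 + y, c2) == x * floordiv64(c1, c2) + floordiv64(y, c2));
    }
  }
  return 0;
}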
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), - x * floordiv(c1, c2) + floordiv(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), - min(x * floordiv(c1, c2), floordiv(y, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), - max(x * floordiv(c1, c2), floordiv(y, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), - floordiv(y, c2) + x * floordiv(c1, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), - min(floordiv(y, c2), x * floordiv(c1, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), - max(floordiv(y, c2), x * floordiv(c1, c2)), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // Rules involving 3-operands. 
- TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), - x * floordiv(c1, c2) + floordiv(y + z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), - x * floordiv(c1, c2) + floordiv(z - y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), - x * floordiv(c1, c2) + floordiv(y - z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), - x * floordiv(c1, c2) + floordiv(y + z, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), - floordiv(x, c2) + floordiv(c1, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); - - TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, - CanProveGreaterEqual(x.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * c1 - y + z, c2), x * floordiv(c1, c2) + floordiv(z - y, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x * c1 + y - z, c2), x * floordiv(c1, c2) + floordiv(y - z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); + + TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv((x + y) + z, x), floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -907,10 +795,8 @@ VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(y + (x + z), x), floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); - TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, - CanProveGreaterEqual(y.Eval(), 0)); - TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, - CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(x * y, y), x, CanProveGreaterEqual(y.Eval(), 0)); + TVM_TRY_REWRITE_IF(floordiv(y * x, y), x, CanProveGreaterEqual(y.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(x * z + y, z), x + floordiv(y, z), CanProveGreaterEqual(z.Eval(), 0)); @@ -924,11 +810,10 @@ VisitExpr_(const FloorDivNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const FloorModNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -967,20 +852,16 @@ VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here TVM_TRY_REWRITE_IF(floormod(x * c1, c2), ZeroWithTypeLike(x), - c2.Eval()->value != 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value != 0 && c1.Eval()->value % c2.Eval()->value 
== 0); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(y, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value % c2.Eval()->value == 0); + c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); // try modular analysis if (floormod(x, c1).Match(ret)) { @@ -994,11 +875,10 @@ VisitExpr_(const FloorModNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MinNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1009,8 +889,7 @@ VisitExpr_(const MinNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(min(x, y), lanes)); + TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes)); TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)), min(x, broadcast(min(y, z), lanes))); } @@ -1035,8 +914,7 @@ VisitExpr_(const MinNode* op) { return (x + c2).Eval(); } } - if (min(x + c1, x).Match(ret) || - min(x, x + c1).Match(ret)) { + if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) { if (c1.Eval()->value < 0) { return (x + c1).Eval(); } else { @@ -1055,40 +933,30 @@ VisitExpr_(const MinNode* op) { // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, x), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(truncdiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(min(x, truncdiv(x + c1, c2) * c2), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(max(x, c2), truncdiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value && - CanProveGreaterEqual(x.Eval(), 0)); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value && + CanProveGreaterEqual(x.Eval(), 0)); // Divide up rounding: floor div TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, x), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(floordiv(x + c1, c2) * c2, max(x, c2)), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(min(x, floordiv(x + c1, c2) * c2), x, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); 
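Both the truncdiv and floordiv variants of these min rules share one reading: when c1 + 1 == c2 and c2 > 0, floordiv(x + c1, c2) * c2 is x rounded up to the next multiple of c2, so taking min of that with x gives x back (the MaxNode rules further down pick the rounded-up value instead). A quick numeric check of that reading, restricted to x >= 0 so truncating and flooring division agree (plain integers, illustration only, not part of the patch):

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t c2 = 4, c1 = c2 - 1;  // the rules require c1 + 1 == c2 and c2 > 0
  for (int64_t x = 0; x <= 32; ++x) {
    const int64_t rounded_up = (x + c1) / c2 * c2;   // floordiv(x + c1, c2) * c2 for x >= 0
    assert(rounded_up >= x && rounded_up < x + c2);  // smallest multiple of c2 not below x
    assert(std::min(rounded_up, x) == x);            // min(floordiv(x + c1, c2) * c2, x) => x
  }
  return 0;
}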
TVM_TRY_REWRITE_IF(min(max(x, c2), floordiv(x + c1, c2) * c2), max(x, c2), - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, - c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(min(x, floordiv(x, c2) * c2), floordiv(x, c2) * c2, c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(min(floordiv(x, c2) * c2, x), floordiv(x, c2) * c2, c2.Eval()->value > 0); TVM_TRY_REWRITE(min(max(x, y), min(x, y)), min(x, y)); TVM_TRY_REWRITE(min(max(x, y), min(y, x)), min(x, y)); @@ -1157,8 +1025,11 @@ VisitExpr_(const MinNode* op) { if (min(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + if (c1val == 0) { + return c2val < 0 ? c2.Eval() : c1.Eval(); + } if (c2val % c1val == 0) { - if (c2val / c1val >= 0) { + if (c1val > 0) { return (min(x, c2val / c1val) * c1val).Eval(); } else { return (max(x, c2val / c1val) * c1val).Eval(); @@ -1168,22 +1039,18 @@ VisitExpr_(const MinNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE_IF( - min(c1 - x, c2), c1 - max(x, c1 - c2), - c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. - TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), - select(x, min(y, s1), min(z, s2))); + TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2))); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const MaxNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1194,8 +1061,7 @@ VisitExpr_(const MaxNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), - broadcast(max(x, y), lanes)); + TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes)); TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)), max(x, broadcast(max(y, z), lanes))); } @@ -1220,8 +1086,7 @@ VisitExpr_(const MaxNode* op) { return (x + c2).Eval(); } } - if (max(x + c1, x).Match(ret) || - max(x, x + c1).Match(ret)) { + if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) { if (c1.Eval()->value > 0) { return (x + c1).Eval(); } else { @@ -1239,27 +1104,19 @@ VisitExpr_(const MaxNode* op) { // DivMod rules // Divide up rounding: truc div // NOTE: trucdiv(x, y) >= floordiv(x, y) - TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), - truncdiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), - truncdiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(truncdiv(x + c1, c2) * c2, x), truncdiv(x + c1, c2) * c2, + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, truncdiv(x + c1, c2) * c2), truncdiv(x + c1, c2) * c2, + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); // Divide up rounding: floor div 
TVM_TRY_REWRITE_IF(max(floordiv(x + c1, c2) * c2, x), floordiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); TVM_TRY_REWRITE_IF(max(x, floordiv(x + c1, c2) * c2), floordiv(x + c1, c2) * c2, - c2.Eval()->value > 0 && - c1.Eval()->value + 1 == c2.Eval()->value); + c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, - c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(max(floordiv(x, c2) * c2, x), x, c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(max(x, floordiv(x, c2) * c2), x, c2.Eval()->value > 0); TVM_TRY_REWRITE(max(min(x, y), max(x, y)), max(x, y)); TVM_TRY_REWRITE(max(min(x, y), max(y, x)), max(x, y)); @@ -1331,8 +1188,11 @@ VisitExpr_(const MaxNode* op) { if (max(x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; + if (c1val == 0) { + return c2val > 0 ? c2.Eval() : c1.Eval(); + } if (c2val % c1val == 0) { - if (c2val / c1val >= 0) { + if (c1val > 0) { return (max(x, c2val / c1val) * c1val).Eval(); } else { return (min(x, c2val / c1val) * c1val).Eval(); @@ -1342,21 +1202,18 @@ VisitExpr_(const MaxNode* op) { // canonicalization TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1)); - TVM_TRY_RECURSIVE_REWRITE_IF( - max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0); } // condition rules. - TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), - select(x, max(y, s1), max(z, s2))); + TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2))); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const EQNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1367,8 +1224,7 @@ VisitExpr_(const EQNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), - broadcast(x == y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x == y, lanes)); } if (IsIndexType(op->a.dtype())) { @@ -1386,31 +1242,26 @@ VisitExpr_(const EQNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const NENode* op) { - return this->VisitExpr(NotNode::make(op->a == op->b)); +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { + return this->VisitExpr(Not(op->a == op->b)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LENode* op) { - return this->VisitExpr(NotNode::make(op->b < op->a)); +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { + return this->VisitExpr(Not(op->b < op->a)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const GTNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const GENode* op) { - return this->VisitExpr(NotNode::make(op->a < op->b)); +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const GENode* op) { + return this->VisitExpr(Not(op->a < op->b)); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LTNode* op) 
{ +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as<LTNode>(); - PrimExpr const_res = TryConstFold<LT>(op->a, op->b); + PrimExpr const_res = TryConstFold<LT>(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1421,10 +1272,8 @@ VisitExpr_(const LTNode* op) { // vector rule if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), - broadcast(x < y, lanes)); - TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), - broadcast(x < y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x < y, lanes)); + TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y, lanes)); } if (IsIndexType(op->a.dtype())) { @@ -1436,6 +1285,7 @@ VisitExpr_(const LTNode* op) { return make_const(op->dtype, false); } + // clang-format off TVM_TRY_REWRITE(x + y < x + z, y < z); TVM_TRY_REWRITE(x + y < z + x, y < z); TVM_TRY_REWRITE(y + x < x + z, y < z); @@ -1449,100 +1299,76 @@ VisitExpr_(const LTNode* op) { TVM_TRY_REWRITE(c1 < x + c2, c1 - c2 < x); TVM_TRY_REWRITE(c1 < c2 - x, x < c2 - c1); - TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, - c1.Eval()->value < 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, x < y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c1 < y * c1, y < x, c1.Eval()->value < 0); // constant cancelation: only need to make use of one mod // truc div - TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1 - 1, c2) + 1, - c1.Eval()->value > 0 && - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(x * c2 < c1, + x < truncdiv(c1 - 1, c2) + 1, c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(x * c2 < c1, x < truncdiv(c1, c2), - c1.Eval()->value <= 0 && - c2.Eval()->value > 0); // NOTE: trunc div required (euclidean is ok too, floored is not) - TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, - c1.Eval()->value > 0 && + TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1 - 1, c2) - 1 < x, c1.Eval()->value > 0 && c2.Eval()->value < 0); // NOTE: trunc div required (floored is ok too, euclidean is not) TVM_TRY_REWRITE_IF(x * c2 < c1, truncdiv(c1, c2) < x, - c1.Eval()->value <= 0 && - c2.Eval()->value < 0); + c1.Eval()->value <= 0 && c2.Eval()->value < 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1 + 1, c2) - 1 < x, - c1.Eval()->value < 0 && - c2.Eval()->value > 0); + c1.Eval()->value < 0 && c2.Eval()->value > 0); TVM_TRY_REWRITE_IF(c1 < x * c2, truncdiv(c1, c2) < x, - c1.Eval()->value >= 0 && - c2.Eval()->value > 0); + c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required (floored is ok too, euclidean is not) TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1 + 1, c2) + 1, - c1.Eval()->value < 0 && - c2.Eval()->value < 0); + c1.Eval()->value < 0 && c2.Eval()->value < 0); // NOTE: trunc div required (euclidean is ok too, floored is not) TVM_TRY_REWRITE_IF(c1 < x * c2, x < truncdiv(c1, c2), - c1.Eval()->value >= 0 && - c2.Eval()->value < 0); + c1.Eval()->value >= 0 && c2.Eval()->value < 0);
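The constant-cancellation rules just above can be sanity-checked by brute force: for positive c1 and c2, x * c2 < c1 holds exactly when x < truncdiv(c1 - 1, c2) + 1 (for example c2 = 4, c1 = 10: x * 4 < 10 iff x <= 2, and truncdiv(9, 4) + 1 = 3). A standalone check over a small range, using plain integers rather than the PVar/Eval machinery (illustration only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t c2 = 1; c2 <= 8; ++c2) {
    for (int64_t c1 = 1; c1 <= 64; ++c1) {
      const int64_t bound = (c1 - 1) / c2 + 1;  // truncdiv(c1 - 1, c2) + 1; operands are non-negative here
      for (int64_t x = -20; x <= 20; ++x) {
        assert((x * c2 < c1) == (x < bound));
      }
    }
  }
  return 0;
}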
// DivMod rules // trucdiv - TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * c2, - c1.Eval()->value > 0 && - c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, + x < c1 * c2, c1.Eval()->value > 0 && c2.Eval()->value > 0); // NOTE: trunc div required - TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, x < c1 * (c2 - 1) + 1, - c1.Eval()->value > 0 && - c2.Eval()->value <= 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) < c2, + x < c1 * (c2 - 1) + 1, c1.Eval()->value > 0 && c2.Eval()->value <= 0); TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), (c1 + 1) * c2 - 1 < x, - c1.Eval()->value >= 0 && - c2.Eval()->value > 0); + c1.Eval()->value >= 0 && c2.Eval()->value > 0); // NOTE: trunc div required TVM_TRY_REWRITE_IF(c1 < truncdiv(x, c2), c1 * c2 < x, - c1.Eval()->value < 0 && - c2.Eval()->value > 0); + c1.Eval()->value < 0 && c2.Eval()->value > 0); // invariance for any div mod: x - (x / c1) * c1 == x % c1 - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, 0 < truncmod(x, c1) + y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, y < truncmod(x, c1), - c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x, 0 < truncmod(x, c1), c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x + y, + 0 < truncmod(x, c1) + y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 < x - y, + y < truncmod(x, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x, - c2 < truncmod(x + c2, c1), - c1.Eval()->value > 0); + c2 < truncmod(x + c2, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x + y, - c2 < truncmod(x + c2, c1) + y, - c1.Eval()->value > 0); + c2 < truncmod(x + c2, c1) + y, c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(truncdiv(x + c2, c1) * c1 < x - y, - y < truncmod(x + c2, c1) + (0 - c2), - c1.Eval()->value > 0); + y < truncmod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // floordiv - TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, - c2.Eval()->value > 0); - - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, 0 < floormod(x, c1) + y, - c1.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, y < floormod(x, c1), - c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(c1 < floordiv(x, c2), (c1 + 1) * c2 - 1 < x, c2.Eval()->value > 0); + + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x, 0 < floormod(x, c1), c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x + y, + 0 < floormod(x, c1) + y, c1.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 < x - y, + y < floormod(x, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x, - c2 < floormod(x + c2, c1), - c1.Eval()->value > 0); + c2 < floormod(x + c2, c1), c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x + y, - c2 < floormod(x + c2, c1) + y, - c1.Eval()->value > 0); + c2 < floormod(x + c2, c1) + y, c1.Eval()->value > 0); TVM_TRY_REWRITE_IF(floordiv(x + c2, c1) * c1 < x - y, - y < floormod(x + c2, c1) + (0 - c2), - c1.Eval()->value > 0); + y < floormod(x + c2, c1) + (0 - c2), c1.Eval()->value > 0); // canonicalization rule TVM_TRY_RECURSIVE_REWRITE(min(x, y) < z, x < z || y < z); @@ -1558,15 +1384,15 @@ VisitExpr_(const LTNode* op) { TVM_TRY_RECURSIVE_REWRITE(x + c1 < c2, x < c2 - c1); TVM_TRY_RECURSIVE_REWRITE(x - c1 < c2, x < c2 + c1); TVM_TRY_REWRITE(x - c1 < 0, x < c1); + // clang-format on } return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const NotNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NotNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = 
ret.as(); - PrimExpr const_res = TryConstFold(op->a); + PrimExpr const_res = TryConstFold(op->a); if (const_res.defined()) return const_res; // Pattern var to match any expression PVar x, y; @@ -1587,11 +1413,10 @@ VisitExpr_(const NotNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const AndNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1601,8 +1426,7 @@ VisitExpr_(const AndNode* op) { PVar lanes; if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), - broadcast(x && y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } auto cfalse = PConst(make_const(op->dtype, false)); @@ -1612,35 +1436,26 @@ VisitExpr_(const AndNode* op) { TVM_TRY_REWRITE(x <= y && y < x, cfalse); TVM_TRY_REWRITE(y < x && x <= y, cfalse); - TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, - c2.Eval()->value + 1 >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, - c2.Eval()->value + 1 >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, - c2.Eval()->value >= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, - c2.Eval()->value >= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, - c2.Eval()->value > c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, - c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x < c1 && c2 <= x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x < c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x <= c1 && c2 < x, cfalse, c2.Eval()->value >= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x && x <= c1, cfalse, c2.Eval()->value >= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 && c2 <= x, cfalse, c2.Eval()->value > c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x && x <= c1, cfalse, c2.Eval()->value > c1.Eval()->value); TVM_TRY_REWRITE(x == c1 && x != c2, x == c1 && c1 != c2); TVM_TRY_REWRITE(x != c2 && x == c1, x == c1 && c1 != c2); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const OrNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); - PrimExpr const_res = TryConstFold(op->a, op->b); + PrimExpr const_res = TryConstFold(op->a, op->b); if (const_res.defined()) return const_res; // Pattern var to match any expression @@ -1650,8 +1465,7 @@ VisitExpr_(const OrNode* op) { PVar lanes; if (op->dtype.lanes() != 1) { - TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), - broadcast(x || y, lanes)); + TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } auto ctrue = PConst(make_const(op->dtype, true)); @@ -1662,32 +1476,23 @@ VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(x <= y || y < x, ctrue); TVM_TRY_REWRITE(y < x || x <= y, ctrue); - TVM_TRY_REWRITE_IF(x < 
c1 || c2 < x, ctrue, - c2.Eval()->value < c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, - c2.Eval()->value < c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, - c2.Eval()->value <= c1.Eval()->value); - TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, - c2.Eval()->value <= c1.Eval()->value); - - TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, - c2.Eval()->value <= c1.Eval()->value + 1); - TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, - c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(x < c1 || c2 < x, ctrue, c2.Eval()->value < c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x < c1, ctrue, c2.Eval()->value < c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 < x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 < x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(x < c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value); + TVM_TRY_REWRITE_IF(c2 <= x || x < c1, ctrue, c2.Eval()->value <= c1.Eval()->value); + + TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); + TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <= c1.Eval()->value + 1); TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2); TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2); return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const SelectNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SelectNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); if (op == nullptr) return ret; @@ -1697,8 +1502,7 @@ VisitExpr_(const SelectNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const CallNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // add condition context to if_then_else PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); @@ -1728,8 +1532,7 @@ VisitExpr_(const CallNode* op) { return ret; } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const VarNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) { Var var = GetRef(op); auto it = var_map_.find(var); if (it != var_map_.end()) { @@ -1738,15 +1541,13 @@ VisitExpr_(const VarNode* op) { return GetRef(op); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const CastNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) { PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); op = ret.as(); return cast(op->dtype, op->value); } -PrimExpr RewriteSimplifier::Impl:: -VisitExpr_(const LetNode* op) { +PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); if (!tir::HasSideEffect(value)) { // it is fine to discard the let binding @@ -1755,11 +1556,10 @@ VisitExpr_(const LetNode* op) { return this->VisitExpr(op->body); } PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } @@ -1775,9 +1575,7 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { return res; } -void RewriteSimplifier::Update(const Var& var, - const PrimExpr& info, - bool override) { +void RewriteSimplifier::Update(const Var& var, const 
PrimExpr& info, bool override) { impl_->Update(var, info, override); } @@ -1785,13 +1583,9 @@ std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constra return impl_->EnterConstraint(constraint); } -RewriteSimplifier::RewriteSimplifier(Analyzer* parent) - : impl_(new Impl(parent)) { -} +RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} -RewriteSimplifier::~RewriteSimplifier() { - delete impl_; -} +RewriteSimplifier::~RewriteSimplifier() { delete impl_; } } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 8798df92777d..68c0dd271410 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -26,11 +26,13 @@ #include #include + #include #include + #include "const_fold.h" -#include "pattern_match.h" #include "ir_mutator_with_analyzer.h" +#include "pattern_match.h" namespace tvm { namespace arith { @@ -46,8 +48,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; - explicit Impl(Analyzer* parent) - : IRMutatorWithAnalyzer(parent) {} + explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} void Update(const Var& var, const PrimExpr& info, bool override_info); PrimExpr VisitExpr_(const AddNode* op) override; @@ -78,19 +79,11 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { protected: /*! \brief internal structure for comparison. */ - enum CompareResult { - kUnknown, - kEQ, - kGT, - kGE, - kLT, - kLE, - kNE - }; + enum CompareResult { kUnknown, kEQ, kGT, kGE, kLT, kLE, kNE }; // counter to record recursive rewrite depth. int recur_depth_{0}; // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; std::vector literal_constraints_; @@ -127,18 +120,17 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { return res; } - template + template PConstWithTypeLike ZeroWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 0); } - template + template PConstWithTypeLike OneWithTypeLike(const Pattern& pattern) { return PConstWithTypeLike(pattern.derived(), 1); } }; - } // namespace arith } // namespace tvm #endif // TVM_ARITH_REWRITE_SIMPLIFY_H_ diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 8142a03155c8..5bf0e0e32984 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -21,25 +21,24 @@ * \file tvm/arith/solve_linear_equation.cc * \brief Solve linear equations. */ -#include -#include #include #include -#include -#include #include -#include #include +#include +#include +#include +#include + +#include "int_operator.h" namespace tvm { namespace arith { using namespace tvm::runtime; -void SmithNormalFormDiag(std::vector >* S, - std::vector >* V, - std::vector* x, - std::vector* y) { +void SmithNormalFormDiag(std::vector>* S, std::vector>* V, + std::vector* x, std::vector* y) { if (S->empty() || V->empty()) return; size_t m = S->size(); size_t n = (*S)[0].size(); // n is # of variables @@ -98,7 +97,7 @@ void SmithNormalFormDiag(std::vector >* S, int64_t g, a, b; // g = a*matrix[index][index] + b*matrix[i][index] if ((*S)[i][index] % (*S)[index][index] != 0) { - std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]); + g = ExtendedEuclidean((*S)[index][index], (*S)[i][index], &a, &b); } else { // Explicitly avoid changing the index-th row. This is important to avoid infinite loop. 
g = (*S)[index][index]; @@ -123,19 +122,19 @@ void SmithNormalFormDiag(std::vector >* S, for (size_t j = index; j < (*S)[i].size(); ++j) { // Multiply index-th row by a and add the i-th row multiplied by b // This will make the index-th diagonal element equal to the gcd - int64_t new_index_j = a*(*S)[index][j] + b*(*S)[i][j]; + int64_t new_index_j = a * (*S)[index][j] + b * (*S)[i][j]; // This transformation performs zeroing of matrix[i][index] - int64_t new_i_j = n_g*(*S)[index][j] - m_g*(*S)[i][j]; + int64_t new_i_j = n_g * (*S)[index][j] - m_g * (*S)[i][j]; (*S)[index][j] = new_index_j; (*S)[i][j] = new_i_j; } // We have to do the same with rhs - PrimExpr ea = te::make_const((*y)[index].dtype(), a); - PrimExpr eb = te::make_const((*y)[i].dtype(), b); - PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g); - PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g); - PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i]; - PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i]; + PrimExpr ea = tir::make_const((*y)[index].dtype(), a); + PrimExpr eb = tir::make_const((*y)[i].dtype(), b); + PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g); + PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g); + PrimExpr new_index_rhs = ea * (*y)[index] + eb * (*y)[i]; + PrimExpr new_i_rhs = e_n_g * (*y)[index] - e_m_g * (*y)[i]; (*y)[index] = new_index_rhs; (*y)[i] = new_i_rhs; } @@ -151,7 +150,7 @@ void SmithNormalFormDiag(std::vector >* S, int64_t g, a, b; // g = a*matrix[index][index] + b*matrix[index][j] if ((*S)[index][j] % (*S)[index][index] != 0) { - std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]); + g = ExtendedEuclidean((*S)[index][index], (*S)[index][j], &a, &b); // During this phase we may disrupt the zeroness of the index-th column, so we will // have to take some action if this might have happened. changed = true; @@ -177,25 +176,25 @@ void SmithNormalFormDiag(std::vector >* S, int64_t n_g = (*S)[index][j] / g; for (size_t i = index; i < m; ++i) { - int64_t new_i_index = a*(*S)[i][index] + b*(*S)[i][j]; - int64_t new_i_j = n_g*(*S)[i][index] - m_g*(*S)[i][j]; + int64_t new_i_index = a * (*S)[i][index] + b * (*S)[i][j]; + int64_t new_i_j = n_g * (*S)[i][index] - m_g * (*S)[i][j]; (*S)[i][index] = new_i_index; (*S)[i][j] = new_i_j; } // We do exactly the same transformations with V for (size_t i = 0; i < n; ++i) { - int64_t new_i_index = a*(*V)[i][index] + b*(*V)[i][j]; - int64_t new_i_j = n_g*(*V)[i][index] - m_g*(*V)[i][j]; + int64_t new_i_index = a * (*V)[i][index] + b * (*V)[i][j]; + int64_t new_i_j = n_g * (*V)[i][index] - m_g * (*V)[i][j]; (*V)[i][index] = new_i_index; (*V)[i][j] = new_i_j; } // And apply reverse transformations to new_to_old. 
- PrimExpr ea = te::make_const((*x)[j].dtype(), a); - PrimExpr eb = te::make_const((*x)[index].dtype(), b); - PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g); - PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g); - PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j]; - PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j]; + PrimExpr ea = tir::make_const((*x)[j].dtype(), a); + PrimExpr eb = tir::make_const((*x)[index].dtype(), b); + PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g); + PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g); + PrimExpr new_index = e_m_g * (*x)[index] + e_n_g * (*x)[j]; + PrimExpr new_j = eb * (*x)[index] - ea * (*x)[j]; (*x)[index] = new_index; (*x)[j] = new_j; } @@ -209,8 +208,7 @@ void SmithNormalFormDiag(std::vector >* S, } } -Map InferRange(const Map& vars_to_infer, - const Array& ori_vars, +Map InferRange(const Map& vars_to_infer, const Array& ori_vars, const Map& ori_ranges) { // The resulting ranges Map new_ranges; @@ -244,8 +242,7 @@ Map InferRange(const Map& vars_to_infer, // pretty print matrix equation void DebugPrint(const std::vector>& S, - const std::vector>& V, - const std::vector& V_inv_x, + const std::vector>& V, const std::vector& V_inv_x, const std::vector& rhs) { std::cout << "S:\n"; for (size_t i = 0; i < S.size(); ++i) { @@ -266,7 +263,7 @@ void DebugPrint(const std::vector>& S, std::cout << "\n" << std::endl; } -IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_solve) { +IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_solve) { // m: # of equations // n: # of variables // we first construct A_{mxn} x_{nx1} = y_{mx1} @@ -274,10 +271,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // S_{mxn} = U_{mxm} A_{mxn} V_{nxn} // => U^{-1} S V^{-1} x = y // S V^{-1} x = U y - std::vector Uy; // mx1 + std::vector Uy; // mx1 std::vector> S; // mxn std::vector> V; // nxn - std::vector V_inv_x; // V^{-1} x, nx1 + std::vector V_inv_x; // V^{-1} x, nx1 // Conditions we don't know what to do with std::vector rest; @@ -300,9 +297,8 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol for (const PrimExpr& equation : system_to_solve->relations) { if (const tir::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] - Array coeffs = arith::DetectLinearEquation( - analyzer_problem.Simplify(eq->a - eq->b), - system_to_solve->variables); + Array coeffs = arith::DetectLinearEquation(analyzer_problem.Simplify(eq->a - eq->b), + system_to_solve->variables); if (!coeffs.empty()) { std::vector row; for (size_t j = 0; j < coeffs.size() - 1; ++j) { @@ -364,13 +360,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol new_relation = analyzer_problem.Simplify(new_relation); if (tir::is_const_int(new_relation, 0)) { // unable to solve the system. 
- return IntConstraintsTransform( - system_to_solve, - IntConstraints( - /*variables=*/{}, - /*ranges=*/{}, - /*relations=*/{te::make_zero(DataType::Bool())}), - {}, {}); + return IntConstraintsTransform(system_to_solve, + IntConstraints( + /*variables=*/{}, + /*ranges=*/{}, + /*relations=*/{tir::make_zero(DataType::Bool())}), + {}, {}); } else if (!tir::is_const_int(new_relation, 1)) { new_relations.push_back(new_relation); } @@ -403,32 +398,30 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol // The j-th variable is just a single value, don't create a tvm variable // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { - PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]); - solution_for_V_inv_x.push_back( - analyzer_problem.Simplify(floordiv(Uy[j], a))); + PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]); + solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers - PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]); - solution_for_V_inv_x.push_back( - analyzer_problem.Simplify(floordiv(-Uy[j], a))); + PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]); + solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a))); } } } // V V^{-1} x = x for (size_t i = 0; i < num_vars; ++i) { - PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype()); + PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype()); for (size_t j = 0; j < num_vars; ++j) { - e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j]; + e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem.Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); } // The resulting ranges - Map new_ranges = InferRange( - new_to_old_map, system_to_solve->variables, system_to_solve->ranges); + Map new_ranges = + InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; analyzer_solution.Bind(new_ranges); @@ -439,10 +432,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol const Range& old_range = p.second; if (old_to_new_map.count(old_var)) { PrimExpr express_by_new_vars = old_to_new_map[old_var]; - PrimExpr lower_cond = analyzer_solution.Simplify( - old_range->min <= express_by_new_vars); - PrimExpr upper_cond = analyzer_solution.Simplify( - express_by_new_vars < old_range->min + old_range->extent); + PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); + PrimExpr upper_cond = + analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); if (!tir::is_const_int(lower_cond, 1)) { new_relations.push_back(lower_cond); } @@ -458,23 +450,21 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol } IntConstraints solution(new_vars, new_ranges, new_relations); - IntConstraintsTransform transform( - system_to_solve, solution, old_to_new_map, new_to_old_map); + IntConstraintsTransform transform(system_to_solve, solution, old_to_new_map, new_to_old_map); return transform; } -TVM_REGISTER_GLOBAL("arith.SolveLinearEquations") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() == 1) { - *ret = SolveLinearEquations(args[0]); - } else if (args.size() == 3) { - IntConstraints problem(args[0], args[1], args[2]); - *ret = SolveLinearEquations(problem); - } else { - LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 
arguments, gets " << args.size(); - } - }); +TVM_REGISTER_GLOBAL("arith.SolveLinearEquations").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = SolveLinearEquations(args[0]); + } else if (args.size() == 3) { + IntConstraints problem(args[0], args[1], args[2]); + *ret = SolveLinearEquations(problem); + } else { + LOG(FATAL) << "arith.SolveLinearEquations expects 1 or 3 arguments, gets " << args.size(); + } +}); } // namespace arith } // namespace tvm diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index da044babdd43..54fc2522db66 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -30,10 +30,9 @@ namespace autotvm { // for loop void FeatureVisitor::VisitStmt_(const ForNode* op) { - const auto *extent = op->extent.as(); + const auto* extent = op->extent.as(); int64_t loop_extent = -1; - if (extent != nullptr) - loop_extent = extent->value; + if (extent != nullptr) loop_extent = extent->value; AnnotationType ann = kSerial; switch (op->for_type) { case ForType ::Parallel: @@ -58,10 +57,9 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { // parallel axis, virtual thread void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Var var = op->node.as()->var; - const auto *extent = op->value.as(); + const auto* extent = op->value.as(); CHECK(extent); std::string name = var.get()->name_hint; diff --git a/src/autotvm/feature_visitor.h b/src/autotvm/feature_visitor.h index 5391bddfa2f6..8180839b0668 100644 --- a/src/autotvm/feature_visitor.h +++ b/src/autotvm/feature_visitor.h @@ -29,6 +29,7 @@ #include #include #include + #include namespace tvm { @@ -40,8 +41,17 @@ using namespace tvm::tir; * \brief Type of for loop, used as one-hot encoding in features */ enum AnnotationType { - kBlockX, kBlockY, kBlockZ, kThreadX, kThreadY, kThreadZ, - kUnrolled, kVectorized, kParallel, kSerial, kVirtualThread, + kBlockX, + kBlockY, + kBlockZ, + kThreadX, + kThreadY, + kThreadZ, + kUnrolled, + kVectorized, + kParallel, + kSerial, + kVirtualThread, kNum, }; @@ -59,17 +69,17 @@ class FeatureVisitor : public StmtExprVisitor { void VisitExpr_(const LoadNode* op) final; void VisitStmt_(const StoreNode* op) final; - using StmtExprVisitor::VisitStmt_; using StmtExprVisitor::VisitExpr_; + using StmtExprVisitor::VisitStmt_; protected: /*! - * \brief Enter a for loop node - * \param var The expression to be printed. - * \param length The output stream - * \param ann_type The type for the for loop - * \return skip Whether skip this node - */ + * \brief Enter a for loop node + * \param var The expression to be printed. + * \param length The output stream + * \param ann_type The type for the for loop + * \return skip Whether skip this node + */ virtual bool EnterItervar_(tir::Var var, int64_t length, AnnotationType ann_type) = 0; /*! 
\brief Exit a for loop subtree */ virtual void ExitItervar_() = 0; diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc index fbd0829c8a60..91e2ee135b16 100644 --- a/src/autotvm/touch_extractor.cc +++ b/src/autotvm/touch_extractor.cc @@ -24,9 +24,9 @@ #include "touch_extractor.h" -#include #include #include +#include #include namespace tvm { @@ -34,9 +34,14 @@ namespace autotvm { int ParallelLevel(AnnotationType ann) { switch (ann) { - case kBlockX: case kBlockY: case kBlockZ: + case kBlockX: + case kBlockY: + case kBlockZ: return 2; - case kThreadX: case kThreadY: case kThreadZ: case kParallel: + case kThreadX: + case kThreadY: + case kThreadZ: + case kParallel: return 1; default: return 0; @@ -44,7 +49,7 @@ int ParallelLevel(AnnotationType ann) { } // get touch pattern from index expression -class IndexParser: public ExprVisitor { +class IndexParser : public ExprVisitor { public: void Parse(PrimExpr expr) { pattern_map.clear(); @@ -95,11 +100,9 @@ bool TouchExtractor::EnterItervar_(Var var, int64_t length, AnnotationType ann_t itervar_map.erase(var); } - itervar_map.insert({var, ItervarFeature(var, length, - static_cast(itervar_stack_.size()), - ann_type, - topdown_product_, - static_cast(itervar_counter_++))}); + itervar_map.insert( + {var, ItervarFeature(var, length, static_cast(itervar_stack_.size()), ann_type, + topdown_product_, static_cast(itervar_counter_++))}); } return true; @@ -120,7 +123,7 @@ void TouchExtractor::ExitItervar_() { CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); touch_pattern->second.count *= itervar_map[var].length; } - } else { // multiply reuse ratio + } else { // multiply reuse ratio for (auto stack_var : itervar_stack_) { auto touch_pattern = itervar_map[stack_var].touch_feature.find(kv.first); CHECK(touch_pattern != itervar_map[stack_var].touch_feature.end()); @@ -131,8 +134,7 @@ void TouchExtractor::ExitItervar_() { itervar_stack_.pop_back(); int64_t length = itervar_map[var].length; - if (length != 0) - topdown_product_ /= length; + if (length != 0) topdown_product_ /= length; int64_t bottomup_product = -1; for (auto kv : itervar_map[var].touch_feature) { bottomup_product = std::max(bottomup_product, kv.second.count * kv.second.reuse); @@ -188,8 +190,7 @@ void TouchExtractor::EnterMem_(Var buffer_var, PrimExpr index) { } } -void TouchExtractor::ExitMem_() { -} +void TouchExtractor::ExitMem_() {} /*! * \brief Get axis-based feature for all axes @@ -219,7 +220,7 @@ void TouchExtractor::ExitMem_() { * \note If you want to flatten these features as the input of your model, * You can use the faster one GetItervarFeatureFlatten below. 
*/ -void GetItervarFeature(Stmt stmt, bool take_log, Array > > *ret_feature) { +void GetItervarFeature(Stmt stmt, bool take_log, Array > >* ret_feature) { // extract TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); @@ -229,7 +230,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -237,28 +238,26 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > std::function trans; if (take_log) { trans = [](int64_t x) { - if (x < 0) - return -std::log(-x+1) / std::log(2); + if (x < 0) return -std::log(-x + 1) / std::log(2); x = x + 1; return std::log(x) / std::log(2); }; } else { - trans = [](int64_t x) { - return x; - }; + trans = [](int64_t x) { return x; }; } // serialize for front end for (auto var : vars) { Array > feature_row; - ItervarFeature &fea = touch_analyzer.itervar_map[var]; - feature_row.push_back(Array{tvm::tir::StringImmNode::make("_itervar_"), var}); - - Array attr{tvm::tir::StringImmNode::make("_attr_"), - FloatImm(DataType::Float(32), trans(fea.length)), - IntImm(DataType::Int(32), fea.nest_level), - FloatImm(DataType::Float(32), trans(fea.topdown_product)), - FloatImm(DataType::Float(32), trans(fea.bottomup_product)), + ItervarFeature& fea = touch_analyzer.itervar_map[var]; + feature_row.push_back(Array{tvm::tir::StringImm("_itervar_"), var}); + + Array attr{ + tvm::tir::StringImm("_attr_"), + FloatImm(DataType::Float(32), trans(fea.length)), + IntImm(DataType::Int(32), fea.nest_level), + FloatImm(DataType::Float(32), trans(fea.topdown_product)), + FloatImm(DataType::Float(32), trans(fea.bottomup_product)), }; // one hot annotation for (int i = 0; i < kNum; i++) { @@ -267,10 +266,11 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > feature_row.push_back(attr); // arithmetic - feature_row.push_back(Array{tvm::tir::StringImmNode::make("_arith_"), - FloatImm(DataType::Float(32), trans(fea.add_ct)), - FloatImm(DataType::Float(32), trans(fea.mul_ct)), - FloatImm(DataType::Float(32), trans(fea.div_ct)), + feature_row.push_back(Array{ + tvm::tir::StringImm("_arith_"), + FloatImm(DataType::Float(32), trans(fea.add_ct)), + FloatImm(DataType::Float(32), trans(fea.mul_ct)), + FloatImm(DataType::Float(32), trans(fea.div_ct)), }); // touch map @@ -280,16 +280,16 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > } std::sort(bufs.begin(), bufs.end()); for (auto k : bufs) { - TouchPattern &v = fea.touch_feature[k]; - feature_row.push_back( - Array{tvm::tir::StringImmNode::make(k), - FloatImm(DataType::Float(32), trans(v.stride)), - FloatImm(DataType::Float(32), trans(v.mod)), - FloatImm(DataType::Float(32), trans(v.count)), - FloatImm(DataType::Float(32), trans(v.reuse)), - FloatImm(DataType::Float(32), trans(v.thread_count)), - FloatImm(DataType::Float(32), trans(v.thread_reuse)), - }); + TouchPattern& v = fea.touch_feature[k]; + feature_row.push_back(Array{ + tvm::tir::StringImm(k), + FloatImm(DataType::Float(32), trans(v.stride)), + FloatImm(DataType::Float(32), trans(v.mod)), + FloatImm(DataType::Float(32), trans(v.count)), + FloatImm(DataType::Float(32), trans(v.reuse)), + FloatImm(DataType::Float(32), trans(v.thread_count)), + FloatImm(DataType::Float(32), trans(v.thread_reuse)), + }); } 
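  // For reference, the feature_row assembled above has this nested-Array layout
  // (a schematic, not actual output; length, the products, the arithmetic counts and the
  // touch statistics are passed through `trans`, i.e. log2-scaled when take_log is set,
  // while nest_level and the one-hot flags are not):
  //   ["_itervar_", var]
  //   ["_attr_",  length, nest_level, topdown_product, bottomup_product,
  //               one-hot annotation flags for kBlockX ... kNum-1]
  //   ["_arith_", add_ct, mul_ct, div_ct]
  //   [buffer_name, stride, mod, count, reuse, thread_count, thread_reuse]  // one per touched buffer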
ret_feature->push_back(feature_row); @@ -305,7 +305,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > > * \note See GetItervarFeature for more details about the return value. * This is an optimized version of GetItervarFeature + Flatten. This runs much faster. */ -void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_feature) { +void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector* ret_feature) { // extract touch feature TouchExtractor touch_analyzer; touch_analyzer.Analyze(stmt); @@ -315,7 +315,7 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ for (auto kv : touch_analyzer.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_analyzer.itervar_map[lhs].order < touch_analyzer.itervar_map[rhs].order; }); @@ -323,20 +323,17 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ std::function trans; if (take_log) { trans = [](int64_t x) { - if (x < 0) - return -std::log(-x+1) / std::log(2); + if (x < 0) return -std::log(-x + 1) / std::log(2); x = x + 1; return std::log(x) / std::log(2); }; } else { - trans = [](int64_t x) { - return x; - }; + trans = [](int64_t x) { return x; }; } // serialize for front end for (auto var : vars) { - ItervarFeature &fea = touch_analyzer.itervar_map[var]; + ItervarFeature& fea = touch_analyzer.itervar_map[var]; ret_feature->push_back(trans(fea.length)); ret_feature->push_back(fea.nest_level); @@ -360,7 +357,7 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ } std::sort(bufs.begin(), bufs.end()); for (auto k : bufs) { - TouchPattern &v = fea.touch_feature[k]; + TouchPattern& v = fea.touch_feature[k]; ret_feature->push_back(trans(v.stride)); ret_feature->push_back(trans(v.mod)); ret_feature->push_back(trans(v.count)); @@ -372,12 +369,12 @@ void GetItervarFeatureFlatten(Stmt stmt, bool take_log, std::vector *ret_ } /*! - * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional vector. - * \param stmt The statement to be extracted - * \param sample_n The number of points used for sampling a curve (along one dimension) - * \param ret_feature The buffer where the return value is stored + * \brief Get curve sample feature (relation feature) and flatten them into a one-dimensional + * vector. 
\param stmt The statement to be extracted \param sample_n The number of points used for + * sampling a curve (along one dimension) \param ret_feature The buffer where the return value is + * stored */ -void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *ret_feature) { +void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector* ret_feature) { // extract touch feature TouchExtractor touch_ext; touch_ext.Analyze(stmt); @@ -387,7 +384,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r for (auto kv : touch_ext.itervar_map) { vars.push_back(kv.first); } - std::sort(vars.begin(), vars.end(), [&](const Var &lhs, const Var &rhs) -> bool { + std::sort(vars.begin(), vars.end(), [&](const Var& lhs, const Var& rhs) -> bool { return touch_ext.itervar_map[lhs].order < touch_ext.itervar_map[rhs].order; }); @@ -401,14 +398,14 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // find maximum depth of loop nest for (auto var : vars) { - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; max_depth = std::max(max_depth, fea.nest_level); } // mark inner most buffer for (auto iter = vars.rbegin(); iter != vars.rend(); iter++) { auto var = *iter; - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; if (fea.nest_level == max_depth) { for (auto kv : fea.touch_feature) { // delete buffer no (e.g. 'A_0' -> 'A', 'A_1' -> 'A') @@ -416,8 +413,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // delete memory scope (e.g. 'A.local' -> 'A', 'A.shared' -> 'A') size_t pos = raw_name.find("."); - if (pos < kv.first.size()) - raw_name = raw_name.substr(0, pos); + if (pos < kv.first.size()) raw_name = raw_name.substr(0, pos); // If there are multiple innermost buffers that are derived from a same raw buffer // We only record the last occurrence (note the `iter` is in reverse order) @@ -441,7 +437,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // extract curves for (auto var : vars) { - ItervarFeature &fea = touch_ext.itervar_map[var]; + ItervarFeature& fea = touch_ext.itervar_map[var]; for (auto kv : fea.touch_feature) { if (innermost_buffers.find(kv.first) != innermost_buffers.end()) { reuse_curve[kv.first].emplace_back(std::log(kv.second.reuse) / std::log(2)); @@ -453,7 +449,7 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r } // sample relation in the curve - auto sample_curve = [&](const std::vector &x, const std::vector &y, + auto sample_curve = [&](const std::vector& x, const std::vector& y, double weight) { for (int i = 0; i < sample_n; i++) { double xx = i * weight; @@ -469,9 +465,9 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r // serialize to frontend for (auto k : innermost_buffers) { - std::vector &count = count_curve[k]; - std::vector &reuse = reuse_curve[k]; - std::vector &top_down = topdown_curve[k]; + std::vector& count = count_curve[k]; + std::vector& reuse = reuse_curve[k]; + std::vector& top_down = topdown_curve[k]; std::sort(count.begin(), count.end()); std::sort(reuse.begin(), reuse.end()); @@ -484,49 +480,45 @@ void GetCurveSampleFeatureFlatten(Stmt stmt, int sample_n, std::vector *r } } - // register API for front end TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeature") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - bool take_log = args[1]; - Array > > ret_feature; + 
.set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + bool take_log = args[1]; + Array > > ret_feature; - GetItervarFeature(stmt, take_log, &ret_feature); - - *ret = ret_feature; -}); + GetItervarFeature(stmt, take_log, &ret_feature); + *ret = ret_feature; + }); TVM_REGISTER_GLOBAL("autotvm.feature.GetItervarFeatureFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - bool take_log = args[1]; - std::vector ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + bool take_log = args[1]; + std::vector ret_feature; - GetItervarFeatureFlatten(stmt, take_log, &ret_feature); - - TVMByteArray arr; - arr.size = sizeof(float) * ret_feature.size(); - arr.data = reinterpret_cast(ret_feature.data()); - *ret = arr; -}); + GetItervarFeatureFlatten(stmt, take_log, &ret_feature); + TVMByteArray arr; + arr.size = sizeof(float) * ret_feature.size(); + arr.data = reinterpret_cast(ret_feature.data()); + *ret = arr; + }); TVM_REGISTER_GLOBAL("autotvm.feature.GetCurveSampleFeatureFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Stmt stmt = args[0]; - int sample_n = args[1]; - std::vector ret_feature; + .set_body([](TVMArgs args, TVMRetValue* ret) { + Stmt stmt = args[0]; + int sample_n = args[1]; + std::vector ret_feature; - GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); - - TVMByteArray arr; - arr.size = sizeof(float) * ret_feature.size(); - arr.data = reinterpret_cast(ret_feature.data()); - *ret = arr; -}); + GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature); + TVMByteArray arr; + arr.size = sizeof(float) * ret_feature.size(); + arr.data = reinterpret_cast(ret_feature.data()); + *ret = arr; + }); } // namespace autotvm } // namespace tvm diff --git a/src/autotvm/touch_extractor.h b/src/autotvm/touch_extractor.h index 23fbc54d843e..313e4d78d6e1 100644 --- a/src/autotvm/touch_extractor.h +++ b/src/autotvm/touch_extractor.h @@ -25,16 +25,17 @@ #ifndef TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ #define TVM_AUTOTVM_TOUCH_EXTRACTOR_H_ +#include #include #include -#include -#include -#include +#include #include +#include #include -#include #include +#include + #include "feature_visitor.h" namespace tvm { @@ -55,11 +56,7 @@ struct TouchPattern { // all the feature of an iter var struct ItervarFeature { - ItervarFeature(Var var, - int64_t extent, - int nest, - AnnotationType ann_type, - int64_t topdown, + ItervarFeature(Var var, int64_t extent, int nest, AnnotationType ann_type, int64_t topdown, int counter) : length(extent), nest_level(nest), ann(ann_type), topdown_product(topdown), order(counter) {} ItervarFeature() {} @@ -67,9 +64,9 @@ struct ItervarFeature { // Axis Attributes int64_t length; int nest_level; - AnnotationType ann; // one-hot axis type - int64_t topdown_product; // accumulative product of axis length, in top-down order - int64_t bottomup_product; // accumulative product of axis length, in bottom-up order + AnnotationType ann; // one-hot axis type + int64_t topdown_product; // accumulative product of axis length, in top-down order + int64_t bottomup_product; // accumulative product of axis length, in bottom-up order // bottomup_product = reuse * count for any touched buffer int order; // used for soring axis @@ -86,42 +83,35 @@ struct ItervarFeature { // extract iter vars and their touch pattern from ir class TouchExtractor : public FeatureVisitor { public: - void Analyze(const Stmt& stmt) { - operator()(stmt); - } + void Analyze(const Stmt& stmt) { operator()(stmt); } // arithmetic stats void 
VisitExpr_(const AddNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const SubNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].add_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].add_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const MulNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].mul_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].mul_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const DivNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } void VisitExpr_(const ModNode* op) final { - if (op->dtype.is_float()) - itervar_map[itervar_stack_.back()].div_ct++; + if (op->dtype.is_float()) itervar_map[itervar_stack_.back()].div_ct++; FeatureVisitor::VisitExpr_(op); } - std::unordered_map itervar_map; + std::unordered_map itervar_map; private: bool EnterItervar_(Var var, int64_t length, AnnotationType ann_type); diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index bb97900833dd..e08f39f8135d 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -20,10 +20,12 @@ /*! * \file codegen_hybrid.cc */ +#include "codegen_hybrid.h" + #include -#include + #include -#include "codegen_hybrid.h" +#include namespace tvm { namespace contrib { @@ -34,7 +36,7 @@ using runtime::TVMRetValue; using namespace tir; std::string dot_to_underscore(std::string s) { - for (auto &ch : s) + for (auto& ch : s) if (ch == '.') ch = '_'; return s; } @@ -57,11 +59,9 @@ std::string CodeGenHybrid::GetUniqueName(std::string prefix) { return prefix; } -std::string CodeGenHybrid::Finish() { - return stream.str(); -} +std::string CodeGenHybrid::Finish() { return stream.str(); } -void CodeGenHybrid::PrintType(DataType t, std::ostream &os) { +void CodeGenHybrid::PrintType(DataType t, std::ostream& os) { if (t.is_float()) { os << "float"; CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); @@ -80,20 +80,19 @@ void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOL os << op->value; } -void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintType(op->dtype, os); os << "(" << std::setprecision(20) << op->value << ")"; } -void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "'" << op->value << "'"; } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenHybrid* p) { - CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; if (isalpha(opstr[0])) { os << opstr << '('; p->PrintExpr(op->a, os); @@ -111,11 +110,10 @@ inline void PrintBinaryExpr(const T* op, } } -inline void PrintBinaryIntrinsitc(const CallNode* op, - const char* opstr, +inline void PrintBinaryIntrinsitc(const CallNode* op, const char* opstr, std::ostream& os, // NOLINT(*) 
CodeGenHybrid* p) { - CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; + CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; CHECK_EQ(op->args.size(), 2U); os << '('; p->PrintExpr(op->args[0], os); @@ -204,18 +202,21 @@ void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT PrintExpr(op->a, os); } +void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { // NOLINT(*) + auto tensor = Downcast(op->producer); + + os << GetTensorID(tensor); + os << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + if (i) os << ", "; + std::stringstream idx; + PrintExpr(op->indices[i], idx); + os << idx.str(); + } + os << "]"; +} void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Halide) { - os << GetTensorID(op->func, op->value_index); - os << "["; - for (size_t i = 0; i < op->args.size(); ++i) { - if (i) os << ", "; - std::stringstream idx; - PrintExpr(op->args[i], idx); - os << idx.str(); - } - os << "]"; - } else if (op->is_intrinsic(CallNode::bitwise_and)) { + if (op->is_intrinsic(CallNode::bitwise_and)) { PrintBinaryIntrinsitc(op, "&", os, this); } else if (op->is_intrinsic(CallNode::bitwise_xor)) { PrintBinaryIntrinsitc(op, "^", os, this); @@ -252,9 +253,7 @@ void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLIN LOG(FATAL) << "Phase 0 has no Load(s)!"; } -void CodeGenHybrid::VisitStmt_(const StoreNode* op) { - LOG(FATAL) << "Phase 0 has no Store(s)!"; -} +void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Phase 0 has no Let(s)!"; @@ -268,7 +267,7 @@ void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLIN LOG(FATAL) << "Ramp to be supported yet"; } -void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } @@ -293,15 +292,15 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { CHECK(iter_var); binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); PrintIndent(); - stream << "for " << binds_[iter_var->var.get()] << " in bind('" - << iter_var->var->name_hint << "', "; + stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint + << "', "; PrintExpr(op->value, stream); stream << "):\n"; indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); + auto v = Downcast(op->node); alloc_storage_scope_[v] = op->value.as()->value; PrintStmt(op->body); } else { @@ -310,20 +309,21 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { } } -void CodeGenHybrid::VisitStmt_(const RealizeNode* op) { - CHECK(alloc_storage_scope_.count(op->func)); - if (!alloc_storage_scope_[op->func].empty()) { +void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { + auto tensor = Downcast(op->producer); + CHECK(alloc_storage_scope_.count(tensor->op)); + if (!alloc_storage_scope_[tensor->op].empty()) { PrintIndent(); - stream << GetTensorID(op->func, op->value_index) << " = allocate(("; + stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { if (i) stream << ", "; stream << PrintExpr(op->bounds[i]->extent); } if (op->bounds.size() == 
1) stream << ", "; stream << "), '"; - PrintType(op->dtype, stream); + PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[op->func] << "')\n"; + stream << alloc_storage_scope_[tensor->op] << "')\n"; } PrintStmt(op->body); } @@ -338,13 +338,14 @@ void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) { PrintStmt(op->body); } -void CodeGenHybrid::VisitStmt_(const ProvideNode* op) { +void CodeGenHybrid::VisitStmt_(const ProducerStoreNode* op) { + auto tensor = Downcast(op->producer); PrintIndent(); - stream << GetTensorID(op->func, op->value_index); + stream << GetTensorID(tensor); stream << "["; - for (size_t i = 0; i < op->args.size(); ++i) { + for (size_t i = 0; i < op->indices.size(); ++i) { if (i) stream << ", "; - PrintExpr(op->args[i], stream); + PrintExpr(op->indices[i], stream); } stream << "] = "; PrintExpr(op->value, stream); @@ -355,17 +356,16 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) { std::string extent = PrintExpr(op->extent); PrintIndent(); std::string vid = GetVarID(op->loop_var.get()); - stream << "for " << vid << " in " << "range(" << extent << "):\n"; + stream << "for " << vid << " in " + << "range(" << extent << "):\n"; indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; } -bool is_noop(const Stmt &stmt) { - if (!stmt.defined()) - return true; - if (auto eval = stmt.as()) - return is_const(eval->value); +bool is_noop(const Stmt& stmt) { + if (!stmt.defined()) return true; + if (auto eval = stmt.as()) return is_const(eval->value); return false; } @@ -395,17 +395,13 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; std::string str = PrintExpr(op->value); - if (!str.empty()) - stream << str << "\n"; + if (!str.empty()) stream << str << "\n"; } -void CodeGenHybrid::PrintIndent() { - stream << std::string(indent_, ' '); -} +void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } -std::string CodeGenHybrid::GetVarID(const VarNode *v) { - if (binds_.count(v)) - return binds_[v]; +std::string CodeGenHybrid::GetVarID(const VarNode* v) { + if (binds_.count(v)) return binds_[v]; auto key = std::make_pair(static_cast(v), 0); if (id_map_.count(key)) { return id_map_[key]; @@ -413,14 +409,14 @@ std::string CodeGenHybrid::GetVarID(const VarNode *v) { return id_map_[key] = GetUniqueName(v->name_hint); } -std::string CodeGenHybrid::GetTensorID(const FunctionRef &func, int value_index) { - auto key = std::make_pair(func.get(), value_index); +std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) { + auto key = std::make_pair(tensor->op.get(), tensor->value_index); if (id_map_.count(key)) { return id_map_[key]; } - std::string name_hint = func->func_name(); - if (func->num_outputs() > 1) { - name_hint += "_v" + std::to_string(value_index); + std::string name_hint = tensor->op->name; + if (tensor->op->num_outputs() > 1) { + name_hint += "_v" + std::to_string(tensor->value_index); } return id_map_[key] = GetUniqueName(name_hint); } @@ -469,10 +465,8 @@ void CodeGenHybrid::ReserveKeywords() { GetUniqueName("max_num_threads"); } -void CodeGenHybrid::DumpStmt(const Stmt &stmt, - const Array &inputs, - const Array &outputs, - const std::string &name) { +void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array& inputs, + const Array& outputs, const std::string& name) { ReserveKeywords(); GetUniqueName(name); @@ -480,7 +474,7 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, for (size_t i = 0; i < inputs.size(); ++i) { if (i) 
stream << ", "; if (auto tensor = inputs[i].as()) { - stream << GetTensorID(tensor->op, tensor->value_index); + stream << GetTensorID(GetRef(tensor)); } else { auto var = inputs[i].as(); CHECK(var) << "Input should either be a tensor or a variable!"; @@ -491,14 +485,12 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, indent_ += tab_; for (size_t i = 0; i < outputs.size(); ++i) { PrintIndent(); - stream << GetTensorID(outputs[i]->op, outputs[i]->value_index) - << " = output_tensor(("; + stream << GetTensorID(outputs[i]) << " = output_tensor(("; for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { if (j) stream << ", "; PrintExpr(outputs[i]->shape[j], stream); } - if (outputs[i]->shape.size() == 1) - stream << ", "; + if (outputs[i]->shape.size() == 1) stream << ", "; stream << "), '" << outputs[i]->dtype << "')\n"; } PrintStmt(stmt); @@ -506,19 +498,18 @@ void CodeGenHybrid::DumpStmt(const Stmt &stmt, stream << "return "; for (size_t i = 0; i < outputs.size(); ++i) { if (i) stream << ", "; - stream << GetTensorID(outputs[i]->op, outputs[i]->value_index); + stream << GetTensorID(outputs[i]); } stream << "\n"; } -TVM_REGISTER_GLOBAL("hybrid._Dump") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CodeGenHybrid codegen; - if (args.size() == 4) - codegen.DumpStmt(args[0], args[1], args[2], args[3]); - else - codegen.DumpStmt(args[0], args[1], args[2]); - *rv = codegen.Finish(); - }); +TVM_REGISTER_GLOBAL("hybrid._Dump").set_body([](TVMArgs args, TVMRetValue* rv) { + CodeGenHybrid codegen; + if (args.size() == 4) + codegen.DumpStmt(args[0], args[1], args[2], args[3]); + else + codegen.DumpStmt(args[0], args[1], args[2]); + *rv = codegen.Finish(); +}); } // namespace contrib } // namespace tvm diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index d282edbb1926..b01ca2763e28 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -24,10 +24,12 @@ #ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ #define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_ -#include -#include #include +#include #include +#include +#include + #include #include #include @@ -45,9 +47,8 @@ using namespace tir; * **NOTE** CodeGenHybrid does not aim at generating Python scripts consumed by Python2/3. * For runtime support, please refer the decorator in ``tvm/python/hybrid/api.py``. */ -class CodeGenHybrid : - public ExprFunctor, - public StmtFunctor { +class CodeGenHybrid : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Dump the given function body to hybrid script. @@ -56,8 +57,8 @@ class CodeGenHybrid : * \param outputs Output tensors of this schedule. * \param name The name of the function. */ - void DumpStmt(const Stmt &stmt, const Array &inputs, const Array &outputs, - const std::string &name = "hybrid_func"); + void DumpStmt(const Stmt& stmt, const Array& inputs, const Array& outputs, + const std::string& name = "hybrid_func"); /*! * \brief Finalize the compilation and return the code. * \return The code. @@ -69,64 +70,61 @@ class CodeGenHybrid : * \brief Print the Stmt n to CodeGenHybrid->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt &n) { - this->VisitStmt(n); - } + void PrintStmt(const Stmt& n) { this->VisitStmt(n); } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. 
* \param os The output stream */ - void PrintExpr(const PrimExpr &n, std::ostream &os) { - this->VisitExpr(n, os); - } + void PrintExpr(const PrimExpr& n, std::ostream& os) { this->VisitExpr(n, os); } /*! * \brief Same as PrintExpr, but simply returns result string * \param n The expression to be printed. */ - std::string PrintExpr(const PrimExpr &n) { + std::string PrintExpr(const PrimExpr& n) { std::ostringstream os; PrintExpr(n, os); return os.str(); } // expression - void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ProducerLoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // 
NOLINT(*) + void VisitExpr_(const FloorDivNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloorModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const StoreNode* op) override; - void VisitStmt_(const ProvideNode* op) override; + void VisitStmt_(const ProducerStoreNode* op) override; void VisitStmt_(const ForNode* op) override; void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const AllocateNode* op) override; - void VisitStmt_(const RealizeNode* op) override; + void VisitStmt_(const ProducerRealizeNode* op) override; void VisitStmt_(const AttrStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; @@ -136,7 +134,7 @@ class CodeGenHybrid : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) private: /*! \brief The current indent of the code dump. */ @@ -150,9 +148,9 @@ class CodeGenHybrid : /*! * \brief Keys are either (tensors, value_index) or (variables, 0). * Values are the corresponding IDs.*/ - std::map, std::string> id_map_; + std::map, std::string> id_map_; /*! \brief Variables (keys) binded to the threads (values). */ - std::map binds_; + std::map binds_; /*! * \brief Find an unallocated name for the given prefix. * \param prefix The given prefix. @@ -164,15 +162,14 @@ class CodeGenHybrid : * \brief Get or allocate the ID for the given variable. * \param v The given variable. */ - std::string GetVarID(const VarNode *v); + std::string GetVarID(const VarNode* v); /*! * \brief Get or allocate the ID for the given tensor. - * \param func The tensor to allocate a name. - * \param value_index The value index of the given tensor. + * \param tensor The tensor to allocate a name. */ - std::string GetTensorID(const FunctionRef &func, int value_index); + std::string GetTensorID(const Tensor& tensor); /*! 
\brief the storage scope of allocation */ - std::map alloc_storage_scope_; + std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc index d74d8fb917e5..705a3347b68c 100644 --- a/src/contrib/tf_op/tvm_dso_op_kernels.cc +++ b/src/contrib/tf_op/tvm_dso_op_kernels.cc @@ -97,12 +97,29 @@ class TensorAsBuf { tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) { auto dtype = tf_tensor.dtype(); - if (dtype == tensorflow::DT_FLOAT) { + + if (dtype == tensorflow::DT_HALF) { + *res = {kDLFloat, 16, 1}; + } else if (dtype == tensorflow::DT_FLOAT) { *res = {kDLFloat, 32, 1}; - } else if (dtype == tensorflow::DT_INT64) { - *res = {kDLInt, 64, 1}; + } else if (dtype == tensorflow::DT_DOUBLE) { + *res = {kDLFloat, 64, 1}; + } else if (dtype == tensorflow::DT_INT8) { + *res = {kDLInt, 8, 1}; + } else if (dtype == tensorflow::DT_INT16) { + *res = {kDLInt, 16, 1}; } else if (dtype == tensorflow::DT_INT32) { *res = {kDLInt, 32, 1}; + } else if (dtype == tensorflow::DT_INT64) { + *res = {kDLInt, 64, 1}; + } else if (dtype == tensorflow::DT_UINT8) { + *res = {kDLUInt, 8, 1}; + } else if (dtype == tensorflow::DT_UINT16) { + *res = {kDLUInt, 16, 1}; + } else if (dtype == tensorflow::DT_UINT32) { + *res = {kDLUInt, 32, 1}; + } else if (dtype == tensorflow::DT_UINT64) { + *res = {kDLUInt, 64, 1}; } else { return tensorflow::Status(tensorflow::error::INTERNAL, "Fail to get dlpack datatype"); } diff --git a/src/contrib/tf_op/tvm_dso_ops.cc b/src/contrib/tf_op/tvm_dso_ops.cc index 1183b2ef34b5..794494298d71 100644 --- a/src/contrib/tf_op/tvm_dso_ops.cc +++ b/src/contrib/tf_op/tvm_dso_ops.cc @@ -21,11 +21,15 @@ REGISTER_OP("TvmDsoOp") .Input("input_args: ListT") - .Attr("ListT: list({int8, int32, int64, float16, float32})") + .Attr( + "ListT: list({float16, float32, float64, int8, int16, int32, int64, uint8, uint16," + "uint32, uint64})") .Input("dynamic_output_shape: int64") .Output("output: output_dtype") .Attr("lib_path: string") .Attr("func_name: string") - .Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT") + .Attr( + "output_dtype: {float16, float32, float64, int8, int16, int32, int64, uint8, uint16," + "uint32, uint64} = DT_FLOAT") .Attr("static_output_shape: list(int) >= 0 = []") .Attr("has_static_output_shape: bool"); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f576c842b25c..9d2a11c265dd 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -23,14 +23,13 @@ */ #include #include -#include - -#include -#include -#include -#include +#include #include #include +#include +#include +#include +#include #include #include @@ -38,9 +37,17 @@ namespace tvm { +// Register build pipeline related options +TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); + +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); @@ -60,12 +67,8 @@ Target DefaultTargetHost(Target target) { } } -tir::Buffer BufferWithOffsetAlignment(Array shape, - DataType dtype, - 
std::string name, - int data_alignment, - int offset_factor, - bool compact) { +tir::Buffer BufferWithOffsetAlignment(Array shape, DataType dtype, std::string name, + int data_alignment, int offset_factor, bool compact) { auto data = tir::Var(name, DataType::Handle()); bool has_any = false; if (!compact) { @@ -85,22 +88,18 @@ tir::Buffer BufferWithOffsetAlignment(Array shape, elem_offset = PrimExpr(); } - return tir::BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - data_alignment, offset_factor, buffer_type); + return tir::Buffer(data, dtype, shape, Array(), elem_offset, name, "", data_alignment, + offset_factor, buffer_type); } -void GetBinds(const Array& args, - bool compact, +void GetBinds(const Array& args, bool compact, const std::unordered_map& binds, - Map* out_binds, - Array* out_arg_list, - const BuildConfig& config) { + Map* out_binds, Array* out_arg_list) { *out_binds = binds; - for (const auto &x : args) { + for (const auto& x : args) { if (out_binds->find(x) == out_binds->end()) { - auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, - config->data_alignment, config->offset_factor, compact); + auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact); out_binds->Set(x, buf); out_arg_list->push_back(buf); } else { @@ -109,64 +108,6 @@ void GetBinds(const Array& args, } } -/*! -* \brief Build a Stmt given a schedule, args and binds. This function runs the IR passes. -* \param sch The schedule to build. -* \param args The arguments for the schedule. -* \param binds Buffer assignments. -* \param loop_partition True if the LoopPartition pass should be included. -* \param out_arg_list Returns the arguments for the Stmt. -* \param config The build configuration. -* \return The built Stmt. 
-*/ -tir::Stmt BuildStmt(te::Schedule sch, - const Array& args, - const std::unordered_map& binds, - bool loop_partition, - Array *out_arg_list, - const BuildConfig& config) { - sch = sch.normalize(); - - // Phase 0 - auto bounds = te::InferBound(sch); - auto stmt = te::ScheduleOps(sch, bounds, false); - stmt = tir::InjectPrefetch(stmt); - - bool compact = tir::VerifyCompactBuffer(stmt); - Map out_binds; - GetBinds(args, compact, binds, &out_binds, out_arg_list, config); - - // Phase 1 - stmt = tir::StorageFlatten(stmt, out_binds, 64, - config->instrument_bound_checkers); - stmt = tir::CanonicalSimplify(stmt); - if (loop_partition) { - stmt = tir::LoopPartition(stmt, config->partition_const_loop); - } - if (config->disable_vectorize) { - stmt = tir::SkipVectorize(stmt); - } else { - stmt = tir::VectorizeLoop(stmt); - } - stmt = tir::InjectVirtualThread(stmt); - stmt = tir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop); - stmt = tir::StorageRewrite(stmt); - stmt = tir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth, - config->auto_unroll_max_extent, config->unroll_explicit); - - // Phase 2 - stmt = tir::Simplify(stmt); - stmt = tir::RemoveNoOp(stmt); - - if (!(config->disable_select_rewriting)) - stmt = tir::RewriteUnsafeSelect(stmt); - - if (config->instrument_bound_checkers) - stmt = tir::InstrumentBoundCheckers(stmt); - - return stmt; -} - transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -174,9 +115,8 @@ transform::Pass BindTarget(Target target) { return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); } - -template -transform::Pass FilterBy(FCond fcond) { +template +transform::Pass Filter(FCond fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { return f; @@ -184,52 +124,72 @@ transform::Pass FilterBy(FCond fcond) { return tir::PrimFunc(nullptr); } }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {}); + return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } - -IRModule lower(te::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config) { +IRModule lower(te::Schedule sch, const Array& args, const std::string& name, + const std::unordered_map& binds) { Array out_arg_list; - auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); + auto pass_ctx = transform::PassContext::Current(); - Array params; - Map buffer_map; + sch = sch.normalize(); - for (auto var : out_arg_list) { - if (auto* n = var.as()) { - params.push_back(GetRef(n)); - } else { - tir::Buffer buffer = Downcast(var); - tir::Var bptr(buffer->name, DataType::Handle()); - params.push_back(bptr); - buffer_map.Set(bptr, buffer); - } - } + // Before TIR transformation. 
+ auto bounds = te::InferBound(sch); + auto stmt = te::ScheduleOps(sch, bounds, false); + bool compact = te::VerifyCompactBuffer(stmt); - auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map); + Map out_binds; + GetBinds(args, compact, binds, &out_binds, &out_arg_list); + + // build the function + tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds); f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); - if (config->restricted_func) { - f = WithAttr(std::move(f), "tir.noalias", Integer(1)); + bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); + bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); + bool instrument_bound_checkers = + pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); + + if (noalias) { + f = WithAttr(std::move(f), "tir.noalias", Bool(true)); } - return IRModule(Map({{GlobalVar(name), f}})); -} + auto mod = IRModule(Map({{GlobalVar(name), f}})); + auto pass_list = Array(); -std::pair -split_dev_host_funcs(IRModule mod_mixed, - const Target& target, - const Target& target_host, - const BuildConfig& config) { - mod_mixed = BindTarget(target)(std::move(mod_mixed)); - tir::VerifyMemory(mod_mixed); + // Phase 0 + pass_list.push_back(tir::transform::InjectPrefetch()); + pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); + // Phase 1 + pass_list.push_back(tir::transform::NarrowDataType(32)); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::LoopPartition()); + pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); + pass_list.push_back(tir::transform::InjectVirtualThread()); + pass_list.push_back(tir::transform::InjectDoubleBuffer()); + pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tir::transform::UnrollLoop()); + // Phase 2 + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::RemoveNoOp()); + pass_list.push_back(tir::transform::RewriteUnsafeSelect()); + if (instrument_bound_checkers) { + pass_list.push_back(tir::transform::InstrumentBoundCheckers()); + } + // run + auto optimize = transform::Sequential(pass_list); + mod = optimize(std::move(mod)); + return mod; +} + +std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target& target, + const Target& target_host, + const transform::PassContext& pass_ctx) { + Array mixed_pass_list = {BindTarget(target), + tir::transform::VerifyMemory()}; - Array mixed_pass_list = {BindTarget(target)}; - if (config->detect_global_barrier) { + if (pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value()) { mixed_pass_list.push_back(tir::transform::ThreadSync("global")); } mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); @@ -242,31 +202,30 @@ split_dev_host_funcs(IRModule mod_mixed, mod_mixed = opt_mixed(std::move(mod_mixed)); auto host_pass_list = { - FilterBy([](const tir::PrimFunc& f) { - return f->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target_host), - tir::transform::LowerTVMBuiltin(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), - tir::transform::CombineContextCall(), + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target_host), + tir::transform::LowerTVMBuiltin(), + 
tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + tir::transform::CombineContextCall(), }; auto opt_host = transform::Sequential(host_pass_list); auto mhost = opt_host(mod_mixed); // device pipeline auto device_pass_list = { - FilterBy([](const tir::PrimFunc& f) { - return f->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; - }), - BindTarget(target), - tir::transform::LowerWarpMemory(), - tir::transform::LowerIntrin(), - tir::transform::LowerDeviceStorageAccessInfo(), + Filter([](const tir::PrimFunc& f) { + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDeviceKernelLaunch; + }), + BindTarget(target), + tir::transform::LowerWarpMemory(), + tir::transform::Simplify(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), }; auto opt_device = transform::Sequential(device_pass_list); auto mdevice = opt_device(mod_mixed); @@ -275,33 +234,28 @@ split_dev_host_funcs(IRModule mod_mixed, auto keys = target->keys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && mdevice->functions.size() == 0) { - LOG(WARNING) << "Specified target " - << target->str() + LOG(WARNING) << "Specified target " << target->str() << " but cannot find device code. Did you forget to bind?"; } - if (target->device_type == target::llvm()->device_type && - target_host == target) { - CHECK(mdevice->functions.empty()) - << "No device code should be generated when target " - << "and host_target are both llvm target." - << "\n"; + if (target->device_type == target::llvm()->device_type && target_host == target) { + CHECK(mdevice->functions.empty()) << "No device code should be generated when target " + << "and host_target are both llvm target." + << "\n"; } return {mhost, mdevice}; } - // Build for heterogeneous execution. -runtime::Module build(const Map& inputs, - const Target& target_host, - const BuildConfig& config) { - std::vector device_modules; +runtime::Module build(const Map& inputs, const Target& target_host) { + auto pass_ctx = transform::PassContext::Current(); + std::vector device_modules; Target target_host_val = target_host; if (!target_host.defined()) { for (const auto& it : inputs) { - if (it.first->device_type == kDLCPU) { + if (it.first->device_type == kDLCPU || it.first->device_type == kDLMicroDev) { target_host_val = it.first; break; } @@ -315,8 +269,7 @@ runtime::Module build(const Map& inputs, IRModule mhost_all = IRModule(Map()); for (const auto& it : inputs) { - auto pair = - split_dev_host_funcs(it.second, it.first, target_host_val, config); + auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx); auto& mhost = pair.first; auto& mdevice = pair.second; @@ -337,9 +290,7 @@ runtime::Module build(const Map& inputs, } // Build for heterogeneous execution when target is a string. -runtime::Module build(const Map& inputs, - const Target& target_host, - const BuildConfig& config) { +runtime::Module build(const Map& inputs, const Target& target_host) { Map updated_input; for (const auto& it : inputs) { auto target = Target::Create(it.first); @@ -348,16 +299,13 @@ runtime::Module build(const Map& inputs, } updated_input.Set(target, it.second); } - return build(updated_input, target_host, config); + return build(updated_input, target_host); } // Build for homogeneous execution. 
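The driver_api.cc hunks above replace the old BuildConfig fields with options registered via TVM_REGISTER_PASS_CONFIG_OPTION and read from the ambient PassContext. A minimal C++ sketch of that lookup pattern, mirroring the GetConfig calls in lower() and SplitDevHostFuncs; the helper name below is illustrative and not part of the patch:

#include <tvm/ir/expr.h>       // tvm::Bool
#include <tvm/ir/transform.h>  // tvm::transform::PassContext

// Illustrative helper: query one of the newly registered lowering options,
// with a typed default, from the current PassContext.
bool VectorizeEnabled() {
  auto pass_ctx = tvm::transform::PassContext::Current();
  bool disable_vectorize =
      pass_ctx->GetConfig<tvm::Bool>("tir.disable_vectorize", tvm::Bool(false)).value();
  return !disable_vectorize;
}
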
-runtime::Module build(const IRModule& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config) { +runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host) { Map inputs = {{target, funcs}}; - return build(inputs, target_host, config); + return build(inputs, target_host); } } // namespace tvm diff --git a/src/ir/adt.cc b/src/ir/adt.cc index 4650a3bed4a7..f0ce859f3f87 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -21,14 +21,12 @@ * \file src/ir/adt.cc * \brief ADT type definitions. */ -#include #include +#include namespace tvm { -Constructor::Constructor(std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { +Constructor::Constructor(String name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->inputs = std::move(inputs); @@ -39,21 +37,18 @@ Constructor::Constructor(std::string name_hint, TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_GLOBAL("ir.Constructor") -.set_body_typed([](std::string name_hint, - tvm::Array inputs, - GlobalTypeVar belong_to) { - return Constructor(name_hint, inputs, belong_to); -}); + .set_body_typed([](String name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { + return Constructor(name_hint, inputs, belong_to); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorNode(" << node->name_hint << ", " - << node->inputs << ", " << node->belong_to << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorNode(" << node->name_hint << ", " << node->inputs << ", " + << node->belong_to << ")"; + }); -TypeData::TypeData(GlobalTypeVar header, - tvm::Array type_vars, +TypeData::TypeData(GlobalTypeVar header, tvm::Array type_vars, tvm::Array constructors) { ObjectPtr n = make_object(); n->header = std::move(header); @@ -65,17 +60,16 @@ TypeData::TypeData(GlobalTypeVar header, TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_GLOBAL("ir.TypeData") -.set_body_typed([](GlobalTypeVar header, - tvm::Array type_vars, - tvm::Array constructors) { - return TypeData(header, type_vars, constructors); -}); + .set_body_typed([](GlobalTypeVar header, tvm::Array type_vars, + tvm::Array constructors) { + return TypeData(header, type_vars, constructors); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " - << node->constructors << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", " + << node->constructors << ")"; + }); } // namespace tvm diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index dbd5a4fab23b..12b4f6f65b11 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -32,6 +32,7 @@ #include #include + #include namespace tvm { @@ -39,16 +40,13 @@ namespace tvm { template class AttrFunctor; -#define ATTR_FUNCTOR_DEFAULT \ +#define ATTR_FUNCTOR_DEFAULT \ { return VisitAttrDefault_(op, std::forward(args)...); } - -#define ATTR_FUNCTOR_DISPATCH(OP) \ - vtable.template set_dispatch( \ - [](const ObjectRef& n, TSelf* self, Args... 
args) { \ - return self->VisitAttr_(static_cast(n.get()), \ - std::forward(args)...); \ - }); \ +#define ATTR_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitAttr_(static_cast(n.get()), std::forward(args)...); \ + }); // A functor for common attribute information. template @@ -78,7 +76,6 @@ class AttrFunctor { } virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -115,7 +112,6 @@ class AttrFunctor { using namespace tir; FType vtable; // Set dispatch - ATTR_FUNCTOR_DISPATCH(StrMapNode); ATTR_FUNCTOR_DISPATCH(ArrayNode); ATTR_FUNCTOR_DISPATCH(IntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index bee103d7ed20..af46439cff7c 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -22,26 +22,22 @@ */ #include #include + #include "attr_functor.h" namespace tvm { -void DictAttrsNode::VisitAttrs(AttrVisitor* v) { - v->Visit("__dict__", &dict); -} +void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { - v->Visit("__dict__", &dict); -} +void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } -void DictAttrsNode::InitByPackedArgs( - const runtime::TVMArgs& args, bool allow_unknown) { +void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; runtime::TVMArgValue val = args[i + 1]; if (val.IsObjectRef()) { dict.Set(key, val.operator ObjectRef()); - } else if (val.type_code() == kTVMStr) { + } else if (String::CanConvertFrom(val)) { dict.Set(key, val.operator String()); } else { dict.Set(key, val.operator PrimExpr()); @@ -49,33 +45,29 @@ void DictAttrsNode::InitByPackedArgs( } } -Array DictAttrsNode::ListFieldInfo() const { - return {}; -} +Array DictAttrsNode::ListFieldInfo() const { return {}; } -DictAttrs::DictAttrs(Map dict) { +DictAttrs::DictAttrs(Map dict) { ObjectPtr n = make_object(); n->dict = std::move(dict); data_ = std::move(n); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dict; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dict; + }); TVM_REGISTER_NODE_TYPE(DictAttrsNode); TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode); -TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict") -.set_body_typed([](DictAttrs attrs) { +TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict").set_body_typed([](DictAttrs attrs) { return attrs->dict; }); -TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo") -.set_body_typed([](Attrs attrs) { +TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo").set_body_typed([](Attrs attrs) { return attrs->ListFieldInfo(); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 4d3ed30bc032..7b0d6e6f09c2 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -26,16 +26,15 @@ namespace tvm { - using runtime::PackedFunc; using runtime::TVMArgs; using 
runtime::TVMRetValue; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "EnvFunc(" << op->name << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "EnvFunc(" << op->name << ")"; + }); ObjectPtr CreateEnvNode(const std::string& name) { auto* f = runtime::Registry::Get(name); @@ -46,31 +45,24 @@ ObjectPtr CreateEnvNode(const std::string& name) { return n; } -EnvFunc EnvFunc::Get(const std::string& name) { - return EnvFunc(CreateEnvNode(name)); -} +EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } -TVM_REGISTER_GLOBAL("ir.EnvFuncGet") -.set_body_typed(EnvFunc::Get); +TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); -TVM_REGISTER_GLOBAL("ir.EnvFuncCall") -.set_body([](TVMArgs args, TVMRetValue* rv) { - EnvFunc env = args[0]; - CHECK_GE(args.size(), 1); - env->func.CallPacked(TVMArgs(args.values + 1, - args.type_codes + 1, - args.size() - 1), rv); - }); +TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) { + EnvFunc env = args[0]; + CHECK_GE(args.size(), 1); + env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv); +}); -TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") -.set_body_typed([](const EnvFunc&n) { - return n->func; - }); +TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) { + return n->func; +}); TVM_REGISTER_NODE_TYPE(EnvFuncNode) -.set_creator(CreateEnvNode) -.set_repr_bytes([](const Object* n) -> std::string { - return static_cast(n)->name; - }); + .set_creator(CreateEnvNode) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); } // namespace tvm diff --git a/src/ir/error.cc b/src/ir/error.cc index 9d498288d2ba..5cd7a247d025 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -22,8 +22,8 @@ * \brief Utilities for error tracking and reporting. */ -#include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -31,14 +31,16 @@ // Rationale: use relay's printer for astext. #include +// clang-format off #include #include #include +// clang-format on namespace tvm { -template -using NodeMap = std::unordered_map; +template +using NodeMap = std::unordered_map; void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // First we pick an error reporting strategy for each error. @@ -76,9 +78,9 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // Setup error map. auto it = error_maps.find(global); if (it != error_maps.end()) { - it->second.insert({ node, err_msg.str() }); + it->second.insert({node, err_msg.str()}); } else { - error_maps.insert({ global, { { node, err_msg.str() }}}); + error_maps.insert({global, {{node, err_msg.str()}}}); } } @@ -87,10 +89,10 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { std::stringstream annotated_prog; // First we output a header for the errors. - annotated_prog << - rang::style::bold << std::endl << - "Error(s) have occurred. The program has been annotated with them:" - << std::endl << std::endl << rang::style::reset; + annotated_prog << rang::style::bold << std::endl + << "Error(s) have occurred. 
The program has been annotated with them:" << std::endl + << std::endl + << rang::style::reset; // For each global function which contains errors, we will // construct an annotated function. @@ -101,11 +103,8 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // We output the name of the function before displaying // the annotated program. - annotated_prog << - rang::style::bold << - "In `" << global->name_hint << "`: " << - std::endl << - rang::style::reset; + annotated_prog << rang::style::bold << "In `" << global->name_hint << "`: " << std::endl + << rang::style::reset; // We then call into the Relay printer to generate the program. // @@ -140,9 +139,9 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, con if (it != this->node_to_error_.end()) { it->second.push_back(index_to_insert); } else { - this->node_to_error_.insert({ node, { index_to_insert }}); + this->node_to_error_.insert({node, {index_to_insert}}); } - this->node_to_gv_.insert({ node, global }); + this->node_to_gv_.insert({node, global}); } } // namespace tvm diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7272213ad406..289477e096f3 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -21,9 +21,9 @@ * \file src/ir/expr.cc * \brief The expression AST nodes for the common IR infra. */ -#include #include #include +#include // NOTE: reverse dependency on top/tir. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -34,11 +34,9 @@ namespace tvm { -PrimExpr::PrimExpr(int32_t value) - : PrimExpr(IntImm(DataType::Int(32), value)) {} +PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} -PrimExpr::PrimExpr(float value) - : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; @@ -49,20 +47,17 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { return GetRef(ptr)(); } if (auto* ptr = ref.as()) { - return tir::StringImmNode::make(GetRef(ptr)); + return tir::StringImm(GetRef(ptr)); } CHECK(ObjectTypeChecker::Check(ref.get())) - << "Expect type " << ObjectTypeChecker::TypeName() - << " but get " << ref->GetTypeKey(); + << "Expect type " << ObjectTypeChecker::TypeName() << " but get " + << ref->GetTypeKey(); return Downcast(ref); } - IntImm::IntImm(DataType dtype, int64_t value) { - CHECK(dtype.is_scalar()) - << "ValueError: IntImm can only take scalar."; - CHECK(dtype.is_int() || dtype.is_uint()) - << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar."; + CHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm can only take scalar."; if (dtype.is_uint()) { CHECK_GE(value, 0U); } @@ -72,88 +67,77 @@ IntImm::IntImm(DataType dtype, int64_t value) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.IntImm") -.set_body_typed([](DataType dtype, int64_t value) { +TVM_REGISTER_GLOBAL("ir.IntImm").set_body_typed([](DataType dtype, int64_t value) { return IntImm(dtype, value); }); TVM_REGISTER_NODE_TYPE(IntImmNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - if (op->dtype == DataType::Int(32)) { - p->stream << op->value; - } else { - p->stream << "(" << op->dtype << ")" << op->value; - } - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + 
auto* op = static_cast(node.get()); + if (op->dtype == DataType::Int(32)) { + p->stream << op->value; + } else { + p->stream << "(" << op->dtype << ")" << op->value; + } + }); FloatImm::FloatImm(DataType dtype, double value) { - CHECK_EQ(dtype.lanes(), 1) - << "ValueError: FloatImm can only take scalar."; + CHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; ObjectPtr node = make_object(); node->dtype = dtype; node->value = value; data_ = std::move(node); } -TVM_REGISTER_GLOBAL("ir.FloatImm") -.set_body_typed([](DataType dtype, double value) { +TVM_REGISTER_GLOBAL("ir.FloatImm").set_body_typed([](DataType dtype, double value) { return FloatImm(dtype, value); }); TVM_REGISTER_NODE_TYPE(FloatImmNode); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - auto& stream = p->stream; - switch (op->dtype.bits()) { - case 64: - stream << op->value; - break; - case 32: - stream << op->value << 'f'; - break; - case 16: - stream << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + auto& stream = p->stream; + switch (op->dtype.bits()) { + case 64: + stream << op->value; + break; + case 32: + stream << op->value << 'f'; + break; + case 16: + stream << op->value << 'h'; + break; + default: + LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); + } + }); Range::Range(PrimExpr begin, PrimExpr end) - : Range(make_object( - begin, - tir::is_zero(begin) ? end : (end - begin))) { -} + : Range(make_object(begin, tir::is_zero(begin) ? end : (end - begin))) {} Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) { return Range(make_object(min, extent)); } -TVM_REGISTER_GLOBAL("ir.range_by_min_extent") -.set_body_typed(Range::make_by_min_extent); +TVM_REGISTER_GLOBAL("ir.range_by_min_extent").set_body_typed(Range::make_by_min_extent); -TVM_REGISTER_GLOBAL("ir.Range") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Range(args[0], args[1]); - }); +}); TVM_REGISTER_NODE_TYPE(RangeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; + }); -GlobalVar::GlobalVar(std::string name_hint) { +GlobalVar::GlobalVar(String name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); data_ = std::move(n); @@ -161,57 +145,51 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar") -.set_body_typed([](std::string name){ - return GlobalVar(name); -}); +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalVar(" << node->name_hint << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalVar(" << node->name_hint << ")"; + }); // Container printer 
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '['; - for (size_t i = 0 ; i < op->data.size(); ++i) { - if (i != 0) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + p->stream << ", "; + } + p->Print(op->at(i)); } - p->Print(op->data[i]); - } - p->stream << ']'; -}); + p->stream << ']'; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '{'; + for (auto it = op->data.begin(); it != op->data.end(); ++it) { + if (it != op->data.begin()) { + p->stream << ", "; + } + if (it->first->IsInstance()) { + p->stream << '\"' << Downcast(it->first) << "\": "; + } else { + p->Print(it->first); + p->stream << ": "; + } + p->Print(it->second); } - p->Print(it->first); - p->stream << ": "; - p->Print(it->second); - } - p->stream << '}'; - }); + p->stream << '}'; + }); + +TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { + std::stringstream ss; + ss << ref; + return ss.str(); +}); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '{'; - for (auto it = op->data.begin(); it != op->data.end(); ++it) { - if (it != op->data.begin()) { - p->stream << ", "; - } - p->stream << '\"' << it->first << "\": "; - p->Print(it->second); - } - p->stream << '}'; - }); } // namespace tvm diff --git a/src/ir/function.cc b/src/ir/function.cc index 08cdc93e28b5..c0cda704c424 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -21,40 +21,32 @@ * \file src/ir/function.cc * \brief The function data structure. */ -#include #include +#include // NOTE: reverse dependency on relay, tir/ // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. 
// // Rationale: We calls into the type specific WithAttr function -#include #include - +#include namespace tvm { -TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs") -.set_body_typed([](BaseFunc func) { - return func->attrs; -}); +TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { return func->attrs; }); -TVM_REGISTER_GLOBAL("ir.BaseFuncCopy") -.set_body_typed([](BaseFunc func) { - return func; -}); +TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") -.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc { - if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); - } else { - LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - return func; - } -}); - + .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); } // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index 6262150556c7..c7393749dc37 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -21,9 +21,9 @@ * \file module.cc * \brief The global module in Relay. */ -#include #include #include +#include // NOTE: reverse dependency on relay. // These dependencies do not happen at the interface-level, // and are only used in minimum cases where they are clearly marked. @@ -32,15 +32,15 @@ #include #include -#include #include +#include #include namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set) { + std::unordered_set import_set) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -52,14 +52,14 @@ IRModule::IRModule(tvm::Map functions, for (const auto& kv : n->functions) { // set global var map CHECK(n->global_var_map_.count(kv.first->name_hint) == 0) - << "Duplicate global function name " << kv.first->name_hint; + << "Duplicate global function name " << kv.first->name_hint; n->global_var_map_.Set(kv.first->name_hint, kv.first); } for (const auto& kv : n->type_definitions) { // set global typevar map CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0) - << "Duplicate global type definition name " << kv.first->name_hint; + << "Duplicate global type definition name " << kv.first->name_hint; n->global_type_var_map_.Set(kv.first->name_hint, kv.first); n->RegisterConstructors(kv.first, kv.second); } @@ -87,9 +87,8 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { auto reduce_temp = [&]() { // sort by the hash key of the keys. 
- std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); hash_reduce(static_cast(temp.size())); // hash the content @@ -111,15 +110,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { reduce_temp(); } -bool IRModuleNode::ContainGlobalVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalVar(const String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalTypeVar(const String& name) const { return global_type_var_map_.find(name) != global_type_var_map_.end(); } -GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; @@ -146,15 +145,15 @@ tvm::Array IRModuleNode::GetGlobalVars() const { return tvm::Array(global_vars); } -GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { +GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) - << "Cannot find global type var " << name << " in the Module"; + << "Cannot find global type var " << name << " in the Module"; return (*it).second; } -Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const { +Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const { TypeData typeDef = this->LookupTypeDef(adt); for (Constructor c : typeDef->constructors) { if (cons.compare(c->name_hint) == 0) { @@ -174,7 +173,7 @@ tvm::Array IRModuleNode::GetGlobalTypeVars() const { return tvm::Array(global_type_vars); } -template +template tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { tvm::Array ret(l); for (const T& t : r) { @@ -184,55 +183,37 @@ tvm::Array concat(const tvm::Array& l, const tvm::Array& r) { } // helper function to run type check -relay::Function RunTypeCheck(const IRModule& mod, - const GlobalVar& var, - relay::Function f) { +relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) { auto func = Downcast(relay::DeDup(std::move(f))); // Type check the item before we add it to the module. auto fv = relay::FreeVars(func); auto ftv = relay::FreeTypeVars(func, mod); if (fv.size() != 0) { - LOG(WARNING) - << "There are free variables: " - << fv - << " in function: " - << AsText(func, false) - << std::endl; + LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false) + << std::endl; } if (ftv.size() != 0) { - LOG(WARNING) - << "There are free type variables: " - << ftv - << " in function: " - << AsText(func, false) - << std::endl; + LOG(WARNING) << "There are free type variables: " << ftv + << " in function: " << AsText(func, false) << std::endl; } - func = relay::Function(concat(func->params, fv), - func->body, - func->ret_type, - concat(func->type_params, ftv), - func->attrs); + func = relay::Function(concat(func->params, fv), func->body, func->ret_type, + concat(func->type_params, ftv), func->attrs); // Type check the item before we add it to the module. 
relay::Function checked_func = InferType(func, mod, var); return checked_func; } -void IRModuleNode::Add(const GlobalVar& var, - const BaseFunc& f, - bool update) { +void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) { BaseFunc checked_func = f; if (auto* ptr = f.as()) { - checked_func = RunTypeCheck(GetRef(this), - var, - GetRef(ptr)); + checked_func = RunTypeCheck(GetRef(this), var, GetRef(ptr)); } Type type = checked_func->checked_type(); CHECK(type.as() == nullptr); if (functions.find(var) != functions.end()) { - CHECK(update) - << "Already have definition for " << var->name_hint; + CHECK(update) << "Already have definition for " << var->name_hint; auto old_type = functions[var]->checked_type(); CHECK(tvm::StructuralEqual()(type, old_type)) << "Module#update changes type, not possible in this mode."; @@ -241,8 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } -void IRModuleNode::AddUnchecked(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { this->functions.Set(var, func); auto it = global_var_map_.find(var->name_hint); @@ -268,36 +248,31 @@ void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData } } -void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, - const TypeData& type, - bool update) { +void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) { AddTypeDefUnchecked(var, type, update); // need to kind check at the end because the check can look up // a definition potentially CHECK(relay::KindCheck(type, GetRef(this)) == TypeKind::kTypeData) - << "Invalid or malformed typedata given to module: " << type; + << "Invalid or malformed typedata given to module: " << type; } -void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, - const TypeData& type, +void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) { this->type_definitions.Set(var, type); if (!update) { // set global type var map CHECK(global_type_var_map_.count(var->name_hint) == 0) - << "Duplicate global type definition name " << var->name_hint; + << "Duplicate global type definition name " << var->name_hint; } global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); } -void IRModuleNode::Update(const GlobalVar& var, - const BaseFunc& func) { +void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) { this->Add(var, func, true); } -void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, - const TypeData& type) { +void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) { this->AddTypeDef(var, type, true); } @@ -310,32 +285,29 @@ void IRModuleNode::Remove(const GlobalVar& var) { BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); - CHECK(it != functions.end()) - << "There is no definition of " << var->name_hint; + CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -BaseFunc IRModuleNode::Lookup(const std::string& name) const { +BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); - CHECK(it != type_definitions.end()) - << "There is no definition of " << var->name_hint; + CHECK(it != type_definitions.end()) << "There is no definition of " << 
var->name_hint; return (*it).second; } -TypeData IRModuleNode::LookupTypeDef(const std::string& name) const { +TypeData IRModuleNode::LookupTypeDef(const String& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupTypeDef(id); } Constructor IRModuleNode::LookupTag(const int32_t tag) { auto it = constructor_tag_map_.find(tag); - CHECK(it != constructor_tag_map_.end()) - << "There is no constructor with the tag " << tag; + CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; return (*it).second; } @@ -356,10 +328,9 @@ void IRModuleNode::Update(const IRModule& mod) { } } -IRModule IRModule::FromExpr( - const RelayExpr& expr, - const tvm::Map& global_funcs, - const tvm::Map& type_definitions) { +IRModule IRModule::FromExpr(const RelayExpr& expr, + const tvm::Map& global_funcs, + const tvm::Map& type_definitions) { auto mod = IRModule(global_funcs, type_definitions); BaseFunc func; std::string gv_name = "main"; @@ -371,39 +342,35 @@ IRModule IRModule::FromExpr( } } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), - relay::FreeTypeVars(expr, mod), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); } auto main_gv = GlobalVar(gv_name); mod->Add(main_gv, func); return mod; } -void IRModuleNode::Import(const std::string& path) { +void IRModuleNode::Import(const String& path) { if (this->import_set_.count(path) == 0) { this->import_set_.insert(path); DLOG(INFO) << "Importing: " << path; std::fstream src_file(path, std::fstream::in); - std::string file_contents { - std::istreambuf_iterator(src_file), - std::istreambuf_iterator() }; + std::string file_contents{std::istreambuf_iterator(src_file), + std::istreambuf_iterator()}; auto mod_to_import = IRModule::FromText(file_contents, path); Update(mod_to_import); } } -void IRModuleNode::ImportFromStd(const std::string& path) { +void IRModuleNode::ImportFromStd(const String& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); - return this->Import(std_path + "/" + path); + this->Import(std_path + "/" + path.operator std::string()); } -std::unordered_set IRModuleNode::Imports() const { - return this->import_set_; -} +std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } -IRModule IRModule::FromText(const std::string& text, const std::string& source_path) { +IRModule IRModule::FromText(const String& text, const String& source_path) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; IRModule mod = (*f)(text, source_path); @@ -413,13 +380,12 @@ IRModule IRModule::FromText(const std::string& text, const std::string& source_p TVM_REGISTER_NODE_TYPE(IRModuleNode); TVM_REGISTER_GLOBAL("ir.IRModule") -.set_body_typed([](tvm::Map funcs, - tvm::Map types) { - return IRModule(funcs, types, {}); -}); + .set_body_typed([](tvm::Map funcs, + tvm::Map types) { + return IRModule(funcs, types, {}); + }); -TVM_REGISTER_GLOBAL("ir.Module_Add") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) { IRModule mod = args[0]; GlobalVar var = args[1]; ObjectRef val = args[2]; @@ -443,75 +409,65 @@ TVM_REGISTER_GLOBAL("ir.Module_Add") *ret = mod; }); -TVM_REGISTER_GLOBAL("ir.Module_AddDef") 
-.set_body_method(&IRModuleNode::AddTypeDef); +TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") -.set_body_method(&IRModuleNode::GetGlobalVar); + .set_body_method(&IRModuleNode::GetGlobalVar); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars") -.set_body_method(&IRModuleNode::GetGlobalVars); + .set_body_method(&IRModuleNode::GetGlobalVars); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") -.set_body_method(&IRModuleNode::GetGlobalTypeVars); + .set_body_method(&IRModuleNode::GetGlobalTypeVars); TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") -.set_body_method(&IRModuleNode::ContainGlobalVar); + .set_body_method(&IRModuleNode::ContainGlobalVar); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") -.set_body_method(&IRModuleNode::GetGlobalTypeVar); + .set_body_method(&IRModuleNode::GetGlobalTypeVar); -TVM_REGISTER_GLOBAL("ir.Module_Lookup") -.set_body_typed([](IRModule mod, GlobalVar var) { +TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_Lookup_str") -.set_body_typed([](IRModule mod, std::string var) { +TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupDef") -.set_body_typed([](IRModule mod, GlobalTypeVar var) { +TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str") -.set_body_typed([](IRModule mod, std::string var) { +TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { return mod->LookupTypeDef(var); }); -TVM_REGISTER_GLOBAL("ir.Module_LookupTag") -.set_body_typed([](IRModule mod, int32_t tag) { - return mod->LookupTag(tag); - }); +TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { + return mod->LookupTag(tag); +}); TVM_REGISTER_GLOBAL("ir.Module_FromExpr") -.set_body_typed([](RelayExpr e, - tvm::Map funcs, - tvm::Map type_defs) { - return IRModule::FromExpr(e, funcs, type_defs); -}); + .set_body_typed([](RelayExpr e, tvm::Map funcs, + tvm::Map type_defs) { + return IRModule::FromExpr(e, funcs, type_defs); + }); -TVM_REGISTER_GLOBAL("ir.Module_Update") -.set_body_typed([](IRModule mod, IRModule from) { +TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { mod->Update(from); }); -TVM_REGISTER_GLOBAL("ir.Module_Import") -.set_body_typed([](IRModule mod, std::string path) { +TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); -TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd") -.set_body_typed([](IRModule mod, std::string path) { +TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { mod->ImportFromStd(path); -});; +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IRModuleNode( " << node->functions << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IRModuleNode( " << node->functions << ")"; + }); } // namespace tvm diff --git a/src/ir/op.cc b/src/ir/op.cc index b024165c1a4c..63d223050ff5 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -23,195 +23,108 @@ */ #include #include -#include #include 
+#include #include #include -#include -namespace dmlc { -// enable registry -DMLC_REGISTRY_ENABLE(::tvm::OpRegistry); -} // namespace dmlc +#include "../node/attr_registry.h" namespace tvm { -using runtime::TVMRetValue; -using runtime::TVMArgs; using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; -::dmlc::Registry* OpRegistry::Registry() { - return ::dmlc::Registry::Get(); -} - -// single manager of operator information. -struct OpManager { - // mutex to avoid registration from multiple threads. - std::mutex mutex; - // global operator counter - std::atomic op_counter{0}; - // storage of additional attribute table. - std::unordered_map> attr; - // frontend functions - std::vector frontend_funcs; - // get singleton of the op manager - static OpManager* Global() { - static OpManager* inst = new OpManager(); - return inst; - } -}; +using OpRegistry = AttrRegistry; // find operator by name -const Op& Op::Get(const std::string& name) { - const OpRegistry* reg = dmlc::Registry::Find(name); +const Op& Op::Get(const String& name) { + const OpRegEntry* reg = OpRegistry::Global()->Get(name); CHECK(reg != nullptr) << "Operator " << name << " is not registered"; return reg->op(); } -OpRegistry::OpRegistry() { - OpManager* mgr = OpManager::Global(); +OpRegEntry::OpRegEntry(uint32_t reg_index) { ObjectPtr n = make_object(); - n->index_ = mgr->op_counter++; + n->index_ = reg_index; op_ = Op(n); } +OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) { + return OpRegistry::Global()->RegisterOrGet(name); +} + // Get attribute map by key -const GenericOpMap& Op::GetGenericAttr(const std::string& key) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - auto it = mgr->attr.find(key); - if (it == mgr->attr.end()) { - LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered"; - } - return *it->second.get(); +const AttrRegistryMapContainerMap& Op::GetAttrMapContainer(const String& attr_name) { + return OpRegistry::Global()->GetAttrMap(attr_name); } // Check if a key is present in the registry. -bool Op::HasGenericAttr(const std::string& key) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - auto it = mgr->attr.find(key); - if (it == mgr->attr.end()) { - return false; - } - return true; -} +bool Op::HasAttrMap(const String& attr_name) { return OpRegistry::Global()->HasAttrMap(attr_name); } -// Resets attr of the OpMap. -void OpRegistry::reset_attr(const std::string& key) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - std::unique_ptr& op_map = mgr->attr[key]; - if (op_map == nullptr) { - return; - } - uint32_t index = op_->index_; - if (op_map->data_.size() > index) { - op_map->data_[index] = std::make_pair(TVMRetValue(), 0); - } +// Resets attr of the OpAttrMap. 
+void OpRegEntry::reset_attr(const std::string& attr_name) { + OpRegistry::Global()->ResetAttr(attr_name, op_); } -void OpRegistry::UpdateAttr(const std::string& key, - TVMRetValue value, - int plevel) { - OpManager* mgr = OpManager::Global(); - std::lock_guard lock(mgr->mutex); - std::unique_ptr& op_map = mgr->attr[key]; - if (op_map == nullptr) { - op_map.reset(new GenericOpMap()); - op_map->attr_name_ = key; - } - uint32_t index = op_->index_; - if (op_map->data_.size() <= index) { - op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); - } - std::pair& p = op_map->data_[index]; - CHECK(p.second != plevel) - << "Attribute " << key << " of operator " << this->name - << " is already registered with same plevel=" << plevel; - CHECK(value.type_code() != kTVMNullptr) - << "Registered packed_func is Null for " << key - << " of operator " << this->name; - if (p.second < plevel && value.type_code() != kTVMNullptr) { - op_map->data_[index] = std::make_pair(value, plevel); - } +void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { + OpRegistry::Global()->UpdateAttr(key, op_, value, plevel); } // Frontend APIs -TVM_REGISTER_GLOBAL("relay.op._ListOpNames") -.set_body_typed([]() { - Array ret; - for (const std::string& name : dmlc::Registry::ListAllNames()) { - ret.push_back(name); - } - return ret; - }); - -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get); - -TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); - if (op_map.count(op)) { - *rv = op_map[op]; - } - }); - -TVM_REGISTER_GLOBAL("relay.op._OpSetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); - reg.set_attr(attr_name, value, plevel); - }); - -TVM_REGISTER_GLOBAL("relay.op._OpResetAttr") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op->name); - reg.reset_attr(attr_name); - }); - -TVM_REGISTER_GLOBAL("relay.op._Register") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string op_name = args[0]; - std::string attr_key = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = - OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - LOG(FATAL) << "attrs type key no longer supported"; - } else { - // normal attr table override. - if (args[2].type_code() == kTVMPackedFuncHandle) { - // do an eager copy of the PackedFunc - PackedFunc f = args[2]; - // If we get a function from frontend, avoid deleting it. 
- OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f)); - reg.set_attr(attr_key, f, plevel); +TVM_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { + return OpRegistry::Global()->ListAllNames(); +}); + +TVM_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); + +TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> TVMRetValue { + auto op_map = Op::GetAttrMap(attr_name); + TVMRetValue rv; + if (op_map.count(op)) { + rv = op_map[op]; + } + return rv; +}); + +TVM_REGISTER_GLOBAL("ir.OpSetAttr") + .set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_attr(attr_name, value, plevel); + }); + +TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); + reg.reset_attr(attr_name); +}); + +TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") + .set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + LOG(FATAL) << "attrs type key no longer supported"; } else { - reg.set_attr(attr_key, args[2], plevel); + // normal attr table override. + if (value.type_code() == kTVMPackedFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = value; + // If we get a function from frontend, avoid deleting it. + auto* fcopy = new PackedFunc(f); + reg.set_attr(attr_key, *fcopy, plevel); + } else { + reg.set_attr(attr_key, value, plevel); + } } - } - }); + }); // helper to get internal dev function in objectref. struct Op2ObjectPtr : public ObjectRef { - static ObjectPtr Get(const Op& op) { - return GetDataPtr(op); - } + static ObjectPtr Get(const Op& op) { return GetDataPtr(op); } }; ObjectPtr CreateOp(const std::string& name) { @@ -221,16 +134,13 @@ ObjectPtr CreateOp(const std::string& name) { return Op2ObjectPtr::Get(op); } -TVM_REGISTER_NODE_TYPE(OpNode) -.set_creator(CreateOp) -.set_repr_bytes([](const Object* n) { - return static_cast(n)->name; - }); +TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes( + [](const Object* n) -> std::string { return static_cast(n)->name; }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Op(" << node->name << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Op(" << node->name << ")"; + }); } // namespace tvm diff --git a/src/ir/span.cc b/src/ir/span.cc index f84353de2a8b..565439f2ad74 100644 --- a/src/ir/span.cc +++ b/src/ir/span.cc @@ -25,10 +25,10 @@ namespace tvm { -ObjectPtr GetSourceNameNode(const std::string& name) { +ObjectPtr GetSourceNameNode(const String& name) { // always return pointer as the reference can change as map re-allocate. 
// or use another level of indirection by creating a unique_ptr - static std::unordered_map > source_map; + static std::unordered_map > source_map; auto sn = source_map.find(name); if (sn == source_map.end()) { @@ -41,42 +41,44 @@ ObjectPtr GetSourceNameNode(const std::string& name) { } } -SourceName SourceName::Get(const std::string& name) { - return SourceName(GetSourceNameNode(name)); +ObjectPtr GetSourceNameNodeByStr(const std::string& name) { + return GetSourceNameNode(name); } -TVM_REGISTER_GLOBAL("ir.SourceName") -.set_body_typed(SourceName::Get); +SourceName SourceName::Get(const String& name) { return SourceName(GetSourceNameNode(name)); } + +TVM_REGISTER_GLOBAL("ir.SourceName").set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "SourceName(" << node->name << ", " << node << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "SourceName(" << node->name << ", " << node << ")"; + }); TVM_REGISTER_NODE_TYPE(SourceNameNode) -.set_creator(GetSourceNameNode) -.set_repr_bytes([](const Object* n) { - return static_cast(n)->name; - }); + .set_creator(GetSourceNameNodeByStr) + .set_repr_bytes([](const Object* n) -> std::string { + return static_cast(n)->name; + }); -Span SpanNode::make(SourceName source, int lineno, int col_offset) { +Span::Span(SourceName source, int lineno, int col_offset) { auto n = make_object(); n->source = std::move(source); n->lineno = lineno; n->col_offset = col_offset; - return Span(n); + data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(SpanNode); -TVM_REGISTER_GLOBAL("ir.Span") -.set_body_typed(SpanNode::make); +TVM_REGISTER_GLOBAL("ir.Span").set_body_typed([](SourceName source, int lineno, int col_offset) { + return Span(source, lineno, col_offset); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Span(" << node->source << ", " << node->lineno << ", " - << node->col_offset << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Span(" << node->source << ", " << node->lineno << ", " << node->col_offset + << ")"; + }); } // namespace tvm diff --git a/src/ir/tensor_type.cc b/src/ir/tensor_type.cc index 92f0ea229421..0fab0acb8964 100644 --- a/src/ir/tensor_type.cc +++ b/src/ir/tensor_type.cc @@ -21,8 +21,8 @@ * \file src/ir/tensor_type.cc * \brief The type system AST nodes of Relay. 
*/ -#include #include +#include #include namespace tvm { @@ -37,9 +37,7 @@ TensorType::TensorType(Array shape, DataType dtype) { data_ = std::move(n); } -TensorType TensorType::Scalar(DataType dtype) { - return TensorType({}, dtype); -} +TensorType TensorType::Scalar(DataType dtype) { return TensorType({}, dtype); } PrimExpr TensorTypeNode::Size() const { if (shape.size() == 0) { @@ -55,15 +53,14 @@ PrimExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); -TVM_REGISTER_GLOBAL("ir.TensorType") -.set_body_typed([](Array shape, DataType dtype) { +TVM_REGISTER_GLOBAL("ir.TensorType").set_body_typed([](Array shape, DataType dtype) { return TensorType(shape, dtype); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; + }); } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index c1547d5205a4..d74b95abebdb 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -22,24 +22,23 @@ * \brief Infrastructure for transformation passes. */ #include -#include +#include +#include #include #include -#include -#include - -// TODO(tqchen): Update to use String container after it is merged. -#include +#include #include #include +#include "../runtime/object_internal.h" + namespace tvm { namespace transform { +using tvm::ReprPrinter; using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; -using tvm::ReprPrinter; struct PassContextThreadLocalEntry { /*! \brief The default pass context. */ @@ -48,32 +47,26 @@ struct PassContextThreadLocalEntry { /*! \brief The current pass context. */ std::stack context_stack; - PassContextThreadLocalEntry() { - default_context = PassContext(make_object()); - } + PassContextThreadLocalEntry() { default_context = PassContext(make_object()); } }; /*! \brief Thread local store to hold the pass context. 
*/ -typedef dmlc::ThreadLocalStore - RelayPassContextThreadLocalStore; +typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } void PassContext::ExitWithScope() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } PassContext PassContext::Current() { - PassContextThreadLocalEntry* entry = - RelayPassContextThreadLocalStore::Get(); + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); if (!entry->context_stack.empty()) { return entry->context_stack.top(); } else { @@ -81,15 +74,77 @@ PassContext PassContext::Current() { } } -PassContext PassContext::Create() { - return PassContext(make_object()); +class PassConfigManager { + public: + void Register(std::string key, uint32_t value_type_index) { + CHECK_EQ(key2vtype_.count(key), 0U); + ValueTypeInfo info; + info.type_index = value_type_index; + info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + key2vtype_[key] = info; + } + + // Trying to validate and legalize a config. + void Legalize(Map* config) { + std::vector> update; + auto* reflection = ReflectionVTable::Global(); + + for (auto kv : *config) { + auto it = key2vtype_.find(kv.first); + if (it == key2vtype_.end()) { + std::ostringstream os; + os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; + int counter = 0; + for (const auto& kv : key2vtype_) { + os << ' '; + if (counter++ != 0) os << ','; + os << kv.first; + } + LOG(FATAL) << os.str(); + } + const auto& info = it->second; + CHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; + if (kv.second->IsInstance::ContainerType>()) { + ObjectRef converted = + reflection->CreateObject(info.type_key, Downcast>(kv.second)); + update.emplace_back(kv.first, converted); + } else { + if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { + LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " + << info.type_key << " but get " << kv.second->GetTypeKey(); + } + } + } + for (auto&& kv : update) { + config->Set(kv.first, kv.second); + } + } + + static PassConfigManager* Global() { + static auto* inst = new PassConfigManager(); + return inst; + } + + private: + struct ValueTypeInfo { + std::string type_key; + uint32_t type_index; + }; + + std::unordered_map key2vtype_; +}; + +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { + PassConfigManager::Global()->Register(key, value_type_index); } +PassContext PassContext::Create() { return PassContext(make_object()); } + void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const { - auto pass_ctx_node = this->operator->(); - if (pass_ctx_node->trace_func != nullptr) { - pass_ctx_node->trace_func(module, info, is_before); - } + auto pass_ctx_node = this->operator->(); + if (pass_ctx_node->trace_func != nullptr) { + pass_ctx_node->trace_func(module, info, is_before); + } } class ModulePass; @@ -114,9 +169,7 @@ class ModulePassNode : public PassNode { ModulePassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - 
v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a module pass on given pass context. @@ -211,9 +264,7 @@ class SequentialNode : public PassNode { TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; -PassInfo::PassInfo(int opt_level, - std::string name, - tvm::Array required) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -221,9 +272,8 @@ PassInfo::PassInfo(int opt_level, data_ = std::move(pass_info); } -ModulePass::ModulePass( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { +ModulePass::ModulePass(runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); @@ -231,13 +281,10 @@ ModulePass::ModulePass( } // Module -> Module optimizations. -IRModule ModulePassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule ModulePassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); - DLOG(INFO) << "Executing module pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; + DLOG(INFO) << "Executing module pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); @@ -254,7 +301,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { data_ = std::move(n); } -Sequential::Sequential(tvm::Array passes, std::string name) { +Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); PassInfo pass_info = PassInfo(2, std::move(name), {}); @@ -298,29 +345,27 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const { return ctx->opt_level >= info->opt_level; } -Pass GetPass(const std::string& pass_name) { +Pass GetPass(const String& pass_name) { using tvm::runtime::Registry; const runtime::PackedFunc* f = nullptr; - if (pass_name.find("transform.") != std::string::npos) { + if (pass_name.operator std::string().find("transform.") != std::string::npos) { f = Registry::Get(pass_name); } else if ((f = Registry::Get("transform." + pass_name))) { // pass } else if ((f = Registry::Get("relay._transform." + pass_name))) { } - CHECK(f != nullptr) << "Cannot use " << pass_name - << "to create the pass"; + CHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass"; return (*f)(); } // TODO(zhiics): we currenlty only sequentially execute each pass in // a Sequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. 
-IRModule SequentialNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { for (const Pass& pass : passes) { CHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); - if (!PassEnabled(pass_info)) continue; + if (!PassEnabled(pass_info)) continue; // resolve dependencies for (const auto& it : pass_info->required) { mod = GetPass(it)(std::move(mod), pass_ctx); @@ -330,11 +375,8 @@ IRModule SequentialNode::operator()(IRModule mod, return mod; } -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { +Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, + int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); } @@ -342,55 +384,50 @@ Pass CreateModulePass( TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") -.set_body_typed([](int opt_level, std::string name, tvm::Array required) { - return PassInfo(opt_level, name, required); -}); + .set_body_typed([](int opt_level, String name, tvm::Array required) { + return PassInfo(opt_level, name, required); + }); -TVM_REGISTER_GLOBAL("transform.Info") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Info").set_body([](TVMArgs args, TVMRetValue* ret) { Pass pass = args[0]; *ret = pass->Info(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "The meta data of the pass: "; - p->stream << "pass name: " << node->name; - p->stream << "opt_level: " << node->opt_level; - p->stream << "required passes: [" << "\n"; - for (const auto& it : node->required) { - p->stream << it << ", "; - } - p->stream << "]\n"; -}); + .set_dispatch([](const ObjectRef& ref, tvm::ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "The meta data of the pass: "; + p->stream << "pass name: " << node->name; + p->stream << "opt_level: " << node->opt_level; + p->stream << "required passes: [" + << "\n"; + for (const auto& it : node->required) { + p->stream << it << ", "; + } + p->stream << "]\n"; + }); TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_GLOBAL("transform.MakeModulePass") -.set_body_typed( - [](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return ModulePass(pass_func, pass_info); -}); + .set_body_typed([](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return ModulePass(pass_func, pass_info); }); -TVM_REGISTER_GLOBAL("transform.RunPass") -.set_body_typed([](Pass pass, IRModule mod) { +TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) { return pass(std::move(mod)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Module pass: " << info->name - << " at the optimization level " << info->opt_level; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Module pass: " << info->name << " at the optimization level " + << info->opt_level; + }); TVM_REGISTER_NODE_TYPE(SequentialNode); -TVM_REGISTER_GLOBAL("transform.Sequential") 
-.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("transform.Sequential").set_body([](TVMArgs args, TVMRetValue* ret) { tvm::Array passes = args[0]; int opt_level = args[1]; std::string name = args[2]; @@ -400,91 +437,80 @@ TVM_REGISTER_GLOBAL("transform.Sequential") }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Sequential pass: " << info->name - << " at the optimization level " << info->opt_level << ". "; - p->stream << "The passes will be executed are: ["; - for (const auto& it : node->passes) { - const PassInfo pass_info = it->Info(); - p->stream << pass_info->name << " "; - } - p->stream << "]"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Sequential pass: " << info->name << " at the optimization level " + << info->opt_level << ". "; + p->stream << "The passes will be executed are: ["; + for (const auto& it : node->passes) { + const PassInfo pass_info = it->Info(); + p->stream << pass_info->name << " "; + } + p->stream << "]"; + }); TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_GLOBAL("transform.PassContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { - auto pctx = PassContext::Create(); - int opt_level = args[0]; - int fallback_device = args[1]; - tvm::Array required = args[2]; - tvm::Array disabled = args[3]; - TraceFunc trace_func = args[4]; - pctx->opt_level = opt_level; - pctx->fallback_device = fallback_device; - pctx->required_pass = std::move(required); - pctx->disabled_pass = std::move(disabled); - pctx->trace_func = std::move(trace_func); - *ret = pctx; -}); + .set_body_typed([](int opt_level, Array required, Array disabled, + TraceFunc trace_func, Optional> config) { + auto pctx = PassContext::Create(); + pctx->opt_level = opt_level; + + pctx->required_pass = std::move(required); + pctx->disabled_pass = std::move(disabled); + pctx->trace_func = std::move(trace_func); + if (config.defined()) { + pctx->config = config.value(); + } + PassConfigManager::Global()->Legalize(&(pctx->config)); + return pctx; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Pass context information: " << "\n"; - p->stream << "\topt_level: " << node->opt_level << "\n"; - p->stream << "\tfallback device: " - << runtime::DeviceName(node->fallback_device) - << "\n"; - - p->stream << "\trequired passes: [" << node->opt_level; - for (const auto& it : node->required_pass) { - p->stream << it << " "; - } - p->stream << "]\n"; - - p->stream << "\tdisabled passes: [" << node->opt_level; - for (const auto& it : node->disabled_pass) { - p->stream << it << " "; - } - p->stream << "]"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Pass context information: " + << "\n"; + p->stream << "\topt_level: " << node->opt_level << "\n"; + + p->stream << "\trequired passes: ["; + for (const auto& it : node->required_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + + p->stream << "\tdisabled passes: ["; + for (const auto& it : node->disabled_pass) { + p->stream << it << " "; + } + p->stream << "]\n"; + p->stream << "\tconfig: " << node->config; + }); class PassContext::Internal { public: - static void EnterScope(PassContext 
pass_ctx) { - pass_ctx.EnterWithScope(); - } + static void EnterScope(PassContext pass_ctx) { pass_ctx.EnterWithScope(); } - static void ExitScope(PassContext pass_ctx) { - pass_ctx.ExitWithScope(); - } + static void ExitScope(PassContext pass_ctx) { pass_ctx.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext") -.set_body_typed(PassContext::Current); - -TVM_REGISTER_GLOBAL("transform.EnterPassContext") -.set_body_typed(PassContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("transform.GetCurrentPassContext").set_body_typed(PassContext::Current); -TVM_REGISTER_GLOBAL("transform.ExitPassContext") -.set_body_typed(PassContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); -Pass PrintIR(std::string header, bool show_meta_data) { - auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) { - LOG(INFO) << "PrintIR(" << header << "):\n" - << AsText(mod, show_meta_data); +Pass PrintIR(String header, bool show_meta_data) { + auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { + LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); } -TVM_REGISTER_GLOBAL("transform.PrintIR") -.set_body_typed(PrintIR); +TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); } // namespace transform } // namespace tvm diff --git a/src/ir/type.cc b/src/ir/type.cc index 5b038218c127..38a6ec3e6805 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -33,17 +33,15 @@ PrimType::PrimType(runtime::DataType dtype) { TVM_REGISTER_NODE_TYPE(PrimTypeNode); -TVM_REGISTER_GLOBAL("ir.PrimType") -.set_body_typed([](runtime::DataType dtype) { +TVM_REGISTER_GLOBAL("ir.PrimType").set_body_typed([](runtime::DataType dtype) { return PrimType(dtype); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << node->dtype; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << node->dtype; + }); PointerType::PointerType(Type element_type) { ObjectPtr n = make_object(); @@ -53,20 +51,18 @@ PointerType::PointerType(Type element_type) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); -TVM_REGISTER_GLOBAL("ir.PointerType") -.set_body_typed([](Type element_type) { +TVM_REGISTER_GLOBAL("ir.PointerType").set_body_typed([](Type element_type) { return PointerType(element_type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->Print(node->element_type); - p->stream << '*'; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->element_type); + p->stream << '*'; + }); - -TypeVar::TypeVar(std::string name, TypeKind kind) { +TypeVar::TypeVar(String name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); @@ -75,20 +71,17 @@ TypeVar::TypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); -TVM_REGISTER_GLOBAL("ir.TypeVar") -.set_body_typed([](std::string name, int kind) { +TVM_REGISTER_GLOBAL("ir.TypeVar").set_body_typed([](String name, int kind) { return TypeVar(name, static_cast(kind)); }); 
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVar(" << node->name_hint << ", " - << node->kind << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); - -GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { +GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); @@ -97,21 +90,17 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalTypeVar") -.set_body_typed([](std::string name, int kind) { +TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) { return GlobalTypeVar(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVar(" << node->name_hint << ", " - << node->kind << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")"; + }); -FuncType::FuncType(tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, +FuncType::FuncType(tvm::Array arg_types, Type ret_type, tvm::Array type_params, tvm::Array type_constraints) { ObjectPtr n = make_object(); n->arg_types = std::move(arg_types); @@ -124,21 +113,17 @@ FuncType::FuncType(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_GLOBAL("ir.FuncType") -.set_body_typed([](tvm::Array arg_types, - Type ret_type, - tvm::Array type_params, - tvm::Array type_constraints) { - return FuncType(arg_types, ret_type, type_params, type_constraints); -}); + .set_body_typed([](tvm::Array arg_types, Type ret_type, tvm::Array type_params, + tvm::Array type_constraints) { + return FuncType(arg_types, ret_type, type_params, type_constraints); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FuncType(" << node->type_params << ", " - << node->arg_types << ", " << node->ret_type << ", " - << node->type_constraints << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FuncType(" << node->type_params << ", " << node->arg_types << ", " + << node->ret_type << ", " << node->type_constraints << ")"; + }); TupleType::TupleType(Array fields) { ObjectPtr n = make_object(); @@ -146,23 +131,19 @@ TupleType::TupleType(Array fields) { data_ = std::move(n); } -TupleType TupleType::Empty() { - return TupleType(Array()); -} +TupleType TupleType::Empty() { return TupleType(Array()); } TVM_REGISTER_NODE_TYPE(TupleTypeNode); -TVM_REGISTER_GLOBAL("ir.TupleType") -.set_body_typed([](Array fields) { +TVM_REGISTER_GLOBAL("ir.TupleType").set_body_typed([](Array fields) { return TupleType(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleTypeNode(" << node->fields << ")"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleTypeNode(" << node->fields << ")"; 
+ }); IncompleteType::IncompleteType(TypeKind kind) { auto n = make_object(); @@ -172,17 +153,15 @@ IncompleteType::IncompleteType(TypeKind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); -TVM_REGISTER_GLOBAL("ir.IncompleteType") -.set_body_typed([](int kind) { - return IncompleteType(static_cast(kind)); - }); +TVM_REGISTER_GLOBAL("ir.IncompleteType").set_body_typed([](int kind) { + return IncompleteType(static_cast(kind)); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")"; + }); RelayRefType::RelayRefType(Type value) { ObjectPtr n = make_object(); @@ -190,17 +169,16 @@ RelayRefType::RelayRefType(Type value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.RelayRefType") -.set_body_typed([](Type value) { +TVM_REGISTER_GLOBAL("ir.RelayRefType").set_body_typed([](Type value) { return RelayRefType(value); }); TVM_REGISTER_NODE_TYPE(RelayRefTypeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RelayRefTypeNode(" << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RelayRefTypeNode(" << node->value << ")"; + }); } // namespace tvm diff --git a/src/ir/type_functor.cc b/src/ir/type_functor.cc index 9d9167fa1c0f..21ce3d09d2ae 100644 --- a/src/ir/type_functor.cc +++ b/src/ir/type_functor.cc @@ -22,18 +22,16 @@ * \brief Implementations of type functors. */ #include + #include namespace tvm { -void TypeVisitor::VisitType_(const TypeVarNode* op) { -} +void TypeVisitor::VisitType_(const TypeVarNode* op) {} -void TypeVisitor::VisitType_(const TensorTypeNode* op) { -} +void TypeVisitor::VisitType_(const TensorTypeNode* op) {} -void TypeVisitor::VisitType_(const IncompleteTypeNode* op) { -} +void TypeVisitor::VisitType_(const IncompleteTypeNode* op) {} void TypeVisitor::VisitType_(const FuncTypeNode* op) { for (auto type_param : op->type_params) { @@ -56,9 +54,7 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { } } -void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { - this->VisitType(op->value); -} +void TypeVisitor::VisitType_(const RelayRefTypeNode* op) { this->VisitType(op->value); } void TypeVisitor::VisitType_(const TypeRelationNode* op) { for (const Type& t : op->args) { @@ -66,8 +62,7 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { } } -void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) { -} +void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {} void TypeVisitor::VisitType_(const TypeCallNode* op) { this->VisitType(op->func); @@ -90,12 +85,9 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { } } -void TypeVisitor::VisitType_(const PrimTypeNode* op) { -} +void TypeVisitor::VisitType_(const PrimTypeNode* op) {} -void TypeVisitor::VisitType_(const PointerTypeNode* op) { - this->VisitType(op->element_type); -} +void TypeVisitor::VisitType_(const PointerTypeNode* op) { this->VisitType(op->element_type); } Type TypeMutator::VisitType(const Type& t) { return t.defined() ? 
TypeFunctor::VisitType(t) : t; @@ -115,18 +107,14 @@ Array TypeMutator::MutateArray(Array arr) { return arr; } -Type TypeMutator::VisitType_(const TypeVarNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const TypeVarNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const TensorTypeNode* op) { // TODO(tvm-team) recursively visit to replace Var return GetRef(op); } -Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const IncompleteTypeNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const FuncTypeNode* op) { bool changed = false; @@ -145,8 +133,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { for (auto type_cs : op->type_constraints) { auto new_type_cs = VisitType(type_cs); changed = changed || !new_type_cs.same_as(type_cs); - if (const TypeConstraintNode* tin = - new_type_cs.as()) { + if (const TypeConstraintNode* tin = new_type_cs.as()) { type_constraints.push_back(GetRef(tin)); } else { LOG(FATAL) << new_type_cs; @@ -160,10 +147,7 @@ Type TypeMutator::VisitType_(const FuncTypeNode* op) { changed = changed || !new_ret_type.same_as(op->ret_type); if (!changed) return GetRef(op); - return FuncType(new_args, - new_ret_type, - type_params, - type_constraints); + return FuncType(new_args, new_ret_type, type_params, type_constraints); } Type TypeMutator::VisitType_(const TupleTypeNode* op) { @@ -184,16 +168,11 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { if (new_args.same_as(type_rel->args)) { return GetRef(type_rel); } else { - return TypeRelation(type_rel->func, - new_args, - type_rel->num_inputs, - type_rel->attrs); + return TypeRelation(type_rel->func, new_args, type_rel->num_inputs, type_rel->attrs); } } -Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const TypeCallNode* op) { Type new_func = VisitType(op->func); @@ -205,13 +184,9 @@ Type TypeMutator::VisitType_(const TypeCallNode* op) { } } -Type TypeMutator::VisitType_(const TypeDataNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const TypeDataNode* op) { return GetRef(op); } -Type TypeMutator::VisitType_(const PrimTypeNode* op) { - return GetRef(op); -} +Type TypeMutator::VisitType_(const PrimTypeNode* op) { return GetRef(op); } Type TypeMutator::VisitType_(const PointerTypeNode* op) { Type element_type = VisitType(op->element_type); @@ -226,8 +201,7 @@ Type TypeMutator::VisitType_(const PointerTypeNode* op) { // Implements bind. 
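For readers following the TypeMutator hunks above: the TypeBinder class defined next uses exactly these VisitType_ overrides to substitute type variables. Below is a minimal illustrative sketch (not part of this patch) of the same pattern, assuming only the public headers tvm/ir/type.h and tvm/ir/type_functor.h; the class name SingleVarBinder is hypothetical.

#include <tvm/ir/type.h>
#include <tvm/ir/type_functor.h>

namespace tvm {
// Hypothetical example: rewrite one TypeVar into a concrete type. Composite
// types (FuncType, TupleType, TypeCall, ...) are rebuilt by the TypeMutator
// base class only when one of their children actually changed.
class SingleVarBinder : public TypeMutator {
 public:
  SingleVarBinder(TypeVar var, Type replacement) : var_(var), replacement_(replacement) {}

  Type VisitType_(const TypeVarNode* op) override {
    return GetRef<TypeVar>(op).same_as(var_) ? replacement_ : GetRef<Type>(op);
  }

 private:
  TypeVar var_;
  Type replacement_;
};
}  // namespace tvm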
class TypeBinder : public TypeMutator { public: - explicit TypeBinder(const tvm::Map& args_map) - : args_map_(args_map) {} + explicit TypeBinder(const tvm::Map& args_map) : args_map_(args_map) {} Type VisitType_(const TypeVarNode* op) override { auto id = GetRef(op); diff --git a/src/ir/type_relation.cc b/src/ir/type_relation.cc index ab479e782b56..f038a6678b42 100644 --- a/src/ir/type_relation.cc +++ b/src/ir/type_relation.cc @@ -35,22 +35,17 @@ TypeCall::TypeCall(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); -TVM_REGISTER_GLOBAL("ir.TypeCall") -.set_body_typed([](Type func, Array type) { +TVM_REGISTER_GLOBAL("ir.TypeCall").set_body_typed([](Type func, Array type) { return TypeCall(func, type); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeCallNode(" << node->func << ", " - << node->args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeCallNode(" << node->func << ", " << node->args << ")"; + }); -TypeRelation::TypeRelation(TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { +TypeRelation::TypeRelation(TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { ObjectPtr n = make_object(); n->func = std::move(func); n->args = std::move(args); @@ -62,18 +57,13 @@ TypeRelation::TypeRelation(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_GLOBAL("ir.TypeRelation") -.set_body_typed([](TypeRelationFn func, - Array args, - int num_inputs, - Attrs attrs) { - return TypeRelation(func, args, num_inputs, attrs); -}); + .set_body_typed([](TypeRelationFn func, Array args, int num_inputs, Attrs attrs) { + return TypeRelation(func, args, num_inputs, attrs); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeRelationNode(" - << node->func->name - << ", " << node->args << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypeRelationNode(" << node->func->name << ", " << node->args << ")"; + }); } // namespace tvm diff --git a/src/node/attr_registry.h b/src/node/attr_registry.h new file mode 100644 index 000000000000..9cc5b4d410a7 --- /dev/null +++ b/src/node/attr_registry.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/node/attr_registry.h + * \brief Common global registry for objects that also have additional attrs. 
+ */ +#ifndef TVM_NODE_ATTR_REGISTRY_H_ +#define TVM_NODE_ATTR_REGISTRY_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief Implementation of registry with attributes. + * + * \tparam EntryType The type of the registry entry. + * \tparam KeyType The actual key that is used to lookup the attributes. + * each entry has a corresponding key by default. + */ +template +class AttrRegistry { + public: + using TSelf = AttrRegistry; + /*! + * \brief Get an entry from the registry. + * \param name The name of the item. + * \return The corresponding entry. + */ + const EntryType* Get(const String& name) const { + auto it = entry_map_.find(name); + if (it != entry_map_.end()) return it->second; + return nullptr; + } + + /*! + * \brief Get an entry or register a new one. + * \param name The name of the item. + * \return The corresponding entry. + */ + EntryType& RegisterOrGet(const String& name) { + auto it = entry_map_.find(name); + if (it != entry_map_.end()) return *it->second; + uint32_t registry_index = static_cast(entries_.size()); + auto entry = std::unique_ptr(new EntryType(registry_index)); + auto* eptr = entry.get(); + eptr->name = name; + entry_map_[name] = eptr; + entries_.emplace_back(std::move(entry)); + return *eptr; + } + + /*! + * \brief List all the entry names in the registry. + * \return The entry names. + */ + Array ListAllNames() const { + Array names; + for (const auto& kv : entry_map_) { + names.push_back(kv.first); + } + return names; + } + + /*! + * \brief Update the attribute table. + * \param attr_name The name of the attribute. + * \param key The key to the attribute table. + * \param value The value to be set. + * \param plevel The support level. + */ + void UpdateAttr(const String& attr_name, const KeyType& key, runtime::TVMRetValue value, + int plevel) { + using runtime::TVMRetValue; + std::lock_guard lock(mutex_); + auto& op_map = attrs_[attr_name]; + if (op_map == nullptr) { + op_map.reset(new AttrRegistryMapContainerMap()); + op_map->attr_name_ = attr_name; + } + + uint32_t index = key->AttrRegistryIndex(); + if (op_map->data_.size() <= index) { + op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); + } + std::pair& p = op_map->data_[index]; + CHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName() + << " is already registered with same plevel=" << plevel; + CHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name + << " of operator " << key->AttrRegistryName(); + if (p.second < plevel && value.type_code() != kTVMNullptr) { + op_map->data_[index] = std::make_pair(value, plevel); + } + } + + /*! + * \brief Reset an attribute table entry. + * \param attr_name The name of the attribute. + * \param key The key to the attribute table. + */ + void ResetAttr(const String& attr_name, const KeyType& key) { + std::lock_guard lock(mutex_); + auto& op_map = attrs_[attr_name]; + if (op_map == nullptr) { + return; + } + uint32_t index = key->AttrRegistryIndex(); + if (op_map->data_.size() > index) { + op_map->data_[index] = std::make_pair(TVMRetValue(), 0); + } + } + + /*! + * \brief Get an internal attribute map. + * \param attr_name The name of the attribute. + * \return The result attribute map.
+ */ + const AttrRegistryMapContainerMap& GetAttrMap(const String& attr_name) { + std::lock_guard lock(mutex_); + auto it = attrs_.find(attr_name); + if (it == attrs_.end()) { + LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered"; + } + return *it->second.get(); + } + + /*! + * \brief Check of attribute has been registered. + * \param attr_name The name of the attribute. + * \return The check result. + */ + bool HasAttrMap(const String& attr_name) { + std::lock_guard lock(mutex_); + return attrs_.count(attr_name); + } + + /*! + * \return a global singleton of the registry. + */ + static TSelf* Global() { + static TSelf* inst = new TSelf(); + return inst; + } + + private: + // mutex to avoid registration from multiple threads. + std::mutex mutex_; + // entries in the registry + std::vector> entries_; + // map from name to entries. + std::unordered_map entry_map_; + // storage of additional attribute table. + std::unordered_map>> attrs_; +}; + +} // namespace tvm +#endif // TVM_NODE_ATTR_REGISTRY_H_ diff --git a/src/node/container.cc b/src/node/container.cc index 52e4bf19718c..bdebb7fea778 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -20,10 +20,11 @@ * Expose container API to frontend. * \file src/node/container.cc */ -#include -#include #include +#include +#include #include + #include "../support/str_escape.h" namespace tvm { @@ -32,14 +33,11 @@ namespace tvm { struct StringObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::StringObj* key, - SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue( - runtime::String::HashBytes(key->data, key->size)); + static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(key->data, key->size)); } - static bool SEqualReduce(const runtime::StringObj* lhs, - const runtime::StringObj* rhs, + static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->size != rhs->size) return false; @@ -49,32 +47,29 @@ struct StringObjTrait { }; struct RefToObjectPtr : public ObjectRef { - static ObjectPtr Get(const ObjectRef& ref) { - return GetDataPtr(ref); - } + static ObjectPtr Get(const ObjectRef& ref) { return GetDataPtr(ref); } }; TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) -.set_creator([](const std::string& bytes) { - return RefToObjectPtr::Get(runtime::String(bytes)); -}) -.set_repr_bytes([](const Object* n) -> std::string { - return GetRef( - static_cast(n)).operator std::string(); -}); + .set_creator([](const std::string& bytes) { + return RefToObjectPtr::Get(runtime::String(bytes)); + }) + .set_repr_bytes([](const Object* n) -> std::string { + return GetRef(static_cast(n)) + . 
+ operator std::string(); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '"' << support::StrEscape(op->data, op->size) << '"'; + }); struct ADTObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::ADTObj* key, - SHashReducer hash_reduce) { + static void SHashReduce(const runtime::ADTObj* key, SHashReducer hash_reduce) { hash_reduce(key->tag); hash_reduce(static_cast(key->size)); for (uint32_t i = 0; i < key->size; ++i) { @@ -82,8 +77,7 @@ struct ADTObjTrait { } } - static bool SEqualReduce(const runtime::ADTObj* lhs, - const runtime::ADTObj* rhs, + static bool SEqualReduce(const runtime::ADTObj* lhs, const runtime::ADTObj* rhs, SEqualReducer equal) { if (lhs == rhs) return true; if (lhs->tag != rhs->tag) return false; @@ -98,39 +92,31 @@ struct ADTObjTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); - struct NDArrayContainerTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const runtime::NDArray::Container* key, - SHashReducer hash_reduce) { + static void SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { CHECK_EQ(key->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK(runtime::IsContiguous(key->dl_tensor)) - << "Can only hash contiguous tensor"; + CHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; hash_reduce(runtime::DataType(key->dl_tensor.dtype)); hash_reduce(key->dl_tensor.ndim); for (int i = 0; i < key->dl_tensor.ndim; ++i) { hash_reduce(key->dl_tensor.shape[i]); } - hash_reduce->SHashReduceHashedValue( - runtime::String::HashBytes( - static_cast(key->dl_tensor.data), - runtime::GetDataSize(key->dl_tensor))); + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); } static bool SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { + const runtime::NDArray::Container* rhs, SEqualReducer equal) { if (lhs == rhs) return true; auto ldt = lhs->dl_tensor.dtype; auto rdt = rhs->dl_tensor.dtype; CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - CHECK(runtime::IsContiguous(lhs->dl_tensor)) - << "Can only compare contiguous tensor"; - CHECK(runtime::IsContiguous(rhs->dl_tensor)) - << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + CHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { @@ -147,24 +133,20 @@ struct NDArrayContainerTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait); - struct ArrayNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const ArrayNode* key, - SHashReducer hash_reduce) { - hash_reduce(static_cast(key->data.size())); - for (size_t i = 0; i < key->data.size(); ++i) { - hash_reduce(key->data[i]); + static void 
SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { + hash_reduce(static_cast(key->size())); + for (size_t i = 0; i < key->size(); ++i) { + hash_reduce(key->at(i)); } } - static bool SEqualReduce(const ArrayNode* lhs, - const ArrayNode* rhs, - SEqualReducer equal) { - if (lhs->data.size() != rhs->data.size()) return false; - for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!equal(lhs->data[i], rhs->data[i])) return false; + static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { + if (lhs->size() != rhs->size()) return false; + for (size_t i = 0; i < lhs->size(); ++i) { + if (!equal(lhs->at(i), rhs->at(i))) return false; } return true; } @@ -172,53 +154,43 @@ struct ArrayNodeTrait { TVM_REGISTER_OBJECT_TYPE(ArrayNode); TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -TVM_REGISTER_GLOBAL("node.Array") -.set_body([](TVMArgs args, TVMRetValue* ret) { - std::vector data; - for (int i = 0; i < args.size(); ++i) { - if (args[i].type_code() != kTVMNullptr) { - data.push_back(args[i].operator ObjectRef()); - } else { - data.push_back(ObjectRef(nullptr)); - } + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); + +TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) { + std::vector data; + for (int i = 0; i < args.size(); ++i) { + if (args[i].type_code() != kTVMNullptr) { + data.push_back(args[i].operator ObjectRef()); + } else { + data.push_back(ObjectRef(nullptr)); } - auto node = make_object(); - node->data = std::move(data); - *ret = Array(node); - }); - -TVM_REGISTER_GLOBAL("node.ArrayGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int64_t i = args[1]; - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - CHECK_LT(static_cast(i), n->data.size()) - << "out of bound of array"; - *ret = n->data[static_cast(i)]; - }); - -TVM_REGISTER_GLOBAL("node.ArraySize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - CHECK(ptr->IsInstance()); - *ret = static_cast( - static_cast(ptr)->data.size()); - }); + } + *ret = Array(data); +}); + +TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + int64_t i = args[1]; + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + CHECK_LT(static_cast(i), n->size()) << "out of bound of array"; + *ret = n->at(i); +}); +TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + *ret = static_cast(static_cast(ptr)->size()); +}); struct MapNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; - static void SHashReduce(const MapNode* key, - SHashReducer hash_reduce) { + static void SHashReduceForOMap(const MapNode* key, SHashReducer hash_reduce) { // SHash's var handling depends on the determinism of traversal. // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store @@ -233,15 +205,15 @@ struct MapNodeTrait { } } // sort by the hash key of the keys. 
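As a usage note for the container globals refactored above (node.Array, node.ArrayGetItem, node.ArraySize): they can be exercised directly from C++ through the global registry. The snippet below is an illustrative sketch only, assuming tvm/runtime/registry.h and tvm/ir/type.h; it is not part of this patch.

#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>

void ArrayFFISketch() {
  using namespace tvm;
  const runtime::PackedFunc* make = runtime::Registry::Get("node.Array");
  const runtime::PackedFunc* getitem = runtime::Registry::Get("node.ArrayGetItem");
  const runtime::PackedFunc* size = runtime::Registry::Get("node.ArraySize");
  if (make == nullptr || getitem == nullptr || size == nullptr) return;
  // Pack two ObjectRefs into an Array through the FFI entry point,
  // then read the size and the first element back.
  ObjectRef arr = (*make)(TupleType::Empty(), TupleType::Empty());
  int64_t n = (*size)(arr);
  ObjectRef first = (*getitem)(arr, 0);
  (void)n;
  (void)first;
}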
- std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // add size to the hash hash_reduce(static_cast(key->data.size())); // hash the content for (size_t i = 0; i < temp.size();) { size_t k = i + 1; - for (; k < temp.size() && temp[k].first == temp[i].first; ++k) {} + for (; k < temp.size() && temp[k].first == temp[i].first; ++k) { + } // ties are rare, but we need to skip them to make the hash determinsitic if (k == i + 1) { hash_reduce->SHashReduceHashedValue(temp[i].first); @@ -251,47 +223,19 @@ struct MapNodeTrait { } } - static bool SEqualReduce(const MapNode* lhs, - const MapNode* rhs, - SEqualReducer equal) { - if (rhs->data.size() != lhs->data.size()) return false; - for (const auto& kv : lhs->data) { - // Only allow equal checking if the keys are already mapped - // This resolves common use cases where we want to store - // Map where Var is defined in the function - // parameters. - ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); - if (!rhs_key.defined()) return false; - auto it = rhs->data.find(rhs_key); - if (it == rhs->data.end()) return false; - if (!equal(kv.second, it->second)) return false; - } - return true; - } -}; - -TVM_REGISTER_OBJECT_TYPE(MapNode); -TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -struct StrMapNodeTrait { - static constexpr const std::nullptr_t VisitAttrs = nullptr; - - static void SHashReduce(const StrMapNode* key, - SHashReducer hash_reduce) { + static void SHashReduceForSMap(const MapNode* key, SHashReducer hash_reduce) { // NOTE: only book-keep the mapped hash keys. // This resolves common use cases where we want to store // Map where Var is defined in the function // parameters. - using KV = std::pair; - std::vector temp(key->data.begin(), key->data.end()); + using KV = std::pair; + std::vector temp; + for (const auto& kv : key->data) { + temp.push_back(std::make_pair(Downcast(kv.first), kv.second)); + } // sort by the hash key of the keys. - std::sort(temp.begin(), temp.end(), [](const KV& lhs, const KV& rhs) { - return lhs.first < rhs.first; - }); + std::sort(temp.begin(), temp.end(), + [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; }); // NOTE: we won't have ties // add size to the hash after sorting. hash_reduce(static_cast(key->data.size())); @@ -302,10 +246,33 @@ struct StrMapNodeTrait { } } - static bool SEqualReduce(const StrMapNode* lhs, - const StrMapNode* rhs, - SEqualReducer equal) { - if (rhs->data.size() != lhs->data.size()) return false; + static void SHashReduce(const MapNode* key, SHashReducer hash_reduce) { + bool is_str_map = std::all_of(key->data.begin(), key->data.end(), [](const auto& v) { + return v.first->template IsInstance(); + }); + if (is_str_map) { + SHashReduceForSMap(key, hash_reduce); + } else { + SHashReduceForOMap(key, hash_reduce); + } + } + + static bool SEqualReduceForOMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + for (const auto& kv : lhs->data) { + // Only allow equal checking if the keys are already mapped + // This resolves common use cases where we want to store + // Map where Var is defined in the function + // parameters. 
+ ObjectRef rhs_key = equal->MapLhsToRhs(kv.first); + if (!rhs_key.defined()) return false; + auto it = rhs->data.find(rhs_key); + if (it == rhs->data.end()) return false; + if (!equal(kv.second, it->second)) return false; + } + return true; + } + + static bool SEqualReduceForSMap(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { for (const auto& kv : lhs->data) { auto it = rhs->data.find(kv.first); if (it == rhs->data.end()) return false; @@ -313,124 +280,84 @@ struct StrMapNodeTrait { } return true; } -}; -TVM_REGISTER_OBJECT_TYPE(StrMapNode); -TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait) -.set_creator([](const std::string&) -> ObjectPtr { - return ::tvm::runtime::make_object(); - }); - - -TVM_REGISTER_GLOBAL("node.Map") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size() % 2, 0); - if (args.size() != 0 && args[0].type_code() == kTVMStr) { - // StrMap - StrMapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].type_code() == kTVMStr) - << "key of str map need to be str"; - CHECK(args[i + 1].IsObjectRef()) - << "value of the map to be NodeRef"; - data.emplace(std::make_pair(args[i].operator std::string(), - args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); - } else { - // Container node. - MapNode::ContainerType data; - for (int i = 0; i < args.num_args; i += 2) { - CHECK(args[i].IsObjectRef()) - << "key of str map need to be object"; - CHECK(args[i + 1].IsObjectRef()) - << "value of map to be NodeRef"; - data.emplace(std::make_pair(args[i].operator ObjectRef(), - args[i + 1].operator ObjectRef())); - } - auto node = make_object(); - node->data = std::move(data); - *ret = Map(node); + static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) { + if (rhs->data.size() != lhs->data.size()) return false; + if (rhs->data.size() == 0) return true; + bool ls = std::all_of(lhs->data.begin(), lhs->data.end(), + [](const auto& v) { return v.first->template IsInstance(); }); + bool rs = std::all_of(rhs->data.begin(), rhs->data.end(), + [](const auto& v) { return v.first->template IsInstance(); }); + if (ls != rs) { + return false; } - }); + return (ls && rs) ? 
SEqualReduceForSMap(lhs, rhs, equal) : SEqualReduceForOMap(lhs, rhs, equal); + } +}; +TVM_REGISTER_OBJECT_TYPE(MapNode); +TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait) + .set_creator([](const std::string&) -> ObjectPtr { + return ::tvm::runtime::make_object(); + }); -TVM_REGISTER_GLOBAL("node.MapSize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast(n->data.size()); - } - }); - -TVM_REGISTER_GLOBAL("node.MapGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - auto it = n->data.find(args[1].operator ObjectRef()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - auto it = n->data.find(args[1].operator std::string()); - CHECK(it != n->data.end()) - << "cannot find the corresponding key in the Map"; - *ret = (*it).second; - } - }); - -TVM_REGISTER_GLOBAL("node.MapCount") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - *ret = static_cast( - n->data.count(args[1].operator ObjectRef())); - } else { - CHECK(ptr->IsInstance()); - auto* n = static_cast(ptr); - *ret = static_cast( - n->data.count(args[1].operator std::string())); - } - }); - -TVM_REGISTER_GLOBAL("node.MapItems") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args[0].type_code(), kTVMObjectHandle); - Object* ptr = static_cast(args[0].value().v_handle); - - if (ptr->IsInstance()) { - auto* n = static_cast(ptr); - auto rkvs = make_object(); - for (const auto& kv : n->data) { - rkvs->data.push_back(kv.first); - rkvs->data.push_back(kv.second); - } - *ret = Array(rkvs); +TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size() % 2, 0); + MapNode::ContainerType data; + for (int i = 0; i < args.num_args; i += 2) { + ObjectRef k = + String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef(); + ObjectRef v = args[i + 1]; + data.emplace(std::move(k), std::move(v)); + } + auto node = make_object(); + node->data = std::move(data); + *ret = Map(node); +}); + +TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + auto* n = static_cast(ptr); + *ret = static_cast(n->data.size()); +}); + +TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + + auto* n = static_cast(ptr); + auto it = n->data.find(String::CanConvertFrom(args[1]) ? 
args[1].operator String() + : args[1].operator ObjectRef()); + CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map"; + *ret = (*it).second; +}); + +TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + CHECK(ptr->IsInstance()); + const MapNode* n = static_cast(ptr); + int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String() + : args[1].operator ObjectRef()); + *ret = cnt; +}); + +TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args[0].type_code(), kTVMObjectHandle); + Object* ptr = static_cast(args[0].value().v_handle); + auto* n = static_cast(ptr); + Array rkvs; + for (const auto& kv : n->data) { + if (kv.first->IsInstance()) { + rkvs.push_back(Downcast(kv.first)); } else { - auto* n = static_cast(ptr); - auto rkvs = make_object(); - for (const auto& kv : n->data) { - rkvs->data.push_back(tir::StringImmNode::make(kv.first)); - rkvs->data.push_back(kv.second); - } - *ret = Array(rkvs); + rkvs.push_back(kv.first); } - }); + rkvs.push_back(kv.second); + } + *ret = std::move(rkvs); +}); } // namespace tvm diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 08a914ff38f9..8de21da9a645 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -21,27 +21,25 @@ * Reflection utilities. * \file node/reflection.cc */ -#include -#include +#include #include +#include #include -#include +#include namespace tvm { -using runtime::TVMRetValue; -using runtime::TVMArgs; using runtime::PackedFunc; +using runtime::TVMArgs; +using runtime::TVMRetValue; // Attr getter. class AttrGetter : public AttrVisitor { public: - const std::string& skey; + const String& skey; TVMRetValue* ret; - AttrGetter(const std::string &skey, - TVMRetValue* ret) - : skey(skey), ret(ret) {} + AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {} bool found_ref_object{false}; @@ -86,8 +84,7 @@ class AttrGetter : public AttrVisitor { } }; -runtime::TVMRetValue ReflectionVTable::GetAttr( - Object* self, const std::string& field_name) const { +runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const { runtime::TVMRetValue ret; AttrGetter getter(field_name, &ret); @@ -110,8 +107,8 @@ runtime::TVMRetValue ReflectionVTable::GetAttr( } } if (!success) { - LOG(FATAL) << "AttributeError: " << self->GetTypeKey() - << " object has no attributed " << getter.skey; + LOG(FATAL) << "AttributeError: " << self->GetTypeKey() << " object has no attributed " + << getter.skey; } return ret; } @@ -121,40 +118,19 @@ class AttrDir : public AttrVisitor { public: std::vector* names; - void Visit(const char* key, double* value) final { - names->push_back(key); - } - void Visit(const char* key, int64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, uint64_t* value) final { - names->push_back(key); - } - void Visit(const char* key, bool* value) final { - names->push_back(key); - } - void Visit(const char* key, int* value) final { - names->push_back(key); - } - void Visit(const char* key, void** value) final { - names->push_back(key); - } - void Visit(const char* key, DataType* value) final { - names->push_back(key); - } - void Visit(const char* key, std::string* value) final { - names->push_back(key); - } - void Visit(const char* key, runtime::NDArray* value) final { - names->push_back(key); - } - void Visit(const char* key, 
runtime::ObjectRef* value) final { - names->push_back(key); - } + void Visit(const char* key, double* value) final { names->push_back(key); } + void Visit(const char* key, int64_t* value) final { names->push_back(key); } + void Visit(const char* key, uint64_t* value) final { names->push_back(key); } + void Visit(const char* key, bool* value) final { names->push_back(key); } + void Visit(const char* key, int* value) final { names->push_back(key); } + void Visit(const char* key, void** value) final { names->push_back(key); } + void Visit(const char* key, DataType* value) final { names->push_back(key); } + void Visit(const char* key, std::string* value) final { names->push_back(key); } + void Visit(const char* key, runtime::NDArray* value) final { names->push_back(key); } + void Visit(const char* key, runtime::ObjectRef* value) final { names->push_back(key); } }; -std::vector -ReflectionVTable::ListAttrNames(Object* self) const { +std::vector ReflectionVTable::ListAttrNames(Object* self) const { std::vector names; AttrDir dir; dir.names = &names; @@ -176,13 +152,11 @@ ReflectionVTable* ReflectionVTable::Global() { return &inst; } -ObjectPtr -ReflectionVTable::CreateInitObject(const std::string& type_key, - const std::string& repr_bytes) const { +ObjectPtr ReflectionVTable::CreateInitObject(const std::string& type_key, + const std::string& repr_bytes) const { uint32_t tindex = Object::TypeKey2Index(type_key); if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { - LOG(FATAL) << "TypeError: " << type_key - << " is not registered via TVM_REGISTER_NODE_TYPE"; + LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE"; } return fcreate_[tindex](repr_bytes); } @@ -192,30 +166,16 @@ class NodeAttrSetter : public AttrVisitor { std::string type_key; std::unordered_map attrs; - void Visit(const char* key, double* value) final { - *value = GetAttr(key).operator double(); - } - void Visit(const char* key, int64_t* value) final { - *value = GetAttr(key).operator int64_t(); - } - void Visit(const char* key, uint64_t* value) final { - *value = GetAttr(key).operator uint64_t(); - } - void Visit(const char* key, int* value) final { - *value = GetAttr(key).operator int(); - } - void Visit(const char* key, bool* value) final { - *value = GetAttr(key).operator bool(); - } + void Visit(const char* key, double* value) final { *value = GetAttr(key).operator double(); } + void Visit(const char* key, int64_t* value) final { *value = GetAttr(key).operator int64_t(); } + void Visit(const char* key, uint64_t* value) final { *value = GetAttr(key).operator uint64_t(); } + void Visit(const char* key, int* value) final { *value = GetAttr(key).operator int(); } + void Visit(const char* key, bool* value) final { *value = GetAttr(key).operator bool(); } void Visit(const char* key, std::string* value) final { *value = GetAttr(key).operator std::string(); } - void Visit(const char* key, void** value) final { - *value = GetAttr(key).operator void*(); - } - void Visit(const char* key, DataType* value) final { - *value = GetAttr(key).operator DataType(); - } + void Visit(const char* key, void** value) final { *value = GetAttr(key).operator void*(); } + void Visit(const char* key, DataType* value) final { *value = GetAttr(key).operator DataType(); } void Visit(const char* key, runtime::NDArray* value) final { *value = GetAttr(key).operator runtime::NDArray(); } @@ -235,27 +195,54 @@ class NodeAttrSetter : public AttrVisitor { } }; -void InitNodeByPackedArgs(Object* n, const TVMArgs& args) 
{ +void InitNodeByPackedArgs(ReflectionVTable* reflection, Object* n, const TVMArgs& args) { NodeAttrSetter setter; setter.type_key = n->GetTypeKey(); CHECK_EQ(args.size() % 2, 0); for (int i = 0; i < args.size(); i += 2) { - setter.attrs.emplace(args[i].operator std::string(), - args[i + 1]); + setter.attrs.emplace(args[i].operator std::string(), args[i + 1]); } - auto* reflection = ReflectionVTable::Global(); reflection->VisitAttrs(n, &setter); if (setter.attrs.size() != 0) { std::ostringstream os; os << setter.type_key << " does not contain field "; - for (const auto &kv : setter.attrs) { + for (const auto& kv : setter.attrs) { os << " " << kv.first; } LOG(FATAL) << os.str(); } } +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, const TVMArgs& kwargs) { + ObjectPtr n = this->CreateInitObject(type_key); + if (n->IsInstance()) { + static_cast(n.get())->InitByPackedArgs(kwargs); + } else { + InitNodeByPackedArgs(this, n.get(), kwargs); + } + return ObjectRef(n); +} + +ObjectRef ReflectionVTable::CreateObject(const std::string& type_key, + const Map& kwargs) { + // Redirect to the TVMArgs version + // It is not the most efficient way, but CreateObject is not meant to be used + // in a fast code-path and is mainly reserved as a flexible API for frontends. + std::vector values(kwargs.size() * 2); + std::vector tcodes(kwargs.size() * 2); + runtime::TVMArgsSetter setter(values.data(), tcodes.data()); + int index = 0; + + for (auto& kv : static_cast(kwargs.get())->data) { + setter(index, Downcast(kv.first).c_str()); + setter(index + 1, kv.second); + index += 2; + } + + return CreateObject(type_key, runtime::TVMArgs(values.data(), tcodes.data(), kwargs.size() * 2)); +} + // Expose to FFI APIs. void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kTVMObjectHandle); @@ -267,17 +254,17 @@ void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kTVMObjectHandle); Object* self = static_cast(args[0].value().v_handle); - auto names = std::make_shared >( - ReflectionVTable::Global()->ListAttrNames(self)); - - *ret = PackedFunc([names](TVMArgs args, TVMRetValue *rv) { - int64_t i = args[0]; - if (i == -1) { - *rv = static_cast(names->size()); - } else { - *rv = (*names)[i]; - } - }); + auto names = + std::make_shared >(ReflectionVTable::Global()->ListAttrNames(self)); + + *ret = PackedFunc([names](TVMArgs args, TVMRetValue* rv) { + int64_t i = args[0]; + if (i == -1) { + *rv = static_cast(names->size()); + } else { + *rv = (*names)[i]; + } + }); } // API function to make node. 
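To illustrate the reflection path reformatted above (AttrDir for listing fields, AttrGetter for reading them, now keyed by String): a minimal sketch of walking an object's reflected fields, assuming only tvm/node/reflection.h. Illustrative only, not part of this patch.

#include <tvm/node/reflection.h>

void DumpFields(const tvm::ObjectRef& ref) {
  using namespace tvm;
  ReflectionVTable* tab = ReflectionVTable::Global();
  Object* self = const_cast<Object*>(ref.get());
  // ListAttrNames drives AttrDir over the node's VisitAttrs;
  // GetAttr drives AttrGetter with the (now String-typed) field key.
  for (const std::string& name : tab->ListAttrNames(self)) {
    runtime::TVMRetValue value = tab->GetAttr(self, name);
    (void)value;
  }
}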
@@ -287,23 +274,12 @@ void MakeNode(const TVMArgs& args, TVMRetValue* rv) { std::string type_key = args[0]; std::string empty_str; TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); - auto* reflection = ReflectionVTable::Global(); - ObjectPtr n = reflection->CreateInitObject(type_key); - if (n->IsInstance()) { - static_cast(n.get())->InitByPackedArgs(kwargs); - } else { - InitNodeByPackedArgs(n.get(), kwargs); - } - *rv = ObjectRef(n); + *rv = ReflectionVTable::Global()->CreateObject(type_key, kwargs); } +TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr); -TVM_REGISTER_GLOBAL("node.NodeGetAttr") -.set_body(NodeGetAttr); - -TVM_REGISTER_GLOBAL("node.NodeListAttrNames") -.set_body(NodeListAttrNames); +TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames); -TVM_REGISTER_GLOBAL("node.MakeNode") -.set_body(MakeNode); +TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode); } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index bf41c82f5a76..ea263439023f 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -21,8 +21,8 @@ * Printer utilities * \file node/repr_printer.cc */ -#include #include +#include namespace tvm { @@ -51,16 +51,11 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } -void Dump(const runtime::ObjectRef& n) { - std::cerr << n << "\n"; -} +void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } -void Dump(const runtime::Object* n) { - Dump(runtime::GetRef(n)); -} +void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_REGISTER_GLOBAL("node.AsRepr") -.set_body_typed([](runtime::ObjectRef obj) { +TVM_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](runtime::ObjectRef obj) { std::ostringstream os; os << obj; return os.str(); diff --git a/src/node/serialization.cc b/src/node/serialization.cc index ee6072d77c1c..386653349904 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -23,29 +23,25 @@ */ #include #include -#include -#include -#include +#include #include #include #include -#include +#include +#include +#include -#include #include #include +#include #include "../support/base64.h" namespace tvm { -inline std::string Type2String(const DataType& t) { - return runtime::DLDataType2String(t); -} +inline std::string Type2String(const DataType& t) { return runtime::DLDataType2String(t); } -inline DataType String2Type(std::string s) { - return DataType(runtime::String2DLDataType(s)); -} +inline DataType String2Type(std::string s) { return DataType(runtime::String2DLDataType(s)); } inline std::string Base64Decode(std::string s) { dmlc::MemoryStringStream mstrm(&s); @@ -109,19 +105,23 @@ class NodeIndexer : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); - for (const auto& sp : n->data) { + for (const auto& sp : *n) { MakeIndex(const_cast(sp.get())); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - MakeIndex(const_cast(kv.first.get())); - MakeIndex(const_cast(kv.second.get())); - } - } else if (node->IsInstance()) { - StrMapNode* n = static_cast(node); - for (const auto& kv : n->data) { - MakeIndex(const_cast(kv.second.get())); + bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) { + return v.first->template IsInstance(); + }); + if (is_str_map) { + for (const auto& kv : n->data) { + MakeIndex(const_cast(kv.second.get())); + } + } else { + for (const auto& kv : n->data) { + MakeIndex(const_cast(kv.first.get())); 
+ MakeIndex(const_cast(kv.second.get())); + } } } else { // if the node already have repr bytes, no need to visit Attrs. @@ -148,7 +148,7 @@ struct JSONNode { /*! \brief values of a map or array. */ std::vector data; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("type_key", type_key); if (repr_bytes.size() != 0) { @@ -173,7 +173,7 @@ struct JSONNode { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); data.clear(); repr_bytes.clear(); @@ -213,36 +213,23 @@ class JSONAttrGetter : public AttrVisitor { s << (*value); node_->attrs[key] = s.str(); } - void Visit(const char* key, int64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, uint64_t* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, int* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, bool* value) final { - node_->attrs[key] = std::to_string(*value); - } - void Visit(const char* key, std::string* value) final { - node_->attrs[key] = *value; - } + void Visit(const char* key, int64_t* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, uint64_t* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, int* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, bool* value) final { node_->attrs[key] = std::to_string(*value); } + void Visit(const char* key, std::string* value) final { node_->attrs[key] = *value; } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to serialize a pointer"; } - void Visit(const char* key, DataType* value) final { - node_->attrs[key] = Type2String(*value); - } + void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } void Visit(const char* key, runtime::NDArray* value) final { - node_->attrs[key] = std::to_string( - tensor_index_->at(const_cast((*value).operator->()))); + node_->attrs[key] = + std::to_string(tensor_index_->at(const_cast((*value).operator->()))); } void Visit(const char* key, ObjectRef* value) final { - node_->attrs[key] = std::to_string( - node_index_->at(const_cast(value->get()))); + node_->attrs[key] = std::to_string(node_index_->at(const_cast(value->get()))); } // Get the node @@ -261,24 +248,24 @@ class JSONAttrGetter : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); - for (size_t i = 0; i < n->data.size(); ++i) { - node_->data.push_back( - node_index_->at(const_cast(n->data[i].get()))); + for (size_t i = 0; i < n->size(); ++i) { + node_->data.push_back(node_index_->at(const_cast(n->at(i).get()))); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - for (const auto& kv : n->data) { - node_->data.push_back( - node_index_->at(const_cast(kv.first.get()))); - node_->data.push_back( - node_index_->at(const_cast(kv.second.get()))); - } - } else if (node->IsInstance()) { - StrMapNode* n = static_cast(node); - for (const auto& kv : n->data) { - node_->keys.push_back(kv.first); - node_->data.push_back( - node_index_->at(const_cast(kv.second.get()))); + bool is_str_map = std::all_of(n->data.begin(), n->data.end(), [](const auto& v) { + return v.first->template IsInstance(); + }); + if (is_str_map) { + for (const auto& kv : n->data) { + node_->keys.push_back(Downcast(kv.first)); 
+ node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); + } + } else { + for (const auto& kv : n->data) { + node_->data.push_back(node_index_->at(const_cast(kv.first.get()))); + node_->data.push_back(node_index_->at(const_cast(kv.second.get()))); + } } } else { // recursively index normal object. @@ -291,7 +278,7 @@ class JSONAttrGetter : public AttrVisitor { // from given json node. class JSONAttrSetter : public AttrVisitor { public: - const std::vector >* node_list_; + const std::vector>* node_list_; const std::vector* tensor_list_; JSONNode* node_; @@ -304,7 +291,7 @@ class JSONAttrSetter : public AttrVisitor { } return it->second; } - template + template void ParseValue(const char* key, T* value) const { std::istringstream is(GetValue(key)); is >> *value; @@ -312,24 +299,12 @@ class JSONAttrSetter : public AttrVisitor { LOG(FATAL) << "Wrong value format for field " << key; } } - void Visit(const char* key, double* value) final { - ParseValue(key, value); - } - void Visit(const char* key, int64_t* value) final { - ParseValue(key, value); - } - void Visit(const char* key, uint64_t* value) final { - ParseValue(key, value); - } - void Visit(const char* key, int* value) final { - ParseValue(key, value); - } - void Visit(const char* key, bool* value) final { - ParseValue(key, value); - } - void Visit(const char* key, std::string* value) final { - *value = GetValue(key); - } + void Visit(const char* key, double* value) final { ParseValue(key, value); } + void Visit(const char* key, int64_t* value) final { ParseValue(key, value); } + void Visit(const char* key, uint64_t* value) final { ParseValue(key, value); } + void Visit(const char* key, int* value) final { ParseValue(key, value); } + void Visit(const char* key, bool* value) final { ParseValue(key, value); } + void Visit(const char* key, std::string* value) final { *value = GetValue(key); } void Visit(const char* key, void** value) final { LOG(FATAL) << "not allowed to deserialize a pointer"; } @@ -355,23 +330,24 @@ class JSONAttrSetter : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); - n->data.clear(); + CHECK_EQ(n->size(), node_->data.size()); + int64_t i = 0; for (size_t index : node_->data) { - n->data.push_back(ObjectRef(node_list_->at(index))); + n->SetItem(i++, ObjectRef(node_list_->at(index))); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); - CHECK_EQ(node_->data.size() % 2, 0U); - for (size_t i = 0; i < node_->data.size(); i += 2) { - n->data[ObjectRef(node_list_->at(node_->data[i]))] - = ObjectRef(node_list_->at(node_->data[i + 1])); - } - } else if (node->IsInstance()) { - StrMapNode* n = static_cast(node); - CHECK_EQ(node_->data.size(), node_->keys.size()); - for (size_t i = 0; i < node_->data.size(); ++i) { - n->data[node_->keys[i]] - = ObjectRef(node_list_->at(node_->data[i])); + if (node_->keys.empty()) { + CHECK_EQ(node_->data.size() % 2, 0U); + for (size_t i = 0; i < node_->data.size(); i += 2) { + n->data[ObjectRef(node_list_->at(node_->data[i]))] = + ObjectRef(node_list_->at(node_->data[i + 1])); + } + } else { + CHECK_EQ(node_->data.size(), node_->keys.size()); + for (size_t i = 0; i < node_->data.size(); ++i) { + n->data[String(node_->keys[i])] = ObjectRef(node_list_->at(node_->data[i])); + } } } else { reflection_->VisitAttrs(node, this); @@ -390,7 +366,7 @@ struct JSONGraph { // global attributes AttrMap attrs; - void Save(dmlc::JSONWriter *writer) const { + void Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); 
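The JSONNode/JSONGraph plumbing above is exercised end to end by SaveJSON and LoadJSON. A small round-trip sketch, assuming the public header tvm/node/serialization.h; illustrative only, not part of this patch.

#include <tvm/ir/type.h>
#include <tvm/node/serialization.h>

void JsonRoundTrip() {
  using namespace tvm;
  Type ty = TupleType::Empty();
  // NodeIndexer + JSONAttrGetter produce the graph; LoadJSON rebuilds the
  // nodes and replays the attributes through JSONAttrSetter.
  std::string json = SaveJSON(ty);
  ObjectRef restored = LoadJSON(json);
  (void)restored;
}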
writer->WriteObjectKeyValue("root", root); writer->WriteObjectKeyValue("nodes", nodes); @@ -401,7 +377,7 @@ struct JSONGraph { writer->EndObject(); } - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { attrs.clear(); dmlc::JSONObjectReadHelper helper; helper.DeclareField("root", &root); @@ -448,21 +424,23 @@ std::string SaveJSON(const ObjectRef& n) { } ObjectRef LoadJSON(std::string json_str) { - std::istringstream is(json_str); - dmlc::JSONReader reader(&is); JSONGraph jgraph; - // load in json graph. - jgraph.Load(&reader); - std::vector > nodes; + std::vector> nodes; std::vector tensors; - // load in tensors - for (const std::string& blob : jgraph.b64ndarrays) { - dmlc::MemoryStringStream mstrm(const_cast(&blob)); - support::Base64InStream b64strm(&mstrm); - b64strm.InitPosition(); - runtime::NDArray temp; - CHECK(temp.Load(&b64strm)); - tensors.emplace_back(temp); + { + // load in json graph. + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + jgraph.Load(&reader); + // load in tensors + for (const std::string& blob : jgraph.b64ndarrays) { + dmlc::MemoryStringStream mstrm(const_cast(&blob)); + support::Base64InStream b64strm(&mstrm); + b64strm.InitPosition(); + runtime::NDArray temp; + CHECK(temp.Load(&b64strm)); + tensors.emplace_back(temp); + } } ReflectionVTable* reflection = ReflectionVTable::Global(); @@ -470,10 +448,12 @@ ObjectRef LoadJSON(std::string json_str) { nodes.reserve(jgraph.nodes.size()); for (const JSONNode& jnode : jgraph.nodes) { - if (jnode.type_key.length() != 0) { - ObjectPtr node = - reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); - nodes.emplace_back(node); + if (jnode.type_key == ArrayNode::_type_key) { + CHECK(jnode.repr_bytes.empty()); + nodes.emplace_back(ArrayNode::CreateRepeated(jnode.data.size(), ObjectRef(nullptr))); + } else if (jnode.type_key.length() != 0) { + ObjectPtr node = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); + nodes.emplace_back(std::move(node)); } else { nodes.emplace_back(ObjectPtr()); } @@ -488,8 +468,7 @@ ObjectRef LoadJSON(std::string json_str) { // Skip the nodes that has an repr bytes representation. // NOTE: the second condition is used to guard the case // where the repr bytes itself is an empty string "". - if (setter.node_->repr_bytes.length() == 0 && - nodes[i] != nullptr && + if (setter.node_->repr_bytes.length() == 0 && nodes[i] != nullptr && !reflection->GetReprBytes(nodes[i].get(), nullptr)) { setter.Set(nodes[i].get()); } @@ -497,9 +476,7 @@ ObjectRef LoadJSON(std::string json_str) { return ObjectRef(nodes.at(jgraph.root)); } -TVM_REGISTER_GLOBAL("node.SaveJSON") -.set_body_typed(SaveJSON); +TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON); -TVM_REGISTER_GLOBAL("node.LoadJSON") -.set_body_typed(LoadJSON); +TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON); } // namespace tvm diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 03cdf9c1e429..9fcf510b70a8 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -19,10 +19,10 @@ /*! * \file src/node/structural_equal.cc */ -#include -#include #include #include +#include +#include #include #include @@ -30,13 +30,13 @@ namespace tvm { // Define the dispatch functio here since primary user is in this file. 
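For completeness, the handler being reorganized below backs the public StructuralEqual functor and the node.StructuralEqual global. A minimal usage sketch, assuming tvm/node/structural_equal.h; illustrative only, not part of this patch.

#include <tvm/ir/type.h>
#include <tvm/node/structural_equal.h>

bool TypesStructurallyEqual(const tvm::Type& lhs, const tvm::Type& rhs) {
  // The functor defaults to map_free_vars = false, so free TypeVars must be
  // pointer-identical; the registered "node.StructuralEqual" global also
  // exposes assert_mode and map_free_vars.
  return tvm::StructuralEqual()(lhs, rhs);
}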
-bool ReflectionVTable:: -SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const { +bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, + SEqualReducer equal) const { uint32_t tindex = self->type_index(); if (tindex >= fsequal_reduce_.size() || fsequal_reduce_[tindex] == nullptr) { LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey() - << " is not registered via TVM_REGISTER_NODE_TYPE." - << " Did you forget to set _type_has_method_sequal_reduce=true?"; + << " is not registered via TVM_REGISTER_NODE_TYPE." + << " Did you forget to set _type_has_method_sequal_reduce=true?"; } return fsequal_reduce_[tindex](self, other, equal); } @@ -50,11 +50,9 @@ SEqualReduce(const Object* self, const Object* other, SEqualReducer equal) const * The order of SEqual being called is the same as the order as if we * eagerly do recursive calls in SEqualReduce. */ -class RemapVarSEqualHandler : - public SEqualReducer::Handler { +class RemapVarSEqualHandler : public SEqualReducer::Handler { public: - explicit RemapVarSEqualHandler(bool assert_mode) - : assert_mode_(assert_mode) {} + explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {} bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { // We cannot use check lhs.same_as(rhs) to check equality. @@ -121,9 +119,8 @@ class RemapVarSEqualHandler : // Check the result. bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) { if (assert_mode_ && !result) { - LOG(FATAL) - << "ValueError: StructuralEqual check failed, caused by\n" - << "lhs = " << lhs << "\nrhs = " << rhs; + LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n" + << "lhs = " << lhs << "\nrhs = " << rhs; } return result; } @@ -177,9 +174,7 @@ class RemapVarSEqualHandler : // The default equal as registered in the structural equal vtable. bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { auto compute = [=]() { - CHECK(lhs.defined() && - rhs.defined() && - lhs->type_index() == rhs->type_index()); + CHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index()); // skip entries that already have equality maps. auto it = equal_map_lhs_.find(lhs); if (it != equal_map_lhs_.end()) { @@ -221,21 +216,18 @@ class RemapVarSEqualHandler : // reflection vtable ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs - std::unordered_map equal_map_lhs_; + std::unordered_map equal_map_lhs_; // map from rhs to lhs - std::unordered_map equal_map_rhs_; + std::unordered_map equal_map_rhs_; }; TVM_REGISTER_GLOBAL("node.StructuralEqual") -.set_body_typed([](const ObjectRef& lhs, - const ObjectRef& rhs, - bool assert_mode, - bool map_free_vars) { - return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars); -}); - -bool StructuralEqual::operator()(const ObjectRef& lhs, - const ObjectRef& rhs) const { + .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, + bool map_free_vars) { + return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars); + }); + +bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { return RemapVarSEqualHandler(false).Equal(lhs, rhs, false); } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index a29340c931a4..7c32f31fad89 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -19,25 +19,23 @@ /*! 
* \file src/node/structural_hash.cc */ -#include -#include #include #include +#include +#include #include -#include #include - +#include namespace tvm { // Define the dispatch functio here since primary user is in this file. -void ReflectionVTable:: -SHashReduce(const Object* self, SHashReducer reducer) const { +void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) const { uint32_t tindex = self->type_index(); if (tindex >= fshash_reduce_.size() || fshash_reduce_[tindex] == nullptr) { LOG(FATAL) << "TypeError: SHashReduce of " << self->GetTypeKey() - << " is not registered via TVM_REGISTER_NODE_TYPE"; + << " is not registered via TVM_REGISTER_NODE_TYPE"; } fshash_reduce_[tindex](self, reducer); } @@ -49,8 +47,7 @@ SHashReduce(const Object* self, SHashReducer reducer) const { // In particular, when we traverse unordered_map, we should first sort // the entries by keys(or hash of keys) before traversing. -class VarCountingSHashHandler : - public SHashReducer::Handler { +class VarCountingSHashHandler : public SHashReducer::Handler { public: /*! \brief Pending reduce tasks. */ struct Task { @@ -76,7 +73,6 @@ class VarCountingSHashHandler : : object(object), reduced_hash(reduced_hash), map_free_vars(map_free_vars) {} }; - VarCountingSHashHandler() {} void MarkGraphNode() final { @@ -95,8 +91,7 @@ class VarCountingSHashHandler : } void SHashReduceHashedValue(size_t hashed_value) final { - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), hashed_value, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), hashed_value, false)); } void SHashReduceFreeVar(const runtime::Object* var, bool map_free_vars) final { @@ -104,13 +99,11 @@ class VarCountingSHashHandler : if (map_free_vars) { // use counter value. size_t value = std::hash()(free_var_counter_++); - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), value, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); } else { // use pointer hash size_t value = std::hash()(var); - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), value, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), value, false)); } } @@ -124,12 +117,10 @@ class VarCountingSHashHandler : } auto it = hash_memo_.find(object); if (it != hash_memo_.end()) { - pending_tasks_.emplace_back( - Task(ObjectRef(nullptr), it->second, false)); + pending_tasks_.emplace_back(Task(ObjectRef(nullptr), it->second, false)); } else { // Push a pending task with initial value. - pending_tasks_.emplace_back( - Task(object, object->GetTypeKeyHash(), map_free_vars)); + pending_tasks_.emplace_back(Task(object, object->GetTypeKeyHash(), map_free_vars)); } } @@ -195,9 +186,8 @@ class VarCountingSHashHandler : // Append the graph node counter to the hash // so that we can distinguish DAG from trees. 
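[Editor's note] The two handlers above, RemapVarSEqualHandler in structural_equal.cc and VarCountingSHashHandler in structural_hash.cc, are only reformatted here, not changed in behavior. The sketch below shows the invariant they are meant to preserve together: structural equality implies equal structural hashes. The sample expressions and header paths are assumptions, not part of the patch.

```cpp
// Consistency sketch for StructuralEqual / StructuralHash (assumes a working TVM build).
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <cassert>

void check_consistency() {
  tvm::tir::Var x("x", tvm::DataType::Int(32));
  tvm::PrimExpr a = x + 1;   // two distinct Add nodes ...
  tvm::PrimExpr b = x + 1;   // ... with identical structure and the same free var
  bool eq = tvm::StructuralEqual()(a, b);
  size_t ha = tvm::StructuralHash()(a);
  size_t hb = tvm::StructuralHash()(b);
  assert(eq);                // same structure, same Var object
  assert(!eq || ha == hb);   // equality must imply hash equality
}
```

Free variables are only remapped when map_free_vars is true, which is why the default StructuralEqual::operator() shown above passes false.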
if (entry.graph_node_hash) { - entry.reduced_hash = HashCombine( - entry.reduced_hash, - std::hash()(graph_node_counter_++)); + entry.reduced_hash = + HashCombine(entry.reduced_hash, std::hash()(graph_node_counter_++)); } hash_memo_[entry.object] = entry.reduced_hash; } @@ -265,16 +255,14 @@ class VarCountingSHashHandler : // reflection vtable ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs - std::unordered_map hash_memo_; + std::unordered_map hash_memo_; }; - TVM_REGISTER_GLOBAL("node.StructuralHash") -.set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { - size_t hashed_value = - VarCountingSHashHandler().Hash(object, map_free_vars); - return static_cast(hashed_value); -}); + .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { + size_t hashed_value = VarCountingSHashHandler().Hash(object, map_free_vars); + return static_cast(hashed_value); + }); size_t StructuralHash::operator()(const ObjectRef& object) const { return VarCountingSHashHandler().Hash(object, false); diff --git a/src/printer/doc.cc b/src/printer/doc.cc index ee260f41df55..d487e3e7aa3e 100644 --- a/src/printer/doc.cc +++ b/src/printer/doc.cc @@ -23,10 +23,12 @@ * * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 */ +#include "doc.h" + #include -#include + #include -#include "doc.h" +#include namespace tvm { @@ -38,9 +40,7 @@ class DocTextNode : public DocAtomNode { /*! \brief The str content in the text. */ std::string str; - explicit DocTextNode(std::string str_val) - : str(str_val) { - } + explicit DocTextNode(std::string str_val) : str(str_val) {} static constexpr const char* _type_key = "printer.DocText"; TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode); @@ -68,8 +68,7 @@ class DocLineNode : public DocAtomNode { /*! \brief The amount of indent in newline. 
*/ int indent; - explicit DocLineNode(int indent) - : indent(indent) {} + explicit DocLineNode(int indent) : indent(indent) {} static constexpr const char* _type_key = "printer.DocLine"; TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode); @@ -79,9 +78,7 @@ TVM_REGISTER_OBJECT_TYPE(DocLineNode); class DocLine : public DocAtom { public: - explicit DocLine(int indent) { - data_ = runtime::make_object(indent); - } + explicit DocLine(int indent) { data_ = runtime::make_object(indent); } TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode); }; @@ -89,14 +86,11 @@ class DocLine : public DocAtom { // DSL function implementations Doc& Doc::operator<<(const Doc& right) { CHECK(this != &right); - this->stream_.insert( - this->stream_.end(), right.stream_.begin(), right.stream_.end()); + this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); return *this; } -Doc& Doc::operator<<(std::string right) { - return *this << DocText(right); -} +Doc& Doc::operator<<(std::string right) { return *this << DocText(right); } Doc& Doc::operator<<(const DocAtom& right) { this->stream_.push_back(right); @@ -117,13 +111,9 @@ std::string Doc::str() { return os.str(); } -Doc Doc::NewLine(int indent) { - return Doc() << DocLine(indent); -} +Doc Doc::NewLine(int indent) { return Doc() << DocLine(indent); } -Doc Doc::Text(std::string text) { - return Doc() << DocText(text); -} +Doc Doc::Text(std::string text) { return Doc() << DocText(text); } Doc Doc::RawText(std::string text) { return Doc() << DocAtom(runtime::make_object(text)); @@ -152,10 +142,7 @@ Doc Doc::PyBoolLiteral(bool value) { } } -Doc Doc::Brace(std::string open, - const Doc& body, - std::string close, - int indent) { +Doc Doc::Brace(std::string open, const Doc& body, std::string close, int indent) { Doc doc; doc << open; doc << Indent(indent, NewLine() << body) << NewLine(); diff --git a/src/printer/doc.h b/src/printer/doc.h index 7d8d72e00b4c..dc6ba8952f3e 100644 --- a/src/printer/doc.h +++ b/src/printer/doc.h @@ -26,12 +26,13 @@ #ifndef TVM_PRINTER_DOC_H_ #define TVM_PRINTER_DOC_H_ +#include #include #include -#include + #include -#include #include +#include namespace tvm { @@ -48,7 +49,7 @@ class DocAtomNode : public Object { /*! * \brief Managed reference to DocAtomNode. * \sa DocAtomNode. -*/ + */ class DocAtom : public ObjectRef { public: TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode); @@ -93,8 +94,7 @@ class Doc { * \tparam T the type of the value. * \return reference to self. */ - template::value>::type> + template ::value>::type> Doc& operator<<(const T& value) { std::ostringstream os; os << value; @@ -149,10 +149,7 @@ class Doc { * \param indent amount of indentation. * \return The created doc. */ - static Doc Brace(std::string open, - const Doc& body, - std::string close, - int indent = 2); + static Doc Brace(std::string open, const Doc& body, std::string close, int indent = 2); /*! * \brief Create a doc by concatenating together with separator. * \param vec The docs to be concatenated. 
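[Editor's note] Between the doc.cc/doc.h cleanups above and the meta_data.h changes below, it may help to see the Doc DSL in action. The sketch uses only the API visible in the diff (Doc::Brace, Doc::NewLine, operator<<, str()); the surrounding main and the include path are assumptions.

```cpp
// Small pretty-printing sketch built on the Doc DSL reformatted above.
#include <iostream>

#include "doc.h"  // src/printer/doc.h (internal header; path is an assumption)

int main() {
  using tvm::Doc;
  Doc body;
  body << "x = 1;" << Doc::NewLine() << "y = x + 1;";
  // Brace() wraps the body in "{ ... }" and indents it by two spaces.
  Doc block = Doc::Brace("{", body, "}", /*indent=*/2);
  std::cout << block.str() << std::endl;
  // Prints roughly:
  // {
  //   x = 1;
  //   y = x + 1;
  // }
  return 0;
}
```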
diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h index d3906926363c..df27d92170c6 100644 --- a/src/printer/meta_data.h +++ b/src/printer/meta_data.h @@ -24,10 +24,12 @@ #ifndef TVM_PRINTER_META_DATA_H_ #define TVM_PRINTER_META_DATA_H_ -#include #include +#include + #include #include + #include "doc.h" namespace tvm { @@ -98,8 +100,7 @@ class TextMetaDataContext { } std::string type_key = node->GetTypeKey(); CHECK(!type_key.empty()); - Array& mvector = - meta_data_[type_key]; + Array& mvector = meta_data_[type_key]; int64_t index = static_cast(mvector.size()); mvector.push_back(node); Doc doc; @@ -108,6 +109,13 @@ class TextMetaDataContext { return meta_repr_[node]; } + /*! + * \brief Test whether a node has been put in meta + * \param node The query node + * \return whether the node has been put in meta + */ + bool InMeta(const ObjectRef& node) { return meta_repr_.find(node) != meta_repr_.end(); } + /*! * \brief Print a key value pair */ @@ -121,20 +129,17 @@ class TextMetaDataContext { */ Doc GetMetaSection() const { if (meta_data_.size() == 0) return Doc(); - return Doc::RawText( - SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); + return Doc::RawText(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); } /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } + bool empty() const { return meta_data_.empty(); } private: /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; + std::unordered_map > meta_data_; /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; + std::unordered_map meta_repr_; }; } // namespace tvm #endif // TVM_PRINTER_META_DATA_H_ diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bda997a59d4d..a09e24b12429 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -18,7 +18,7 @@ */ /*! - * \file text_format_printer.cc + * \file relay_text_printer.cc * \brief Printer to print out the IR text format * that can be parsed by a parser. * @@ -32,154 +32,130 @@ * - Var * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ -#include #include -#include +#include #include #include +#include + +#include "../ir/attr_functor.h" +#include "../relay/analysis/dependency_graph.h" #include "doc.h" #include "meta_data.h" -#include "../relay/analysis/dependency_graph.h" -#include "../ir/attr_functor.h" +#include "text_printer.h" namespace tvm { namespace relay { -class RelayTextPrinter : - public ExprFunctor, - public PatternFunctor, - public TypeFunctor, - public AttrFunctor { - public: - explicit RelayTextPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) - : show_meta_data_(show_meta_data), - annotate_(annotate) {} - - /*! - * \brief Print additional info about expr in comment. - * \param expr The expression. - */ - Doc PrintOptionalInfo(const Expr& expr) { - Doc doc; - // default annotations - if (annotate_ == nullptr) { - if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { - doc << " /* ty=" << Print(expr->checked_type()) << " */"; - } - } else { - std::string annotated_expr = annotate_(expr); - if (annotated_expr != "") { - doc << annotated_expr; - } +/*! + * \brief Print additional info about expr in comment. + * \param expr The expression. 
+ */ +Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { + Doc doc; + // default annotations + if (annotate_ == nullptr) { + if ((expr.as() || expr.as()) && expr->checked_type_.defined()) { + doc << " /* ty=" << Print(expr->checked_type()) << " */"; + } + } else { + std::string annotated_expr = annotate_(expr); + if (annotated_expr != "") { + doc << annotated_expr; } - - return doc; } - // indent a new body - Doc PrintBody(const ObjectRef& node, int indent = 2) { - Doc doc; - Doc body; - doc << "{"; - doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); - doc << "}"; - return doc; - } + return doc; +} - // create a new scope by creating a new printer object. This allows temp var - // numbers to be reused and prevents hoisted vars from escaping too far - Doc PrintScope(const ObjectRef& node) { - // print in a new scope - doc_stack_.push_back(Doc()); - // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node, false, true); - doc = doc_stack_.back() << doc; - doc_stack_.pop_back(); - return doc; - } +// indent a new body +Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) { + Doc doc; + Doc body; + doc << "{"; + doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); + doc << "}"; + return doc; +} - Doc PrintFinal(const ObjectRef& node) { - if (node->IsInstance() && - !node->IsInstance()) { - // Temporarily skip non-relay functions. - // TODO(tvm-team) enhance the code to work for all functions - } else if (node.as()) { - Expr expr = Downcast(node); - dg_ = DependencyGraph::Create(&arena_, expr); - } +// create a new scope by creating a new printer object. This allows temp var +// numbers to be reused and prevents hoisted vars from escaping too far +Doc RelayTextPrinter::PrintScope(const ObjectRef& node) { + // print in a new scope + doc_stack_.push_back(Doc()); + // must print first so doc_stack_.back() reference doesn't become stale + Doc doc = Print(node, false, true); + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); + return doc; +} - Doc doc; - doc << PrintScope(node); - if (!meta_.empty()) { - doc << Doc::NewLine(); - if (show_meta_data_) { - // append meta data in the end. - doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); - } else { - doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; - } - } - return doc; +Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { + if (node.defined() && node->IsInstance() && + !node->IsInstance()) { + // Temporarily skip non-relay functions. + // TODO(tvm-team) enhance the code to work for all functions + } else if (node.as()) { + Expr expr = Downcast(node); + dg_ = DependencyGraph::Create(&arena_, expr); } - std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); - std::vector PrintFuncAttrs(const Attrs& attrs); - - Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false) { - bool is_non_relay_func = - node->IsInstance() && - !node->IsInstance(); - if (node.as() && !is_non_relay_func) { - return PrintExpr(Downcast(node), meta, try_inline); - } else if (node.as()) { - return PrintType(Downcast(node), meta); - } else if (node.as()) { - return PrintPattern(Downcast(node), meta); - } else if (node.as()) { - return PrintMod(Downcast(node)); - } else { - // default module. 
- std::ostringstream os; - os << node; - return Doc::RawText(os.str()); - } + Doc doc; + doc << PrintScope(node); + return doc; +} + +Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { + bool is_non_relay_func = node.defined() && node->IsInstance() && + !node->IsInstance(); + if (node.as() && !is_non_relay_func) { + return PrintExpr(Downcast(node), meta, try_inline); + } else if (node.as()) { + return PrintType(Downcast(node), meta); + } else if (node.as()) { + return PrintPattern(Downcast(node), meta); + } else if (node.as()) { + return PrintMod(Downcast(node)); + } else { + // default module. + std::ostringstream os; + os << node; + return Doc::RawText(os.str()); } +} - Doc TempVar(int n) { - Doc doc; - return doc << "%" << n; - } - - Doc AllocTemp() { - return TempVar(temp_var_counter_++); - } - - /*! - * \brief get a unique name with the corresponding prefix - * \param prefix The prefix of the name - * \return The returned name. - */ - Doc GetUniqueName(const std::string& prefix) { - std::string unique_prefix = prefix; - auto it = name_alloc_map_.find(prefix); - if (it != name_alloc_map_.end()) { - while (true) { - std::ostringstream os; - os << prefix << (++it->second); - std::string name = os.str(); - if (name_alloc_map_.count(name) == 0) { - unique_prefix = name; - break; - } +Doc RelayTextPrinter::TempVar(int n) { + Doc doc; + return doc << "%" << n; +} + +Doc RelayTextPrinter::AllocTemp() { return TempVar(temp_var_counter_++); } + +/*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ +Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) { + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + unique_prefix = name; + break; } } - name_alloc_map_[unique_prefix] = 0; - return Doc::Text(unique_prefix); } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} - Doc Print(Kind k) { - switch (k) { +Doc RelayTextPrinter::Print(Kind k) { + switch (k) { case kType: return Doc::Text("Type"); case kShapeVar: @@ -195,642 +171,605 @@ class RelayTextPrinter : default: LOG(ERROR) << "Unknown Kind"; throw; - } } - /*! - * \brief Allocate name to a type variable. - * \param var The input type variable. - * \return The corresponding name. - */ - Doc AllocTypeVar(const TypeVar& var) { - if (memo_type_.count(var)) { - Doc val = memo_type_[var]; - val << "-malformed-ir"; - return val; - } - std::string name = var->name_hint; - if (name.length() == 0 || !std::isalpha(name[0])) { - name = "t" + name; - } - Doc val = GetUniqueName(name); - memo_type_[var] = val; - if (var->kind != kType) { - val << ": " << Print(var->kind); - } +} +/*! + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ +Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) { + if (memo_type_.count(var)) { + Doc val = memo_type_[var]; + val << "-malformed-ir"; return val; } + std::string name = var->name_hint; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "t" + name; + } + Doc val = GetUniqueName(name); + memo_type_[var] = val; + if (var->kind != kType) { + val << ": " << Print(var->kind); + } + return val; +} - /*! - * \brief Allocate name to a variable. - * \param var The input variable. 
- * \return The corresponding name. - */ - Doc AllocVar(const Var& var) { - // still print if ir is malformed, but show the error. - if (memo_.count(var)) { - Doc val = memo_[var]; - val << "-malformed-ir"; - return val; - } - std::string name = var->name_hint(); - // always make sure first name is alpha - if (name.length() == 0 || !std::isalpha(name[0])) { - name = "v" + name; - } - Doc val = GetUniqueName("%" + name); - memo_[var] = val; - if (var->type_annotation.defined()) { - val << ": " << Print(var->type_annotation); - } +/*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. + */ +Doc RelayTextPrinter::AllocVar(const Var& var) { + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + Doc val = memo_[var]; + val << "-malformed-ir"; return val; } - - bool IsUnique(const Expr& expr) { - auto it = dg_.expr_node.find(expr); - if (it == dg_.expr_node.end()) { - return true; - } else { - return !(it->second->parents.head && it->second->parents.head->next); - } + std::string name = var->name_hint(); + // always make sure first name is alpha + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; } + Doc val = GetUniqueName("%" + name); + memo_[var] = val; + if (var->type_annotation.defined()) { + val << ": " << Print(var->type_annotation); + } + return val; +} - bool AlwaysInline(const Expr& expr) { - return expr.as() || expr.as() || expr.as() || - expr.as() || expr.as(); +bool RelayTextPrinter::IsUnique(const Expr& expr) { + auto it = dg_.expr_node.find(expr); + if (it == dg_.expr_node.end()) { + return true; + } else { + return !(it->second->parents.head && it->second->parents.head->next); } +} - //------------------------------------ - // Overload of Expr printing functions - //------------------------------------ - Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) { - // Exploit memoization to print GNF. - // The first time we visit an expression, we need to allocate a temp var - // for it. Every subsequent time we can just use its assigned variable. - // This works since hashing uses pointer equality. +bool RelayTextPrinter::AlwaysInline(const Expr& expr) { + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as(); +} - // determine whether to inline - bool inline_expr = AlwaysInline(expr); - if (try_inline) { - inline_expr |= IsUnique(expr); - } +//------------------------------------ +// Overload of Expr printing functions +//------------------------------------ +Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline) { + // Exploit memoization to print GNF. + // The first time we visit an expression, we need to allocate a temp var + // for it. Every subsequent time we can just use its assigned variable. + // This works since hashing uses pointer equality. 
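[Editor's note] The comment block above describes the printer's graph-normal-form strategy: every non-inlined subexpression is bound to a fresh temporary (%0, %1, ...) on first visit and the memoized Doc is reused afterwards, so shared nodes print exactly once. Below is a sketch of what that looks like end to end, driving the printer through the "ir.AsText" packed function registered later in this patch. The relay constructor signatures, headers, and exact output are assumptions for illustration.

```cpp
// GNF printing sketch (assumes a TVM build of roughly this vintage).
#include <tvm/ir/op.h>
#include <tvm/ir/tensor_type.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/registry.h>

#include <iostream>
#include <string>

int main() {
  using namespace tvm;
  relay::Var x = relay::Var("x", TensorType(Array<PrimExpr>(), DataType::Float(32)));
  relay::Expr add = relay::Call(Op::Get("add"), {x, x});
  // `add` is used twice below, so the printer should bind it to a temp var once.
  relay::Expr mul = relay::Call(Op::Get("multiply"), {add, add});
  relay::Function fn = relay::Function({x}, mul, Type(), {});

  const runtime::PackedFunc* as_text = runtime::Registry::Get("ir.AsText");
  std::string text = (*as_text)(fn, false, nullptr);
  std::cout << text << std::endl;
  // Expected shape of the output (modulo the version header and whitespace):
  //   fn (%x: float32) {
  //     %0 = add(%x, %x);
  //     multiply(%0, %0)
  //   }
  return 0;
}
```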
+ + // determine whether to inline + bool inline_expr = AlwaysInline(expr); + if (try_inline) { + inline_expr |= IsUnique(expr); + } + + auto it = memo_.find(expr); + if (it != memo_.end()) return it->second; + + Doc printed_expr; + if (meta) { + printed_expr = meta_->GetMetaNode(GetRef(expr.get())); + } else if (!inline_expr && expr.as()) { + // wrap GNFed let in brackets + Doc body; + printed_expr << "("; + printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); + printed_expr << ")"; + } else { + printed_expr = VisitExpr(expr); + } - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - - Doc printed_expr; - if (meta) { - printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - } else if (!inline_expr && expr.as()) { - // wrap GNFed let in brackets - Doc body; - printed_expr << "("; - printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); - printed_expr << ")"; - } else { - printed_expr = VisitExpr(expr); - } + printed_expr << PrintOptionalInfo(expr); - printed_expr << PrintOptionalInfo(expr); - - // add expr to doc - if (expr.as()) { - // This is our first time visiting the var and we hit the VarNode case - // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); - // Memoization is done in AllocVar. - return memo_[expr]; - } else if (inline_expr) { - memo_[expr] = printed_expr; - return printed_expr; - } else { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); - return temp_var; - } + // add expr to doc + if (expr.as()) { + // This is our first time visiting the var and we hit the VarNode case + // in the visitor. Thus the variable is free. + doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); + // Memoization is done in AllocVar. + return memo_[expr]; + } else if (inline_expr) { + memo_[expr] = printed_expr; + return printed_expr; + } else { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); + return temp_var; } +} - // Should only be triggered when op is a free variable being visited for the - // first time. - Doc VisitExpr_(const VarNode* op) final { - return AllocVar(GetRef(op)); +// Should only be triggered when op is a free variable being visited for the +// first time. +Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef(op)); } + +/*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ +template +Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Int(32)) { + os << value; + } else if (dtype == DataType::Float(32)) { + os << value << 'f'; + } else if (dtype == DataType::Float(64)) { + os << value; + } else if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; } + return Doc::Text(os.str()); +} - /*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param value The value to be printed. - */ - template - static Doc ScalarLiteral(DataType dtype, const T& value) { +Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) { + // Print out simple scalars directly. 
+ if (op->is_scalar()) { std::ostringstream os; + DataType dtype = DataType(op->data->dtype); + CHECK_EQ(op->data->ctx.device_type, kDLCPU); if (dtype == DataType::Int(32)) { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); + } else if (dtype == DataType::Int(64)) { + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(32)) { - os << value << 'f'; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(64)) { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Bool()) { - return Doc::PyBoolLiteral(value != 0); - } else { - os << value; + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } - return Doc::Text(os.str()); } + // default fall-back, record it as meta node. + Doc doc; + return doc << Print(GetRef(op), true); +} - Doc VisitExpr_(const ConstantNode* op) final { - // Print out simple scalars directly. - if (op->is_scalar()) { - std::ostringstream os; - DataType dtype = DataType(op->data->dtype); - CHECK_EQ(op->data->ctx.device_type, kDLCPU); - if (dtype == DataType::Int(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Int(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(32)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Float(64)) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } else if (dtype == DataType::Bool()) { - return ScalarLiteral(dtype, static_cast(op->data->data)[0]); - } - } - // default fall-back, record it as meta node. - Doc doc; - return doc << Print(GetRef(op), true); +Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { + std::vector fields; + for (Expr field : op->fields) { + fields.push_back(Print(field)); } - - Doc VisitExpr_(const TupleNode* op) final { - std::vector fields; - for (Expr field : op->fields) { - fields.push_back(Print(field)); - } - Doc doc; - doc << "(" << Doc::Concat(fields); - // conform to python tuple format (1,) - if (op->fields.size() == 1) { - doc << ","; - } - return doc << ")"; + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (op->fields.size() == 1) { + doc << ","; } + return doc << ")"; +} - Doc VisitExpr_(const TupleGetItemNode* op) final { - Doc doc; - return doc << Print(op->tuple) << "." << op->index; - } +Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { + Doc doc; + return doc << Print(op->tuple) << "." 
<< op->index; +} - Doc VisitExpr_(const IfNode* op) final { +Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { + Doc doc; + doc << "if (" << Print(op->cond) << ") "; + doc << PrintBody(op->true_branch); + doc << " else "; + doc << PrintBody(op->false_branch); + return doc; +} + +Doc RelayTextPrinter::VisitExpr_(const LetNode* op) { + int n = 0; + Expr let = GetRef(op); + while (auto let_node = let.as()) { Doc doc; - doc << "if (" << Print(op->cond) << ") "; - doc << PrintBody(op->true_branch); - doc << " else "; - doc << PrintBody(op->false_branch); - return doc; + doc << "let " << AllocVar(let_node->var) << " = " << Print(let_node->value, false, true) << ";" + << Doc::NewLine(); + doc_stack_.push_back(doc); + let = let_node->body; + ++n; + } + Doc doc = PrintScope(let); + for (int i = 0; i < n; ++i) { + doc = doc_stack_.back() << doc; + doc_stack_.pop_back(); } + return doc; +} - Doc VisitExpr_(const LetNode* op) final { - Doc doc; - doc - << "let " - << AllocVar(op->var) - << " = " - << Print(op->value, false, true) - << ";" - << Doc::NewLine(); - // we use a scope here so GNF hoisting doesn't escape too far - // and nested, unique lets are not hoisted - doc << PrintScope(op->body); - return doc; +Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { + Doc doc; + doc << prefix; + if (fn->type_params.size() > 0) { + doc << "["; + std::vector type_params; + for (const TypeVar& tv : fn->type_params) { + type_params.push_back(Doc::Text(tv->name_hint)); + } + doc << Doc::Concat(type_params); + doc << "]"; + } + doc << "("; + std::vector params; + for (Var param : fn->params) { + params.push_back(AllocVar(param)); + } + for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + params.push_back(d); + } + doc << Doc::Concat(params) << ") "; + if (fn->ret_type.defined()) { + doc << "-> " << Print(fn->ret_type) << " "; } + doc << PrintBody(fn->body); + return doc; +} - Doc PrintFunc(const Doc& prefix, const relay::Function& fn) { +Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) { + if (auto* n = base_func.as()) { + return PrintFunc(prefix, GetRef(n)); + } else if (auto* n = base_func.as()) { + std::ostringstream os; + os << GetRef(n); + return Doc::RawText(os.str()); + } else { + // def @xyz = meta['ExternalFunc'][id] Doc doc; - doc << prefix; - if (fn->type_params.size() > 0) { - doc << "["; - std::vector type_params; - for (const TypeVar& tv : fn->type_params) { - type_params.push_back(Doc::Text(tv->name_hint)); - } - doc << Doc::Concat(type_params); - doc << "]"; - } - doc << "("; - std::vector params; - for (Var param : fn->params) { - params.push_back(AllocVar(param)); - } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { - params.push_back(d); - } - doc << Doc::Concat(params) << ") "; - if (fn->ret_type.defined()) { - doc << "-> " << Print(fn->ret_type) << " "; - } - doc << PrintBody(fn->body); + doc << prefix << " = " << meta_->GetMetaNode(base_func); return doc; } +} - Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) { - if (auto* n = base_func.as()) { - return PrintFunc(prefix, GetRef(n)); - } else if (auto* n = base_func.as()) { - std::ostringstream os; - os << GetRef(n); - return Doc::RawText(os.str()); - } else { - // def @xyz = meta['ExternalFunc'][id] - Doc doc; - doc << prefix << " = " << meta_.GetMetaNode(base_func); - return doc; +Doc RelayTextPrinter::PrintMod(const IRModule& mod) { + Doc doc; + int counter = 0; + // type definitions + for (const auto& kv : mod->type_definitions) { + if (counter++ != 0) { + doc << 
Doc::NewLine(); } + doc << Print(kv.second); + doc << Doc::NewLine(); } - - Doc PrintMod(const IRModule& mod) { - Doc doc; - int counter = 0; - // type definitions - for (const auto& kv : mod->type_definitions) { - if (counter++ != 0) { - doc << Doc::NewLine(); - } - doc << Print(kv.second); - doc << Doc::NewLine(); + // functions + for (const auto& kv : mod->functions) { + if (kv.second.as()) { + dg_ = DependencyGraph::Create(&arena_, kv.second); } - // functions - for (const auto& kv : mod->functions) { - if (kv.second.as()) { - dg_ = DependencyGraph::Create(&arena_, kv.second); - } - if (counter++ != 0) { - doc << Doc::NewLine(); - } - std::ostringstream os; - os << "def @" << kv.first->name_hint; - doc << PrintFunc(Doc::Text(os.str()), kv.second); + if (counter++ != 0) { doc << Doc::NewLine(); } - return doc; + std::ostringstream os; + os << "def @" << kv.first->name_hint; + doc << PrintFunc(Doc::Text(os.str()), kv.second); + doc << Doc::NewLine(); } + return doc; +} - Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Doc::Text("fn "), GetRef(op)); - } +Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { + return PrintFunc(Doc::Text("fn "), GetRef(op)); +} - Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc::Text('@' + op->name_hint); - } +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { + return Doc::Text('@' + op->name_hint.operator std::string()); +} - Doc VisitExpr_(const OpNode* op) final { - return Doc::Text(op->name); - } +Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); } - Doc VisitExpr_(const CallNode* op) final { - Doc doc; - // visit args first so they are lifted before the op - // this places op closer to its call site - std::vector args; - for (const Expr& arg : op->args) { - args.push_back(Print(arg)); - } - for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { - args.push_back(d); - } - const auto* cons_node = op->op.as(); - if (cons_node) { - doc << cons_node->name_hint; - } else { - doc << Print(op->op); - } - - if (cons_node && cons_node->inputs.size() == 0) { - // don't print as a call if it's a 0-arity cons - return doc; - } else { - return doc << "(" << Doc::Concat(args) << ")"; - } +Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { + Doc doc; + // visit args first so they are lifted before the op + // this places op closer to its call site + std::vector args; + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); + } + for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { + args.push_back(d); + } + const auto* cons_node = op->op.as(); + if (cons_node) { + doc << cons_node->name_hint; + } else { + doc << Print(op->op); } - Doc VisitExpr_(const RefCreateNode* op) final { - Doc doc; - return doc << "ref(" << Print(op->value) << ")"; + if (cons_node && cons_node->inputs.size() == 0) { + // don't print as a call if it's a 0-arity cons + return doc; + } else { + return doc << "(" << Doc::Concat(args) << ")"; } +} - Doc VisitExpr_(const RefReadNode* op) final { - Doc doc; - return doc << Print(op->ref) << "^"; - } +Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) { + Doc doc; + return doc << "ref(" << Print(op->value) << ")"; +} - Doc VisitExpr_(const RefWriteNode* op) final { - Doc doc; - return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; - } +Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) { + Doc doc; + return doc << Print(op->ref) << "^"; +} - Doc VisitExpr_(const MatchNode* op) final { - // TODO(jmp): Lots of code duplication 
here because PrintBody and PrintScope don't accept Docs. - Doc doc; - Doc body; - doc << "match"; - if (!op->complete) { - doc << "?"; - } - doc << " (" << Print(op->data) << ") {"; - std::vector clause_docs; - for (const auto& clause : op->clauses) { - Doc clause_doc; - clause_doc << PrintPattern(clause->lhs, false) << " => "; - Doc rhs_doc = PrintScope(clause->rhs); - if (clause->rhs.as()) { - // only add braces if there are multiple lines on the rhs - rhs_doc = Doc::Brace("{", rhs_doc, "}"); - } - clause_doc << rhs_doc << ","; - clause_docs.push_back(clause_doc); - } - doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) - << Doc::NewLine() << "}"; - return doc; - } +Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) { + Doc doc; + return doc << "(" << Print(op->ref) << " := " << Print(op->value) << ")"; +} - Doc PrintPattern(const Pattern& pattern, bool meta) { - auto it = memo_pattern_.find(pattern); - if (it != memo_pattern_.end()) return it->second; - Doc printed_pattern; - if (meta) { - printed_pattern = meta_.GetMetaNode(GetRef(pattern.get())); - } else { - printed_pattern = VisitPattern(pattern); - } - memo_pattern_[pattern] = printed_pattern; - return printed_pattern; - } +Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) { + // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. + Doc doc; + Doc body; + doc << "match"; + if (!op->complete) { + doc << "?"; + } + doc << " (" << Print(op->data) << ") {"; + std::vector clause_docs; + for (const auto& clause : op->clauses) { + Doc clause_doc; + clause_doc << PrintPattern(clause->lhs, false) << " => "; + Doc rhs_doc = PrintScope(clause->rhs); + if (clause->rhs.as()) { + // only add braces if there are multiple lines on the rhs + rhs_doc = Doc::Brace("{", rhs_doc, "}"); + } + clause_doc << rhs_doc << ","; + clause_docs.push_back(clause_doc); + } + doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) + << Doc::NewLine() << "}"; + return doc; +} - Doc VisitPattern_(const PatternConstructorNode* p) final { - Doc doc; - doc << p->constructor->name_hint; - if (!p->patterns.empty()) { - doc << "("; - std::vector pats; - for (const auto& pat : p->patterns) { - pats.push_back(Print(pat)); - } - doc << Doc::Concat(pats) << ")"; - } - return doc; +Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) { + auto it = memo_pattern_.find(pattern); + if (it != memo_pattern_.end()) return it->second; + Doc printed_pattern; + if (meta) { + printed_pattern = meta_->GetMetaNode(GetRef(pattern.get())); + } else { + printed_pattern = VisitPattern(pattern); } + memo_pattern_[pattern] = printed_pattern; + return printed_pattern; +} - Doc VisitPattern_(const PatternTupleNode* pt) final { - Doc doc; +Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) { + Doc doc; + doc << p->constructor->name_hint; + if (!p->patterns.empty()) { doc << "("; std::vector pats; - for (const auto& pat : pt->patterns) { + for (const auto& pat : p->patterns) { pats.push_back(Print(pat)); } doc << Doc::Concat(pats) << ")"; - return doc; } + return doc; +} - Doc VisitPattern_(const PatternWildcardNode* pw) final { - return Doc::Text("_"); +Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) { + Doc doc; + doc << "("; + std::vector pats; + for (const auto& pat : pt->patterns) { + pats.push_back(Print(pat)); } + doc << Doc::Concat(pats) << ")"; + return doc; +} - Doc VisitPattern_(const PatternVarNode* pv) final { - 
return AllocVar(pv->var); - } +Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { return Doc::Text("_"); } - Doc VisitExpr_(const ConstructorNode* n) final { - Doc doc; - doc << n->name_hint; - if (in_adt_def_ && n->inputs.size() != 0) { - doc << "("; - std::vector inputs; - for (Type input : n->inputs) { - inputs.push_back(Print(input)); - } - doc << Doc::Concat(inputs) << ")"; - } - return doc; - } +Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { return AllocVar(pv->var); } - //------------------------------------ - // Overload of Type printing functions - //------------------------------------ - Doc PrintType(const Type& type, bool meta) { - auto it = memo_type_.find(type); - if (it != memo_type_.end()) return it->second; - Doc printed_type; - if (meta) { - printed_type = meta_.GetMetaNode(GetRef(type.get())); - } else { - printed_type = VisitType(type); +Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) { + Doc doc; + doc << n->name_hint; + if (in_adt_def_ && n->inputs.size() != 0) { + doc << "("; + std::vector inputs; + for (Type input : n->inputs) { + inputs.push_back(Print(input)); } - memo_type_[type] = printed_type; - return printed_type; + doc << Doc::Concat(inputs) << ")"; } + return doc; +} - Doc VisitTypeDefault_(const Object* node) final { - // by default always print as meta data - return Print(GetRef(node), true); +//------------------------------------ +// Overload of Type printing functions +//------------------------------------ +Doc RelayTextPrinter::PrintType(const Type& type, bool meta) { + auto it = memo_type_.find(type); + if (it != memo_type_.end()) return it->second; + Doc printed_type; + if (meta) { + printed_type = meta_->GetMetaNode(GetRef(type.get())); + } else { + printed_type = VisitType(type); } + memo_type_[type] = printed_type; + return printed_type; +} - Doc VisitType_(const TypeVarNode* node) final { - return Doc::Text(node->name_hint); - } +Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) { + // by default always print as meta data + return Print(GetRef(node), true); +} - Doc VisitType_(const GlobalTypeVarNode* node) final { - return Doc::Text(node->name_hint); - } +Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { return Doc::Text(node->name_hint); } - Doc VisitType_(const TypeCallNode* node) final { - Doc doc = PrintType(node->func, false); - std::vector args; - for (const Type& t : node->args) { - args.push_back(PrintType(t, false)); - } - doc << "["; - doc << Doc::Concat(args); - doc << "]"; - return doc; - } +Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) { + return Doc::Text(node->name_hint); +} - Doc PrintDType(DataType dtype) { - return Doc::Text(runtime::DLDataType2String(dtype)); +Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) { + Doc doc = PrintType(node->func, false); + std::vector args; + for (const Type& t : node->args) { + args.push_back(PrintType(t, false)); } + doc << "["; + doc << Doc::Concat(args); + doc << "]"; + return doc; +} - Doc VisitType_(const TensorTypeNode* node) final { - // scalar type - if (node->shape.size() == 0) { - return PrintDType(node->dtype); - } - Doc doc; - doc << "Tensor[("; - std::vector shapes; - for (ObjectRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); - } - doc << Doc::Concat(shapes); - return doc << "), " << PrintDType(node->dtype) << "]"; - } +Doc RelayTextPrinter::PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); +} - Doc VisitType_(const TupleTypeNode* node) 
final { - std::vector fields; - for (Type field : node->fields) { - fields.push_back(Print(field)); - } - Doc doc; - doc << "(" << Doc::Concat(fields); - // conform to python tuple format (1,) - if (node->fields.size() == 1) { - doc << ","; - } - return doc << ")"; +Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { + // scalar type + if (node->shape.size() == 0) { + return PrintDType(node->dtype); } - - Doc VisitType_(const FuncTypeNode* node) final { - Doc doc; - doc << "fn "; - if (node->type_params.size() != 0) { - doc << "["; - std::vector type_params; - for (Type type_param : node->type_params) { - type_params.push_back(Print(type_param)); - } - doc << Doc::Concat(type_params); - doc << "]"; - } - std::vector arg_types; - for (Type arg_type : node->arg_types) { - arg_types.push_back(Print(arg_type)); - } - return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); + Doc doc; + doc << "Tensor[("; + std::vector shapes; + for (ObjectRef shape : node->shape) { + shapes.push_back(PrintAttr(shape)); } + doc << Doc::Concat(shapes); + return doc << "), " << PrintDType(node->dtype) << "]"; +} - Doc VisitType_(const RelayRefTypeNode* node) final { - Doc doc; - return doc << "ref(" << Print(node->value) << ")"; +Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; +} - Doc VisitType_(const TypeDataNode* node) final { - in_adt_def_ = true; - Doc doc; - doc << "type " << Print(node->header); - - // type vars - if (node->type_vars.size() != 0) { - doc << "["; - std::vector type_vars; - for (Type type_var : node->type_vars) { - type_vars.push_back(Print(type_var)); - } - doc << Doc::Concat(type_vars) << "]"; - } - doc << " "; - - std::vector constructor_docs; - for (Constructor constructor : node->constructors) { - constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); - } - Doc separator; - separator << "," << Doc::NewLine(); - Doc adt_body; - adt_body << Doc::Concat(constructor_docs, separator); - // add trailing comma if there are any constructors - if (!constructor_docs.empty()) { - adt_body << ","; +Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) { + Doc doc; + doc << "fn "; + if (node->type_params.size() != 0) { + doc << "["; + std::vector type_params; + for (Type type_param : node->type_params) { + type_params.push_back(Print(type_param)); } - doc << Doc::Brace("{", adt_body, "}"); - in_adt_def_ = false; - return doc; + doc << Doc::Concat(type_params); + doc << "]"; + } + std::vector arg_types; + for (Type arg_type : node->arg_types) { + arg_types.push_back(Print(arg_type)); } + return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); +} - //------------------------------------ - // Overload of Attr printing functions - //------------------------------------ +Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) { + Doc doc; + return doc << "ref(" << Print(node->value) << ")"; +} - Doc PrintAttr(const ObjectRef& value, bool meta = false) { - if (value.defined()) { - Doc printed_attr; - if (value.as()) { - printed_attr << "?"; - } else if (meta) { - printed_attr = meta_.GetMetaNode(Downcast(value)); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; - } else { - return Doc::Text("None"); +Doc 
RelayTextPrinter::VisitType_(const TypeDataNode* node) { + in_adt_def_ = true; + Doc doc; + doc << "type " << Print(node->header); + + // type vars + if (node->type_vars.size() != 0) { + doc << "["; + std::vector type_vars; + for (Type type_var : node->type_vars) { + type_vars.push_back(Print(type_var)); } + doc << Doc::Concat(type_vars) << "]"; } + doc << " "; - Doc VisitAttrDefault_(const Object* op) final { - return PrintAttr(GetRef(op), true); + std::vector constructor_docs; + for (Constructor constructor : node->constructors) { + constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); + } + Doc separator; + separator << "," << Doc::NewLine(); + Doc adt_body; + adt_body << Doc::Concat(constructor_docs, separator); + // add trailing comma if there are any constructors + if (!constructor_docs.empty()) { + adt_body << ","; } + doc << Doc::Brace("{", adt_body, "}"); + in_adt_def_ = false; + return doc; +} - Doc VisitAttr_(const ArrayNode* op) final { - Doc doc; - doc << "["; - std::vector arr_vals; - for (auto val : op->data) { - arr_vals.push_back(PrintAttr(val)); +//------------------------------------ +// Overload of Attr printing functions +//------------------------------------ + +Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { + if (value.defined()) { + Doc printed_attr; + if (value.as()) { + printed_attr << "?"; + } else if (meta) { + printed_attr = meta_->GetMetaNode(Downcast(value)); + } else { + printed_attr = VisitAttr(value); } - doc << Doc::Concat(arr_vals); - doc << "]"; - return doc; + return printed_attr; + } else { + return Doc::Text("None"); } +} - Doc VisitAttr_(const tir::IntImmNode* op) final { - return ScalarLiteral(op->dtype, op->value); - } +Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { + return PrintAttr(GetRef(op), true); +} - Doc VisitAttr_(const tir::FloatImmNode* op) final { - return ScalarLiteral(op->dtype, op->value); - } +Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { + Doc doc; + doc << "["; + std::vector arr_vals; + for (auto val : *op) { + arr_vals.push_back(PrintAttr(val)); + } + doc << Doc::Concat(arr_vals); + doc << "]"; + return doc; +} - Doc VisitAttr_(const tir::StringImmNode* op) final { - return Doc::StrLiteral(op->value); - } +Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { + return ScalarLiteral(op->dtype, op->value); +} - private: - /*! \brief Whether to print meta data. */ - bool show_meta_data_; - /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; - /*! \brief Stack of docs to implement scoped GNFing. */ - std::vector doc_stack_{}; - /*! \brief Map from Expr to Doc */ - std::unordered_map memo_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_type_; - /*! \brief Map from Type to Doc */ - std::unordered_map memo_pattern_; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - /*! \brief meta data context */ - TextMetaDataContext meta_; - /*! \brief counter of temporary variable */ - size_t temp_var_counter_{0}; - /*! \brief whether the printer is currently in an ADT definition */ - bool in_adt_def_; - /*! \brief arena for dependency graph */ - support::Arena arena_; - /*! 
\brief dependency graph of the expr */ - DependencyGraph dg_; - class AttrPrinter; - friend class AttrPrinter; -}; +Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { + return ScalarLiteral(op->dtype, op->value); +} + +Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { + return Doc::StrLiteral(op->value); +} /*! * \brief Attribute printer which prints the attributes in the call. */ -class RelayTextPrinter::AttrPrinter : - public AttrVisitor { +class RelayTextPrinter::AttrPrinter : public AttrVisitor { public: - AttrPrinter(std::vector* doc, RelayTextPrinter* parent) - : docs(doc), parent_(parent) {} + AttrPrinter(std::vector* doc, RelayTextPrinter* parent) : docs(doc), parent_(parent) {} - template + template void PrintKV(const char* key, const T& value) { Doc doc; doc << key << "=" << value; @@ -842,24 +781,12 @@ class RelayTextPrinter::AttrPrinter : doc << key << "=" << *value << "f"; docs->push_back(doc); } - void Visit(const char* key, int64_t* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, uint64_t* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, int* value) final { - PrintKV(key, *value); - } - void Visit(const char* key, bool* value) final { - PrintKV(key, Doc::PyBoolLiteral(*value)); - } - void Visit(const char* key, std::string* value) final { - PrintKV(key, Doc::StrLiteral(*value)); - } - void Visit(const char* key, void** value) final { - LOG(FATAL) << "do not allow void as argument"; - } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, int* value) final { PrintKV(key, *value); } + void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } + void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } + void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; } void Visit(const char* key, DataType* value) final { PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); } @@ -875,15 +802,14 @@ class RelayTextPrinter::AttrPrinter : RelayTextPrinter* parent_; }; -std::vector RelayTextPrinter::PrintCallAttrs( - const Attrs& attrs, const Expr& op) { +std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; if (!attrs.defined()) return docs; const auto* op_node = op.as(); if (op_node && (attrs->type_index() != op_node->attrs_type_index)) { // fallback Doc doc; - doc << meta_.GetMetaNode(attrs); + doc << meta_->GetMetaNode(attrs); docs.push_back(doc); return docs; } else { @@ -905,38 +831,13 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { } return docs; } -} // namespace relay -static const char* kSemVer = "v0.0.4"; - -// TODO(tvm-team): split into files, related: arith/analyzer.h -// -// - text_printer.h (common header) -// - text_printer.cc (prints modules dispatch into relay and tir files) -// - type_text_printer.cc(specific printing logics for types, -// can also consider put under type_text_printer) -// - Implements AsText -// - relay_text_printer.cc (specific printing logics for relay) -// - tir_text_printer.cc (specific printing logics for TIR) -std::string PrettyPrint(const ObjectRef& node) { - Doc doc; - doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); - return doc.str(); -} +TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { + std::cout << 
"The program: " << node << std::endl; + auto text = AsText(node, false, nullptr); + std::cout << "The text " << text; + return text; +}); -std::string AsText(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - Doc doc; - doc << kSemVer << Doc::NewLine(); - doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); - return doc.str(); -} - - -TVM_REGISTER_GLOBAL("ir.PrettyPrint") -.set_body_typed(PrettyPrint); - -TVM_REGISTER_GLOBAL("ir.AsText") -.set_body_typed(AsText); +} // namespace relay } // namespace tvm diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc new file mode 100644 index 000000000000..2993d38234ea --- /dev/null +++ b/src/printer/text_printer.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.cc + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#include "text_printer.h" + +#include + +#include + +namespace tvm { + +static const char* kSemVer = "v0.0.4"; + +// TODO(tvm-team): split into files, related: arith/analyzer.h +// +// - text_printer.h (common header) +// - text_printer.cc (prints modules dispatch into relay and tir files) +// - type_text_printer.cc(specific printing logics for types, +// can also consider put under type_text_printer) +// - Implements AsText +// - relay_text_printer.cc (specific printing logics for relay) +// - tir_text_printer.cc (specific printing logics for TIR) + +Doc TextPrinter::PrintMod(const IRModule& mod) { + Doc doc; + int counter = 0; + // type definitions + for (const auto& kv : mod->type_definitions) { + if (counter++ != 0) { + doc << Doc::NewLine(); + } + doc << relay_text_printer_.Print(kv.second); + doc << Doc::NewLine(); + } + // functions + for (const auto& kv : mod->functions) { + if (kv.second.as()) { + relay_text_printer_.dg_ = + relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second); + } + if (counter++ != 0) { + doc << Doc::NewLine(); + } + if (kv.second.as()) { + std::ostringstream os; + os << "def @" << kv.first->name_hint; + doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); + } else if (kv.second.as()) { + doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); + } + doc << Doc::NewLine(); + } + return doc; +} + +String PrettyPrint(const ObjectRef& node) { + Doc doc; + doc << TextPrinter(false, nullptr).PrintFinal(node); + return doc.str(); +} + +String AsText(const ObjectRef& node, bool show_meta_data, + runtime::TypedPackedFunc annotate) { + Doc doc; + doc << kSemVer << Doc::NewLine(); + runtime::TypedPackedFunc ftyped = nullptr; + if (annotate != nullptr) { + ftyped = runtime::TypedPackedFunc( + [&annotate](const ObjectRef& expr) -> std::string { return 
annotate(expr); }); + } + doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node); + return doc.str(); +} + +TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint); + +TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText); + +} // namespace tvm diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h new file mode 100644 index 000000000000..c7b2b31019ae --- /dev/null +++ b/src/printer/text_printer.h @@ -0,0 +1,401 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.h + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#ifndef TVM_PRINTER_TEXT_PRINTER_H_ +#define TVM_PRINTER_TEXT_PRINTER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../ir/attr_functor.h" +#include "../relay/analysis/dependency_graph.h" +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +class TextPrinter; +} // namespace tvm + +namespace tvm { +namespace relay { + +class RelayTextPrinter : public ExprFunctor, + public PatternFunctor, + public TypeFunctor, + public AttrFunctor { + public: + explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, + runtime::TypedPackedFunc annotate) + : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} + + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Expr& expr); + // indent a new body + Doc PrintBody(const ObjectRef& node, int indent = 2); + // create a new scope by creating a new printer object. This allows temp var + // numbers to be reused and prevents hoisted vars from escaping too far + Doc PrintScope(const ObjectRef& node); + Doc PrintFinal(const ObjectRef& node); + std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); + std::vector PrintFuncAttrs(const Attrs& attrs); + + Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); + + Doc TempVar(int n); + Doc AllocTemp(); + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + Doc GetUniqueName(const std::string& prefix); + Doc Print(Kind k); + /*! + * \brief Allocate name to a type variable. + * \param var The input type variable. + * \return The corresponding name. + */ + Doc AllocTypeVar(const TypeVar& var); + /*! + * \brief Allocate name to a variable. + * \param var The input variable. + * \return The corresponding name. 
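[Editor's note] For completeness, a sketch of how the two entry points defined in text_printer.cc above differ when reached through their registered globals: PrettyPrint emits just the IR text, while AsText prefixes the "v0.0.4" semver header so the output can be fed back to the parser. The sample expression and headers are assumptions, not part of the patch.

```cpp
// PrettyPrint vs AsText sketch via the globals registered above (assumes a working TVM build).
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <iostream>
#include <string>

int main() {
  tvm::tir::Var x("x", tvm::DataType::Int(32));
  tvm::PrimExpr e = x + 1;

  const tvm::runtime::PackedFunc* pretty = tvm::runtime::Registry::Get("ir.PrettyPrint");
  const tvm::runtime::PackedFunc* as_text = tvm::runtime::Registry::Get("ir.AsText");

  std::string plain = (*pretty)(e);                       // IR text only
  std::string parseable = (*as_text)(e, false, nullptr);  // starts with "v0.0.4"
  std::cout << plain << "\n---\n" << parseable << std::endl;
  return 0;
}
```

This mirrors how the Python frontend reaches these printers, since both are exposed only through the PackedFunc registry in this patch.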
+ */ + Doc AllocVar(const Var& var); + bool IsUnique(const Expr& expr); + bool AlwaysInline(const Expr& expr); + + Doc PrintFunc(const Doc& prefix, const relay::Function& fn); + Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func); + Doc PrintMod(const IRModule& mod); + + //------------------------------------ + // Overload of Expr printing functions + //------------------------------------ + Doc PrintExpr(const Expr& expr, bool meta, bool try_inline); + // Should only be triggered when op is a free variable being visited for the + // first time. + Doc VisitExpr_(const VarNode* op) final; + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ + template + static Doc ScalarLiteral(DataType dtype, const T& value); + Doc VisitExpr_(const ConstantNode* op) final; + Doc VisitExpr_(const TupleNode* op) final; + Doc VisitExpr_(const TupleGetItemNode* op) final; + Doc VisitExpr_(const IfNode* op) final; + Doc VisitExpr_(const LetNode* op) final; + Doc VisitExpr_(const FunctionNode* op) final; + Doc VisitExpr_(const GlobalVarNode* op) final; + Doc VisitExpr_(const OpNode* op) final; + Doc VisitExpr_(const CallNode* op) final; + Doc VisitExpr_(const RefCreateNode* op) final; + Doc VisitExpr_(const RefReadNode* op) final; + Doc VisitExpr_(const RefWriteNode* op) final; + Doc VisitExpr_(const MatchNode* op) final; + Doc PrintPattern(const Pattern& pattern, bool meta); + Doc VisitPattern_(const PatternConstructorNode* p) final; + Doc VisitPattern_(const PatternTupleNode* pt) final; + Doc VisitPattern_(const PatternWildcardNode* pw) final; + Doc VisitPattern_(const PatternVarNode* pv) final; + Doc VisitExpr_(const ConstructorNode* n) final; + //------------------------------------ + // Overload of Type printing functions + //------------------------------------ + Doc PrintType(const Type& type, bool meta); + Doc VisitTypeDefault_(const Object* node) final; + Doc VisitType_(const TypeVarNode* node) final; + Doc VisitType_(const GlobalTypeVarNode* node); + Doc VisitType_(const TypeCallNode* node) final; + Doc PrintDType(DataType dtype); + Doc VisitType_(const TensorTypeNode* node) final; + Doc VisitType_(const TupleTypeNode* node) final; + Doc VisitType_(const FuncTypeNode* node) final; + Doc VisitType_(const RelayRefTypeNode* node) final; + Doc VisitType_(const TypeDataNode* node) final; + //------------------------------------ + // Overload of Attr printing functions + //------------------------------------ + Doc PrintAttr(const ObjectRef& value, bool meta = false); + Doc VisitAttrDefault_(const Object* op) final; + Doc VisitAttr_(const ArrayNode* op) final; + Doc VisitAttr_(const tir::IntImmNode* op) final; + Doc VisitAttr_(const tir::FloatImmNode* op) final; + Doc VisitAttr_(const tir::StringImmNode* op) final; + + private: + /*! \brief Whether to print meta data. */ + bool show_meta_data_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Stack of docs to implement scoped GNFing. */ + std::vector doc_stack_{}; + /*! \brief Map from Expr to Doc */ + std::unordered_map memo_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_type_; + /*! \brief Map from Type to Doc */ + std::unordered_map memo_pattern_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + /*! \brief counter of temporary variable */ + size_t temp_var_counter_{0}; + /*! 
\brief whether the printer is currently in an ADT definition */ + bool in_adt_def_; + /*! \brief arena for dependency graph */ + support::Arena arena_; + /*! \brief dependency graph of the expr */ + DependencyGraph dg_; + class AttrPrinter; + friend class AttrPrinter; + friend class tvm::TextPrinter; +}; + +} // namespace relay +} // namespace tvm + +namespace tvm { +namespace tir { + +/*! + * \brief Meta node collector + * If we decide to put some node into meta, then all the sub-nodes inside + * it need to be put in meta as well, since when parsing we need to know + * whether two refs are the same + */ +class MetaCollector : public StmtExprVisitor { + public: + explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {} + + void Collect(const ObjectRef& n) { + // these nodes can be print directly(StringLiteral or use identifier to identify) + if (!n.defined() || n.as() || n.as() || n.as() || + n.as() || n.as() || n.as()) { + return; + } + if (n->IsInstance()) { + VisitStmt(Downcast(n)); + } else if (n->IsInstance()) { + VisitExpr(Downcast(n)); + } + } + + void VisitStmt(const Stmt& n) override { + meta_->GetMetaNode(n); + StmtVisitor::VisitStmt(n); + } + + void VisitExpr(const PrimExpr& n) override { + meta_->GetMetaNode(n); + ExprVisitor::VisitExpr(n); + } + + private: + TextMetaDataContext* meta_; +}; + +class TIRTextPrinter : public StmtFunctor, + public ExprFunctor, + public TypeFunctor { + public: + explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) + : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} + + /*! \brief Print the node */ + Doc Print(const ObjectRef& node); + + private: + /*! \brief whether show meta data */ + bool show_meta_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + /*! \brief meta collector */ + MetaCollector meta_collector_; + /*! \brief Map from Var to Doc */ + std::unordered_map memo_var_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_buf_; + /*! 
\brief name allocation map */ + std::unordered_map name_alloc_map_; + + friend class tvm::TextPrinter; + + Doc VisitExpr_(const IntImmNode* op) override; + Doc VisitExpr_(const FloatImmNode* op) override; + Doc VisitExpr_(const StringImmNode* op) override; + Doc VisitExpr_(const CastNode* op) override; + Doc VisitExpr_(const VarNode* op) override; + Doc VisitExpr_(const AddNode* op) override; + Doc VisitExpr_(const SubNode* op) override; + Doc VisitExpr_(const MulNode* op) override; + Doc VisitExpr_(const DivNode* op) override; + Doc VisitExpr_(const ModNode* op) override; + Doc VisitExpr_(const FloorDivNode* op) override; + Doc VisitExpr_(const FloorModNode* op) override; + Doc VisitExpr_(const MinNode* op) override; + Doc VisitExpr_(const MaxNode* op) override; + Doc VisitExpr_(const EQNode* op) override; + Doc VisitExpr_(const NENode* op) override; + Doc VisitExpr_(const LTNode* op) override; + Doc VisitExpr_(const LENode* op) override; + Doc VisitExpr_(const GTNode* op) override; + Doc VisitExpr_(const GENode* op) override; + Doc VisitExpr_(const AndNode* op) override; + Doc VisitExpr_(const OrNode* op) override; + Doc VisitExpr_(const NotNode* op) override; + Doc VisitExpr_(const SelectNode* op) override; + Doc VisitExpr_(const BufferLoadNode* op) override; + Doc VisitExpr_(const ProducerLoadNode* op) override; + Doc VisitExpr_(const LoadNode* op) override; + Doc VisitExpr_(const RampNode* op) override; + Doc VisitExpr_(const BroadcastNode* op) override; + Doc VisitExpr_(const LetNode* op) override; + Doc VisitExpr_(const CallNode* op) override; + Doc VisitExpr_(const ShuffleNode* op) override; + Doc VisitExpr_(const ReduceNode* op) override; + Doc VisitExprDefault_(const Object* op) override; + + Doc VisitStmt_(const LetStmtNode* op) override; + Doc VisitStmt_(const AttrStmtNode* op) override; + Doc VisitStmt_(const AssertStmtNode* op) override; + Doc VisitStmt_(const StoreNode* op) override; + Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const BufferRealizeNode* op) override; + Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const FreeNode* op) override; + Doc VisitStmt_(const IfThenElseNode* op) override; + Doc VisitStmt_(const SeqStmtNode* op) override; + Doc VisitStmt_(const EvaluateNode* op) override; + Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const PrefetchNode* op) override; + Doc VisitStmtDefault_(const Object* op) override; + + Doc VisitType_(const PrimTypeNode* node) override; + Doc VisitType_(const PointerTypeNode* node) override; + Doc VisitType_(const TupleTypeNode* node) override; + + Doc PrintIRModule(const IRModule& module); + Doc PrintPrimFunc(const PrimFunc& primFunc); + Doc PrintArray(const ArrayNode* op); + Doc PrintIterVar(const IterVarNode* op); + Doc PrintRange(const RangeNode* op); + Doc PrintBuffer(const BufferNode* op); + Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } + + /*! + * \brief special method to print out data type + * \param dtype The data type + */ + static Doc PrintDType(DataType dtype); + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ + template + static Doc PrintConstScalar(DataType dtype, const T& data); + Doc GetUniqueName(std::string prefix); + Doc AllocVar(const Var& var); + Doc AllocBuf(const Buffer& buffer); + /*! 
+ * \brief special method to render vectors of docs with a separator + * \param vec vector of docs + * \param sep separator + */ + static Doc PrintSep(const std::vector& vec, const Doc& sep); + Doc PrintBody(const Stmt& body, bool indent = true); +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { + +class TextPrinter { + public: + explicit TextPrinter(bool show_meta_data, + const runtime::TypedPackedFunc& annotate) + : show_meta_data_(show_meta_data), + annotate_(annotate), + relay_text_printer_(show_meta_data, &meta_, annotate), + tir_text_printer_(show_meta_data, &meta_) {} + + /*! \brief whether show meta data */ + bool show_meta_data_; + /*! \brief meta data context */ + TextMetaDataContext meta_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief Relay Text Printer */ + relay::RelayTextPrinter relay_text_printer_; + /*! \brief TIR Text Printer */ + tir::TIRTextPrinter tir_text_printer_; + + Doc PrintFinal(const ObjectRef& node) { + Doc doc; + if (node->IsInstance()) { + doc << PrintMod(Downcast(node)); + } else if (node->IsInstance() || node->IsInstance() || + node->IsInstance()) { + doc << tir_text_printer_.Print(node); + } else { + doc << relay_text_printer_.PrintFinal(node); + } + if (!meta_.empty()) { + doc << Doc::NewLine(); + if (show_meta_data_) { + // append meta data in the end. + doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); + } else { + doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; + } + } + return doc; + } + + Doc PrintMod(const IRModule& mod); +}; +} // namespace tvm + +#endif // TVM_PRINTER_TEXT_PRINTER_H_ diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc new file mode 100644 index 000000000000..29927379f17d --- /dev/null +++ b/src/printer/tir_text_printer.cc @@ -0,0 +1,616 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir_text_printer.cc + * \brief Printer to print out the IR text format + * that can be parsed by a parser. 
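// --- Illustrative sketch (not part of this patch) --------------------------
// TextPrinter::PrintFinal above routes an IRModule to PrintMod, TIR
// PrimFunc/PrimExpr/Stmt nodes to the TIR printer, and everything else to the
// Relay printer, then appends either a METADATA section or the "meta data
// omitted" note. Inside src/printer the class can be used directly; elsewhere
// the registered ir.AsText global is the intended entry point. A sketch under
// that assumption, not a public API:
#include <string>

#include "text_printer.h"

std::string DebugDump(const tvm::ObjectRef& node) {
  tvm::TextPrinter printer(/*show_meta_data=*/true, /*annotate=*/nullptr);
  return printer.PrintFinal(node).str();
}
// ----------------------------------------------------------------------------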
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +namespace tir { + +Doc TIRTextPrinter::Print(const ObjectRef& node) { + if (!node.defined()) return Doc::Text("(nullptr)"); + if (node->IsInstance()) { + return VisitStmt(Downcast(node)); + } else if (node->IsInstance()) { + return Doc::Text("?"); + } else if (node->IsInstance()) { + return VisitExpr(Downcast(node)); + } else if (node->IsInstance()) { + return VisitType(Downcast(node)); + } else if (node->IsInstance()) { + return PrintPrimFunc(Downcast(node)); + } else if (node->IsInstance()) { + return PrintIRModule(Downcast(node)); + } else if (node->IsInstance()) { + return PrintArray(node.as()); + } else if (node->IsInstance()) { + return PrintIterVar(node.as()); + } else if (node->IsInstance()) { + return PrintRange(node.as()); + } else if (node->IsInstance()) { + return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintString(node.as()); + } else { + return this->meta_->GetMetaNode(node); + } +} + +Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) { + const auto* op = primFunc.operator->(); + const auto& signature = op->func_type_annotation(); + // collect Meta in DictAttr + for (const auto& it : primFunc->attrs->dict) { + meta_collector_.Collect(it.second); + } + // collect buffers in buffer_map + memo_var_.clear(); + memo_buf_.clear(); + for (const auto& it : op->buffer_map) { + memo_buf_[it.second] = AllocBuf(it.second); + } + // print PrimFunc + Doc doc; + doc << "primfn" + << "("; + // print params and its type annotation + std::vector params; + for (const auto& param : op->params) { + params.push_back(Print(param)); + } + Doc sep; + doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")"; + // print return type + doc << " -> " << Print(signature->ret_type); + // print attr + Doc attr_doc; + std::vector attr_docs; + for (const auto& it : op->attrs->dict) { + attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); + } + attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; + doc << Doc::Indent(2, attr_doc); + // print all the buffers in the tree + Doc buffer_doc; + std::vector buffer_docs; + for (const auto& it : memo_buf_) { + const auto& buf = it.first; + buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", " + << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " + << Print(buf->strides)); + if (!is_zero(buf->elem_offset)) { + buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->scope != "global") { + buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope); + } + if (buf->data_alignment != 128) { + buffer_docs.back() << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 1) { + buffer_docs.back() << ", offset_factor=" << buf->offset_factor; + } + if (buf->buffer_type != 1) { + buffer_docs.back() << ", type=" << Doc::StrLiteral("auto"); + } + buffer_docs.back() << ")"; + } + buffer_doc << Doc::NewLine() << "buffers = {"; + buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); + doc << Doc::Indent(2, buffer_doc) << "}"; + // print buffer_map + std::vector buffer_map_doc; + for (const auto& it : op->buffer_map) { + buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second)); + } + doc << Doc::Indent( + 2, Doc::NewLine() << "buffer_map = {" << 
PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); + doc << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { + const auto* op = module.operator->(); + Doc doc; + + Doc body; + body << Doc::NewLine(); + std::vector functions; + for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { + if ((*it).second.as()) { + functions.push_back(Print((*it).second)); + } + } + body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); + doc << Doc::Indent(0, body); + return doc; +} + +Doc TIRTextPrinter::PrintArray(const ArrayNode* op) { + Doc doc; + doc << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + doc << ", "; + } + doc << Print(op->at(i)); + } + doc << ']'; + return doc; +} + +Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) { + Doc doc; + doc << "IterVar(" << Print(op->var); + if (op->dom.defined()) { + doc << ", [" << Print(op->dom) << "], "; + } else { + doc << ", " << Print(op->dom) << ", "; + } + doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "; + doc << Doc::StrLiteral(op->thread_tag) << ")"; + return doc; +} + +Doc TIRTextPrinter::PrintRange(const RangeNode* op) { + return Print(op->min) << ":" << Print(op->min + op->extent); +} + +Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { + const Buffer& buffer = GetRef(op); + CHECK_GT(memo_buf_.count(buffer), 0); + return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer]; +} + +Doc TIRTextPrinter::VisitExprDefault_(const Object* op) { + return this->meta_->GetMetaNode(GetRef(op)); +} + +Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) { + return this->meta_->GetMetaNode(GetRef(op)); +} + +Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) { + return PrintConstScalar(op->dtype, op->value); +} + +Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) { + return PrintConstScalar(op->dtype, op->value); +} + +Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } + +Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { + Doc doc; + doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const VarNode* op) { + const Var& var = GetRef(op); + return meta_->InMeta(var) ? 
meta_->GetMetaNode(var) : AllocVar(GetRef(op)); +} + +#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \ + Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \ + Doc doc; \ + doc << "(" << Print(op->a) << OpString; \ + doc << Print(op->b) << ")"; \ + return doc; \ + } + +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ") + +Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) { + Doc doc; + doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) { + Doc doc; + doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const MinNode* op) { + Doc doc; + doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) { + Doc doc; + doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const NotNode* op) { + Doc doc; + doc << "!" << Print(op->a); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) { + Doc doc; + doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " + << Print(op->false_value); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) { + // TODO(tvm-team): consider make a better text format for producer. 
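// --- Illustrative note (not part of this patch) -----------------------------
// TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP above stamps out one VisitExpr_ per
// binary PrimExpr node. After preprocessing, the AddNode instance expands to
// roughly the following (do not add this by hand; the macro already emits it):
//
//   Doc TIRTextPrinter::VisitExpr_(const AddNode* op) {
//     Doc doc;
//     doc << "(" << Print(op->a) << " + ";
//     doc << Print(op->b) << ")";
//     return doc;
//   }
//
// so `a + b` renders as "(a + b)", while MulNode uses the tighter "*" string
// and renders as "(a*b)".
// ----------------------------------------------------------------------------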
+ Doc doc; + doc << op->producer->GetNameHint() << Print(op->indices); + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { + Doc doc; + doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index) + << "])"; + if (!is_one(op->predicate)) { + doc << " if " << Print(op->predicate); + } + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { + Doc doc; + doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { + Doc doc; + doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const LetNode* op) { + Doc doc; + doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body); + return doc; +} + +inline const char* CallType2String(CallNode::CallType t) { + switch (t) { + case CallNode::Extern: + return "extern"; + case CallNode::ExternCPlusPlus: + return "extern_cpp"; + case CallNode::PureExtern: + return "pure_extern"; + case CallNode::Intrinsic: + return "intrin"; + case CallNode::PureIntrinsic: + return "pure_intrin"; + } + LOG(FATAL) << "Unknown CallType"; + return "Unknown"; +} + +Doc TIRTextPrinter::VisitExpr_(const CallNode* op) { + Doc doc; + doc << "@" << Doc::Text(op->name) << "("; + std::vector args; + for (const auto& arg : op->args) { + args.push_back(Print(arg)); + } + doc << PrintSep(args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) + << ", type=" << Doc::StrLiteral(CallType2String(op->call_type)) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) { + Doc doc; + doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { + Doc doc; + doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis) + << ", " << op->value_index << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { + Doc doc; + doc << "let " << Print(op->var) << " = " << Print(op->value) << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { + Doc doc; + meta_collector_.Collect(op->node); + doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = " + << Print(op->value); + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { + Doc doc; + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" + << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) { + Doc doc; + doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); + if (!is_one(op->predicate)) { + doc << " if " << Print(op->predicate); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { + Doc doc; + doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " + << Print(op->condition) << PrintBody(op->body) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { + Doc doc; + doc << "allocate(" << Print(op->buffer_var) << ", " << 
PrintDType(op->dtype) << ", " + << Print(op->extents) << ")"; + if (!is_one(op->condition)) { + doc << " if " << Print(op->condition); + } + if (op->body->IsInstance()) { + doc << PrintBody(op->body); + } else { + doc << ";" << Doc::NewLine() << Print(op->body); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const FreeNode* op) { + Doc doc; + doc << "free(" << Print(op->buffer_var) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { + Doc doc; + doc << "if " << Print(op->condition) << PrintBody(op->then_case); + if (!is_one(op->condition) && op->else_case.defined()) { + doc << " else" << PrintBody(op->else_case); + } + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) { + std::vector stmts; + Doc seq_doc, doc; + for (Stmt stmt : op->seq) { + seq_doc << Doc::NewLine() << Print(stmt); + } + doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}"; + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { + Doc doc; + doc << Print(op->value); + return doc; +} + +inline const char* ForType2String(ForType t) { + switch (t) { + case ForType::Serial: + return "serial"; + case ForType::Parallel: + return "parallel"; + case ForType::Vectorized: + return "vectorized"; + case ForType::Unrolled: + return "unroll"; + } + LOG(FATAL) << "Unknown ForType"; + return "Unknown"; +} + +Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { + Doc doc; + doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " + << Print(op->min + op->extent) << ")"; + if (op->for_type != ForType::Serial) { + doc << " " << Doc::StrLiteral(ForType2String(op->for_type)); + } + doc << PrintBody(op->body); + return doc; +} + +Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { + Doc doc; + doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) { + Doc doc; + doc << PrintDType(node->dtype); + return doc; +} + +Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) { + Doc doc; + doc << "Pointer(" << Print(node->element_type) << ")"; + return doc; +} + +Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + Doc doc; + doc << "(" << Doc::Concat(fields); + // conform to python tuple format (1,) + if (node->fields.size() == 1) { + doc << ","; + } + return doc << ")"; +} + +Doc TIRTextPrinter::PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); +} + +template +Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) { + Doc doc; + std::ostringstream os; + os << data; + if (dtype == DataType::Int(32)) { + doc << Doc::Text(os.str()); + } else { + if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) { + doc << ((data == 1) ? 
"True" : "False"); + return doc; + } + doc << Doc::Text(os.str()); + switch (dtype.code()) { + case kDLInt: + doc << "i"; + break; + case kDLUInt: + doc << "u"; + break; + case kDLFloat: + doc << "f"; + break; + } + doc << Doc::Text(std::to_string(dtype.bits())); + if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes())); + } + return doc; +} + +Doc TIRTextPrinter::GetUniqueName(std::string prefix) { + // std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { + } + } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} + +Doc TIRTextPrinter::AllocVar(const Var& var) { + const auto& it = memo_var_.find(var); + if (it != memo_var_.end()) { + return it->second; + } + std::string name = var->name_hint.operator std::string(); + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName(name); + memo_var_[var] = val; + return val << ": " << Print(GetType(var)); +} + +Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { + const auto& it = memo_buf_.find(buffer); + if (it != memo_buf_.end()) { + return it->second; + } + std::string name = buffer->name; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "buf_" + name; + } + Doc val = GetUniqueName(name); + memo_buf_[buffer] = val; + return val; +} + +Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + seq = vec[0]; + for (size_t i = 1; i < vec.size(); i++) { + seq << sep << vec[i]; + } + } + return seq; +} + +Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { + Doc doc; + if (body->IsInstance()) return Print(body); + doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}"; + return doc; +} + +} // namespace tir +} // namespace tvm diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 94c7621e60af..587add36706f 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -19,14 +19,13 @@ #include "annotated_region_set.h" -#include #include +#include #include #include #include - namespace tvm { namespace relay { @@ -39,8 +38,7 @@ AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { return AnnotatedRegion(nullptr); } -void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, - AnnotatedRegion dest) { +void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, AnnotatedRegion dest) { if (dest == src) { return; } @@ -86,32 +84,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) { return *ret.first; } -class AnnotatedRegionSet::Creator : public ExprVisitor { +class AnnotatedRegionSet::Creator : protected MixedModeVisitor { public: Creator(const Op& region_begin_op, const Op& region_end_op) : begin_op_(region_begin_op), end_op_(region_end_op) {} + AnnotatedRegionSet Create(const Expr& expr) { + VisitExpr(expr); + return std::move(region_set_); + } + + void AddToArgRegion(Expr expr, Array args) { + // Merge argument regions and add itself to the region. + + // Find the first open region. + AnnotatedRegion region; + for (auto arg : args) { + const CallNode* end = arg.as(); + if (end && end->op == end_op_) { // Ignore closed regions. 
+ continue; + } + + region = region_set_->GetRegion(arg); + if (region.defined()) { + break; + } + } + + // Try to merge open regions. + for (auto arg : args) { + const CallNode* end = arg.as(); + if (end && end->op == end_op_) { // Ignore closed regions. + continue; + } + + auto arg_region = region_set_->GetRegion(arg); + CHECK_EQ(region.defined(), arg_region.defined()) + << "Arg regions are inconsistent: " << AsText(expr); + if (region.defined() && region != arg_region) { + region_set_->MergeRegions(arg_region, region); + } + } + if (region.defined()) { + region_set_->AddToRegion(region, expr); + } + } + void VisitExpr_(const CallNode* call) { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { - // Propagate region to arguments - auto region = region_set_->GetRegion(GetRef(call)); - if (region.defined()) { - for (auto arg : call->args) { - region_set_->AddToRegion(region, arg); - } - } + AddToArgRegion(GetRef(call), call->args); } else if (call->op == begin_op_) { // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); + std::string target = call->attrs.as()->compiler; + // Check if the argument already belongs to a region auto region = region_set_->GetRegion(GetRef(call)); - if (!region.defined()) { - throw Error(ErrorBuilder() - << "Cannot find the corresponding region for start annotation:\n" - << AsText(GetRef(call), false)); - } + CHECK(!region.defined()); + + // Create a new region. + region = region_set_->MakeRegion(target); + region->nodes_.insert(GetRef(call)); region->ins_.push_back(GetRef(call)); } else { CHECK_EQ(call->op, end_op_); @@ -122,9 +157,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { // Check if the argument already belongs to a region auto region = region_set_->GetRegion(call->args[0]); if (!region.defined()) { - // Create a new region if the argument is not belonged to any regions yet. - region = region_set_->MakeRegion(target); - region->nodes_.insert(call->args[0]); + throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n" + << AsText(GetRef(call), false)); } else { // If the argument is belonged to a region, it must have the same target. // Otherwise we should see a region_begin op. 
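// --- Illustrative note (not part of this patch) -----------------------------
// With this change the Creator builds regions at the annotation boundaries:
// a region_begin call must not already belong to a region (it creates one for
// the "compiler" target carried in its attrs), while a region_end call must
// find an open region on its argument, otherwise the error above is raised.
// A hedged construction sketch, assuming the usual annotation ops are
// registered and the internal annotated_region_set.h header is visible:
tvm::relay::AnnotatedRegionSet BuildRegions(const tvm::relay::Expr& expr) {
  const tvm::Op& begin = tvm::Op::Get("annotation.compiler_begin");
  const tvm::Op& end = tvm::Op::Get("annotation.compiler_end");
  return tvm::relay::AnnotatedRegionSet::Create(expr, begin, end);
}
// ----------------------------------------------------------------------------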
@@ -133,83 +167,42 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { region->nodes_.insert(GetRef(call)); region->outs_.push_back(GetRef(call)); } - ExprVisitor::VisitExpr_(call); } - AnnotatedRegionSet Create(const Expr& expr) { - VisitExpr(expr); - return std::move(region_set_); - } - - void VisitExpr_(const TupleNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - for (auto field : op->fields) { - region_set_->AddToRegion(region, field); - } - } - ExprVisitor::VisitExpr_(op); - } + void VisitExpr_(const TupleNode* op) { AddToArgRegion(GetRef(op), op->fields); } void VisitExpr_(const TupleGetItemNode* g) { - auto region = region_set_->GetRegion(GetRef(g)); - if (region.defined()) { - region_set_->AddToRegion(region, g->tuple); - } - ExprVisitor::VisitExpr_(g); - } - - void VisitExpr_(const FunctionNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - for (auto param : op->params) { - region_set_->AddToRegion(region, param); - } - } - ExprVisitor::VisitExpr_(op); + Array args = {g->tuple}; + AddToArgRegion(GetRef(g), args); } void VisitExpr_(const LetNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->var); - region_set_->AddToRegion(region, op->value); - region_set_->AddToRegion(region, op->body); - } + Array args = {op->var, op->value, op->body}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const IfNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->cond); - region_set_->AddToRegion(region, op->true_branch); - region_set_->AddToRegion(region, op->false_branch); - } + Array args = {op->cond, op->true_branch, op->false_branch}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const RefCreateNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->value); - } + Array args = {op->value}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const RefReadNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->ref); - } + Array args = {op->ref}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } void VisitExpr_(const RefWriteNode* op) { - auto region = region_set_->GetRegion(GetRef(op)); - if (region.defined()) { - region_set_->AddToRegion(region, op->ref); - } + Array args = {op->ref}; + AddToArgRegion(GetRef(op), args); ExprVisitor::VisitExpr_(op); } @@ -230,15 +223,14 @@ TVM_REGISTER_NODE_TYPE(AnnotatedRegionNode); TVM_REGISTER_NODE_TYPE(AnnotatedRegionSetNode); TVM_REGISTER_GLOBAL("relay.analysis.AnnotatedRegionSet") -.set_body_typed([](Expr expr, Op begin, Op end) { - return AnnotatedRegionSet::Create(expr, begin, end); -}); + .set_body_typed([](Expr expr, Op begin, Op end) { + return AnnotatedRegionSet::Create(expr, begin, end); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetRegion") -.set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { - return region_set->GetRegion(expr); -}); - + .set_body_typed([](AnnotatedRegionSet region_set, Expr expr) { + return region_set->GetRegion(expr); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 3bd569387d46..cbcf155350df 100644 --- 
a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -27,19 +27,19 @@ #ifndef TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ #define TVM_RELAY_ANALYSIS_ANNOTATED_REGION_SET_H_ +#include #include #include #include -#include #include -#include #include +#include +#include #include #include #include #include -#include namespace tvm { namespace relay { @@ -61,29 +61,19 @@ class AnnotatedRegionNode : public Object { } /*! \brief Get the region ID. */ - int GetID() const { - return id_; - } + int GetID() const { return id_; } /*! \brief Get the region target. */ - std::string GetTarget() const { - return target_; - } + std::string GetTarget() const { return target_; } /*! \brief Get the region's inputs. */ - std::list GetInputs() const { - return ins_; - } + std::list GetInputs() const { return ins_; } /*! \brief Get the region's outputs. */ - std::list GetOutputs() const { - return outs_; - } + std::list GetOutputs() const { return outs_; } /*! \brief Get the region's nodes. */ - std::unordered_set GetNodes() const { - return nodes_; - } + std::unordered_set GetNodes() const { return nodes_; } static constexpr const char* _type_key = "relay.AnnotatedRegion"; TVM_DECLARE_FINAL_OBJECT_INFO(AnnotatedRegionNode, Object); @@ -98,7 +88,7 @@ class AnnotatedRegionNode : public Object { /*! \brief The outputs of this region */ std::list outs_; /*! \brief Nodes in this region. */ - std::unordered_set nodes_; + std::unordered_set nodes_; friend class AnnotatedRegionSet; friend class AnnotatedRegionSetNode; @@ -107,7 +97,7 @@ class AnnotatedRegionNode : public Object { /*! * \brief An object to hold the properties of a region as used by the * AnnotatedRegionSet class. This should be considered read-only. -*/ + */ class AnnotatedRegion : public ObjectRef { public: AnnotatedRegion() { @@ -116,9 +106,9 @@ class AnnotatedRegion : public ObjectRef { } /*! - * \brief Construct from an object pointer. - * \param n The object pointer. - */ + * \brief Construct from an object pointer. + * \param n The object pointer. + */ explicit AnnotatedRegion(ObjectPtr n) : ObjectRef(n) {} /*! \return Mutable pointers to the node. */ @@ -130,8 +120,7 @@ class AnnotatedRegion : public ObjectRef { }; class AnnotatedRegionSetNode : public Object { - using UnorderedRegionSet = - std::unordered_set; + using UnorderedRegionSet = std::unordered_set; // Create iterator alias for a RegionSet object. using iterator = UnorderedRegionSet::iterator; using const_iterator = UnorderedRegionSet::const_iterator; @@ -141,21 +130,13 @@ class AnnotatedRegionSetNode : public Object { AnnotatedRegionSetNode() = default; /*! \return The begin iterator */ - iterator begin() { - return regions_.begin(); - } + iterator begin() { return regions_.begin(); } /*! \return The end iterator */ - iterator end() { - return regions_.end(); - } + iterator end() { return regions_.end(); } /*! \return The const begin iterator */ - const_iterator begin() const { - return regions_.begin(); - } + const_iterator begin() const { return regions_.begin(); } /*! \return The const end iterator */ - const_iterator end() const { - return regions_.end(); - } + const_iterator end() const { return regions_.end(); } /*! * \brief Get the region that an expression belongs to. @@ -168,11 +149,11 @@ class AnnotatedRegionSetNode : public Object { AnnotatedRegion GetRegion(const Expr& expr) const; /*! - * \brief Merge src region into dest region. - * - * \param src The region to merge - will be erased. 
- * \param dest The region into which src will be merged. - */ + * \brief Merge src region into dest region. + * + * \param src The region to merge - will be erased. + * \param dest The region into which src will be merged. + */ void MergeRegions(AnnotatedRegion src, AnnotatedRegion dest); void VisitAttrs(AttrVisitor* v) { @@ -199,7 +180,7 @@ class AnnotatedRegionSetNode : public Object { */ AnnotatedRegion MakeRegion(const std::string& target); - std::unordered_set regions_; + std::unordered_set regions_; /*! \brief The next region ID to assign. */ int region_id_{0}; @@ -214,8 +195,7 @@ class AnnotatedRegionSetNode : public Object { * to update and query regions. */ class AnnotatedRegionSet : public ObjectRef { - using UnorderedRegionSet = - std::unordered_set; + using UnorderedRegionSet = std::unordered_set; // Create iterator alias for a RegionSet object. using iterator = UnorderedRegionSet::iterator; using const_iterator = UnorderedRegionSet::const_iterator; @@ -227,10 +207,10 @@ class AnnotatedRegionSet : public ObjectRef { } /*! - * \brief Construct from an object pointer. - * - * \param n The object pointer. - */ + * \brief Construct from an object pointer. + * + * \param n The object pointer. + */ explicit AnnotatedRegionSet(ObjectPtr n) : ObjectRef(n) {} /*! \return The begin iterator. */ @@ -253,7 +233,7 @@ class AnnotatedRegionSet : public ObjectRef { } /*! \return The end iterator. */ const_iterator end() const { - const auto *n = operator->(); + const auto* n = operator->(); CHECK(n); return n->end(); } @@ -267,7 +247,7 @@ class AnnotatedRegionSet : public ObjectRef { /*! \return The region an expression belongs to. */ AnnotatedRegion operator[](const Expr& expr) { - const auto *n = operator->(); + const auto* n = operator->(); CHECK(n); return n->GetRegion(expr); } @@ -280,9 +260,7 @@ class AnnotatedRegionSet : public ObjectRef { * * \return The created RegionSet for the expression. */ - static AnnotatedRegionSet Create(const Expr& expr, - const Op& begin, - const Op& end); + static AnnotatedRegionSet Create(const Expr& expr, const Op& begin, const Op& end); private: /*! 
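// --- Illustrative sketch (not part of this patch) --------------------------
// AnnotatedRegionSet is iterable and indexable by expression, so a pass can
// walk the regions produced by Create() and read their target, inputs,
// outputs, and member nodes. A minimal sketch, assuming the internal
// annotated_region_set.h header is visible:
void DumpRegions(const tvm::relay::AnnotatedRegionSet& regions) {
  for (tvm::relay::AnnotatedRegion region : regions) {
    LOG(INFO) << "region " << region->GetID() << " -> " << region->GetTarget() << " ("
              << region->GetNodes().size() << " nodes, " << region->GetInputs().size()
              << " inputs, " << region->GetOutputs().size() << " outputs)";
  }
}
// ----------------------------------------------------------------------------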
\brief Helper class to construct a RegionSet from an expr.*/ diff --git a/src/relay/analysis/call_graph.cc b/src/relay/analysis/call_graph.cc index a12d23d88a30..0d3fedcde0f7 100644 --- a/src/relay/analysis/call_graph.cc +++ b/src/relay/analysis/call_graph.cc @@ -26,6 +26,7 @@ #include #include + #include #include #include @@ -72,22 +73,21 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) { const CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) const { const_iterator cit = call_graph_.find(gv); - CHECK(cit != call_graph_.end()) - << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint + << " not found in the call graph!"; return cit->second.get(); } CallGraphEntry* CallGraphNode::operator[](const GlobalVar& gv) { const_iterator cit = call_graph_.find(gv); - CHECK(cit != call_graph_.end()) - << "GlobalVar " << gv->name_hint << " not found in the call graph!"; + CHECK(cit != call_graph_.end()) << "GlobalVar " << gv->name_hint + << " not found in the call graph!"; return cit->second.get(); } BaseFunc CallGraphNode::GetGlobalFunction(const GlobalVar& var) const { CHECK(module->ContainGlobalVar(var->name_hint)) - << "GlobalVar " << var->name_hint - << " not found in the current ir module"; + << "GlobalVar " << var->name_hint << " not found in the current ir module"; return module->Lookup(var); } @@ -120,8 +120,8 @@ GlobalVar CallGraphNode::RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph) { CHECK(cg_node->empty() || (cg_node->IsRecursive() && cg_node->size() == 1)) << "Cannot remove global var " << cg_node->GetNameHint() - << " from call graph, because it still calls " - << cg_node->size() << " other global functions"; + << " from call graph, because it still calls " << cg_node->size() + << " other global functions"; if (update_call_graph) { // Update the call graph by removing all edges that point to the node @@ -172,8 +172,7 @@ std::vector CallGraphNode::TopologicalOrder() const { << " with # refs = " << (*this)[it.first]->GetRefCount(); } } - LOG(FATAL) << "Expected " << module->functions.size() - << " globals, but received " + LOG(FATAL) << "Expected " << module->functions.size() << " globals, but received " << ret.size(); } @@ -184,8 +183,7 @@ std::vector CallGraphNode::TopologicalOrder() const { // that are visited by previous CallGraphEntry entries can be memoized. This // helps us to make sure no entry will be visited multiple times when collecting // the nodes for an entire call graph. -std::vector CallGraphEntry::TopologicalOrder( - CallGraphEntrySet* visited) const { +std::vector CallGraphEntry::TopologicalOrder(CallGraphEntrySet* visited) const { std::vector ret; std::vector current_nodes; if (visited->find(this) == visited->end()) { @@ -234,8 +232,7 @@ inline void CallGraphEntry::AddCalledGlobal(CallGraphEntry* cg_node) { // Remove an edge from the current global function to the callee. void CallGraphEntry::RemoveCallTo(const GlobalVar& callee) { for (auto it = begin();; ++it) { - CHECK(it != end()) << "Cannot find global function " - << callee->name_hint << " to remove!"; + CHECK(it != end()) << "Cannot find global function " << callee->name_hint << " to remove!"; if (it->second->GetGlobalVar() == callee) { // Only remove one occurrence of the call site. it->second->DecRef(); @@ -260,8 +257,7 @@ void CallGraphEntry::RemoveAllCallTo(CallGraphEntry* callee) { } // Make sure all references to the callee are removed. 
CHECK_EQ(callee->GetRefCount(), 0U) - << "All references to " << callee->GetNameHint() - << " should have been removed"; + << "All references to " << callee->GetNameHint() << " should have been removed"; } void CallGraphEntry::Print(std::ostream& os) const { @@ -293,54 +289,51 @@ std::ostream& operator<<(std::ostream& os, const CallGraphEntry& cgn) { TVM_REGISTER_NODE_TYPE(CallGraphNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - CHECK(node); - p->stream << "CallGraph: \n" << GetRef(node); -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + CHECK(node); + p->stream << "CallGraph: \n" << GetRef(node); + }); -TVM_REGISTER_GLOBAL("relay.analysis.CallGraph") -.set_body_typed([](IRModule module) { +TVM_REGISTER_GLOBAL("relay.analysis.CallGraph").set_body_typed([](IRModule module) { return CallGraph(module); }); -TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph") -.set_body_typed([](CallGraph call_graph) { +TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraph").set_body_typed([](CallGraph call_graph) { std::stringstream ss; ss << call_graph; return ss.str(); }); -TVM_REGISTER_GLOBAL("relay.analysis.GetModule") -.set_body_typed([](CallGraph call_graph) { +TVM_REGISTER_GLOBAL("relay.analysis.GetModule").set_body_typed([](CallGraph call_graph) { return call_graph->module; }); TVM_REGISTER_GLOBAL("relay.analysis.PrintCallGraphGlobalVar") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - std::stringstream ss; - ss << *entry_node; - return ss.str(); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + std::stringstream ss; + ss << *entry_node; + return ss.str(); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetRefCountGlobalVar") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return static_cast(entry_node->GetRefCount()); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->GetRefCount()); + }); TVM_REGISTER_GLOBAL("relay.analysis.GetGlobalVarCallCount") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return static_cast(entry_node->size()); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return static_cast(entry_node->size()); + }); TVM_REGISTER_GLOBAL("relay.analysis.IsRecursive") -.set_body_typed([](CallGraph call_graph, GlobalVar var) { - const auto* entry_node = call_graph[var]; - return entry_node->IsRecursive(); -}); + .set_body_typed([](CallGraph call_graph, GlobalVar var) { + const auto* entry_node = call_graph[var]; + return entry_node->IsRecursive(); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/call_graph.h b/src/relay/analysis/call_graph.h index 86bc6469c316..07b25278b1d6 100644 --- a/src/relay/analysis/call_graph.h +++ b/src/relay/analysis/call_graph.h @@ -32,6 +32,7 @@ #include #include #include + #include #include #include @@ -47,8 +48,7 @@ class CallGraph; class CallGraphNode : public Object { using CallGraphMap = - std::unordered_map, ObjectHash, - ObjectEqual>; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>; // Create iterator alias for a CallGraphNode object. 
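// --- Illustrative sketch (not part of this patch) --------------------------
// The globals registered above ("relay.analysis.CallGraph", "...IsRecursive",
// and friends) are thin wrappers over the C++ CallGraph object, so the same
// queries can be made directly. A minimal sketch, assuming call_graph.h is
// visible from within src/relay:
void InspectEntry(const tvm::IRModule& mod, const tvm::GlobalVar& gv) {
  tvm::relay::CallGraph call_graph(mod);
  const tvm::relay::CallGraphEntry* entry = call_graph[gv];
  LOG(INFO) << gv->name_hint << ": referenced from " << entry->GetRefCount()
            << " call site(s), calls " << entry->size()
            << " global(s), recursive=" << entry->IsRecursive();
}
// ----------------------------------------------------------------------------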
using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -60,9 +60,7 @@ class CallGraphNode : public Object { /*! \brief Default constructor. */ CallGraphNode() {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("module", &module); - } + void VisitAttrs(AttrVisitor* v) { v->Visit("module", &module); } /*! * \brief Print the call graph. @@ -72,21 +70,13 @@ class CallGraphNode : public Object { void Print(std::ostream& os) const; /*! \return The begin iterator. */ - iterator begin() { - return call_graph_.begin(); - } + iterator begin() { return call_graph_.begin(); } /*! \return The end iterator. */ - iterator end() { - return call_graph_.end(); - } + iterator end() { return call_graph_.end(); } /*! \return The begin iterator. */ - const_iterator begin() const { - return call_graph_.begin(); - } + const_iterator begin() const { return call_graph_.begin(); } /*! \return The end iterator. */ - const_iterator end() const { - return call_graph_.end(); - } + const_iterator end() const { return call_graph_.end(); } /*! * \brief Get an element from the CallGraphNode using a GlobalVar. @@ -157,8 +147,7 @@ class CallGraphNode : public Object { * * \return The GlobalVar removed from the current module. */ - GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, - bool update_call_graph = false); + GlobalVar RemoveGlobalVarFromModule(CallGraphEntry* cg_node, bool update_call_graph = false); /*! * \brief Lookup a GlobalVar for the CallGraphNode. It creates an entry for @@ -207,8 +196,7 @@ class CallGraphNode : public Object { */ class CallGraph : public ObjectRef { using CallGraphMap = - std::unordered_map, ObjectHash, - ObjectEqual>; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual>; // Create iterator alias for a CallGraph object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -340,30 +328,20 @@ class CallGraphEntry { CallGraphEntry& operator=(const CallGraphEntry&) = delete; /*! \return The begin iterator */ - iterator begin() { - return called_globals_.begin(); - } + iterator begin() { return called_globals_.begin(); } /*! \return The end iterator */ - iterator end() { - return called_globals_.end(); - } + iterator end() { return called_globals_.end(); } /*! \return The const begin iterator */ - const_iterator begin() const { - return called_globals_.begin(); - } + const_iterator begin() const { return called_globals_.begin(); } /*! \return The const end iterator */ - const_iterator end() const { - return called_globals_.end(); - } + const_iterator end() const { return called_globals_.end(); } /*! * \brief Return if the list of called nodes is empty. * * \return true if the list is empty. Otherwise, false. */ - bool empty() const { - return called_globals_.empty(); - } + bool empty() const { return called_globals_.empty(); } /*! * \brief Return the size of the list that represents the nodes are called by @@ -371,9 +349,7 @@ class CallGraphEntry { * * \return The number of called nodes. */ - uint32_t size() const { - return static_cast(called_globals_.size()); - } + uint32_t size() const { return static_cast(called_globals_.size()); } /*! * \brief Fetch the i-th CallGraphEntry from the list of nodes that are called @@ -400,27 +376,21 @@ class CallGraphEntry { * * \return The count. */ - uint32_t GetRefCount() const { - return ref_cnt_; - } + uint32_t GetRefCount() const { return ref_cnt_; } /*! * \brief Return the GlobalVar stored in the current CallGraphEntry. * * \return The GlobalVar. 
*/ - GlobalVar GetGlobalVar() const { - return global_; - } + GlobalVar GetGlobalVar() const { return global_; } /*! * \brief Return the name hint of the GlobalVar stored in the CallGraphEntry. * * \return The name hint of the global function. */ - std::string GetNameHint() const { - return global_->name_hint; - } + std::string GetNameHint() const { return global_->name_hint; } /*! * \brief Return if the global function corresponding to the current @@ -428,9 +398,7 @@ class CallGraphEntry { * * \return true if it is recursive. Otherwise, false. */ - bool IsRecursive() const { - return is_recursive_; - } + bool IsRecursive() const { return is_recursive_; } /*! * \brief Return if the global function corresponding to the current @@ -439,9 +407,7 @@ class CallGraphEntry { * * \return true if it is both a recursive function and an entry. Otherwise, false. */ - bool IsRecursiveEntry() const { - return GetRefCount() == 1 && IsRecursive(); - } + bool IsRecursiveEntry() const { return GetRefCount() == 1 && IsRecursive(); } /*! * \brief Return the topological order of the CallGraphEntry. diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc index 7e48d12d0cf3..5db833866a3e 100644 --- a/src/relay/analysis/dependency_graph.cc +++ b/src/relay/analysis/dependency_graph.cc @@ -22,7 +22,9 @@ * \brief Implementation of dependency graph APIs. */ #include "dependency_graph.h" + #include + #include #include @@ -32,8 +34,7 @@ namespace relay { // Creator of DependencyGraph class DependencyGraph::Creator : private ExprFunctor { public: - explicit Creator(support::Arena* arena) - : arena_(arena) {} + explicit Creator(support::Arena* arena) : arena_(arena) {} DependencyGraph Create(const Expr& body) { this->VisitExpr(body); @@ -64,7 +65,7 @@ class DependencyGraph::Creator : private ExprFunctor { parent->children.Push(child_link); } - std::unordered_set visited_; + std::unordered_set visited_; DependencyGraph::Node* NewNode(bool new_scope) { auto* ret = arena_->make(); @@ -164,15 +165,15 @@ class DependencyGraph::Creator : private ExprFunctor { } } - void VisitExpr_(const VarNode* v) final { } + void VisitExpr_(const VarNode* v) final {} - void VisitExpr_(const GlobalVarNode* v) final { } + void VisitExpr_(const GlobalVarNode* v) final {} - void VisitExpr_(const ConstantNode* c) final { } + void VisitExpr_(const ConstantNode* c) final {} - void VisitExpr_(const OpNode* o) final { } + void VisitExpr_(const OpNode* o) final {} - void VisitExpr_(const ConstructorNode* c) final { } + void VisitExpr_(const ConstructorNode* c) final {} }; DependencyGraph DependencyGraph::Create(support::Arena* arena, const Expr& body) { diff --git a/src/relay/analysis/dependency_graph.h b/src/relay/analysis/dependency_graph.h index 5e2dc0c899d1..1de125770c7c 100644 --- a/src/relay/analysis/dependency_graph.h +++ b/src/relay/analysis/dependency_graph.h @@ -25,16 +25,18 @@ #define TVM_RELAY_ANALYSIS_DEPENDENCY_GRAPH_H_ #include + #include #include -#include "../transforms/let_list.h" + #include "../../support/arena.h" +#include "../transforms/let_list.h" namespace tvm { namespace relay { -using support::LinkNode; using support::LinkedList; +using support::LinkNode; /* DependencyGraph track input and output of an Expr. * Additionally, dummy scope is created to model scope. @@ -54,7 +56,7 @@ class DependencyGraph { }; /*! \brief Maps a Relay Expr to its node in the dependency graph. */ - std::unordered_map expr_node; + std::unordered_map expr_node; /*! \brief The dependency graph in post DFS order. 
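// --- Illustrative sketch (not part of this patch) --------------------------
// DependencyGraph::Create(arena, body) above returns a graph whose nodes live
// in the caller-supplied arena; post_dfs_order and expr_node are the two views
// the text printer relies on for scoping. A minimal traversal sketch, assuming
// the internal dependency_graph.h header is visible:
void SummarizeDependencyGraph(const tvm::relay::Expr& body) {
  tvm::support::Arena arena;
  auto dg = tvm::relay::DependencyGraph::Create(&arena, body);
  LOG(INFO) << dg.post_dfs_order.size() << " graph nodes in post-DFS order, "
            << dg.expr_node.size() << " expressions mapped to nodes";
}
// ----------------------------------------------------------------------------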
*/ std::vector post_dfs_order; diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index ff3756cd318d..e76b54e2d0b7 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -50,7 +50,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor { const IRModule mod_; // This is not simply Map because GlobalVar doesn't // have the desired equals property - Map functions; + Map functions; void VisitExpr_(const FunctionNode* n) final { if (n->HasNonzeroAttr(attr::kPrimitive)) { diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 95c2f731ff72..a145b28d55e8 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -21,11 +21,12 @@ * \file feature.cc * \brief Detect features used in Expr/Module */ -#include +#include #include #include #include -#include +#include + #include "../transforms/pass_util.h" namespace tvm { @@ -36,7 +37,7 @@ FeatureSet DetectFeature(const Expr& expr) { return FeatureSet::No(); } struct FeatureDetector : ExprVisitor { - std::unordered_set visited_; + std::unordered_set visited_; FeatureSet fs = FeatureSet::No(); void VisitExpr(const Expr& expr) final { @@ -49,34 +50,30 @@ FeatureSet DetectFeature(const Expr& expr) { } } } -#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ - void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \ - STMT \ - fs += f##CONSTRUCT_NAME; \ - } -#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, { \ - ExprVisitor::VisitExpr_(op); \ - }) +#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ + void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; } +#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \ + DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); }) DETECT_DEFAULT_CONSTRUCT(Var) DETECT_DEFAULT_CONSTRUCT(GlobalVar) DETECT_DEFAULT_CONSTRUCT(Constant) DETECT_DEFAULT_CONSTRUCT(Tuple) DETECT_DEFAULT_CONSTRUCT(TupleGetItem) DETECT_CONSTRUCT(Function, { - if (!op->HasNonzeroAttr(attr::kPrimitive)) { - ExprVisitor::VisitExpr_(op); - } - }) + if (!op->HasNonzeroAttr(attr::kPrimitive)) { + ExprVisitor::VisitExpr_(op); + } + }) DETECT_DEFAULT_CONSTRUCT(Op) DETECT_DEFAULT_CONSTRUCT(Call) DETECT_CONSTRUCT(Let, { - for (const Var& v : FreeVars(op->value)) { - if (op->var == v) { - fs += fLetRec; - } + for (const Var& v : FreeVars(op->value)) { + if (op->var == v) { + fs += fLetRec; } - ExprVisitor::VisitExpr_(op); - }) + } + ExprVisitor::VisitExpr_(op); + }) DETECT_DEFAULT_CONSTRUCT(If) DETECT_DEFAULT_CONSTRUCT(RefCreate) DETECT_DEFAULT_CONSTRUCT(RefRead) @@ -99,13 +96,15 @@ FeatureSet DetectFeature(const IRModule& mod) { return fs; } -Array PyDetectFeature(const Expr& expr, const IRModule& mod) { - FeatureSet fs = DetectFeature(expr) + DetectFeature(mod); +Array PyDetectFeature(const Expr& expr, const Optional& mod) { + FeatureSet fs = DetectFeature(expr); + if (mod.defined()) { + fs = fs + DetectFeature(mod.value()); + } return static_cast>(fs); } -TVM_REGISTER_GLOBAL("relay.analysis.detect_feature") -.set_body_typed(PyDetectFeature); +TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/kind_check.cc b/src/relay/analysis/kind_check.cc index b4835ccb7a3c..ac0abc065557 100644 --- a/src/relay/analysis/kind_check.cc +++ b/src/relay/analysis/kind_check.cc @@ -31,9 +31,9 @@ * We check this by ensuring the `dtype` field 
of a Tensor always * contains a data type such as `int`, `float`, `uint`. */ +#include #include #include -#include namespace tvm { namespace relay { @@ -51,40 +51,28 @@ struct KindChecker : TypeFunctor { this->err_reporter.RenderErrors(mod); } - void CheckKindMatches(const Type& t, const Type& outer, - Kind expected, const std::string& description) { + void CheckKindMatches(const Type& t, const Type& outer, Kind expected, + const std::string& description) { Kind k = this->VisitType(t); if (k != expected) { ReportFatalError(ErrorBuilder() - << "Incorrect kind for a " << description - << ". Type " << t << " inside " << outer - << " is of kind " << k - << " but was expected to be " - << expected); + << "Incorrect kind for a " << description << ". Type " << t << " inside " + << outer << " is of kind " << k << " but was expected to be " << expected); } } - Kind VisitType_(const IncompleteTypeNode* op) override { - return op->kind; - } + Kind VisitType_(const IncompleteTypeNode* op) override { return op->kind; } - Kind VisitType_(const TypeVarNode* op) override { - return op->kind; - } + Kind VisitType_(const TypeVarNode* op) override { return op->kind; } - Kind VisitType_(const GlobalTypeVarNode* op) override { - return op->kind; - } + Kind VisitType_(const GlobalTypeVarNode* op) override { return op->kind; } - Kind VisitType_(const TensorTypeNode* op) override { - return Kind::kType; - } + Kind VisitType_(const TensorTypeNode* op) override { return Kind::kType; } Kind VisitType_(const TupleTypeNode* op) override { // tuples should only contain normal types for (const Type& t : op->fields) { - CheckKindMatches(t, GetRef(op), Kind::kType, - "tuple member"); + CheckKindMatches(t, GetRef(op), Kind::kType, "tuple member"); } return Kind::kType; } @@ -117,8 +105,7 @@ struct KindChecker : TypeFunctor { Kind VisitType_(const TypeRelationNode* op) override { // arguments to type relation should be normal types for (const Type& t : op->args) { - CheckKindMatches(t, GetRef(op), Kind::kType, - "argument to type relation"); + CheckKindMatches(t, GetRef(op), Kind::kType, "argument to type relation"); } return Kind::kConstraint; } @@ -128,9 +115,8 @@ struct KindChecker : TypeFunctor { TypeCall tc = GetRef(op); const auto* gtv = op->func.as(); if (gtv == nullptr) { - ReportFatalError( - ErrorBuilder() <<"The callee in " << tc - << " is not a global type var, but is " << op->func); + ReportFatalError(ErrorBuilder() << "The callee in " << tc + << " is not a global type var, but is " << op->func); } CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); @@ -143,9 +129,8 @@ struct KindChecker : TypeFunctor { auto var = GetRef(gtv); auto data = mod->LookupTypeDef(var); if (data->type_vars.size() != op->args.size()) { - ReportFatalError(ErrorBuilder() - << "Expected " << data->type_vars.size() << "arguments for " << tc - << "; got " << op->args.size()); + ReportFatalError(ErrorBuilder() << "Expected " << data->type_vars.size() << "arguments for " + << tc << "; got " << op->args.size()); } return Kind::kType; } @@ -164,9 +149,8 @@ struct KindChecker : TypeFunctor { for (const auto& con : op->constructors) { if (!con->belong_to.same_as(op->header)) { - ReportFatalError(ErrorBuilder() - <belong_to - << " but " << op << " has header " << op->header); + ReportFatalError(ErrorBuilder() << con << " has header " << con->belong_to << " but " << op + << " has header " << op->header); } for (const Type& t : con->inputs) { @@ -176,9 +160,7 @@ struct KindChecker : TypeFunctor { return Kind::kTypeData; } - Kind 
Check(const Type& t) { - return this->VisitType(t); - } + Kind Check(const Type& t) { return this->VisitType(t); } }; Kind KindCheck(const Type& t, const IRModule& mod) { @@ -186,14 +168,13 @@ Kind KindCheck(const Type& t, const IRModule& mod) { return kc.Check(t); } -TVM_REGISTER_GLOBAL("relay.analysis.check_kind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args.size() == 1) { - *ret = KindCheck(args[0], IRModule({}, {})); - } else { - *ret = KindCheck(args[0], args[1]); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.check_kind").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 1) { + *ret = KindCheck(args[0], IRModule({}, {})); + } else { + *ret = KindCheck(args[0], args[1]); + } +}); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/mac_count.cc b/src/relay/analysis/mac_count.cc index fecde3c75669..d2e62b705d99 100644 --- a/src/relay/analysis/mac_count.cc +++ b/src/relay/analysis/mac_count.cc @@ -26,11 +26,12 @@ * otherwise the count is 0. */ -#include +#include #include #include -#include +#include #include + #include "../transforms/pattern_util.h" namespace tvm { @@ -52,8 +53,7 @@ inline int64_t GetCartesianProd(Array arr) { * \param call_node The call node. * \return The number of MACs. */ -using FMacCount = runtime::TypedPackedFunc< - int64_t(const Call& call_node)>; +using FMacCount = runtime::TypedPackedFunc; //---------------------------------------------- // Per operator defs for MAC count @@ -65,30 +65,26 @@ int64_t ConvMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK_EQ(args.size(), 2) - << "The number of input arguments of a CONV 2D node should be 2."; + CHECK_EQ(args.size(), 2) << "The number of input arguments of a CONV 2D node should be 2."; const auto* conv_2d_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; std::string data_layout = conv_2d_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK_NE(C_ind, -1) - << "There is no input channel dimension."; + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); - if (c_ind != -1) - input_channel *= static_cast(data_shape[c_ind].as()->value); + if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_attr->kernel_size; - CHECK_EQ(kernel_size.size(), 2) - << "The dimension of the kernel in Conv 2D should be 2."; + CHECK_EQ(kernel_size.size(), 2) << "The dimension of the kernel in Conv 2D should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D should be 4 or 5."; + << "The dimension of the output tensor in Conv 2D should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_attr->groups, 0) - << "The number of input channels is not divisble by groups."; - count *= input_channel/conv_2d_attr->groups; + << "The number of input channels is not divisble by groups."; + count *= input_channel / conv_2d_attr->groups; return count; } @@ -99,29 +95,27 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) { } Array args = call_node->args; CHECK_EQ(args.size(), 2) - << "The number of input arguments of a CONV 2D Transpose 
node should be 2."; + << "The number of input arguments of a CONV 2D Transpose node should be 2."; const auto* conv_2d_transpose_attr = call_node->attrs.as(); const auto* data_type = args[0]->checked_type().as(); Array data_shape = data_type->shape; std::string data_layout = conv_2d_transpose_attr->data_layout; int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C')); int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c')); - CHECK_NE(C_ind, -1) - << "There is no input channel dimension."; + CHECK_NE(C_ind, -1) << "There is no input channel dimension."; int64_t input_channel = static_cast(data_shape[C_ind].as()->value); - if (c_ind != -1) - input_channel *= static_cast(data_shape[c_ind].as()->value); + if (c_ind != -1) input_channel *= static_cast(data_shape[c_ind].as()->value); Array kernel_size = conv_2d_transpose_attr->kernel_size; CHECK_EQ(kernel_size.size(), 2) - << "The dimension of the kernel in Conv 2D Transpose should be 2."; + << "The dimension of the kernel in Conv 2D Transpose should be 2."; const auto* expr = call_node->checked_type().as(); Array output_tensor = expr->shape; CHECK(output_tensor.size() == 4 || output_tensor.size() == 5) - << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; + << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5."; int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size); CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0) - << "The number of input channels is not divisble by groups."; - count *= input_channel/conv_2d_transpose_attr->groups; + << "The number of input channels is not divisble by groups."; + count *= input_channel / conv_2d_transpose_attr->groups; return count; } @@ -131,20 +125,18 @@ int64_t DenseMacCount(const Call& call_node) { return 0; } Array args = call_node->args; - CHECK_EQ(args.size(), 2) - << "The number of input arguments of a Dense node should be 2."; + CHECK_EQ(args.size(), 2) << "The number of input arguments of a Dense node should be 2."; const auto* data_type = args[0]->checked_type().as(); const auto* weight_type = args[1]->checked_type().as(); Array data_shape = data_type->shape; Array weight_shape = weight_type->shape; CHECK(data_shape.size() == 2 && weight_shape.size() == 2) - << "The dimension of an input tensor to Dense node should be 2."; + << "The dimension of an input tensor to Dense node should be 2."; int64_t d1 = static_cast(data_shape[0].as()->value); int64_t d2 = static_cast(data_shape[1].as()->value); int64_t d3 = static_cast(weight_shape[0].as()->value); int64_t d4 = static_cast(weight_shape[1].as()->value); - CHECK_EQ(d2, d4) - << "The dimensions of input arguments do not match."; + CHECK_EQ(d2, d4) << "The dimensions of input arguments do not match."; int64_t count = d1 * d2 * d3; return count; } @@ -165,23 +157,17 @@ int64_t BatchMatmulMacCount(const Call& call_node) { return batch * m * k * n; } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FMacCount", ConvMacCount); +RELAY_REGISTER_OP("nn.conv2d").set_attr("FMacCount", ConvMacCount); -RELAY_REGISTER_OP("nn.conv2d_transpose") -.set_attr("FMacCount", Conv2dTransposeMacCount); +RELAY_REGISTER_OP("nn.conv2d_transpose").set_attr("FMacCount", Conv2dTransposeMacCount); -RELAY_REGISTER_OP("nn.dense") -.set_attr("FMacCount", DenseMacCount); +RELAY_REGISTER_OP("nn.dense").set_attr("FMacCount", DenseMacCount); -RELAY_REGISTER_OP("nn.batch_matmul") -.set_attr("FMacCount", BatchMatmulMacCount); +RELAY_REGISTER_OP("nn.batch_matmul").set_attr("FMacCount", 
BatchMatmulMacCount); class MacCounter : private ExprVisitor { public: - MacCounter() { - count_ = 0; - } + MacCounter() { count_ = 0; } static int64_t GetTotalMacNumber(const Expr& expr) { LOG(INFO) << "This pass only counts MACs in direct conv2d, " << "conv2d_transpose, dense, and batch_matmul ops"; @@ -192,8 +178,7 @@ class MacCounter : private ExprVisitor { private: void VisitExpr_(const CallNode* call_node) final { - static const auto& fprep = - Op::GetAttr("FMacCount"); + static const auto& fprep = Op::GetAttrMap("FMacCount"); auto f = fprep.get(call_node->op, nullptr); if (f != nullptr) count_ += f(GetRef(call_node)); ExprVisitor::VisitExpr_(call_node); @@ -202,12 +187,9 @@ class MacCounter : private ExprVisitor { int64_t count_; }; -int64_t GetTotalMacNumber(const Expr& expr) { - return MacCounter::GetTotalMacNumber(expr); -} +int64_t GetTotalMacNumber(const Expr& expr) { return MacCounter::GetTotalMacNumber(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber") -.set_body_typed(GetTotalMacNumber); +TVM_REGISTER_GLOBAL("relay.analysis.GetTotalMacNumber").set_body_typed(GetTotalMacNumber); } // namespace mac_count } // namespace relay diff --git a/src/relay/analysis/match_exhaustion.cc b/src/relay/analysis/match_exhaustion.cc index eeb7fce18c52..e852c40dfeba 100644 --- a/src/relay/analysis/match_exhaustion.cc +++ b/src/relay/analysis/match_exhaustion.cc @@ -27,10 +27,11 @@ * code correctness, since hitting an unmatched case results in a * dynamic error unless exhaustiveness is checked in advance. */ -#include #include +#include #include #include + #include namespace tvm { @@ -154,17 +155,14 @@ Array> CartesianProduct(Array> fields) { } Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, - const Pattern& cand, - const IRModule& mod); + const Pattern& cand, const IRModule& mod); -Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, - const Pattern& cand, +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, const IRModule& mod); // Expands all wildcards in the candidate pattern once // Returns a list of all possible expansions. -Array ExpandWildcards(const Pattern& clause_pat, - const Pattern& cand, +Array ExpandWildcards(const Pattern& clause_pat, const Pattern& cand, const IRModule& mod) { if (auto clause_ctor = clause_pat.as()) { return ExpandWildcardsConstructor(GetRef(clause_ctor), cand, mod); @@ -179,8 +177,7 @@ Array ExpandWildcards(const Pattern& clause_pat, // Use the pattern to decide which constructors to insert. // Returns a list of all possible expansions. Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, - const Pattern& cand, - const IRModule& mod) { + const Pattern& cand, const IRModule& mod) { auto gtv = Downcast(clause_ctor->constructor->belong_to); // for a wildcard node, create constructor nodes with wildcards for all args. @@ -203,9 +200,8 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < ctor_cand->constructor->inputs.size(); i++) { - values_by_field.push_back(ExpandWildcards(clause_ctor->patterns[i], - ctor_cand->patterns[i], - mod)); + values_by_field.push_back( + ExpandWildcards(clause_ctor->patterns[i], ctor_cand->patterns[i], mod)); } // generate new candidates using a cartesian product. 
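
For readers following the match-exhaustion changes: wildcard expansion works field by field and then combines the per-field candidate lists with a cartesian product, as the comment above notes. Below is a minimal standalone sketch of that combination step using plain STL containers; the names `Candidate` and `CartesianProduct` here are illustrative stand-ins, not TVM's API (the real pass operates on relay::Pattern values).

#include <string>
#include <vector>

// A stand-in for one expanded pattern; the real pass uses relay::Pattern.
using Candidate = std::string;

// Given one candidate list per field, produce every combination that picks
// exactly one candidate from each field.
std::vector<std::vector<Candidate>> CartesianProduct(
    const std::vector<std::vector<Candidate>>& fields) {
  std::vector<std::vector<Candidate>> result = {{}};  // start with one empty combination
  for (const auto& field : fields) {
    std::vector<std::vector<Candidate>> next;
    for (const auto& partial : result) {
      for (const auto& choice : field) {
        auto extended = partial;
        extended.push_back(choice);
        next.push_back(std::move(extended));
      }
    }
    result = std::move(next);
  }
  return result;
}

For example, expanding a two-field tuple pattern where each field has two possible expansions yields four candidate tuples, which are then checked against the remaining clauses.
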
@@ -219,8 +215,7 @@ Array ExpandWildcardsConstructor(const PatternConstructor& clause_ctor, // Expands all wildcards in the candidate pattern once. // Returns a list of all possible expansions. -Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, - const Pattern& cand, +Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, const Pattern& cand, const IRModule& mod) { // for a wildcard node, create constructor nodes with wildcards for all args. if (cand.as()) { @@ -236,9 +231,8 @@ Array ExpandWildcardsTuple(const PatternTuple& clause_tuple, // for constructors, we will expand the wildcards in any field that is an ADT. Array> values_by_field; for (size_t i = 0; i < tuple_cand->patterns.size(); i++) { - values_by_field.push_back(ExpandWildcards(clause_tuple->patterns[i], - tuple_cand->patterns[i], - mod)); + values_by_field.push_back( + ExpandWildcards(clause_tuple->patterns[i], tuple_cand->patterns[i], mod)); } // generate new candidates using a cartesian product @@ -311,14 +305,10 @@ Array UnmatchedCases(const Match& match, const IRModule& mod) { // expose for testing only TVM_REGISTER_GLOBAL("relay.analysis.unmatched_cases") -.set_body_typed( - [](const Match& match, const IRModule& mod_ref) { - IRModule call_mod = mod_ref; - if (!call_mod.defined()) { - call_mod = IRModule({}, {}); - } - return UnmatchedCases(match, call_mod); - }); + .set_body_typed([](const Match& match, const Optional& mod_ref) { + IRModule call_mod = mod_ref.defined() ? mod_ref.value() : IRModule({}, {}); + return UnmatchedCases(match, call_mod); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 650403ca5267..a192002825e6 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -21,26 +21,25 @@ * \file type_solver.cc * \brief Type solver implementations. */ -#include +#include "type_solver.h" + #include +#include #include -#include + #include +#include #include #include -#include "type_solver.h" namespace tvm { namespace relay { class TypeSolver::Reporter : public TypeReporterNode { public: - explicit Reporter(TypeSolver* solver) - : solver_(solver) {} + explicit Reporter(TypeSolver* solver) : solver_(solver) {} - void Assign(const Type& dst, const Type& src) final { - solver_->Unify(dst, src, location); - } + void Assign(const Type& dst, const Type& src) final { solver_->Unify(dst, src, location); } bool Assert(const IndexExpr& cond) final { if (const int64_t* pdiff = tir::as_const_int(cond)) { @@ -58,13 +57,9 @@ class TypeSolver::Reporter : public TypeReporterNode { return true; } - TVM_DLL void SetLocation(const ObjectRef& ref) final { - location = ref; - } + TVM_DLL void SetLocation(const ObjectRef& ref) final { location = ref; } - TVM_DLL IRModule GetModule() final { - return this->solver_->module_; - } + TVM_DLL IRModule GetModule() final { return this->solver_->module_; } private: /*! \brief The location to report unification errors at. 
*/ @@ -76,7 +71,7 @@ class TypeSolver::Reporter : public TypeReporterNode { class TypeSolver::OccursChecker : public TypeVisitor { public: explicit OccursChecker(TypeSolver* solver, TypeNode* var) - : solver_(solver), var_(var), found_(false) {} + : solver_(solver), var_(var), found_(false) {} bool Check(const Type& t) { VisitType(t); @@ -112,25 +107,24 @@ class TypeSolver::Unifier : public TypeFunctor { if (lhs->resolved_type.as()) { CHECK(!OccursCheck(lhs, rhs->resolved_type)) - << "Incomplete type " << lhs->resolved_type << " occurs in " - << rhs->resolved_type << ", cannot unify"; + << "Incomplete type " << lhs->resolved_type << " occurs in " << rhs->resolved_type + << ", cannot unify"; solver_->MergeFromTo(lhs, rhs); return rhs->resolved_type; } else if (rhs->resolved_type.as()) { CHECK(!OccursCheck(rhs, lhs->resolved_type)) - << "Incomplete type " << rhs->resolved_type << " occurs in " - << lhs->resolved_type << ", cannot unify"; + << "Incomplete type " << rhs->resolved_type << " occurs in " << lhs->resolved_type + << ", cannot unify"; solver_->MergeFromTo(rhs, lhs); return lhs->resolved_type; } else { Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); if (!resolved.defined()) { - solver_->ReportError( - ErrorBuilder() << "unable to unify: " - << "`" << PrettyPrint(lhs->resolved_type) << "` and `" - << PrettyPrint(rhs->resolved_type) << "`", - this->loc); + solver_->ReportError(ErrorBuilder() << "unable to unify: " + << "`" << PrettyPrint(lhs->resolved_type) << "` and `" + << PrettyPrint(rhs->resolved_type) << "`", + this->loc); return lhs->resolved_type; } else { TypeNode* top = solver_->GetTypeNode(resolved); @@ -181,8 +175,8 @@ class TypeSolver::Unifier : public TypeFunctor { if (ulhs.same_as(urhs)) { return ulhs; } - if (ulhs.as() || urhs.as()) { - return Any::make(); + if (ulhs.as() || urhs.as()) { + return Any(); } auto left_index0 = ulhs.as(); @@ -227,14 +221,11 @@ class TypeSolver::Unifier : public TypeFunctor { tvm::Array shape; if (tt1->shape.size() != tt2->shape.size()) { - this->solver_->ReportError( - ErrorBuilder() << - "tensor type `" << PrettyPrint(tt1) << - "` has " << tt1->shape.size() << - " dimensions, while `" << - PrettyPrint(tt2) << - "` has " << tt2->shape.size() << - " dimensions", this->loc); + this->solver_->ReportError(ErrorBuilder() << "tensor type `" << PrettyPrint(tt1) << "` has " + << tt1->shape.size() << " dimensions, while `" + << PrettyPrint(tt2) << "` has " << tt2->shape.size() + << " dimensions", + this->loc); return Type(nullptr); } @@ -259,12 +250,8 @@ class TypeSolver::Unifier : public TypeFunctor { ErrorBuilder err; err << "in particular "; for (auto mismatch : mismatches) { - err << "dimension " - << std::get<0>(mismatch) - << " conflicts " - << std::get<1>(mismatch) - << " does not match " - << std::get<2>(mismatch); + err << "dimension " << std::get<0>(mismatch) << " conflicts " << std::get<1>(mismatch) + << " does not match " << std::get<2>(mismatch); } Error error(err); this->solver_->ReportError(error, this->loc); @@ -293,9 +280,8 @@ class TypeSolver::Unifier : public TypeFunctor { Type VisitType_(const FuncTypeNode* op, const Type& tn) final { const auto* ftn = tn.as(); - if (!ftn - || op->arg_types.size() != ftn->arg_types.size() - || op->type_constraints.size() != ftn->type_constraints.size()) { + if (!ftn || op->arg_types.size() != ftn->arg_types.size() || + op->type_constraints.size() != ftn->type_constraints.size()) { return Type(nullptr); } @@ -316,10 +302,7 @@ class TypeSolver::Unifier : public TypeFunctor { 
subst_map.Set(op->type_params[i], IncompleteType(kType)); } - FuncType ft = FuncType(op->arg_types, - op->ret_type, - ft_type_params, - op->type_constraints); + FuncType ft = FuncType(op->arg_types, op->ret_type, ft_type_params, op->type_constraints); auto ft1 = Downcast(Bind(ft, subst_map)); auto ft2 = GetRef(ftn); @@ -333,8 +316,7 @@ class TypeSolver::Unifier : public TypeFunctor { std::vector type_constraints; for (size_t i = 0; i < ft1->type_constraints.size(); ++i) { - Type unified_constraint = Unify(ft1->type_constraints[i], - ft2->type_constraints[i]); + Type unified_constraint = Unify(ft1->type_constraints[i], ft2->type_constraints[i]); const auto* tcn = unified_constraint.as(); CHECK(tcn) << "Two type constraints unified into a non-constraint?" << ft1->type_constraints[i] << " and " << ft2->type_constraints[i]; @@ -397,12 +379,10 @@ class TypeSolver::Resolver : public TypeMutator { class TypeSolver::Propagator : public TypeFunctor { public: explicit Propagator(TypeSolver* solver, const std::unordered_set* rels) - : solver_(solver), rels_(rels) {} + : solver_(solver), rels_(rels) {} // adds the relation node to t and all child types of t - void Propagate(const Type& t) { - VisitType(t); - } + void Propagate(const Type& t) { VisitType(t); } void UpdateRelSet(const Type& t) { TypeNode* tnode = solver_->GetTypeNode(t); @@ -532,10 +512,8 @@ class TypeSolver::Merger : public TypeFunctor { }; // constructor -TypeSolver::TypeSolver( - const GlobalVar& current_func, - const IRModule& module, - ErrorReporter* err_reporter) +TypeSolver::TypeSolver(const GlobalVar& current_func, const IRModule& module, + ErrorReporter* err_reporter) : reporter_(make_object(this)), current_func(current_func), err_reporter_(err_reporter), @@ -566,7 +544,7 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const ObjectRef& loc) { return unifier.Unify(dst, src); } -void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { +void TypeSolver::ReportError(const Error& err, const ObjectRef& location) { CHECK(location.defined()); CHECK(current_func.defined()); err_reporter_->ReportAt(current_func, location, err); @@ -583,20 +561,19 @@ void TypeSolver::AddConstraint(const TypeConstraint& constraint, const ObjectRef // populate the type information. for (size_t i = 0; i < op->args.size(); ++i) { // insert link to the type list - LinkNode* tlink = arena_.make >(); + LinkNode* tlink = arena_.make>(); TypeNode* tnode = GetTypeNode(op->args[i]); tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - std::unordered_set singleton { rnode }; + std::unordered_set singleton{rnode}; Propagator prop(this, &singleton); prop.Propagate(tnode->resolved_type); } // add the relation to the working queue. 
this->AddToQueue(rnode); } else { - LOG(FATAL) << "Do not know how to handle constraint type" - << constraint->GetTypeKey(); + LOG(FATAL) << "Do not know how to handle constraint type" << constraint->GetTypeKey(); } } @@ -642,11 +619,9 @@ bool TypeSolver::Solve() { rnode->resolved = false; } catch (const dmlc::Error& err) { rnode->resolved = false; - this->ReportError( - ErrorBuilder() << "an internal invariant was violated while " - << "typechecking your program " - << err.what(), - rnode->location); + this->ReportError(ErrorBuilder() << "an internal invariant was violated while " + << "typechecking your program " << err.what(), + rnode->location); } // Mark inqueue as false after the function call @@ -661,45 +636,40 @@ bool TypeSolver::Solve() { // Expose type solver only for debugging purposes. TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") -.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { - using runtime::PackedFunc; - using runtime::TypedPackedFunc; - ErrorReporter *err_reporter = new ErrorReporter(); - auto module = IRModule({}, {}); - auto dummy_fn_name = GlobalVar("test"); - module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); - auto solver = std::make_shared(dummy_fn_name, module, err_reporter); - - auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { - if (name == "Solve") { - return TypedPackedFunc([solver]() { - return solver->Solve(); - }); - } else if (name == "Unify") { - return TypedPackedFunc( - [module, solver, err_reporter](Type lhs, Type rhs) { - auto res = solver->Unify(lhs, rhs, lhs); - if (err_reporter->AnyErrors()) { - err_reporter->RenderErrors(module, true); - } - return res; - }); - } else if (name == "Resolve") { - return TypedPackedFunc([solver](Type t) { - return solver->Resolve(t); - }); - } else if (name == "AddConstraint") { - return TypedPackedFunc([solver](TypeConstraint c) { - Expr e = Var("dummy_var", - IncompleteType(Kind::kType)); + .set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) { + using runtime::PackedFunc; + using runtime::TypedPackedFunc; + ErrorReporter* err_reporter = new ErrorReporter(); + auto module = IRModule({}, {}); + auto dummy_fn_name = GlobalVar("test"); + module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); + auto solver = std::make_shared(dummy_fn_name, module, err_reporter); + + auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc { + if (name == "Solve") { + return TypedPackedFunc([solver]() { return solver->Solve(); }); + } else if (name == "Unify") { + return TypedPackedFunc( + [module, solver, err_reporter](Type lhs, Type rhs) { + auto res = solver->Unify(lhs, rhs, lhs); + if (err_reporter->AnyErrors()) { + err_reporter->RenderErrors(module, true); + } + return res; + }); + } else if (name == "Resolve") { + return TypedPackedFunc([solver](Type t) { return solver->Resolve(t); }); + } else if (name == "AddConstraint") { + return TypedPackedFunc([solver](TypeConstraint c) { + Expr e = Var("dummy_var", IncompleteType(Kind::kType)); return solver->AddConstraint(c, e); }); - } else { - return PackedFunc(); - } - }; - *ret = runtime::TypedPackedFunc(mod); - }); + } else { + return PackedFunc(); + } + }; + *ret = runtime::TypedPackedFunc(mod); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index 8ccc2c7244b0..dcd8de075854 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -24,21 
+24,23 @@ #ifndef TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ #define TVM_RELAY_ANALYSIS_TYPE_SOLVER_H_ +#include +#include #include #include -#include -#include -#include + #include #include #include +#include + #include "../../support/arena.h" namespace tvm { namespace relay { -using support::LinkNode; using support::LinkedList; +using support::LinkNode; /*! * \brief Interface of type solver used in type inference. @@ -166,7 +168,7 @@ class TypeSolver { /*! \brief Number of resolved relations */ size_t num_resolved_rels_{0}; /*! \brief map from types to type nodes. */ - std::unordered_map tmap_; + std::unordered_map tmap_; /*! \brief Internal queue to update the relation */ std::queue update_queue_; /*! \brief allocator of all the internal node obhect*/ diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc index a86faeb50531..b681b90d58e8 100644 --- a/src/relay/analysis/util.cc +++ b/src/relay/analysis/util.cc @@ -25,17 +25,20 @@ */ #include #include +#include #include #include +#include #include + #include "../transforms/pass_util.h" namespace tvm { namespace relay { -template +template struct InsertionSet { - std::unordered_set set; + std::unordered_set set; std::vector data; void Insert(const T& t) { if (set.count(t) == 0) { @@ -47,10 +50,8 @@ struct InsertionSet { class TypeVarTVisitor : public TypeVisitor { public: - TypeVarTVisitor( - InsertionSet* type_vars, - InsertionSet* bound_type_vars) - : type_vars_(type_vars), bound_type_vars_(bound_type_vars) { } + TypeVarTVisitor(InsertionSet* type_vars, InsertionSet* bound_type_vars) + : type_vars_(type_vars), bound_type_vars_(bound_type_vars) {} void VisitType_(const TypeVarNode* tp) final { TypeVar var = GetRef(tp); @@ -149,8 +150,7 @@ class TypeVarEVisitor : private ExprVisitor { } void VisitType(const Type& t) final { - TypeVarTVisitor(&type_vars_, &bound_type_vars_) - .VisitType(t); + TypeVarTVisitor(&type_vars_, &bound_type_vars_).VisitType(t); } private: @@ -204,9 +204,7 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { vars_.Insert(v); } - void VisitExpr_(const VarNode* var) final { - vars_.Insert(GetRef(var)); - } + void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef(var)); } void VisitExpr_(const FunctionNode* op) final { for (const auto& param : op->params) { @@ -221,13 +219,9 @@ class VarVisitor : protected ExprVisitor, protected PatternVisitor { VisitExpr(op->body); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitPattern_(const PatternVarNode* op) final { - MarkBounded(op->var); - } + void VisitPattern_(const PatternVarNode* op) final { MarkBounded(op->var); } private: InsertionSet vars_; @@ -258,82 +252,66 @@ tvm::Array AllTypeVars(const Type& type, const IRModule& mod) { return TypeVarEVisitor(mod).All(type); } -tvm::Array FreeVars(const Expr& expr) { - return VarVisitor().Free(expr); -} +tvm::Array FreeVars(const Expr& expr) { return VarVisitor().Free(expr); } -tvm::Array BoundVars(const Expr& expr) { - return VarVisitor().Bound(expr); -} +tvm::Array BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); } -tvm::Array BoundVars(const Pattern& pat) { - return VarVisitor().Bound(pat); -} +tvm::Array BoundVars(const Pattern& pat) { return VarVisitor().Bound(pat); } -tvm::Array AllVars(const Expr& expr) { - return VarVisitor().All(expr); -} +tvm::Array AllVars(const Expr& expr) { return VarVisitor().All(expr); } 
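
The util.cc changes above lean on an InsertionSet helper: an unordered set for O(1) de-duplication paired with a vector that preserves first-insertion order, so FreeVars/BoundVars/AllVars return variables in a deterministic order. A self-contained sketch of the same idea, generic over the element type (the real struct is keyed on TVM object references):

#include <unordered_set>
#include <vector>

// Keeps elements unique while remembering the order in which they were first inserted.
template <typename T>
struct InsertionSet {
  std::unordered_set<T> set;  // fast membership test
  std::vector<T> data;        // stable, first-seen order

  void Insert(const T& t) {
    if (set.count(t) == 0) {
      set.insert(t);
      data.push_back(t);
    }
  }
};

// Usage: collecting names during a traversal.
//   InsertionSet<std::string> vars;
//   vars.Insert("x"); vars.Insert("y"); vars.Insert("x");
//   vars.data is now {"x", "y"}.
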
-TVM_REGISTER_GLOBAL("relay.analysis.free_vars") -.set_body_typed(FreeVars); +TVM_REGISTER_GLOBAL("relay.analysis.free_vars").set_body_typed(FreeVars); -TVM_REGISTER_GLOBAL("relay.analysis.bound_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - if (x.as()) { - *ret = BoundVars(Downcast(x)); - } else { - *ret = BoundVars(Downcast(x)); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.bound_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + if (x.as()) { + *ret = BoundVars(Downcast(x)); + } else { + *ret = BoundVars(Downcast(x)); + } +}); -TVM_REGISTER_GLOBAL("relay.analysis.all_vars") -.set_body_typed(AllVars); +TVM_REGISTER_GLOBAL("relay.analysis.all_vars").set_body_typed(AllVars); -TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = FreeTypeVars(Downcast(x), mod); - } else { - *ret = FreeTypeVars(Downcast(x), mod); - } - }); - -TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = BoundTypeVars(Downcast(x), mod); - } else { - *ret = BoundTypeVars(Downcast(x), mod); - } - }); - -TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef x = args[0]; - IRModule mod = args[1]; - if (x.as()) { - *ret = AllTypeVars(Downcast(x), mod); - } else { - *ret = AllTypeVars(Downcast(x), mod); - } - }); +TVM_REGISTER_GLOBAL("relay.analysis.free_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = FreeTypeVars(Downcast(x), mod); + } else { + *ret = FreeTypeVars(Downcast(x), mod); + } +}); + +TVM_REGISTER_GLOBAL("relay.analysis.bound_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = BoundTypeVars(Downcast(x), mod); + } else { + *ret = BoundTypeVars(Downcast(x), mod); + } +}); + +TVM_REGISTER_GLOBAL("relay.analysis.all_type_vars").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef x = args[0]; + IRModule mod = args[1]; + if (x.as()) { + *ret = AllTypeVars(Downcast(x), mod); + } else { + *ret = AllTypeVars(Downcast(x), mod); + } +}); /*! * \brief Get reference counter of each internal ExprNode in body. * \param body The body expression. * \return The reference count mapping. */ -std::unordered_map -GetExprRefCount(const Expr& body) { +std::unordered_map GetExprRefCount(const Expr& body) { class ExprRefCounter : private MixedModeVisitor { public: - std::unordered_map - Get(const Expr& body) { + std::unordered_map Get(const Expr& body) { this->VisitExpr(body); return std::move(this->visit_counter_); } @@ -361,13 +339,13 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) { return true; } -// Cache the operators that are checked recursively to reduce lookup overhead. -static const auto& expand_dims_op = Op::Get("expand_dims"); -static const auto& reshape_op = Op::Get("reshape"); -static const auto& transpose_op = Op::Get("transpose"); -static const auto& squeeze_op = Op::Get("squeeze"); - bool IsAllPositiveConstant(const Expr& expr) { + // Cache the operators that are checked recursively to reduce lookup overhead. 
+ static const auto& expand_dims_op = Op::Get("expand_dims"); + static const auto& reshape_op = Op::Get("reshape"); + static const auto& transpose_op = Op::Get("transpose"); + static const auto& squeeze_op = Op::Get("squeeze"); + // peel through a few common transform ops. if (const auto* constant = expr.as()) { const auto& tensor = constant->data; @@ -391,9 +369,7 @@ bool IsAllPositiveConstant(const Expr& expr) { } } else if (const auto* op = expr.as()) { // tail recursion. - if (op->op == expand_dims_op || - op->op == reshape_op || - op->op == transpose_op || + if (op->op == expand_dims_op || op->op == reshape_op || op->op == transpose_op || op->op == squeeze_op) { return IsAllPositiveConstant(op->args[0]); } else { @@ -419,17 +395,11 @@ Type TypeSubst(const Type& type, const tvm::Map& subst_map) { Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { class TypeSubstMutator : public ExprMutator, public PatternMutator { public: - explicit TypeSubstMutator(const tvm::Map& subst_map) : subst_map_(subst_map) { } - Type VisitType(const Type& t) final { - return TypeSubst(t, subst_map_); - } - Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); - } + explicit TypeSubstMutator(const tvm::Map& subst_map) : subst_map_(subst_map) {} + Type VisitType(const Type& t) final { return TypeSubst(t, subst_map_); } + Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); @@ -446,5 +416,58 @@ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map) { return ret; } +struct IsDynamicVisitor : public TypeVisitor { + bool is_dyn{false}; + void VisitType_(const TensorTypeNode* tt) { + for (auto dim : tt->shape) { + if (dim.as()) { + is_dyn = true; + break; + } + } + } +}; + +bool IsDynamic(const Type& ty) { + IsDynamicVisitor v; + v.VisitType(ty); + return v.is_dyn; +} + +TVM_REGISTER_GLOBAL("relay.ir.IsDynamic").set_body_typed(IsDynamic); + +bool IsDataDependant(const CallNode* call) { + static auto tshape_data_dependant = Op::GetAttrMap("TShapeDataDependant"); + Op op = Downcast(call->op); + + if (!tshape_data_dependant.count(op)) { + return false; + } + + if (op->name == "reshape") { + if (const auto* attrs = call->attrs.as()) { + if (attrs->newshape) { + // If newshape attribute exists, it isn't data dependant. + return false; + } + } + } else if (op->name == "topk") { + if (const auto* attrs = call->attrs.as()) { + if (attrs->k) { + // If k attribute exists, it isn't data dependant. + return false; + } + } + } else if (op->name == "strided_slice") { + if (const auto* attrs = call->attrs.as()) { + if (attrs->begin && attrs->end && attrs->strides) { + // not data dependant if begin, end and strides exist + return false; + } + } + } + + return tshape_data_dependant[op]; +} } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index f3a2cadb363f..16f8285c83b3 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -24,26 +24,24 @@ #include #include #include + #include namespace tvm { namespace relay { - //! brief make sure each Var is bound at most once in a scope. 
class WellFormedChecker : private ExprVisitor, PatternVisitor { bool well_formed = true; - std::vector> scope; - std::unordered_set current_bound; - std::unordered_set total_bound; - std::unordered_set free; + std::vector> scope; + std::unordered_set current_bound; + std::unordered_set total_bound; + std::unordered_set free; struct Scope { WellFormedChecker* wfc; - explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { - wfc->scope.push_back({{}}); - } + explicit Scope(WellFormedChecker* wfc) : wfc(wfc) { wfc->scope.push_back({{}}); } ~Scope() { CHECK_GE(wfc->scope.size(), 0); for (const Var& v : wfc->scope.back()) { @@ -98,13 +96,9 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { VisitExpr(c->rhs); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitVar(const Var& v) final { - Bound(v); - } + void VisitVar(const Var& v) final { Bound(v); } void VisitExpr(const Expr& e) final { if (auto v = e.as()) { @@ -121,12 +115,9 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { } }; -bool WellFormed(const Expr& e) { - return WellFormedChecker().CheckWellFormed(e); -} +bool WellFormed(const Expr& e) { return WellFormedChecker().CheckWellFormed(e); } -TVM_REGISTER_GLOBAL("relay.analysis.well_formed") -.set_body_typed(WellFormed); +TVM_REGISTER_GLOBAL("relay.analysis.well_formed").set_body_typed(WellFormed); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index e2d5e93fa5c1..f9ce24d410b7 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -21,14 +21,14 @@ * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. */ -#include #include -#include -#include +#include #include -#include #include -#include +#include +#include +#include + #include #include "../../target/source/codegen_source_base.h" @@ -38,7 +38,6 @@ namespace tvm { namespace relay { namespace backend { - using TargetsMap = Map; using namespace tvm::relay::transform; @@ -64,24 +63,18 @@ struct GraphCodegen { } ~GraphCodegen() {} - void Init(runtime::Module* m, TargetsMap targets) { - CallFunc("init", m, targets); - } + void Init(runtime::Module* m, TargetsMap targets) { CallFunc("init", m, targets); } - void Codegen(const Function& func) { - CallFunc("codegen", func); - } + void Codegen(const Function& func) { CallFunc("codegen", func); } - std::string GetJSON() { - return CallFunc("get_graph_json", nullptr); - } + std::string GetJSON() { return CallFunc("get_graph_json", nullptr); } Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); } - Map GetIRModule() { - return CallFunc>("get_irmodule", nullptr); + Map GetIRModule() { + return CallFunc>("get_irmodule", nullptr); } std::unordered_map GetParams() { @@ -97,13 +90,13 @@ struct GraphCodegen { protected: tvm::runtime::Module mod; - template - R CallFunc(const std::string &name, Args... args) { + template + R CallFunc(const std::string& name, Args... args) { auto pf = mod.GetFunction(name, false); return pf(std::forward(args)...); } - template - void CallFunc(const std::string &name, Args... args) { + template + void CallFunc(const std::string& name, Args... args) { auto pf = mod.GetFunction(name, false); pf(std::forward(args)...); return; @@ -122,43 +115,38 @@ class RelayBuildModule : public runtime::ModuleNode { * \param sptr_to_self The pointer to the module node. 
* \return The corresponding member function. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "get_graph_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetGraphJSON(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); }); } else if (name == "get_module") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetModule(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); this->Build(args[0], args[1], args[2]); }); } else if (name == "list_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->ListParamNames(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->ListParamNames(); }); } else if (name == "get_params") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetParams(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetParams(); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Map params = args[0]; + Map params = args[0]; for (const auto& kv : params) { this->SetParam(kv.first, kv.second->data); } }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetIRModule(); + *rv = this->graph_codegen_->GetIRModule(); }); } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetExternalModules(); + *rv = this->graph_codegen_->GetExternalModules(); }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -176,18 +164,14 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return const std::string graph_json */ - const std::string& GetGraphJSON() { - return ret_.graph_json; - } + const std::string& GetGraphJSON() { return ret_.graph_json; } /*! * \brief Get the Module object * * \return runtime::Module */ - runtime::Module GetModule() { - return ret_.mod; - } + runtime::Module GetModule() { return ret_.mod; } /*! * \brief List all paramter names @@ -205,10 +189,10 @@ class RelayBuildModule : public runtime::ModuleNode { /*! * \brief Get params dictionary * - * \return Map params dictionary + * \return Map params dictionary */ - Map GetParams() { - Map ret; + Map GetParams() { + Map ret; for (const auto& kv : ret_.params) { ret.Set(kv.first, Constant(kv.second)); } @@ -221,18 +205,14 @@ class RelayBuildModule : public runtime::ModuleNode { * \param name name of parameter * \param data_in input DLTensor */ - void SetParam(const std::string& name, runtime::NDArray data_in) { - params_[name] = data_in; - } + void SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } /*! * \brief type key * * \return const char* */ - const char* type_key() const final { - return "RelayBuildModule"; - } + const char* type_key() const final { return "RelayBuildModule"; } /*! 
* \brief Build relay IRModule for graph runtime @@ -241,9 +221,7 @@ class RelayBuildModule : public runtime::ModuleNode { * \param target Target device * \param target_host Host target device */ - void Build(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { + void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; BuildRelay(mod, params_); @@ -259,13 +237,10 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return relay::IRModule The updated Relay IR module after optimization. */ - IRModule Optimize( - IRModule relay_module, - const TargetsMap& targets, - const std::unordered_map& params) { + IRModule Optimize(IRModule relay_module, const TargetsMap& targets, + const std::unordered_map& params) { if (params.size()) { - CHECK(relay_module->ContainGlobalVar("main")) - << "Missing the main entry function"; + CHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); Function main_func = Downcast(relay_module->Lookup(main_glb_var)); auto new_main = BindParamsByName(main_func, params); @@ -329,8 +304,11 @@ class RelayBuildModule : public runtime::ModuleNode { // Handle heterogeneous compilation. transform::PassContext pass_ctx = PassContext::Current(); if (targets_.size() > 1) { - relay_module = - RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); + auto fallback_dev = opt_fallback_dev.value(); + CHECK_GT(fallback_dev->value, 0U); + relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev->value); } // Fuse the operations if it is needed. @@ -387,8 +365,7 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return updated_module The updated module after device annotation. */ - IRModule RunDeviceAnnotationPass(const IRModule& relay_module, - int fallback_device) { + IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) { UpdateHeterogeneousInputs(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device); auto updated_module = rewrite(relay_module); @@ -417,12 +394,11 @@ class RelayBuildModule : public runtime::ModuleNode { break; } for (auto kv : annotation_map) { - CHECK_EQ(kv.second->value, dev_type) - << "Expressions in the function are " - << "annotated with various device types," - << "but not device copy operators " - << "found. Please check the " - << "RewriteAnnotation pass."; + CHECK_EQ(kv.second->value, dev_type) << "Expressions in the function are " + << "annotated with various device types," + << "but not device copy operators " + << "found. Please check the " + << "RewriteAnnotation pass."; } targets_.Set(0, CreateDefaultTarget(dev_type)); } @@ -436,9 +412,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param relay_module The Relay IR module. * \param params The parameters. */ - void BuildRelay( - IRModule relay_module, - const std::unordered_map& params) { + void BuildRelay(IRModule relay_module, + const std::unordered_map& params) { // Relay IRModule -> IRModule optimizations. relay_module = Optimize(relay_module, targets_, params); // Get the updated function. 
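
The heterogeneous-compilation hunk above replaces the old fallback_device field with a PassContext config lookup that falls back to a default (CPU) when the option is unset. The exact TVM call appears in the diff; the snippet below is only a plain-C++ analogue of the "optional config with default" pattern, with hypothetical names, to make the control flow explicit.

#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in for a pass-context option table.
using ConfigTable = std::map<std::string, int64_t>;

// Look up a config key, falling back to a caller-supplied default when absent,
// mirroring the GetConfig("relay.fallback_device_type", <default>) call above.
int64_t GetConfigOr(const ConfigTable& cfg, const std::string& key, int64_t default_value) {
  auto it = cfg.find(key);
  return it != cfg.end() ? it->second : default_value;
}

// Usage: when more than one target is given and the user did not pick a
// fallback device, default to device type 1 (kDLCPU).
//   int64_t fallback_dev = GetConfigOr(cfg, "relay.fallback_device_type", 1);
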
@@ -474,23 +449,19 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); } } else { - ret_.mod = tvm::build( - lowered_funcs, - target_host_, - BuildConfig::Current()); + ret_.mod = tvm::build(lowered_funcs, target_host_); } Array ext_mods = graph_codegen_->GetExternalModules(); // Import all external runtime modules. - for (const auto& it : ext_mods) - ret_.mod.Import(it); + for (const auto& it : ext_mods) ret_.mod.Import(it); } private: Target GetTargetHost() { Target target_host = target_host_; if (!target_host_.defined()) { - for (const auto &it : targets_) { + for (const auto& it : targets_) { if (it.second->device_type == kDLCPU) { target_host = it.second; break; @@ -517,20 +488,19 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); }); TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Map params = args[1]; - std::unordered_map params_; - for (const auto& kv : params) { - params_[kv.first] = kv.second->data; - } - *rv = relay::backend::BindParamsByName(args[0], params_); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + Map params = args[1]; + std::unordered_map params_; + for (const auto& kv : params) { + params_[kv.first] = kv.second->data; + } + *rv = relay::backend::BindParamsByName(args[0], params_); + }); } // namespace backend } // namespace relay diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ce0a314f265b..3687b75c8ce8 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -45,6 +45,7 @@ #include #include +#include "../transforms/pass_util.h" #include "utils.h" namespace tvm { @@ -70,28 +71,6 @@ CCacheKey::CCacheKey(Function source_func, Target target) { data_ = std::move(n); } -struct IsDynamicVisitor : public TypeVisitor { - bool is_dyn{false}; - void VisitType_(const TensorTypeNode* tt) { - for (auto dim : tt->shape) { - if (dim.as()) { - is_dyn = true; - break; - } - } - } -}; - -bool IsDynamic(const Type& ty) { - IsDynamicVisitor v; - v.VisitType(ty); - return v.is_dyn; -} - -// TODO(@jroesch): MOVE ME -TVM_REGISTER_GLOBAL("relay.ir.IsDynamic") -.set_body_typed(IsDynamic); - Array GetShape(const Array& shape) { // for now, we always use int32 shape when possible // even if the result of shape inference becomes int64. 
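
GetShape in compile_engine.cc (continued in the next hunk) prefers int32 dimensions even when type inference produced int64 values, since many schedules and codegens assume 32-bit shapes. Below is a standalone sketch of that narrowing rule under the assumption of plain integer dims rather than TVM's IndexExpr; the Dim record and NarrowShape helper are hypothetical.

#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Narrow 64-bit static dimensions to 32-bit when they fit; keep them 64-bit otherwise.
struct Dim {
  bool fits_int32;
  int64_t value;
};

std::vector<Dim> NarrowShape(const std::vector<int64_t>& shape) {
  std::vector<Dim> result;
  result.reserve(shape.size());
  for (int64_t d : shape) {
    assert(d >= 0 && "tensor dimensions are non-negative");
    bool fits = d <= std::numeric_limits<int32_t>::max();
    result.push_back(Dim{fits, d});
  }
  return result;
}
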
@@ -124,8 +103,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> for (Var param : prim_func->params) { Array inputs; if (const auto* ttype = param->checked_type().as()) { - tvm::te::Tensor tensor = tvm::te::placeholder( - GetShape(ttype->shape), ttype->dtype); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); } else { @@ -135,8 +113,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> const auto* ttype = field.as(); // TODO(@icemelon): Allow recursive tuple CHECK(ttype != nullptr); - tvm::te::Tensor tensor = tvm::te::placeholder( - GetShape(ttype->shape), ttype->dtype); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); cache_node->inputs.push_back(tensor); inputs.push_back(tensor); } @@ -149,7 +126,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> constexpr static size_t kMaxFuncNameLength = 80; if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); truncated_name << "_" << std::hash{}(candidate_name) << "_"; candidate_name = truncated_name.str(); } @@ -190,29 +167,31 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> CHECK(op->is_scalar()); void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = te::compute({}, [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, "compile_engine_const", topi::kBroadcast); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); scalars_.push_back(value->op); return {value}; } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = - Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttrMap("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); CHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -227,12 +206,10 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> } } if (count_tuple) { - CHECK_EQ(call_node->args.size(), 1U) - << "Only allow function with a single tuple input"; + CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } - CHECK(call_node->op.as()) - << "Primitive function only allows call into primitive ops"; + CHECK(call_node->op.as()) << 
"Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); Array outputs; @@ -240,8 +217,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> // Skip fcompute for device copy operators as it is not registered. if (op == device_copy_op_) { const auto* copy_input = inputs[0].operator->(); - outputs.push_back(te::TensorNode::make(copy_input->shape, copy_input->dtype, - te::Operation(), 0)); + outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0)); } else { LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); outputs = lowered_out->outputs; @@ -251,8 +227,8 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> int op_pattern = fpattern[op]; if (op_pattern >= kCommReduce) { CHECK(!master_op_.defined() || master_op_pattern_ < kCommReduce) - << "Two complicated op in a primitive function " - << " master=" << master_op_ << " current=" << op; + << "Two complicated op in a primitive function " + << " master=" << master_op_ << " current=" << op; } if (op_pattern >= master_op_pattern_) { master_op_ = op; @@ -261,8 +237,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> master_implementation_ = impl; } if (outputs.size() != 1) { - const auto* tuple_type = - call_node->checked_type().as(); + const auto* tuple_type = call_node->checked_type().as(); CHECK(tuple_type) << "Expect output to be a tuple type"; CHECK_EQ(tuple_type->fields.size(), outputs.size()); } @@ -292,8 +267,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { - CHECK(field->checked_type().as()) - << "Only allow Tuple of Tensor"; + CHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); CHECK_EQ(res.size(), 1); fields.push_back(res[0]); @@ -349,15 +323,15 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> shape_inputs.push_back(shape_tensor); }; - if (const auto *ttype = param->checked_type().as()) { + if (const auto* ttype = param->checked_type().as()) { add_placeholder(ttype); } else { // flatten tuple of tensor type. 
- const auto *tuple_type = param->type_as(); + const auto* tuple_type = param->type_as(); // TODO(@icemelon): Support recursive tuple CHECK(tuple_type); for (Type field : tuple_type->fields) { - const auto *ttype = field.as(); + const auto* ttype = field.as(); CHECK(ttype); add_placeholder(ttype); } @@ -372,7 +346,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> constexpr static size_t kMaxFuncNameLength = 80; if (candidate_name.size() > kMaxFuncNameLength) { std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); truncated_name << "_" << std::hash{}(candidate_name) << "_"; candidate_name = truncated_name.str(); } @@ -448,49 +422,49 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> if (data_dependant) { void* data = op->data->data; DataType dtype = DataType(op->data->dtype); - auto value = tvm::te::compute({}, [&](const Array&) { - if (dtype == DataType::Int(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Int(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(32)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Float(64)) { - return make_const(dtype, static_cast(data)[0]); - } else if (dtype == DataType::Bool()) { - return make_const(dtype, static_cast(data)[0]); - } else { - LOG(FATAL) << "not handled"; - return tvm::PrimExpr(); - } - }, "data_const", topi::kBroadcast); + auto value = tvm::te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "data_const", topi::kBroadcast); scalars_.push_back(value); return {value}; } else { - auto value = tvm::te::compute({}, [&](const Array&) { - return tir::make_const(DataType::Int(64), 0); - }, "shape_const", topi::kBroadcast); + auto value = tvm::te::compute( + {}, [&](const Array&) { return tir::make_const(DataType::Int(64), 0); }, + "shape_const", topi::kBroadcast); scalars_.push_back(value); return {value}; } } Array VisitExpr_(const CallNode* call_node) final { - static auto fshape_func = Op::GetAttr("FShapeFunc"); - static auto tshape_data_dependant = Op::GetAttr( - "TShapeDataDependant"); - CHECK(call_node->op.as()) - << "Primitive function only allows call into primitive ops"; + static auto fshape_func = Op::GetAttrMap("FShapeFunc"); + static auto tshape_data_dependant = Op::GetAttrMap("TShapeDataDependant"); + CHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); CHECK(data_dependants_.empty() || !data_dependants_.back()) - << "Error in op fusion: output of the shape func is fed to a " - << "data-dependant shape func"; - CHECK_GT(fshape_func.count(op), 0) - << "Internal error, cannot find ShapeFunc for " << op->name; + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependant shape func"; + CHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << 
op->name; CHECK_GT(tshape_data_dependant.count(op), 0) - << "Internal error, cannot find TShapeDataDependant for " << op->name; + << "Internal error, cannot find TShapeDataDependant for " << op->name; - data_dependants_.push_back(tshape_data_dependant[op]); + data_dependants_.push_back(IsDataDependant(call_node)); // Visit all inputs Array inputs; int count_tuple = 0; @@ -503,8 +477,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> } } if (count_tuple) { - CHECK_EQ(call_node->args.size(), 1U) - << "Only allow function with a single tuple input"; + CHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; } // Get output ndims auto ret_type = call_node->checked_type(); @@ -543,8 +516,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> Array VisitExpr_(const TupleNode* op) final { Array fields; for (Expr field : op->fields) { - CHECK(field->checked_type().as()) - << "Only allow Tuple of Tensor"; + CHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; Array res = VisitExpr(field); CHECK_EQ(res.size(), 1); fields.push_back(res[0]); @@ -552,15 +524,22 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> return fields; } + Array VisitExpr_(const TupleGetItemNode* op) final { + Array input_shapes = VisitExpr(op->tuple); + Array out; + out.push_back(input_shapes[op->index]); + return out; + } + private: /*! \brief String stream for function name */ std::ostringstream readable_name_stream_; /*! \brief Map from parameter to its shape function usage state */ - std::unordered_map param_states_; + std::unordered_map param_states_; /*! \brief Map from parameter to list of data placeholder */ - std::unordered_map, ObjectHash, ObjectEqual> param_data_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_data_; /*! \brief Map from parameter to list of shape placeholder */ - std::unordered_map, ObjectHash, ObjectEqual> param_shapes_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> param_shapes_; /*! \brief Stack of data dependencies for shape function */ std::vector data_dependants_; /*! \brief Scalars used in the shape function */ @@ -570,9 +549,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. - CachedFunc Lower(const CCacheKey& key) { - return LowerInternal(key)->cached_func; - } + CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } // For now, build one module per function. PackedFunc JIT(const CCacheKey& key) final { @@ -583,7 +560,7 @@ class CompileEngineImpl : public CompileEngineNode { if (const auto* f = runtime::Registry::Get("relay.backend.build")) { m = (*f)(value->cached_func->funcs, key->target); } else { - m = build(value->cached_func->funcs, key->target, Target(nullptr), BuildConfig::Current()); + m = build(value->cached_func->funcs, key->target, Target(nullptr)); } value->packed_func = m.GetFunction(value->cached_func->func_name); return value->packed_func; @@ -609,7 +586,10 @@ class CompileEngineImpl : public CompileEngineNode { auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(std::string(symbol_name.value())); + auto gv = GlobalVar(symbol_name.value()); + // No need to keep compiler attribute at this point, functions have been + // extracted for specific codegen. 
+ src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue()); ext_mods[code_gen_name]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -633,9 +613,7 @@ class CompileEngineImpl : public CompileEngineNode { return ret; } - void Clear() final { - cache_.clear(); - } + void Clear() final { cache_.clear(); } // List all items in the cache. Array ListItems() { std::lock_guard lock(mutex_); @@ -659,7 +637,7 @@ class CompileEngineImpl : public CompileEngineNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key) { + CCacheValue LowerInternal(const CCacheKey& key) { std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); @@ -676,10 +654,8 @@ class CompileEngineImpl : public CompileEngineNode { // codegen tool once and lower all functions together. if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); - const auto name_node = - key->source_func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(name_node.defined()) - << "External function has not been attached a name yet."; + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); @@ -690,8 +666,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto cfunc = CreateSchedule(key->source_func, key->target); - auto cache_node = make_object( - *(cfunc.operator->())); + auto cache_node = make_object(*(cfunc.operator->())); // Skip lowering for device copy node. const Expr body = (key->source_func)->body; @@ -710,13 +685,13 @@ class CompileEngineImpl : public CompileEngineNode { } // lower the function if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)( - cfunc->schedule, all_args, cache_node->func_name, key->source_func); + cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); } else { - tvm::BuildConfig bcfg = BuildConfig::Create(); + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + std::unordered_map binds; - cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, - binds, bcfg); + cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds); } value->cached_func = CachedFunc(cache_node); return value; @@ -740,8 +715,7 @@ class CompileEngineImpl : public CompileEngineNode { CHECK(!value->cached_func.defined()); auto spair = MakeShapeFunc().Create(key->source_func); - auto cache_node = make_object( - *(spair.second.operator->())); + auto cache_node = make_object(*(spair.second.operator->())); cache_node->func_name = GetUniqueName(cache_node->func_name); cache_node->target = key->target; @@ -749,9 +723,12 @@ class CompileEngineImpl : public CompileEngineNode { for (te::Tensor arg : cache_node->outputs) { all_args.push_back(arg); } - tvm::BuildConfig bcfg = BuildConfig::Create(); + + using tvm::transform::PassContext; + With fresh_pass_ctx_scope(PassContext::Create()); + std::unordered_map binds; - cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); + cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds); value->cached_func = CachedFunc(cache_node); return value; } @@ -792,57 +769,41 @@ class CompileEngineImpl 
: public CompileEngineNode { const CompileEngine& CompileEngine::Global() { // intentionally allocate raw pointer to avoid // free during destructuion. - static CompileEngine* inst = new CompileEngine( - make_object()); + static CompileEngine* inst = new CompileEngine(make_object()); return *inst; } TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") -.set_body_typed([](tvm::Array outputs, OpImplementation impl) { - return LoweredOutput(outputs, impl); -}); + .set_body_typed([](tvm::Array outputs, OpImplementation impl) { + return LoweredOutput(outputs, impl); + }); TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") -.set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); -}); + .set_body_typed([](Function source_func, Target target) { + return CCacheKey(source_func, target); + }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") -.set_body_typed([]() { +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { return CompileEngine::Global(); }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") -.set_body_typed([](CompileEngine self) { +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](CompileEngine self) { self->Clear(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->Lower(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->LowerShapeFunc(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") -.set_body_typed([](CompileEngine self) { - return self->LowerExternalFunctions(); -}); + .set_body_typed([](CompileEngine self) { return self->LowerExternalFunctions(); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") -.set_body_typed( - [](CompileEngine self, CCacheKey key) { - return self->JIT(key); -}); + .set_body_typed([](CompileEngine self, CCacheKey key) { return self->JIT(key); }); -TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") -.set_body_typed( - [](CompileEngine self){ +TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems").set_body_typed([](CompileEngine self) { return static_cast(self.operator->())->ListItems(); }); } // namespace relay diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 4a3a04d02dcd..a5f3f6359f89 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -27,13 +27,14 @@ #include #include -#include #include #include -#include #include -#include +#include +#include + #include +#include namespace tvm { namespace relay { @@ -81,7 +82,7 @@ struct CachedFuncNode : public Object { /*! \brief The schedule to the function */ te::Schedule schedule; /*! \brief The lowered functions to support the function. */ - IRModule funcs = IRModule::Empty(); + IRModule funcs = IRModule(); /*! \brief Parameter usage states in the shape function. 
*/ tvm::Array shape_func_param_states; @@ -150,9 +151,7 @@ class CCacheKey : public ObjectRef { */ TVM_DLL CCacheKey(Function source_func, Target target); - const CCacheKeyNode* operator->() const { - return static_cast(get()); - } + const CCacheKeyNode* operator->() const { return static_cast(get()); } // comparator inline bool operator==(const CCacheKey& other) const { CHECK(defined() && other.defined()); @@ -184,12 +183,8 @@ class CCacheValue : public ObjectRef { public: CCacheValue() {} explicit CCacheValue(ObjectPtr n) : ObjectRef(n) {} - CCacheValueNode* operator->() { - return static_cast(get_mutable()); - } - const CCacheValueNode* operator->() const { - return static_cast(get()); - } + CCacheValueNode* operator->() { return static_cast(get_mutable()); } + const CCacheValueNode* operator->() const { return static_cast(get()); } using ContainerType = CCacheValueNode; }; @@ -240,9 +235,7 @@ class CompileEngine : public ObjectRef { public: CompileEngine() {} explicit CompileEngine(ObjectPtr n) : ObjectRef(n) {} - CompileEngineNode* operator->() { - return static_cast(get_mutable()); - } + CompileEngineNode* operator->() { return static_cast(get_mutable()); } using ContainerType = CompileEngineNode; /*! \brief The global compile engine. */ TVM_DLL static const CompileEngine& Global(); @@ -260,17 +253,15 @@ inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. hash_ = tvm::StructuralHash()(this->source_func); - hash_ = dmlc::HashCombine( - hash_, std::hash()(target->str())); + hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; return hash_; } -inline bool CCacheKeyNode::Equal( - const CCacheKeyNode* other) const { +inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && - tvm::StructuralEqual()(this->source_func, other->source_func); + tvm::StructuralEqual()(this->source_func, other->source_func); } } // namespace relay @@ -278,7 +269,7 @@ inline bool CCacheKeyNode::Equal( namespace std { // overload hash -template<> +template <> struct hash<::tvm::relay::CCacheKey> { size_t operator()(const ::tvm::relay::CCacheKey& key) const { CHECK(key.defined()); diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 0b3510c85779..2968966e8039 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -56,6 +56,25 @@ class CodegenC : public MemoizedExprTranslator>, public Code return {output}; } + std::vector VisitExpr_(const TupleNode* node) final { + std::vector outs; + for (auto field : node->fields) { + auto res = VisitExpr(field); + CHECK_EQ(res.size(), 1U) << "Do not support tuple nest"; + outs.push_back(res[0]); + } + return outs; + } + + std::vector VisitExpr_(const TupleGetItemNode* op) final { + auto res = VisitExpr(op->tuple); + CHECK_GT(res.size(), static_cast(op->index)); + + // Only keep the item we want for the child node. + // FIXME(@comaniac): The other items should still be requried for the primary outputs. + return {res[op->index]}; + } + std::vector VisitExpr_(const ConstantNode* cn) final { // Note this is for demonstration purpose. ConstantNode doesn't necessarily // belong to calls. We need to revisit this when tuples come into play. 
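// Illustrative sketch (assumption, not taken from this patch): the TupleNode and
// TupleGetItemNode visitors added above let the C codegen accept function bodies
// such as
//   %t = (%a, %b);   // TupleNode: each field is visited and one output is kept per field
//   %x = %t.1;       // TupleGetItemNode: only the selected field is forwarded
// Nested tuples are still rejected by the CHECK_EQ(res.size(), 1U) guard, and the
// FIXME above notes that the non-selected fields may still be needed for the
// primary outputs.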
@@ -68,7 +87,6 @@ class CodegenC : public MemoizedExprTranslator>, public Code runtime::NDArray array = cn->data; const auto& shape = array.Shape(); - const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor; // Get the number of elements. int64_t num_elems = 1; @@ -83,11 +101,11 @@ class CodegenC : public MemoizedExprTranslator>, public Code // to avoid possible stack overflow. buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {"; if (dtype == "float") { - float* p_flt = static_cast(dl_tensor.data); + float* p_flt = static_cast(array->data); for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", "; if (num_elems) buf_stream << p_flt[num_elems - 1]; } else if (dtype == "int") { - int* p_flt = static_cast(dl_tensor.data); + int* p_flt = static_cast(array->data); for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", "; if (num_elems) buf_stream << p_flt[num_elems - 1]; } else { @@ -152,8 +170,8 @@ class CodegenC : public MemoizedExprTranslator>, public Code for (size_t i = 0; i < out_shape.size(); ++i) { out_size *= out_shape[i]; } - buf_stream << dtype << "* " << out << - " = (" << dtype << "*)std::malloc(4 * " << out_size << ");"; + buf_stream << dtype << "* " << out << " = (" << dtype << "*)std::malloc(4 * " << out_size + << ");"; buf_decl_.push_back(buf_stream.str()); decl_stream << ", " << out << ");"; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 92263861d359..3a3c486bb035 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -25,9 +25,10 @@ #define TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ #include -#include #include +#include #include + #include #include #include @@ -69,8 +70,7 @@ class CSourceModuleCodegenBase { * \return An external symbol. 
*/ std::string GetExtSymbol(const Function& func) const { - const auto name_node = - func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } @@ -124,9 +124,8 @@ class CodegenCBase { * * \endcode */ - void GenerateBackendCFunc(const std::string& func_name, - const Array& args, - const Output& out) { + void GenerateBackendCFunc(const std::string& func_name, const Array& args, + const std::vector& outs) { // Print signature code_stream_ << "\n"; code_stream_ << "extern \"C\" int " << func_name << "_wrapper_("; @@ -134,9 +133,11 @@ class CodegenCBase { code_stream_ << "DLTensor* arg" << i << ",\n"; code_stream_ << "\t"; } - if (args.size() > 0) { - code_stream_ << "DLTensor* arg" << args.size() << ") {\n"; + for (size_t i = 0; i < outs.size() - 1; i++) { + code_stream_ << "DLTensor* out" << i << ",\n"; + code_stream_ << "\t"; } + code_stream_ << "DLTensor* out" << outs.size() - 1 << ") {\n"; EnterScope(); @@ -148,18 +149,20 @@ class CodegenCBase { code_stream_ << "static_cast<" << dtype_str << "*>(arg" << i << "->data),\n"; PrintIndents(); } - if (args.size() > 0) { - code_stream_ << "static_cast<" << out.dtype << "*>(arg" << args.size() << "->data)"; + for (size_t i = 0; i < outs.size() - 1; i++) { + code_stream_ << "static_cast<" << outs[i].dtype << "*>(out" << i << "->data),\n"; + PrintIndents(); } - code_stream_ << ");\n"; + code_stream_ << "static_cast<" << outs.back().dtype << "*>(out" << outs.size() - 1 + << "->data));\n"; PrintIndents(); code_stream_ << "return 0;\n"; ExitScope(); code_stream_ << "}\n\n"; // Generate the macro - code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " - << func_name << "_wrapper_);\n\n"; + code_stream_ << "TVM_DLL_EXPORT_TYPED_FUNC(" << func_name << ", " << func_name + << "_wrapper_);\n\n"; } /*! @@ -187,19 +190,19 @@ class CodegenCBase { */ std::string JitImpl(const std::string& ext_func_id, const Array& args, const std::vector& buf_decl, - const std::vector& body, - const std::vector& out) { + const std::vector& body, const std::vector& outs) { // Create the signature. 
For example, it could be: - // extern "C" void dnnl_0_(float* input0, float* input1, float* out, int M, int N) {} + // extern "C" void dnnl_0_(float* in0, float* in1, float* out0, float* out1) {} code_stream_ << "extern \"C\" void " << ext_func_id << "_("; - CHECK_EQ(out.size(), 1U) << "Internal error: only single output is support."; - for (const auto& arg : args) { const auto& dtype_str = GetDtypeString(arg); code_stream_ << dtype_str << "* " << arg->name_hint() << ", "; } - code_stream_ << out[0].dtype << "* out) {\n"; + for (size_t i = 0; i < outs.size() - 1; ++i) { + code_stream_ << outs[i].dtype << "* out" << i << ", "; + } + code_stream_ << outs.back().dtype << "* out" << outs.size() - 1 << ") {\n"; this->EnterScope(); // Function body @@ -214,22 +217,26 @@ class CodegenCBase { } // Copy output - if (out[0].need_copy) { + for (size_t i = 0; i < outs.size(); ++i) { + if (!outs[i].need_copy) { + continue; + } this->PrintIndents(); - code_stream_ << "std::memcpy(out, " << out[0].name << ", 4 * " << out[0].size << ");\n"; + code_stream_ << "std::memcpy(out" << i << ", " << outs[i].name << ", 4 * " << outs[i].size + << ");\n"; + } - // Free buffers - for (size_t i = 0; i < buf_decl.size(); i++) { - this->PrintIndents(); - code_stream_ << "std::free(buf_" << i << ");\n"; - } + // Free buffers + for (size_t i = 0; i < buf_decl.size(); i++) { + this->PrintIndents(); + code_stream_ << "std::free(buf_" << i << ");\n"; } this->ExitScope(); code_stream_ << "}\n"; // Create the wrapper to call the ext_func - this->GenerateBackendCFunc(ext_func_id, args, out[0]); + this->GenerateBackendCFunc(ext_func_id, args, outs); return code_stream_.str(); } diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 26bc8786902c..3f9ad7cdc69f 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -144,6 +144,16 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C return {output}; } + std::vector VisitExpr_(const TupleNode* node) final { + std::vector outs; + for (auto field : node->fields) { + auto res = VisitExpr(field); + CHECK_EQ(res.size(), 1U) << "Do not support tuple nest"; + outs.push_back(res[0]); + } + return outs; + } + std::vector VisitExpr_(const TupleGetItemNode* op) final { auto res = VisitExpr(op->tuple); CHECK_GT(res.size(), static_cast(op->index)); @@ -169,12 +179,12 @@ class CodegenDNNL : public MemoizedExprTranslator>, public C CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now."; std::ostringstream buf_stream; - const float* ptr = static_cast(array.ToDLPack()->dl_tensor.data); + const float* ptr = static_cast(array->data); // Allocate large arrays on the static section to avoid stakc overflow. // Note that this would probably increase compilation time as the source // file could be really large. - buf_stream << "static float " << output.name << "[" << num_elems <<"] = {"; + buf_stream << "static float " << output.name << "[" << num_elems << "] = {"; for (int64_t i = 0; i < num_elems - 1; i++) { buf_stream << ptr[i] << ","; } @@ -347,8 +357,6 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { // Create a corresponding DNNL function for the given relay Function. void GenDNNLFunc(const Function& func) { CHECK(func.defined()) << "Input error: expect a Relay function."; - const auto* call = func->body.as(); - CHECK(call) << "DNNL expects a single convolution or dense op"; // Record the external symbol for runtime lookup. 
auto sid = GetExtSymbol(func); diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 736509d2d97f..820e17f8a498 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -22,10 +22,11 @@ * \brief Memory index assignment pass for executing * the program in the graph runtime. */ -#include +#include #include #include -#include +#include + #include "../../support/arena.h" namespace tvm { @@ -60,9 +61,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { } } - void VisitExpr_(const ConstantNode* op) final { - this->CreateToken(op, false); - } + void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); } void VisitExpr_(const VarNode* op) final { // Do nothing. @@ -96,9 +95,7 @@ class StorageAllocaBaseVisitor : public ExprVisitor { token_map_[op] = {tok[op->index]}; } - void VisitExpr_(const IfNode* op) final { - LOG(FATAL) << "if is not supported."; - } + void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } void VisitExpr_(const LetNode* op) final { auto token = GetToken(op->value); @@ -131,12 +128,11 @@ class StorageAllocaBaseVisitor : public ExprVisitor { class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: - explicit StorageAllocaInit(support::Arena* arena) - : arena_(arena) {} + explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ - std::unordered_map > - GetInitTokenMap(const Function& func) { + std::unordered_map > GetInitTokenMap( + const Function& func) { node_device_map_ = CollectDeviceInfo(func); this->Run(func); return std::move(token_map_); @@ -145,12 +141,11 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateToken(const ExprNode* op, bool can_realloc) final { + void CreateToken(const ExprNode* op, bool can_realloc) final { CHECK(!token_map_.count(op)); std::vector tokens; - int device_type = node_device_map_.count(GetRef(op)) - ? node_device_map_[GetRef(op)]->value - : 0; + int device_type = + node_device_map_.count(GetRef(op)) ? node_device_map_[GetRef(op)]->value : 0; if (const auto* tuple_type = op->checked_type().as()) { for (Type t : tuple_type->fields) { const auto* ttype = t.as(); @@ -227,10 +222,9 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { - LOG(FATAL) - << num_annotated_nodes << " out of " << num_nodes - << "expressions are assigned with virtual device types. Either all " - "or none of the expressions are expected to be annotated."; + LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes + << "expressions are assigned with virtual device types. 
Either all " + "or none of the expressions are expected to be annotated."; } return smap; } @@ -296,12 +290,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { size_t size = 1; for (IndexExpr dim : ttype->shape) { const int64_t* pval = tir::as_const_int(dim); - CHECK(pval != nullptr) - << "Cannot allocate memory symbolic tensor shape " - << ttype->shape; - CHECK_GE(*pval, 0) - << "Cannot allocate memory for tensor with negative shape" - << *pval; + CHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape; + CHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; size *= static_cast(pval[0]); } size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8); @@ -324,7 +314,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { auto end = free_.upper_bound(size * match_range_); // search for memory blocks larger than requested for (auto it = mid; it != end; ++it) { - StorageToken *tok = it->second; + StorageToken* tok = it->second; if (tok->device_type != prototype->device_type) continue; CHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy @@ -337,7 +327,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { // then search for memory blocks smaller than requested space for (auto it = mid; it != begin;) { --it; - StorageToken *tok = it->second; + StorageToken* tok = it->second; if (tok->device_type != prototype->device_type) continue; CHECK_EQ(tok->ref_counter, 0); // Use exect matching strategy @@ -390,8 +380,7 @@ Map > GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } -TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") -.set_body_typed(GraphPlanMemory); +TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7b686c76e3e7..4226cc872589 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -44,7 +44,7 @@ class GraphInputNode; class GraphOpNode; using IntegerArray = Array; -using ShapeVector = std::vector >; +using ShapeVector = std::vector>; using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; @@ -54,7 +54,7 @@ using TargetsMap = std::unordered_map; /*! 
\brief Lowered outputs */ struct LoweredOutput { std::string graph_json; - Map lowered_funcs; + Map lowered_funcs; Array external_mods; std::unordered_map params; }; @@ -70,8 +70,7 @@ class GraphNodeRef { public: GraphNodeRef() {} GraphNodeRef(int ident, int index, int version = 0) - : ident_(ident), index_(index), version_(version) {} - + : ident_(ident), index_(index), version_(version) {} inline void Save(dmlc::JSONWriter* writer) const { writer->BeginArray(); @@ -81,9 +80,7 @@ class GraphNodeRef { writer->EndArray(); } - inline void Load(dmlc::JSONReader* reader) { - LOG(FATAL) << "Not implemented."; - } + inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; } protected: int ident_; @@ -136,11 +133,8 @@ class GraphInputNode : public GraphNode { class GraphOpNode : public GraphNode { public: GraphOpNode() {} - GraphOpNode(const std::string& name, - const GraphAttrs& nd_attrs, - const std::string& op_name, - const std::vector& inputs, - const GraphAttrs& attrs, + GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name, + const std::vector& inputs, const GraphAttrs& attrs, size_t num_outputs = 1) { name_ = name; attrs_ = nd_attrs; @@ -173,8 +167,7 @@ class GraphOpNode : public GraphNode { const GraphAttrs& nd_attrs, const std::string& op_name, const std::vector& inputs, - const GraphAttrs& attrs, - size_t num_outputs = 1) { + const GraphAttrs& attrs, size_t num_outputs = 1) { auto ptr = std::make_shared(name, nd_attrs, op_name, inputs, attrs, num_outputs); return std::dynamic_pointer_cast(ptr); } @@ -214,7 +207,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorUpdate(kv.second); @@ -335,8 +328,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator GraphAddCallNode(const CallNode* op, - const std::string& op_name, + std::vector GraphAddCallNode(const CallNode* op, const std::string& op_name, const std::string& func_name) { std::vector inputs; for (auto arg : op->args) { @@ -345,11 +337,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator(op)); } @@ -384,11 +372,11 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; // Normal Relay Function if (targets_.size() == 1) { - // homogeneous execution. + // homogeneous execution. 
const auto& it = targets_.begin(); target = (*it).second; } else { @@ -400,20 +388,17 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorstr())) { - lowered_funcs_[target->str()] = IRModule::Empty(); + lowered_funcs_[target->str()] = IRModule(); } lowered_funcs_[target->str()]->Update(lowered_func->funcs); - return GraphAddCallNode(op, - _GetUniqueName(lowered_func->func_name), - lowered_func->func_name); + return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name); } std::vector VisitExpr_(const LetNode* op) override { @@ -560,37 +545,34 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator& sptr_to_self) { - if (name == "init") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args.num_args, 2) - << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; - void* mod = args[0]; - Map tmp = args[1]; - TargetsMap targets; - for (const auto& it : tmp) { - auto dev_type = it.first.as(); - CHECK(dev_type); - targets[dev_type->value] = it.second; - } - codegen_ = std::make_shared( - reinterpret_cast(mod), targets); - }); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + if (name == "init") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.num_args, 2) << "The expected of arguments are: " + << "runtime::Module mod and Map targets"; + void* mod = args[0]; + Map tmp = args[1]; + TargetsMap targets; + for (const auto& it : tmp) { + auto dev_type = it.first.as(); + CHECK(dev_type); + targets[dev_type->value] = it.second; + } + codegen_ = + std::make_shared(reinterpret_cast(mod), targets); + }); } else if (name == "codegen") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Function func = args[0]; this->output_ = this->codegen_->Codegen(func); }); } else if (name == "get_graph_json") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->output_.graph_json; - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; }); } else if (name == "list_params_name") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { Array ret; - for (const auto &kv : this->output_.params) { + for (const auto& kv : this->output_.params) { ret.push_back(kv.first); } *rv = ret; @@ -614,9 +596,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { } } - const char* type_key() const final { - return "RelayGraphRuntimeCodegenModule"; - } + const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; } private: std::shared_ptr codegen_; @@ -629,9 +609,7 @@ runtime::Module CreateGraphCodegenMod() { } TVM_REGISTER_GLOBAL("relay.build_module._GraphRuntimeCodegen") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CreateGraphCodegenMod(); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); }); } // namespace backend } // namespace relay diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 465f788449e2..9a75c0ab76ee 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -21,16 +21,16 @@ * \file src/relay/interpreter.cc * \brief An interpreter for the Relay IR. 
*/ -#include -#include -#include -#include -#include -#include +#include #include #include +#include #include -#include +#include +#include +#include +#include +#include #include "compile_engine.h" @@ -39,8 +39,7 @@ namespace relay { using namespace runtime; -InterpreterClosure::InterpreterClosure(tvm::Map env, - Function func) { +InterpreterClosure::InterpreterClosure(tvm::Map env, Function func) { ObjectPtr n = make_object(); n->env = std::move(env); n->func = std::move(func); @@ -48,10 +47,10 @@ InterpreterClosure::InterpreterClosure(tvm::Map env, } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "InterpreterClosureNode(" << node->func << ", " << node->env << ")"; + }); inline const PackedFunc& GetPackedFunc(const std::string& name) { const PackedFunc* pf = tvm::runtime::Registry::Get(name); @@ -69,10 +68,10 @@ RecClosure::RecClosure(InterpreterClosure clos, Var bind) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RecClosureObj(" << node->clos << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RecClosureObj(" << node->clos << ")"; + }); RefValue::RefValue(ObjectRef value) { ObjectPtr n = make_object(); @@ -80,21 +79,19 @@ RefValue::RefValue(ObjectRef value) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("relay._make.RefValue") -.set_body_typed([](ObjectRef value){ +TVM_REGISTER_GLOBAL("relay._make.RefValue").set_body_typed([](ObjectRef value) { return RefValue(value); }); TVM_REGISTER_NODE_TYPE(RefValueObj); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefValueObj(" << node->value << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefValueObj(" << node->value << ")"; + }); -ConstructorValue::ConstructorValue(int32_t tag, - tvm::Array fields, +ConstructorValue::ConstructorValue(int32_t tag, tvm::Array fields, Constructor constructor) { ObjectPtr n = make_object(); n->tag = tag; @@ -104,19 +101,17 @@ ConstructorValue::ConstructorValue(int32_t tag, } TVM_REGISTER_GLOBAL("relay._make.ConstructorValue") -.set_body_typed([](int32_t tag, tvm::Array fields, - Constructor constructor) { - return ConstructorValue(tag, fields, constructor); -}); + .set_body_typed([](int32_t tag, tvm::Array fields, Constructor constructor) { + return ConstructorValue(tag, fields, constructor); + }); TVM_REGISTER_NODE_TYPE(ConstructorValueObj); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ConstructorValueObj(" << node->tag << "," - << node->fields << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ConstructorValueObj(" << node->tag << "," << node->fields << ")"; + }); /*! * \brief A stack frame in the Relay interpreter. 
@@ -161,9 +156,7 @@ struct Stack { */ struct LocalFrame { Stack& st; - explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { - st.frames.push_back(fr); - } + explicit LocalFrame(Stack& st, const Frame& fr) : st(st) { st.frames.push_back(fr); } ~LocalFrame() { st.frames.pop_back(); } }; }; @@ -188,22 +181,25 @@ class InterpreterStateObj : public Object { v->Visit("stack", &stack); } - static InterpreterState make(Expr current_expr, Stack stack); - static constexpr const char* _type_key = "relay.InterpreterState"; TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterStateObj, Object); }; class InterpreterState : public ObjectRef { public: + using Frame = tvm::Map; + using Stack = tvm::Array; + + InterpreterState(Expr current_expr, Stack stack); + TVM_DEFINE_OBJECT_REF_METHODS(InterpreterState, ObjectRef, InterpreterStateObj); }; -InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { +InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack stack) { ObjectPtr n = make_object(); n->current_expr = std::move(current_expr); n->stack = std::move(stack); - return InterpreterState(n); + data_ = std::move(n); } // NOTE: the current interpreter assumes A-normal form. @@ -213,9 +209,8 @@ InterpreterState InterpreterStateObj::make(Expr current_expr, Stack stack) { // contains DAG in dataflow-form. // // Conversion to ANF is recommended before running the interpretation. -class Interpreter : - public ExprFunctor, - PatternFunctor { +class Interpreter : public ExprFunctor, + PatternFunctor { public: Interpreter(IRModule mod, DLContext context, Target target) : mod_(mod), @@ -232,21 +227,13 @@ class Interpreter : return f(); } - void extend(const Var& id, ObjectRef v) { - stack_.current_frame().locals.Set(id, v); - } + void extend(const Var& id, ObjectRef v) { stack_.current_frame().locals.Set(id, v); } - ObjectRef Lookup(const Var& local) { - return stack_.Lookup(local); - } + ObjectRef Lookup(const Var& local) { return stack_.Lookup(local); } - ObjectRef Eval(const Expr& expr) { - return VisitExpr(expr); - } + ObjectRef Eval(const Expr& expr) { return VisitExpr(expr); } - ObjectRef VisitExpr_(const VarNode* var_node) final { - return Lookup(GetRef(var_node)); - } + ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef(var_node)); } ObjectRef VisitExpr_(const GlobalVarNode* op) final { return Eval(mod_->Lookup(GetRef(op))); @@ -260,9 +247,7 @@ class Interpreter : return ObjectRef(); } - ObjectRef VisitExpr_(const ConstantNode* op) final { - return op->data.CopyTo(context_); - } + ObjectRef VisitExpr_(const ConstantNode* op) final { return op->data.CopyTo(context_); } ObjectRef VisitExpr_(const TupleNode* op) final { std::vector values; @@ -302,8 +287,7 @@ class Interpreter : return MakeClosure(func); } - Array ComputeDynamicShape(const Function& func, - const Array& args) { + Array ComputeDynamicShape(const Function& func, const Array& args) { CCacheKey key(func, Target::Create("llvm")); auto cfunc = engine_->LowerShapeFunc(key); size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); @@ -319,26 +303,26 @@ class Interpreter : cpu_ctx.device_id = 0; auto fset_input = [&](size_t i, ObjectRef val, bool need_shape) { - auto nd_array = Downcast(val); - if (need_shape) { - int64_t ndim = nd_array.Shape().size(); - NDArray shape_arr; - if (ndim == 0) { - shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx); - } else { - shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); - int64_t* data = reinterpret_cast(shape_arr->data); - for 
(auto j = 0; j < ndim; ++j) { - data[j] = nd_array.Shape()[j]; - } - } - inputs[i] = shape_arr; - setter(i, shape_arr); + auto nd_array = Downcast(val); + if (need_shape) { + int64_t ndim = nd_array.Shape().size(); + NDArray shape_arr; + if (ndim == 0) { + shape_arr = NDArray::Empty({}, DataType::Int(64), cpu_ctx); } else { - auto arr = nd_array.CopyTo(cpu_ctx); - inputs[i] = arr; - setter(i, arr); + shape_arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); + int64_t* data = reinterpret_cast(shape_arr->data); + for (auto j = 0; j < ndim; ++j) { + data[j] = nd_array.Shape()[j]; + } } + inputs[i] = shape_arr; + setter(i, shape_arr); + } else { + auto arr = nd_array.CopyTo(cpu_ctx); + inputs[i] = arr; + setter(i, arr); + } }; size_t arg_counter = 0; @@ -367,17 +351,16 @@ class Interpreter : } } } - CHECK_EQ(arg_counter, cfunc->inputs.size()) - << "Shape function input sizes mismatch"; + CHECK_EQ(arg_counter, cfunc->inputs.size()) << "Shape function input sizes mismatch"; auto fset_shape_output = [&](size_t i, Type val_type) { - // TODO(@icemelon): allow recursive tuple - const TensorTypeNode* rtype = val_type.as(); - CHECK(rtype != nullptr); - int64_t ndim = rtype->shape.size(); - auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); - outputs[i] = arr; - setter(arg_counter + i, arr); + // TODO(@icemelon): allow recursive tuple + const TensorTypeNode* rtype = val_type.as(); + CHECK(rtype != nullptr); + int64_t ndim = rtype->shape.size(); + auto arr = NDArray::Empty({ndim}, DataType::Int(64), cpu_ctx); + outputs[i] = arr; + setter(arg_counter + i, arr); }; auto ret_type = func->body->checked_type(); @@ -392,8 +375,7 @@ class Interpreter : auto tt = Downcast(ret_type); fset_shape_output(0, tt); } - CHECK_EQ(cfunc->outputs.size(), out_cnt) - << "Shape function output sizes mismatch"; + CHECK_EQ(cfunc->outputs.size(), out_cnt) << "Shape function output sizes mismatch"; PackedFunc shape_func; Module m; @@ -401,7 +383,7 @@ class Interpreter : if (const auto* f = runtime::Registry::Get("relay.backend.build")) { m = (*f)(cfunc->funcs, cfunc->target); } else { - m = build(cfunc->funcs, cfunc->target, Target(nullptr), BuildConfig::Current()); + m = build(cfunc->funcs, cfunc->target, Target(nullptr)); } shape_func = m.GetFunction(cfunc->func_name); shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); @@ -419,8 +401,7 @@ class Interpreter : return out_shapes; } - ObjectRef InvokePrimitiveOp(const Function& func, - const Array& args) { + ObjectRef InvokePrimitiveOp(const Function& func, const Array& args) { const auto* call_node = func->body.as(); if (call_node && call_node->op == debug_op_) { @@ -451,8 +432,7 @@ class Interpreter : if (const auto* tuple_type = func->body->checked_type().as()) { arg_len += tuple_type->fields.size(); } else { - CHECK(func->body->checked_type().as()) - << func->body->checked_type(); + CHECK(func->body->checked_type().as()) << func->body->checked_type(); arg_len += 1; } std::vector values(arg_len); @@ -463,16 +443,14 @@ class Interpreter : const auto nd_array = Downcast(val); setter(i, nd_array); DLContext arg_ctx = nd_array->ctx; - CHECK(arg_ctx.device_type == context_.device_type && - arg_ctx.device_id == context_.device_id) - << "Interpreter expect context to be " - << context_ << ", but get " << arg_ctx; + CHECK(arg_ctx.device_type == context_.device_type && arg_ctx.device_id == context_.device_id) + << "Interpreter expect context to be " << context_ << ", but get " << arg_ctx; }; int arg_counter = 0; for (ObjectRef arg : args) { if 
(arg->IsInstance()) { - fset_input(arg_counter++, arg); + fset_input(arg_counter++, arg); } else { auto adt = Downcast(arg); for (size_t i = 0; i < adt.size(); ++i) { @@ -547,8 +525,7 @@ class Interpreter : } // Invoke the closure - ObjectRef Invoke(const InterpreterClosure& closure, - const tvm::Array& args, + ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { @@ -625,11 +602,9 @@ class Interpreter : ObjectRef VisitExpr_(const TupleGetItemNode* op) final { ObjectRef val = Eval(op->tuple); const auto* adt_obj = val.as(); - CHECK(adt_obj) - << "interal error: when evaluating TupleGetItem expected an ADT value"; + CHECK(adt_obj) << "interal error: when evaluating TupleGetItem expected an ADT value"; auto adt = GetRef(adt_obj); - CHECK_LT(static_cast(op->index), adt.size()) - << "internal error: index out of bounds"; + CHECK_LT(static_cast(op->index), adt.size()) << "internal error: index out of bounds"; return adt[op->index]; } @@ -665,9 +640,7 @@ class Interpreter : } } - ObjectRef VisitExpr_(const RefCreateNode* op) final { - return RefValue(Eval(op->value)); - } + ObjectRef VisitExpr_(const RefCreateNode* op) final { return RefValue(Eval(op->value)); } ObjectRef VisitExpr_(const RefReadNode* op) final { ObjectRef r = Eval(op->ref); @@ -718,9 +691,7 @@ class Interpreter : return true; } - bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { - return true; - } + bool VisitPattern_(const PatternWildcardNode* op, const ObjectRef& v) final { return true; } bool VisitPattern_(const PatternVarNode* op, const ObjectRef& v) final { extend(op->var, v); @@ -733,7 +704,7 @@ class Interpreter : InterpreterStateObj::Frame frame = fr.locals; stack.push_back(frame); } - auto state = InterpreterStateObj::make(e, stack); + auto state = InterpreterState(e, stack); return state; } @@ -754,17 +725,11 @@ class Interpreter : const Op& shape_of_op_; }; - -TypedPackedFunc -CreateInterpreter( - IRModule mod, - DLContext context, - Target target) { +TypedPackedFunc CreateInterpreter(IRModule mod, DLContext context, Target target) { if (mod.defined()) { // eta expand to support constructors in argument position - transform::Sequential seq({ - transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)}); + transform::Sequential seq({transform::EtaExpand( + /* expand_constructor */ true, /* expand_global_var */ false)}); transform::PassContext pass_ctx = transform::PassContext::Current(); tvm::With ctx(pass_ctx); mod = seq(mod); @@ -779,8 +744,7 @@ CreateInterpreter( return TypedPackedFunc(packed); } -TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter") -.set_body_typed(CreateInterpreter); +TVM_REGISTER_GLOBAL("relay.backend.CreateInterpreter").set_body_typed(CreateInterpreter); } // namespace relay } // namespace tvm diff --git a/src/relay/backend/param_dict.cc b/src/relay/backend/param_dict.cc index e517fee3a4af..ef4b6589bdba 100644 --- a/src/relay/backend/param_dict.cc +++ b/src/relay/backend/param_dict.cc @@ -22,86 +22,77 @@ * \brief Implementation and registration of parameter dictionary * serializing/deserializing functions. 
*/ -#include +#include "param_dict.h" + #include +#include #include -#include #include - -#include "param_dict.h" - - +#include namespace tvm { namespace relay { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict") -.set_body([](TVMArgs args, TVMRetValue *rv) { - CHECK_EQ(args.size() % 2, 0u); - // `args` is in the form "key, value, key, value, ..." - size_t num_params = args.size() / 2; - std::vector names; - names.reserve(num_params); - std::vector arrays; - arrays.reserve(num_params); - for (size_t i = 0; i < num_params * 2; i += 2) { - names.emplace_back(args[i].operator std::string()); - arrays.emplace_back(args[i + 1].operator DLTensor*()); - } - std::string bytes; - dmlc::MemoryStringStream strm(&bytes); - dmlc::Stream* fo = &strm; - uint64_t header = kTVMNDArrayListMagic, reserved = 0; - fo->Write(header); - fo->Write(reserved); - fo->Write(names); - { - uint64_t sz = static_cast(arrays.size()); - fo->Write(sz); - for (size_t i = 0; i < sz; ++i) { - tvm::runtime::SaveDLTensor(fo, arrays[i]); - } +TVM_REGISTER_GLOBAL("tvm.relay._save_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_EQ(args.size() % 2, 0u); + // `args` is in the form "key, value, key, value, ..." + size_t num_params = args.size() / 2; + std::vector names; + names.reserve(num_params); + std::vector arrays; + arrays.reserve(num_params); + for (size_t i = 0; i < num_params * 2; i += 2) { + names.emplace_back(args[i].operator String()); + arrays.emplace_back(args[i + 1].operator DLTensor*()); + } + std::string bytes; + dmlc::MemoryStringStream strm(&bytes); + dmlc::Stream* fo = &strm; + uint64_t header = kTVMNDArrayListMagic, reserved = 0; + fo->Write(header); + fo->Write(reserved); + fo->Write(names); + { + uint64_t sz = static_cast(arrays.size()); + fo->Write(sz); + for (size_t i = 0; i < sz; ++i) { + tvm::runtime::SaveDLTensor(fo, arrays[i]); } - TVMByteArray arr; - arr.data = bytes.c_str(); - arr.size = bytes.length(); - *rv = arr; - }); + } + TVMByteArray arr; + arr.data = bytes.c_str(); + arr.size = bytes.length(); + *rv = arr; +}); -TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string bytes = args[0]; - std::vector names; - dmlc::MemoryStringStream memstrm(&bytes); - dmlc::Stream* strm = &memstrm; - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; - CHECK(strm->Read(&names)) - << "Invalid parameters file format"; - uint64_t sz; - strm->Read(&sz, sizeof(sz)); - size_t size = static_cast(sz); - CHECK(size == names.size()) - << "Invalid parameters file format"; - tvm::Array ret; - for (size_t i = 0; i < size; ++i) { - tvm::runtime::NDArray temp; - temp.Load(strm); - auto n = tvm::make_object(); - n->name = std::move(names[i]); - n->array = temp; - ret.push_back(NamedNDArray(n)); - } - *rv = ret; - }); +TVM_REGISTER_GLOBAL("tvm.relay._load_param_dict").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string bytes = args[0]; + std::vector names; + dmlc::MemoryStringStream memstrm(&bytes); + dmlc::Stream* strm = &memstrm; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; + CHECK(strm->Read(&names)) << "Invalid parameters file format"; + 
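// Descriptive note (not part of the original patch): the byte layout consumed here is
// the one produced by tvm.relay._save_param_dict above:
//   uint64  header    -- must equal kTVMNDArrayListMagic
//   uint64  reserved
//   vector<string>    -- parameter names
//   uint64  sz        -- number of arrays, checked against names.size() below
//   sz x NDArray      -- each written by SaveDLTensor and read back via NDArray::Load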
uint64_t sz; + strm->Read(&sz, sizeof(sz)); + size_t size = static_cast(sz); + CHECK(size == names.size()) << "Invalid parameters file format"; + tvm::Array ret; + for (size_t i = 0; i < size; ++i) { + tvm::runtime::NDArray temp; + temp.Load(strm); + auto n = tvm::make_object(); + n->name = std::move(names[i]); + n->array = temp; + ret.push_back(NamedNDArray(n)); + } + *rv = ret; +}); TVM_REGISTER_NODE_TYPE(NamedNDArrayNode); diff --git a/src/relay/backend/param_dict.h b/src/relay/backend/param_dict.h index c829e546b90b..384201f94648 100644 --- a/src/relay/backend/param_dict.h +++ b/src/relay/backend/param_dict.h @@ -25,9 +25,9 @@ #define TVM_RELAY_BACKEND_PARAM_DICT_H_ #include -#include #include #include +#include #include diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 65e6ae9e79c6..4475d43f2898 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -32,7 +32,6 @@ #include #include #include -#include #include #include @@ -74,7 +73,7 @@ class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor memo_; + std::unordered_map memo_; }; /*! @@ -129,7 +128,7 @@ inline std::string DType2String(const tvm::DataType dtype) { inline relay::Function BindParamsByName( relay::Function func, const std::unordered_map& params) { std::unordered_map name_dict; - std::unordered_set repeat_var; + std::unordered_set repeat_var; for (auto arg : func->params) { const auto& name = arg->name_hint(); if (name_dict.count(name)) { @@ -139,7 +138,7 @@ inline relay::Function BindParamsByName( } } - std::unordered_map bind_dict; + std::unordered_map bind_dict; for (auto& kv : params) { if (name_dict.count(kv.first) == 0) { continue; @@ -216,5 +215,4 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth, } // namespace relay } // namespace tvm - #endif // TVM_RELAY_BACKEND_UTILS_H_ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7e2d43e7b35d..81db34125bd7 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -22,26 +22,29 @@ * \brief A compiler from relay::Module to the VM byte code. 
*/ -#include +#include "compiler.h" + +#include #include +#include #include #include #include -#include #include #include -#include -#include +#include +#include #include #include #include #include #include -#include "../utils.h" + #include "../../backend/compile_engine.h" -#include "../../transforms/pass_util.h" #include "../../op/op_common.h" +#include "../../transforms/pass_util.h" +#include "../utils.h" #include "compiler.h" namespace tvm { @@ -54,10 +57,22 @@ Pass InlinePrimitives(); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); - CHECK(f != nullptr) << "could not load memory allocation pass"; + CHECK(f != nullptr) << "unable to load allocation manifestation pass"; return (*f)(target_host); } +Pass MemoryPlan() { + auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan"); + CHECK(f != nullptr) << "unable to load the memory planning pass"; + return (*f)(); +} + +Pass LiftConstants() { + auto f = tvm::runtime::Registry::Get("relay.transform.LiftConstants"); + CHECK(f != nullptr) << "unable to load the constant lifting pass"; + return (*f)(); +} + } // namespace transform namespace vm { @@ -93,8 +108,7 @@ struct AccessField : MatchValue { // Runtime register num after compiling the access field path RegName reg{-1}; - AccessField(MatchValuePtr parent, size_t index) - : parent(parent), index(index) {} + AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {} ~AccessField() {} }; @@ -115,8 +129,7 @@ struct VarBinding : ConditionNode { Var var; MatchValuePtr val; - VarBinding(Var var, MatchValuePtr val) - : var(var), val(val) {} + VarBinding(Var var, MatchValuePtr val) : var(var), val(val) {} ~VarBinding() {} }; @@ -131,9 +144,7 @@ struct TagCompare : ConditionNode { /*! 
\brief The expected tag */ int target_tag; - TagCompare(MatchValuePtr obj, size_t target) - : obj(obj), target_tag(target) { - } + TagCompare(MatchValuePtr obj, size_t target) : obj(obj), target_tag(target) {} ~TagCompare() {} }; @@ -143,10 +154,8 @@ using TreeLeafNode = relay::TreeLeafNode; using TreeLeafFatalNode = relay::TreeLeafFatalNode; using TreeBranchNode = relay::TreeBranchNode; -TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, - Pattern pattern, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) { +TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern, + TreeObjectPtr then_branch, TreeObjectPtr else_branch) { if (pattern.as()) { // We ignore wildcard binding since it's not producing new vars return then_branch; @@ -176,11 +185,10 @@ TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, } } -TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, - Clause clause, - TreeObjectPtr else_branch) { - return BuildDecisionTreeFromPattern(data, clause->lhs, - TreeLeafNode::Make(clause->rhs), else_branch); +TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause, + TreeObjectPtr else_branch) { + return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs), + else_branch); } TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array clauses) { @@ -193,40 +201,29 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array ToAllocTensorShape64(NDArray shape) { +std::vector ToAllocTensorShape(NDArray shape) { std::vector raw_shape; - DLTensor tensor = shape.ToDLPack()->dl_tensor; - CHECK_EQ(tensor.ndim, 1u); - CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code; - - // TODO(@jroesch): we really need to standaridize the bit width of - // all of the shape manipulating code. - CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits; - int64_t* int_ptr = reinterpret_cast(tensor.data); - for (auto i = 0; i < tensor.shape[0]; i++) { - raw_shape.push_back(int_ptr[i]); - } - return raw_shape; -} - - -std::vector ToAllocTensorShape32(NDArray shape) { - std::vector raw_shape; - DLTensor tensor = shape.ToDLPack()->dl_tensor; - CHECK_EQ(tensor.ndim, 1u); - CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code; - - // TODO(@jroesch): we really need to standaridize the bit width of - // all of the shape manipulating code. 
- CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits; - int32_t* int_ptr = reinterpret_cast(tensor.data); - for (auto i = 0; i < tensor.shape[0]; i++) { - raw_shape.push_back(static_cast(int_ptr[i])); + CHECK_EQ(shape->ndim, 1u); + CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " + << DLDataType2String(shape->dtype); + CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32) + << "The dtype of constant shape must be int32 or int64, but got" + << DLDataType2String(shape->dtype); + + if (shape->dtype.bits == 64) { + int64_t* int_ptr = reinterpret_cast(shape->data); + for (auto i = 0; i < shape->shape[0]; i++) { + raw_shape.push_back(int_ptr[i]); + } + } else { // int32 + int32_t* int_ptr = reinterpret_cast(shape->data); + for (auto i = 0; i < shape->shape[0]; i++) { + raw_shape.push_back(static_cast(int_ptr[i])); + } } return raw_shape; } - class VMFunctionCompiler : ExprFunctor { public: VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) @@ -297,6 +294,15 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const ConstantNode* const_node) { + // Check the shape is valid + NDArray data = const_node->data; + const DLTensor* tensor = data.operator->(); + if (tensor->ndim > 0) { + int64_t* shapes = reinterpret_cast(tensor->shape); + for (auto i = 0; i < tensor->ndim; i++) { + CHECK_GT(shapes[i], 0U); + } + } size_t konst_idx = context_->constants.size(); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); @@ -319,11 +325,7 @@ class VMFunctionCompiler : ExprFunctor { } // TODO(@jroesch): use correct tag - Emit(Instruction::AllocADT( - 0, - tuple->fields.size(), - fields_registers, - NewRegister())); + Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister())); } void VisitExpr_(const MatchNode* match_node) { @@ -424,52 +426,46 @@ class VMFunctionCompiler : ExprFunctor { for (auto input : inputs) { auto reg = var_register_map_.find(Downcast(input)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } for (auto output : outputs) { auto reg = var_register_map_.find(Downcast(output)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } - Emit(Instruction::InvokePacked(op_index, - argument_registers.size(), - outputs.size(), - argument_registers)); + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), outputs.size(), + argument_registers)); } - void EmitInvokeTVMOp(const Function& func, - const Expr& inputs, - const Expr& outputs) { + void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) - << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; + << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); - CHECK(input_tuple) - << "internal error: invoke_tvm_op inputs must be a tuple," - << "please file a bug in the memory manifestation pass"; + CHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple," + << "please file a bug in the 
memory manifestation pass"; auto output_tuple = outputs.as(); - CHECK(output_tuple) - << "internal error: invoke_tvm_op outputs must be a tuple," - << "please file a bug in the memory manifestation pass"; + CHECK(output_tuple) << "internal error: invoke_tvm_op outputs must be a tuple," + << "please file a bug in the memory manifestation pass"; for (auto input : input_tuple->fields) { auto reg = var_register_map_.find(Downcast(input)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } for (auto output : output_tuple->fields) { auto reg = var_register_map_.find(Downcast(output)); CHECK(reg != var_register_map_.end()) - << "internal error: all variables should be in the register mapping"; + << "internal error: all variables should be in the register mapping"; argument_registers.push_back(reg->second); } @@ -509,10 +505,8 @@ class VMFunctionCompiler : ExprFunctor { } } - Emit(Instruction::InvokePacked(op_index, - argument_registers.size(), - output_tuple->fields.size(), - argument_registers)); + Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), + argument_registers)); } void VisitExpr_(const CallNode* call_node) { @@ -523,79 +517,77 @@ class VMFunctionCompiler : ExprFunctor { // allocation operations. if (op.as()) { OpMatch matcher; - matcher.Match("memory.invoke_tvm_op", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); - }).Match("memory.alloc_tensor", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); - - // Get the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) - << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - // The storage will be passed dynamically. - this->VisitExpr(args[0]); - auto storage_register = last_register_; - - // If the shape is constant then we will emit a static tensor allocation instruction. - auto const_shape = args[1].as(); - - if (const_shape) { - NDArray shape = const_shape->data; - std::vector raw_shape; - DLTensor tensor = shape.ToDLPack()->dl_tensor; - // TODO(@jroesch): we need to get an RFC done to standarize this - if (tensor.dtype.bits == 64) { - raw_shape = ToAllocTensorShape64(shape); - } else if (tensor.dtype.bits == 32) { - raw_shape = ToAllocTensorShape32(shape); - } else { - LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits; - } - - // Add context field. - Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); - } else { - this->VisitExpr(args[1]); - auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg( - storage_register, - shape_register, - dtype, - NewRegister())); - } - }).Match("memory.alloc_storage", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); - // Compute the size of the allocation. - this->VisitExpr(args[0]); - auto size_register = last_register_; - - this->VisitExpr(args[1]); - auto alignment_register = last_register_; - - // Get the dtype hint from the attributes. 
- auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) - << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, NewRegister())); - }).Match("memory.shape_func", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - auto shape_func = Downcast(args[0]); - auto inputs = Downcast(args[1]); - auto outputs = Downcast(args[2]); - EmitShapeFunc(shape_func, inputs->fields, outputs->fields); - }).Match("memory.kill", - [](const Array& args, const Attrs& attrs, const Array& type_arg) { - LOG(FATAL) << "memory.kill is not yet supported"; - }); + matcher + .Match("memory.invoke_tvm_op", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); + }) + .Match("memory.alloc_tensor", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + + // Get the attributes. + auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + // The storage will be passed dynamically. + this->VisitExpr(args[0]); + auto storage_register = last_register_; + + // The storage will be passed dynamically. + this->VisitExpr(args[1]); + auto offset_register = last_register_; + + // If the shape is constant then we will emit a static tensor allocation + // instruction. + auto const_shape = args[2].as(); + + if (const_shape) { + NDArray shape = const_shape->data; + // TODO(@jroesch): we need to get an RFC done to standarize shape dtype + std::vector raw_shape = ToAllocTensorShape(shape); + // Add context field. + Emit(Instruction::AllocTensor(storage_register, offset_register, raw_shape, + dtype, NewRegister())); + } else { + this->VisitExpr(args[2]); + auto shape_register = last_register_; + Emit(Instruction::AllocTensorReg(storage_register, offset_register, + shape_register, dtype, NewRegister())); + } + }) + .Match("memory.alloc_storage", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 2); + // Compute the size of the allocation. + this->VisitExpr(args[0]); + auto size_register = last_register_; + + this->VisitExpr(args[1]); + auto alignment_register = last_register_; + + // Get the dtype hint from the attributes. 
+ auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + Emit(Instruction::AllocStorage(size_register, alignment_register, dtype, + NewRegister())); + }) + .Match("memory.shape_func", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + auto shape_func = Downcast(args[0]); + auto inputs = Downcast(args[1]); + auto outputs = Downcast(args[2]); + EmitShapeFunc(shape_func, inputs->fields, outputs->fields); + }) + .Match("memory.kill", + [](const Array& args, const Attrs& attrs, const Array& type_arg) { + LOG(FATAL) << "memory.kill is not yet supported"; + }); matcher(GetRef(call_node)); return; } @@ -618,14 +610,13 @@ class VMFunctionCompiler : ExprFunctor { auto it = context_->global_map.find(global); CHECK(it != context_->global_map.end()); DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint - << " with func_index=" << it->second; + << " with func_index=" << it->second; // TODO(tvm-team): // Think about mixed call into global that is not a relay::Function // perhaps establish as an invariance(all functions in mod must be relay::Function) auto func = Downcast(context_->module->Lookup(global)); - if (IsClosure(func)) { auto arity = func->params.size(); Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); @@ -735,13 +726,13 @@ class VMFunctionCompiler : ExprFunctor { protected: /*! \brief Store the expression a variable points to. */ - std::unordered_map expr_map_; + std::unordered_map expr_map_; /*! \brief Instructions in the VMFunction. */ std::vector instructions_; /*! \brief Parameter names of the function. */ std::vector params_; /*! \brief Map from var to register number. */ - std::unordered_map var_register_map_; + std::unordered_map var_register_map_; /*! \brief Last used register number. */ size_t last_register_; /*! \brief Total number of virtual registers allocated. 
*/ @@ -756,9 +747,7 @@ class VMFunctionCompiler : ExprFunctor { Target target_host_; }; - -PackedFunc VMCompiler::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "lower") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.num_args, 3); @@ -771,19 +760,18 @@ PackedFunc VMCompiler::GetFunction(const std::string& name, this->Codegen(); }); } else if (name == "get_executable") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = runtime::Module(exec_); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = runtime::Module(exec_); }); } else if (name == "set_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Map params = args[0]; + Map params = args[0]; for (const auto& kv : params) { this->SetParam(kv.first, kv.second->data); } }); } else if (name == "get_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - Map ret; + Map ret; for (const auto& kv : params_) { ret.Set(kv.first, Constant(kv.second)); } @@ -804,11 +792,8 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { params_[name] = data_in; } -void VMCompiler::Lower(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host) { - CHECK_EQ(targets.size(), 1) - << "Currently VM compiler doesn't support heterogeneous compilation"; +void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { + CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); CHECK(base_func->IsInstance()) @@ -871,6 +856,44 @@ void VMCompiler::Lower(IRModule mod, } } +transform::Sequential MemoryOpt(tvm::Target host_target) { + Array pass_seqs; + // Manifest the allocations. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Manifest the allocations needed for the shape functions. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Perform memory planning in order to coalesce/reduce allocations. + pass_seqs.push_back(transform::MemoryPlan()); + + // Compute away constant computation introduced by coalescing allocations. + pass_seqs.push_back(transform::FoldConstant()); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Create allocations for math introduced by dynamic region math. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + + // Lift constants to the top-level of the block to simplify VM code generation. 
+ pass_seqs.push_back(transform::LiftConstants()); + + return transform::Sequential(pass_seqs); +} + IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) { Array pass_seqs; Array entry_functions{"main"}; @@ -885,7 +908,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // eta expand to support constructors in argument position pass_seqs.push_back(transform::EtaExpand( - /* expand_constructor */ true, /* expand_global_var */ false)); + /* expand_constructor */ true, /* expand_global_var */ false)); pass_seqs.push_back(transform::SimplifyInference()); PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { @@ -924,13 +947,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::InlinePrimitives()); - // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); - // Compute away possibly introduced constant computation. - pass_seqs.push_back(transform::FoldConstant()); - // Fuse the shape functions. - pass_seqs.push_back(transform::FuseOps()); - // Inline the functions that are lifted to the module scope. We perform this // pass after all other optimization passes but before the memory allocation // pass. This is because memory allocation pass will insert `invoke_tvm_op` @@ -938,8 +954,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); - // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); + pass_seqs.push_back(MemoryOpt(this->target_host_)); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); @@ -967,7 +982,7 @@ void VMCompiler::Codegen() { LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; return; } - auto const &cached_funcs = context_.cached_funcs; + auto const& cached_funcs = context_.cached_funcs; if (cached_funcs.size() == 0) { return; } @@ -993,7 +1008,11 @@ void VMCompiler::Codegen() { auto ext_mods = compile_engine->LowerExternalFunctions(); runtime::Module mod; if (funcs.size() > 0) { - mod = tvm::build(funcs, target_host_, tvm::BuildConfig::Current()); + Map build_funcs; + for (const auto& i : funcs) { + build_funcs.Set(i.first, i.second); + } + mod = tvm::build(build_funcs, target_host_); CHECK(mod.operator->()); } else { CHECK_EQ(ext_mods.size(), 1U) @@ -1017,8 +1036,7 @@ runtime::Module CreateVMCompiler() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay._vm._VMCompiler") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._vm._VMCompiler").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateVMCompiler(); }); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index c1040f1ed18e..8b1df7f5122d 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -28,9 +28,11 @@ #include #include #include -#include #include #include +#include +#include + #include #include #include @@ -38,8 +40,9 @@ #include #include #include -#include "../../../runtime/vm/profiler/vm.h" + #include "../../../runtime/vm/naive_allocator.h" +#include "../../../runtime/vm/profiler/vm.h" #include "../../backend/compile_engine.h" #include "../../transforms/pass_util.h" @@ -52,7 +55,7 @@ using namespace tvm::runtime::vm; using namespace relay::transform; template -using NodeMap = 
std::unordered_map; +using NodeMap = std::unordered_map; using TagMap = NodeMap; using TagNameMap = std::unordered_map; using GlobalMap = NodeMap; @@ -76,20 +79,16 @@ struct VMCompilerContext { // List of cached functions std::vector cached_funcs; // The functions that have been lowered. - std::unordered_map seen_funcs; + std::unordered_map seen_funcs; }; - class VMCompiler : public runtime::ModuleNode { public: virtual ~VMCompiler() {} - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - const char* type_key() const { - return "VMCompiler"; - } + const char* type_key() const { return "VMCompiler"; } /*! * \brief Set the parameters @@ -107,9 +106,7 @@ class VMCompiler : public runtime::ModuleNode { to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ - void Lower(IRModule mod, - const TargetsMap& targets, - const tvm::Target& target_host); + void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); /*! \brief Generate the machine code for lowered functions. */ void Codegen(); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 12113b0683f2..cf4f533a0bee 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -24,9 +24,10 @@ #include #include -#include #include #include +#include + #include #include @@ -53,7 +54,7 @@ namespace vm { */ struct PrimitiveInliner : ExprMutator { IRModule module_; - std::unordered_map var_map; + std::unordered_map var_map; explicit PrimitiveInliner(const IRModule& module) : module_(module) {} @@ -125,18 +126,13 @@ struct PrimitiveInliner : ExprMutator { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - DLOG(INFO) << "Before inlining primitives: " << global - << std::endl << AsText(func, false); + DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); - func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(global, func, true); - DLOG(INFO) << "After inlining primitives: " << global - << std::endl << AsText(func, false); + DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false); } } return module_; @@ -149,16 +145,13 @@ namespace transform { Pass InlinePrimitives() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::vm::PrimitiveInliner(m).Inline(); - }; + [=](IRModule m, PassContext pc) { return relay::vm::PrimitiveInliner(m).Inline(); }; auto inline_pass = CreateModulePass(pass_func, 1, "Inline", {}); // Eliminate dead code for each function after inlining. 
return Sequential({inline_pass, DeadCodeElimination()}, "InlinePrimitives"); } -TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives") -.set_body_typed(InlinePrimitives); +TVM_REGISTER_GLOBAL("relay._transform.InlinePrimitives").set_body_typed(InlinePrimitives); } // namespace transform diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index bfbefd57a310..011c7d2f9a6b 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -24,12 +24,13 @@ #include #include +#include #include #include -#include -#include #include #include +#include + #include #include @@ -44,9 +45,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0) != 0; -} +bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -85,8 +84,7 @@ class LambdaLifter : public ExprMutator { if (!letrec_.empty() && var == letrec_.back()) { auto it = lambda_map_.find(var); CHECK(it != lambda_map_.end()); - return Call(it->second, call->args, call_node->attrs, - call_node->type_args); + return Call(it->second, call->args, call_node->attrs, call_node->type_args); } } return std::move(call); @@ -153,18 +151,15 @@ class LambdaLifter : public ExprMutator { if (captured_vars.size() == 0 && free_type_vars.size() == 0) { lifted_func = Function(body->params, body->body, body->ret_type, body->type_params); } else { - lifted_func = - Function(captured_vars, body, func->func_type_annotation(), free_type_vars); + lifted_func = Function(captured_vars, body, func->func_type_annotation(), free_type_vars); lifted_func = MarkClosure(lifted_func); } CHECK(lifted_func.defined()); - if (module_->ContainGlobalVar(name)) { const auto existing_func = module_->Lookup(name); - CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) - << "lifted function hash collision"; + CHECK(tvm::StructuralEqual()(lifted_func, existing_func)) << "lifted function hash collision"; // If an identical function already exists, use its global var. 
global = module_->GetGlobalVar(name); } else { @@ -192,10 +187,7 @@ class LambdaLifter : public ExprMutator { if (auto* n = pair.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); - func = Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); module_->Add(pair.first, func, true); } @@ -204,7 +196,7 @@ class LambdaLifter : public ExprMutator { } private: - std::unordered_map lambda_map_; + std::unordered_map lambda_map_; std::vector letrec_; IRModule module_; }; @@ -215,14 +207,11 @@ namespace transform { Pass LambdaLift() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::vm::LambdaLifter(m).Lift(); - }; + [=](IRModule m, PassContext pc) { return relay::vm::LambdaLifter(m).Lift(); }; return CreateModulePass(pass_func, 1, "LambdaLift", {}); } -TVM_REGISTER_GLOBAL("relay._transform.LambdaLift") -.set_body_typed(LambdaLift); +TVM_REGISTER_GLOBAL("relay._transform.LambdaLift").set_body_typed(LambdaLift); } // namespace transform diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc index c2fe37f15453..4e8713b4900d 100644 --- a/src/relay/backend/vm/removed_unused_funcs.cc +++ b/src/relay/backend/vm/removed_unused_funcs.cc @@ -22,12 +22,13 @@ * \brief Remove unused global relay functions in a relay module. */ +#include #include #include -#include -#include #include #include +#include + #include #include #include @@ -46,12 +47,9 @@ struct CallTracer : ExprVisitor { std::unordered_set called_funcs_; // Record the expressions that are being visited - std::unordered_set visiting_; + std::unordered_set visiting_; - explicit CallTracer(const IRModule& module) - : module_{module}, - called_funcs_{}, - visiting_{} {} + explicit CallTracer(const IRModule& module) : module_{module}, called_funcs_{}, visiting_{} {} void VisitExpr_(const GlobalVarNode* op) final { called_funcs_.insert(op->name_hint); @@ -86,8 +84,7 @@ struct CallTracer : ExprVisitor { * * \return The module with dead functions removed. */ -IRModule RemoveUnusedFunctions(const IRModule& module, - Array entry_funcs) { +IRModule RemoveUnusedFunctions(const IRModule& module, Array entry_funcs) { std::unordered_set called_funcs{}; for (auto entry : entry_funcs) { auto funcs = CallTracer(module).Trace(entry); @@ -108,15 +105,14 @@ IRModule RemoveUnusedFunctions(const IRModule& module, namespace transform { Pass RemoveUnusedFunctions(Array entry_functions) { - runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { + runtime::TypedPackedFunc pass_func = [=](IRModule m, + PassContext pc) { return relay::vm::RemoveUnusedFunctions(m, entry_functions); }; return CreateModulePass(pass_func, 1, "RemoveUnusedFunctions", {}); } -TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions") -.set_body_typed(RemoveUnusedFunctions); +TVM_REGISTER_GLOBAL("relay._transform.RemoveUnusedFunctions").set_body_typed(RemoveUnusedFunctions); } // namespace transform diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 11c2cbb772fc..d808351e841c 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -21,8 +21,8 @@ * \file src/ir/adt.cc * \brief AST nodes for Relay algebraic data types (ADTs). 
*/ -#include #include +#include namespace tvm { namespace relay { @@ -34,15 +34,12 @@ PatternWildcard::PatternWildcard() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard") -.set_body_typed([]() { - return PatternWildcard(); -}); +TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard").set_body_typed([]() { return PatternWildcard(); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - p->stream << "PatternWildcardNode()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "PatternWildcardNode()"; + }); PatternVar::PatternVar(tvm::relay::Var var) { ObjectPtr n = make_object(); @@ -52,19 +49,17 @@ PatternVar::PatternVar(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternVar") -.set_body_typed([](tvm::relay::Var var) { +TVM_REGISTER_GLOBAL("relay.ir.PatternVar").set_body_typed([](tvm::relay::Var var) { return PatternVar(var); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternVarNode(" << node->var << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternVarNode(" << node->var << ")"; + }); -PatternConstructor::PatternConstructor(Constructor constructor, - tvm::Array patterns) { +PatternConstructor::PatternConstructor(Constructor constructor, tvm::Array patterns) { ObjectPtr n = make_object(); n->constructor = std::move(constructor); n->patterns = std::move(patterns); @@ -74,16 +69,15 @@ PatternConstructor::PatternConstructor(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor") -.set_body_typed([](Constructor constructor, tvm::Array patterns) { - return PatternConstructor(constructor, patterns); -}); + .set_body_typed([](Constructor constructor, tvm::Array patterns) { + return PatternConstructor(constructor, patterns); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternConstructorNode(" << node->constructor - << ", " << node->patterns << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternConstructorNode(" << node->constructor << ", " << node->patterns << ")"; + }); PatternTuple::PatternTuple(tvm::Array patterns) { ObjectPtr n = make_object(); @@ -93,16 +87,15 @@ PatternTuple::PatternTuple(tvm::Array patterns) { TVM_REGISTER_NODE_TYPE(PatternTupleNode); -TVM_REGISTER_GLOBAL("relay.ir.PatternTuple") -.set_body_typed([](tvm::Array patterns) { +TVM_REGISTER_GLOBAL("relay.ir.PatternTuple").set_body_typed([](tvm::Array patterns) { return PatternTuple(patterns); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PatternTupleNode(" << node->patterns << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PatternTupleNode(" << node->patterns << ")"; + }); Clause::Clause(Pattern lhs, Expr rhs) { ObjectPtr n = make_object(); @@ -113,17 +106,15 @@ Clause::Clause(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); -TVM_REGISTER_GLOBAL("relay.ir.Clause") -.set_body_typed([](Pattern lhs, Expr rhs) { 
+TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) { return Clause(lhs, rhs); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "ClauseNode(" << node->lhs << ", " - << node->rhs << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ClauseNode(" << node->lhs << ", " << node->rhs << ")"; + }); Match::Match(Expr data, tvm::Array clauses, bool complete) { ObjectPtr n = make_object(); @@ -136,16 +127,16 @@ Match::Match(Expr data, tvm::Array clauses, bool complete) { TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay.ir.Match") -.set_body_typed([](Expr data, tvm::Array clauses, bool complete) { - return Match(data, clauses, complete); -}); + .set_body_typed([](Expr data, tvm::Array clauses, bool complete) { + return Match(data, clauses, complete); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "MatchNode(" << node->data << ", " - << node->clauses << ", " << node->complete << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "MatchNode(" << node->data << ", " << node->clauses << ", " << node->complete + << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 76a3f9d4446e..5f7b8747a751 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -23,8 +23,8 @@ */ #include -#include #include +#include namespace tvm { namespace relay { @@ -33,14 +33,13 @@ using namespace tvm::runtime; TVM_REGISTER_NODE_TYPE(IdNode); -Id::Id(std::string name_hint) { +Id::Id(String name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.NodeSetSpan") -.set_body_typed([](ObjectRef node_ref, Span sp) { +TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) { if (auto* rn = node_ref.as()) { rn->span = sp; } else if (auto* rn = node_ref.as()) { diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc new file mode 100644 index 000000000000..c9bf11e884ab --- /dev/null +++ b/src/relay/ir/dataflow_matcher.cc @@ -0,0 +1,823 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. 
+ */ + +#include +#include +#include +#include + +#include + +#include "indexed_graph.h" + +namespace tvm { +namespace relay { + +// Pattern Matcher + +class DominatorMatcher; + +class DFPatternMatcher : public DFPatternFunctor { + public: + explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + const IndexedGraph expr_graph_; + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); + bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); + + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; + std::vector matched_nodes_; + bool memoize_ = true; +}; + +bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + memo_.clear(); + matched_nodes_.clear(); + return VisitDFPattern(pattern, expr); +} + +void DFPatternMatcher::ClearMap(size_t watermark) { + for (size_t i = watermark; i < matched_nodes_.size(); ++i) { + memo_.erase(matched_nodes_[i]); + } + matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); +} + +bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { + if (memoize_ && memo_.count(pattern)) { + CHECK_EQ(memo_[pattern].size(), 1); + return expr.same_as(memo_[pattern][0]); + } else { + auto watermark = matched_nodes_.size(); + auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern].push_back(expr); + matched_nodes_.push_back(pattern); + } else { + ClearMap(watermark); + } + return out; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { + return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); +} + +bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + switch (rhs.type_code()) { + case kDLInt: + if (auto* val = lhs.as()) { + return val->value == rhs.operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = lhs.as()) { + return val->value == rhs.operator double(); + } + break; + case kTVMStr: + if (auto* val = lhs.as()) { + return val->value == rhs.operator std::string(); + } else if (auto* val = lhs.as()) { + return val->data == rhs.operator std::string(); + } + break; + case kTVMObjectHandle: + if (rhs.IsObjectRef()) { + if (auto* val = lhs.as()) { + return rhs.operator 
String() == val->value; + } else if (auto* val = lhs.as()) { + return rhs.operator String() == val->data; + } + } + break; + default: + CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code(); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) { + bool matches = false; + auto attributes = attr_pattern->attrs.as()->dict; + if (const auto* op_node = expr.as()) { + Op op = GetRef(op_node); + for (auto kv : attributes) { + auto attr_name = kv.first; + auto attr_value = kv.second; + auto op_map = Op::GetAttrMap(attr_name); + if (op_map.count(op)) { + matches = MatchRetValue(attr_value, op_map[op]); + } + } + } else if (auto* op = expr.as()) { + matches = true; + // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this + // and replace the whole thing with a Visitor-based approach + ReflectionVTable* reflection = ReflectionVTable::Global(); + auto attrs_node = const_cast(op->attrs.get()); + auto attr_names = reflection->ListAttrNames(attrs_node); + for (auto kv : attributes) { + std::string attr = kv.first; + if (matches && std::find(attr_names.begin(), attr_names.end(), attr) != attr_names.end()) { + matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, attr)); + } else { + matches = false; + break; + } + } + } else if (auto* op = expr.as()) { + matches = true; + for (auto kv : attributes) { + if (matches && op->attrs.defined() && op->attrs->dict.count(kv.first)) { + matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]); + } else { + matches = false; + break; + } + } + } + return matches && VisitDFPattern(attr_pattern->pattern, expr); +} + +Array reverse(const Array& args) { + Array new_args; + for (auto it = args.rbegin(); it != args.rend(); ++it) { + new_args.push_back(*it); + } + return new_args; +} + +bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) { + // utilities + auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { + if (op) { + if (auto* expr_pattern = op->op.as()) { + return expr_pattern->expr.as(); + } + } + return nullptr; + }; + auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) { + if (const auto* op_node = get_op_node(op)) { + if (op_node->name == op_type) { + return true; + } + } + return false; + }; + auto is_expr_op = [](const Expr& expr, std::string op_type) { + if (const auto* call_node = expr.as()) { + if (const auto* op_node = call_node->op.as()) { + if (op_node->name == op_type) { + return true; + } + } + } + return false; + }; + // logic + auto watermark = matched_nodes_.size(); + if (const auto* call_node = expr.as()) { + auto matches_op = VisitDFPattern(op->op, call_node->op); + if (matches_op) { + auto watermark2 = matched_nodes_.size(); + + auto match_args = [this, &watermark2](const Array pattern_args, + const Array expr_args) { + bool matches = true; + size_t i = 0; + if (pattern_args.size() == expr_args.size()) { + while (matches && i < pattern_args.size()) { + matches &= VisitDFPattern(pattern_args[i], expr_args[i]); + ++i; + } + } else { + matches = false; + } + if (!matches) { + ClearMap(watermark2); + } + return matches; + }; + + // Standard case + if (match_args(op->args, call_node->args)) { + return true; + } + // Commutative Matching + if (const OpNode* op_node = get_op_node(op)) { + if ((op_node->name == "add") || (op_node->name == "multiply")) { + if (match_args(reverse(op->args), call_node->args)) { + return true; + 
} + } + } + } else { + ClearMap(watermark); + // associate divide/multiply + if (is_pattern_op(op, "divide")) { + if (const auto* arg_node = op->args[0].as()) { + if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") && + (is_expr_op(call_node->args[0], "divide") || + is_expr_op(call_node->args[1], "divide"))) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs, + op->type_args); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, + arg_node->attrs, arg_node->type_args); + out = VisitDFPattern(mul, expr); + if (out) { + return true; + } else { + ClearMap(watermark); + } + } + return out; + } + } + } + if (is_pattern_op(op, "multiply")) { + // associate multiply/divide + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + if (auto* arg_node = op->args[arg_id].as()) { + if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") && + (is_expr_op(call_node->args[0], "multiply") || + is_expr_op(call_node->args[1], "multiply"))) { + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, + op->attrs, op->type_args); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs, + arg_node->type_args); + return VisitDFPattern(div, expr); + } + } + } + } + } + } + return false; +} + +// Recursively find the Dominator parent along all inputs paths. +bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { + auto call_node = expr.as(); + for (auto node : expr_graph_.node_map_.at(expr)->inputs_) { + if (!(call_node && node->ref_ == call_node->op)) { + memoize_ = true; + if (VisitDFPattern(op->parent, node->ref_)) { + return true; + } else { + memoize_ = false; + if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) { + return false; + } + } + } + } + return true; +} + +// Iteratively ensure that the parent is dominated somewhere by the child or the path +bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { + std::stack stack; + std::unordered_set visited; + stack.push(expr); + while (!stack.empty()) { + Expr current = stack.top(); + stack.pop(); + for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) { + if (visited.count(node->ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref_)) { + return true; + } else { + stack.push(node->ref_); + } + visited.insert(node->ref_); + } + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { + if (VisitDFPattern(op->child, expr)) { + bool matches_path = MatchesPath(op, expr); + memoize_ = true; + if (matches_path) { + return DominatesParent(op, expr); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { + return StructuralEqual()(op->expr, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* tuple_get_item_node = expr.as()) { + matches = (op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* tuple_node = expr.as()) { + if (op->fields.size() == tuple_node->fields.size()) { + matches = true; + size_t i = 0; + while (matches && i < 
op->fields.size()) { + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); + ++i; + } + } + } + return matches; +} + +Expr InferType(const Expr& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main"); + } else { + return mod->Lookup("main").as()->body; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + if (const TensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + if (const TensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* var_node = expr.as()) { + matches = true; + if (op->name_hint() != "") { + matches &= op->name_hint() == var_node->name_hint(); + } + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) { + return expr.as() != nullptr; +} + +bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { + return true; +} + +bool MatchPattern(DFPattern pattern, Expr expr) { + return DFPatternMatcher(expr).Match(pattern, expr); +} + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); + +/* \brief PatternGrouper does pre-rewriting pattern matching and analysis + * + * This class creates a number of groups of matched expressions, ensures they don't overlap, and + * returns them to the caller for post-analysis rewriting. + * + * This is primarily needed to support the post-dominator analysis required for dominator pattern + * matching. + */ +class PatternGrouper { + public: + /* \brief Internal Group class for storing analysis */ + struct Group { + Expr root_node; + int gid; + Map> matched_nodes; + std::string name; + Function function; + Array args; + }; + + /* \brief Return the group assignments of expressions */ + const std::unordered_map& GetGIDAssignments() { + return gid_assignments_; + } + /* \brief Group expressions that match the pattern */ + const std::vector& GroupMatches(const DFPattern& pattern, const Expr& pre) { + groups_ = {Group()}; + gid_assignments_.clear(); + + pattern_ = pattern; + pattern_graph_ = CreateIndexedGraph(pattern_); + auto matcher = DFPatternMatcher(pre); + matcher_ = &matcher; + this->VisitExprs(); + return this->groups_; + } + + protected: + /* \brief Iteratively traverse the Expression in pre-order to find subgraphs + * + * If we traverse the graph in post-order, we can run into situtations where a small subgraph will + * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in + * the graph may also match the pattern. With post-order traversal, we mark the smaller subgraph + * as matched and fail to catch the larger subgraph. 
This problem is fixed by using pre-order + * traversal. + */ + void VisitExprs() { + std::unordered_set pre_partitioned; + for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + if (auto op = current.as()) { + if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + pre_partitioned.insert(current); + PostOrderVisit(op->body, + [&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); }); + } + } + if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) { + CreateGroup(current); + } + } + } + /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + * group overlap analysis */ + class MatchExtractor : public ExprMutator { + public: + explicit MatchExtractor( + const std::unordered_map& inputs) + : inputs_(inputs) {} + const std::unordered_map& GetMemo() { + return this->memo_; + } + const std::string& GetName() { return name_; } + + protected: + Expr VisitExpr(const Expr& pre) override { + if (inputs_.count(pre)) { + return inputs_.at(pre); + } + return ExprMutator::VisitExpr(pre); + } + Expr VisitExpr_(const TupleNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Tuple_"; + return out; + }; + Expr VisitExpr_(const FunctionNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Function"; + return out; + }; + Expr VisitExpr_(const CallNode* call_node) override { + auto out = ExprMutator::VisitExpr_(call_node); + if (auto operation = call_node->op.as()) { + name_ += operation->name + "_"; + } else { + name_ += "Call_"; + } + return out; + }; + Expr VisitExpr_(const LetNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Let_"; + return out; + }; + Expr VisitExpr_(const IfNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "If_"; + return out; + }; + Expr VisitExpr_(const TupleGetItemNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "TupleGetItem" + std::to_string(op->index) + "_"; + return out; + }; + Expr VisitExpr_(const MatchNode* op) override { + auto out = ExprMutator::VisitExpr_(op); + name_ += "Match_"; + return out; + }; + std::string name_; + const std::unordered_map inputs_; + }; + + /* \brief Create a group based on a matched expression */ + void CreateGroup(const Expr& expr) { + int var_number = 0; + + auto node_map = matcher_->GetMemo(); + + // Get fuzzy patterns + std::unordered_set fuzzy_matches; + for (auto node : pattern_graph_.topological_order_) { + if (auto op = node->ref_.as()) { + for (auto fuzzy_op : {op->parent, op->path}) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } + } + } + } + + // Create input variables + Group group; + group.root_node = expr; + group.matched_nodes = node_map; + + std::unordered_map inputs; + Array params; + for (auto node : pattern_graph_.topological_order_) { + if (node->inputs_.size() == 0) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && + match.as() == nullptr && !EmbedConst(match, node->ref_)) { + inputs[match] = Var( + "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), + NullValue()); + group.args.push_back(match); + params.push_back(inputs[match]); + var_number++; + } + } + } + } + } + + graph_number_++; + + // 
Extract a Function. Used in Partition directly, + // used to determine Group overlap in other passes + auto extractor = MatchExtractor(inputs); + auto body = extractor.Mutate(expr); + + // Verify the pattern still holds + CHECK(DFPatternMatcher(body).Match(pattern_, body)); + group.function = Function(params, body, NullValue(), Array()); + group.name = extractor.GetName(); + // Check to make sure we aren't overlapping with another group + // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the + // pattern with the input FunctionVar* Variables. The resulting memoization map will only + // contain nodes in the expression that matched the pattern. If a non-input node of the pattern + // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a + // situation where we try to rewrite the same node twice in the second rewriting or parition + // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants + // because they exist more globally outside of the fusion. + for (auto kv : extractor.GetMemo()) { + if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 && + kv.first.as() == nullptr && kv.first.as() == nullptr && + kv.first.as() == nullptr) { + // Exit due to overlapping partitions + return; + } + } + // Assign Group Ids + group.gid = ++gid_; + for (auto kv : extractor.GetMemo()) { + gid_assignments_[kv.first] = gid_; + } + + // Save Group + groups_.emplace_back(std::move(group)); + CHECK_EQ(groups_[gid_].gid, gid_); + } + + /* \brief EmbedConst implements rules for embedding constants into partitioned functions or + * lifting them into the function arguments. + * + * The rules depend on what pattern the ConstantNode matched. + * + * The basic rules are: + * If the constant matches ExprPattern(relay.const(*)) or a ConstantPattern(), embed the constant + * in the partitioned function. If the constant matched an AltPattern, recursively check the + * matched side of the pattern. For any other matching pattern (i.e, wildcard, VarPattern, etc), + * lift the constant into the arguments of the partitioned function. + */ + bool EmbedConst(const Expr& expr, const DFPattern pattern) { + bool embed = false; + if (expr.as()) { + if (pattern.as() != nullptr) { + embed = true; + } else if (auto expr_pat = pattern.as()) { + if (expr_pat->expr.as()) { + embed = true; + } + } else if (auto alt_pat = pattern.as()) { + if (matcher_->Match(alt_pat->left, expr)) { + embed = EmbedConst(expr, alt_pat->left); + } else { + embed = EmbedConst(expr, alt_pat->right); + } + } + } + return embed; + } + // Internal State + DFPattern pattern_; + std::vector groups_; + std::unordered_map gid_assignments_; + DFPatternMatcher* matcher_ = nullptr; + IndexedGraph pattern_graph_; + int gid_ = 0; + int graph_number_ = 0; +}; + +// Rewrite + +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) { + ObjectPtr n = make_object(); + n->pattern_ = std::move(pattern); + n->function_ = std::move(function); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") + .set_body_typed([](DFPattern pattern, PackedFunc function) { + return DFPatternCallback(pattern, function); + }); + +/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback + * function to rewrite those matches + * + * The class uses PatternGrouper to support the dominator pattern. 
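 *
 * A minimal usage sketch (illustrative only; `pattern` and `expr` are assumed to
 * exist, and the callback here simply returns the post-rewrite node unchanged):
 *
 *   PackedFunc identity_cb([](TVMArgs args, TVMRetValue* rv) { *rv = args[1]; });
 *   DFPatternCallback callback(pattern, identity_cb);
 *   Expr rewritten = RewritePatterns({callback}, expr);
 *
 * The callback is invoked as function_(pre, post, node_map), so a real rewriter
 * would build a replacement expression from `post` and the matched node map.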
+ */ +class PatternRewriter : protected MixedModeMutator { + public: + PatternRewriter() {} + /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the + * callbacks until it stops changing */ + Expr Rewrite(const Array& callbacks, const Expr& pre) { + auto post = pre; + auto last = post; + // rewrite the graph until it stops changing to make sure all rewrites are complete + int count = 0; + do { + last = post; + for (auto callback : callbacks) { + callback_ = callback; + auto grouper = PatternGrouper(); + groups_ = grouper.GroupMatches(callback_->pattern_, post); + gid_assignments_ = grouper.GetGIDAssignments(); + memo_.clear(); + post = this->VisitExpr(post); + count++; + } + } while (last != post || count >= 100); + if (count >= 100) { + throw("Observed 100 rewrite passes, possible conflicting passes?"); + } + return post; + } + + protected: + Expr DispatchVisitExpr(const Expr& pre) override { + auto post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + // Convert the pre-rewrite node map to a post-rewrite node map + auto group = groups_[gid_assignments_[pre]]; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map; + for (auto kv : group.matched_nodes) { + Array tmp; + for (size_t i = 0; i < kv.second.size(); ++i) { + tmp.push_back(this->memo_[kv.second[i]]); + } + node_map.insert({kv.first, tmp}); + } + // run the user callback function + return callback_->function_(pre, post, Map>(node_map)); + } + return post; + } + + DFPatternCallback callback_; + std::vector groups_; + std::unordered_map gid_assignments_; +}; + +Expr RewritePatterns(Array callbacks, Expr expr) { + return PatternRewriter().Rewrite(callbacks, expr); +} + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns); + +/* \brief PatternPartitioner replaces expressions that match a pattern with function call that + * perform the same computation but allow for further analysis and lowering. + * + * The class uses PatternGrouper to support the dominator pattern. 
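 *
 * A minimal usage sketch (illustrative only; `pattern` and `expr` are assumed to
 * exist, no extra attributes are attached to the partitioned functions, and the
 * check callback accepts every candidate match):
 *
 *   PackedFunc accept_all([](TVMArgs args, TVMRetValue* rv) { *rv = true; });
 *   Expr partitioned = PartitionPattern(pattern, expr, Map<String, ObjectRef>(), accept_all);
 *
 * Each matched region becomes a call to a fresh Function carrying the
 * kPartitionedFromPattern attribute, whose value is a name derived from the
 * operators in the matched subgraph.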
+ */ +class PatternPartitioner : protected MixedModeMutator { + public: + Expr Partition(const DFPattern& pattern, const Expr& pre, const Map& attrs, + PackedFunc check) { + auto grouper = PatternGrouper(); + groups_ = grouper.GroupMatches(pattern, pre); + gid_assignments_ = grouper.GetGIDAssignments(); + attrs_ = attrs; + check_ = check; + return this->VisitExpr(pre); + } + + protected: + Expr RewritePartition(const PatternGrouper::Group& group) { + Array args; + for (size_t i = 0; i < group.args.size(); ++i) { + args.push_back(memo_[group.args[i]]); + } + Function func = WithAttr(group.function, attr::kPartitionedFromPattern, String(group.name)); + if (!attrs_.empty()) { + for (auto kv : attrs_) { + func = WithAttr(std::move(func), kv.first, kv.second); + } + } + return Call(func, args); + } + + Expr DispatchVisitExpr(const Expr& pre) override { + auto post = MixedModeMutator::DispatchVisitExpr(pre); + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node && + static_cast(check_(pre))) { + post = RewritePartition(groups_[gid_assignments_[pre]]); + } + return post; + } + + Map attrs_; + std::vector groups_; + std::unordered_map gid_assignments_; + PackedFunc check_; +}; + +Expr PartitionPattern(DFPattern pattern, Expr expr, Map attrs, + PackedFunc check) { + return PatternPartitioner().Partition(pattern, expr, attrs, check); +} + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition") + .set_body_typed([](DFPattern pattern, Expr expr, Map attrs, + PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc new file mode 100644 index 000000000000..4664e5fc8168 --- /dev/null +++ b/src/relay/ir/dataflow_pattern.cc @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_pattern.cc + * \brief The dataflow pattern language for Relay. 
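 *
 * A minimal construction sketch (illustrative only) showing how the pattern
 * objects defined in this file compose; Op::Get is the usual TVM helper for
 * looking up a registered operator:
 *
 *   DFPattern x = WildcardPattern(make_object<WildcardPatternNode>());
 *   DFPattern y = WildcardPattern(make_object<WildcardPatternNode>());
 *   DFPattern add = CallPattern(ExprPattern(Op::Get("add")), {x, y}, Attrs(), {});
 *   DFPattern mul = CallPattern(ExprPattern(Op::Get("multiply")), {x, y}, Attrs(), {});
 *   DFPattern add_or_mul = AltPattern(add, mul);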
+ */ +#include + +namespace tvm { +namespace relay { + +ExprPattern::ExprPattern(Expr expr) { + ObjectPtr n = make_object(); + n->expr = std::move(expr); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ExprPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern").set_body_typed([](Expr e) { + return ExprPattern(e); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->Print(node->expr); + }); + +VarPattern::VarPattern(String name_hint, Type type_annotation) { + ObjectPtr n = make_object(); + n->name = std::move(name_hint); + n->type_annotation = std::move(type_annotation); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(VarPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern") + .set_body_typed([](String name_hint, Type type_annotation) { + return VarPattern(name_hint, type_annotation); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "VarPattern(" << node->name_hint(); + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->Print(node->type_annotation); + } + p->stream << ")"; + }); + +TVM_REGISTER_NODE_TYPE(ConstantPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]() { + auto c = ConstantPattern(make_object()); + return c; +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "ConstantPattern()"; + }); + +CallPattern::CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args) { + ObjectPtr n = make_object(); + n->op = std::move(op); + n->args = std::move(args); + n->attrs = std::move(attrs); + n->type_args = std::move(type_args); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(CallPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern") + .set_body_typed([](DFPattern op, Array args, Attrs attrs, Array type_args) { + return CallPattern(op, args, attrs, type_args); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs + << ", " << node->type_args << ")"; + }); + +TuplePattern::TuplePattern(tvm::Array fields) { + ObjectPtr n = make_object(); + n->fields = std::move(fields); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TuplePatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern") + .set_body_typed([](tvm::Array fields) { return TuplePattern(fields); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TuplePattern(" << node->fields << ")"; + }); + +TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) { + ObjectPtr n = make_object(); + n->tuple = std::move(tuple); + n->index = index; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern") + .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << 
node->index << ")"; + }); + +AltPattern::AltPattern(DFPattern left, DFPattern right) { + ObjectPtr n = make_object(); + n->left = std::move(left); + n->right = std::move(right); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(AltPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern") + .set_body_typed([](DFPattern left, DFPattern right) { return AltPattern(left, right); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AltPattern(" << node->left << " | " << node->right << ")"; + }); + +TVM_REGISTER_NODE_TYPE(WildcardPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() { + auto w = WildcardPattern(make_object()); + return w; +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "*"; + }); + +TypePattern::TypePattern(DFPattern pattern, Type type) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->type = std::move(type); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TypePatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern") + .set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; + }); + +ShapePattern::ShapePattern(DFPattern pattern, Array shape) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->shape = std::move(shape); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapePatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern") + .set_body_typed([](DFPattern pattern, Array shape) { + return ShapePattern(pattern, shape); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; + }); + +DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DataTypePatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern") + .set_body_typed([](DFPattern pattern, DataType dtype) { + return DataTypePattern(pattern, dtype); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; + }); + +AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->attrs = std::move(attrs); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(AttrPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern") + .set_body_typed([](DFPattern pattern, Attrs attrs) { return AttrPattern(pattern, attrs); }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")"; + }); + +DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, 
DFPattern child) { + ObjectPtr n = make_object(); + n->parent = std::move(parent); + n->path = std::move(path); + n->child = std::move(child); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DominatorPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern") + .set_body_typed([](DFPattern parent, DFPattern path, DFPattern child) { + return DominatorPattern(parent, path, child); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "DominatorPattern(" << node->parent << ", " << node->path << ", " << node->child + << ")"; + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc new file mode 100644 index 000000000000..7e9f828c8aa8 --- /dev/null +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/dataflow_matcher.cc + * \brief The dataflow pattern matcher for Relay. 
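The constructors registered above compose freely into larger patterns. As an illustration outside the patch, the sketch below builds an alternative of two activation calls over a shared wildcard input and then constrains the matched value's dtype; MakeActivationPattern is an invented name and the op names are only examples.

DFPattern MakeActivationPattern() {
  auto x = WildcardPattern(make_object<WildcardPatternNode>());
  auto relu = CallPattern(ExprPattern(Op::Get("nn.relu")), {x}, Attrs(), {});
  auto tanh = CallPattern(ExprPattern(Op::Get("tanh")), {x}, Attrs(), {});
  // AltPattern: accept either branch.
  DFPattern either = AltPattern(relu, tanh);
  // DataTypePattern: additionally require the matched expression to be float32.
  return DataTypePattern(either, DataType::Float(32));
}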
+ */ + +#include + +namespace tvm { +namespace relay { + +// DFPatternVisitor + +void DFPatternVisitor::VisitDFPattern(const DFPattern& pattern) { + if (this->visited_.count(pattern.get()) == 0) { + visited_.insert(pattern.get()); + DFPatternFunctor::VisitDFPattern(pattern); + } +} + +void DFPatternVisitor::VisitDFPattern_(const AltPatternNode* op) { + VisitDFPattern(op->left); + VisitDFPattern(op->right); +} + +void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { + VisitDFPattern(op->op); + for (auto arg : op->args) { + VisitDFPattern(arg); + } +} + +void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) { + VisitDFPattern(op->pattern); +} + +void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) { + VisitDFPattern(op->parent); + VisitDFPattern(op->path); + VisitDFPattern(op->child); +} + +void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { + VisitDFPattern(op->tuple); +} + +void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) { + for (auto field : op->fields) { + VisitDFPattern(field); + } +} + +void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } + +void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {} + +void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 169db62eee26..1d9e3cef12b7 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -38,19 +38,18 @@ Constant::Constant(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); -TVM_REGISTER_GLOBAL("relay.ir.Constant") -.set_body_typed([](runtime::NDArray data) { +TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray data) { return Constant(data); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PackedFunc* fprint = Registry::Get("relay._constant_repr"); - CHECK(fprint) << "unable to find printing function for constants"; - std::string data = (*fprint)(GetRef(node)); - p->stream << "Constant(" << data << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PackedFunc* fprint = Registry::Get("relay._constant_repr"); + CHECK(fprint) << "unable to find printing function for constants"; + std::string data = (*fprint)(GetRef(node)); + p->stream << "Constant(" << data << ")"; + }); TensorType ConstantNode::tensor_type() const { auto dtype = DataType(data->dtype); @@ -58,8 +57,7 @@ TensorType ConstantNode::tensor_type() const { for (int i = 0; i < data->ndim; i++) { CHECK_LE(data->shape[i], std::numeric_limits::max()); CHECK_GE(data->shape[i], std::numeric_limits::min()); - shape.push_back( - tvm::IntImm(DataType::Int(32), data->shape[i])); + shape.push_back(tvm::IntImm(DataType::Int(32), data->shape[i])); } return TensorType(shape, dtype); @@ -73,17 +71,15 @@ Tuple::Tuple(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); -TVM_REGISTER_GLOBAL("relay.ir.Tuple") -.set_body_typed([](tvm::Array fields) { 
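The DFPatternVisitor above memoizes visited nodes in visited_, so a pattern graph with shared subpatterns is traversed once per node. For illustration only (not part of the patch, CallPatternCounter is an invented name), a small subclass that counts CallPatternNodes could look like this:

class CallPatternCounter : public DFPatternVisitor {
 public:
  size_t Count(const DFPattern& pattern) {
    // visited_ is inherited and not reset here, so use a fresh counter per pattern.
    count_ = 0;
    VisitDFPattern(pattern);
    return count_;
  }

 protected:
  void VisitDFPattern_(const CallPatternNode* op) override {
    ++count_;
    // Fall back to the base visitor so op->op and op->args are still visited.
    DFPatternVisitor::VisitDFPattern_(op);
  }
  size_t count_ = 0;
};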
+TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields) { return Tuple(fields); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Tuple(" << node->fields << ")"; - }); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Tuple(" << node->fields << ")"; + }); Var::Var(Id vid, Type type_annotation) { ObjectPtr n = make_object(); @@ -94,21 +90,20 @@ Var::Var(Id vid, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_GLOBAL("relay.ir.Var") -.set_body_typed([](std::string str, Type type_annotation) { +TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation) { return Var(str, type_annotation); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "Var(" << node->name_hint(); - if (node->type_annotation.defined()) { - p->stream << ", ty="; - p->Print(node->type_annotation); - } - p->stream << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "Var(" << node->name_hint(); + if (node->type_annotation.defined()) { + p->stream << ", ty="; + p->Print(node->type_annotation); + } + p->stream << ")"; + }); Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { ObjectPtr n = make_object(); @@ -122,16 +117,16 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args) { TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") -.set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { - return Call(op, args, attrs, type_args); -}); + .set_body_typed([](Expr op, Array args, Attrs attrs, Array type_args) { + return Call(op, args, attrs, type_args); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; - }); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "CallNode(" << node->op << ", " << node->args << ", " << node->attrs << ", " + << node->type_args << ")"; + }); Let::Let(Var var, Expr value, Expr body) { ObjectPtr n = make_object(); @@ -143,17 +138,15 @@ Let::Let(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_GLOBAL("relay.ir.Let") -.set_body_typed([](Var var, Expr value, Expr body) { +TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { return Let(var, value, body); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "LetNode(" << node->var << ", " << node->value - << ", " << node->body << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "LetNode(" << node->var << ", " << node->value << ", " << node->body << ")"; + }); If::If(Expr cond, Expr true_branch, Expr false_branch) { ObjectPtr n = make_object(); @@ -166,16 +159,16 @@ If::If(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") -.set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { - return If(cond, 
true_branch, false_branch); -}); + .set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) { + return If(cond, true_branch, false_branch); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "IfNode(" << node->cond << ", " << node->true_branch - << ", " << node->false_branch << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "IfNode(" << node->cond << ", " << node->true_branch << ", " + << node->false_branch << ")"; + }); TupleGetItem::TupleGetItem(Expr tuple, int index) { ObjectPtr n = make_object(); @@ -186,16 +179,15 @@ TupleGetItem::TupleGetItem(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem") -.set_body_typed([](Expr tuple, int index) { +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { return TupleGetItem(tuple, index); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; + }); RefCreate::RefCreate(Expr value) { ObjectPtr n = make_object(); @@ -205,16 +197,15 @@ RefCreate::RefCreate(Expr value) { TVM_REGISTER_NODE_TYPE(RefCreateNode); -TVM_REGISTER_GLOBAL("relay.ir.RefCreate") -.set_body_typed([](Expr value) { +TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { return RefCreate(value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefCreateNode(" << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefCreateNode(" << node->value << ")"; + }); RefRead::RefRead(Expr ref) { ObjectPtr n = make_object(); @@ -224,16 +215,13 @@ RefRead::RefRead(Expr ref) { TVM_REGISTER_NODE_TYPE(RefReadNode); -TVM_REGISTER_GLOBAL("relay.ir.RefRead") -.set_body_typed([](Expr ref) { - return RefRead(ref); -}); +TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefReadNode(" << node->ref << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "RefReadNode(" << node->ref << ")"; + }); RefWrite::RefWrite(Expr ref, Expr value) { ObjectPtr n = make_object(); @@ -244,24 +232,21 @@ RefWrite::RefWrite(Expr ref, Expr value) { TVM_REGISTER_NODE_TYPE(RefWriteNode); -TVM_REGISTER_GLOBAL("relay.ir.RefWrite") -.set_body_typed([](Expr ref, Expr value) { +TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { return RefWrite(ref, value); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << 
"RefWriteNode(" << node->ref << ", " << node->value << ")"; + }); -TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize") -.set_body_typed([](TempExpr temp) { +TVM_REGISTER_GLOBAL("relay.ir.TempExprRealize").set_body_typed([](TempExpr temp) { return temp->Realize(); }); -TVM_REGISTER_GLOBAL("relay.ir.Any") -.set_body_typed([]() { return Any::make(); }); +TVM_REGISTER_GLOBAL("relay.ir.Any").set_body_typed([]() { return Any(); }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index cb5d06f2932c..5b68ff1a0034 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -142,7 +142,8 @@ void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {} void MixedModeMutator::VisitLeaf(const Expr& expr) { if (!memo_.count(expr)) { - this->DispatchVisitExpr(expr); + Expr ret = this->DispatchVisitExpr(expr); + memo_[expr] = ret; } } @@ -154,9 +155,7 @@ bool MixedModeMutator::CheckVisited(const Expr& expr) { } } -Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { - return ExprMutator::VisitExpr(expr); -} +Expr MixedModeMutator::DispatchVisitExpr(const Expr& expr) { return ExprMutator::VisitExpr(expr); } Expr MixedModeMutator::VisitExpr(const Expr& expr) { auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; @@ -165,9 +164,7 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) { return memo_[expr]; } else { ExpandDataflow(expr, fcheck_visited, fvisit_leaf); - Expr ret = this->DispatchVisitExpr(expr); - memo_[expr] = ret; - return ret; + return memo_[expr]; } } @@ -178,6 +175,7 @@ class PostOrderRewriter : public MixedModeMutator { auto post = ExprFunctor::VisitExpr(expr); return rewriter_->Rewrite(expr, post); } + protected: ExprRewriter* rewriter_; }; @@ -208,17 +206,11 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const ConstantNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const OpNode* op) { - return GetRef(op); -} +Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef(op); } Expr ExprMutator::VisitExpr_(const TupleNode* op) { tvm::Array fields; @@ -257,9 +249,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { auto ret_type = this->VisitType(op->ret_type); auto body = this->Mutate(op->body); - if (all_ty_params_unchanged && - all_params_unchanged && - ret_type.same_as(op->ret_type) && + if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(op->ret_type) && body.same_as(op->body)) { return GetRef(op); } else { @@ -297,9 +287,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) { auto value = this->Mutate(op->value); auto body = this->Mutate(op->body); - if (var.same_as(op->var) && - value.same_as(op->value) && - body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(var, value, body); @@ -310,10 +298,9 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) { auto guard = this->Mutate(op->cond); auto true_b = this->Mutate(op->true_branch); auto false_b = this->Mutate(op->false_branch); - if (op->cond.same_as(guard) && - op->true_branch.same_as(true_b) && + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && 
op->false_branch.same_as(false_b)) { - return GetRef(op);; + return GetRef(op); } else { return If(guard, true_b, false_b); } @@ -356,21 +343,31 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { } } -Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { - return GetRef(c); -} +Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } Expr ExprMutator::VisitExpr_(const MatchNode* m) { + bool unchanged = true; std::vector clauses; for (const Clause& p : m->clauses) { - clauses.push_back(VisitClause(p)); + Clause c = VisitClause(p); + clauses.push_back(c); + unchanged &= c.same_as(p); } - return Match(Mutate(m->data), clauses, m->complete); + Expr data = Mutate(m->data); + unchanged &= data.same_as(m->data); + if (unchanged) { + return GetRef(m); + } + return Match(data, clauses, m->complete); } Clause ExprMutator::VisitClause(const Clause& c) { Pattern p = VisitPattern(c->lhs); - return Clause(p, Mutate(c->rhs)); + Expr rhs = Mutate(c->rhs); + if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) { + return c; + } + return Clause(p, rhs); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } @@ -394,11 +391,9 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { } } -void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { -} +void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {} -void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { -} +void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {} void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) { for (auto field : op->fields) { @@ -440,17 +435,11 @@ void ExprVisitor::VisitExpr_(const IfNode* op) { void ExprVisitor::VisitExpr_(const OpNode* op) { return; } -void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { - this->VisitExpr(op->tuple); -} +void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { - this->VisitExpr(op->ref); -} +void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) { this->VisitExpr(op->ref); } void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { this->VisitExpr(op->ref); @@ -501,30 +490,23 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } -TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit") -.set_body_typed([](Expr expr, PackedFunc f) { - PostOrderVisit(expr, [f](const Expr& n) { - f(n); - }); - }); +TVM_REGISTER_GLOBAL("relay.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); +}); // Implement bind. 
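With the change above, MixedModeMutator::VisitLeaf stores the result of DispatchVisitExpr in memo_, so every shared subexpression is rewritten exactly once and VisitExpr can simply return the memoized value. A minimal subclass sketch, not part of the patch and with IdentityDataflowMutator as an invented name, hooks DispatchVisitExpr the same way PatternPartitioner does:

class IdentityDataflowMutator : public MixedModeMutator {
 protected:
  Expr DispatchVisitExpr(const Expr& pre) override {
    // Mutate children first via the default dispatch, then post-process.
    Expr post = MixedModeMutator::DispatchVisitExpr(pre);
    // A real pass would rewrite `post` here; this sketch returns it unchanged,
    // relying on VisitLeaf to cache the result in memo_.
    return post;
  }
};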
class ExprBinder : public ExprMutator, PatternMutator { public: - explicit ExprBinder(const tvm::Map& args_map) - : args_map_(args_map) { - } + explicit ExprBinder(const tvm::Map& args_map) : args_map_(args_map) {} Expr VisitExpr_(const LetNode* op) final { - CHECK(!args_map_.count(op->var)) - << "Cannot bind an internel variable in let"; + CHECK(!args_map_.count(op->var)) << "Cannot bind an internel variable in let"; return ExprMutator::VisitExpr_(op); } Expr VisitExpr_(const FunctionNode* op) final { for (Var param : op->params) { - CHECK(!args_map_.count(param)) - << "Cannnot bind an internal function parameter"; + CHECK(!args_map_.count(param)) << "Cannnot bind an internal function parameter"; } return ExprMutator::VisitExpr_(op); } @@ -539,9 +521,7 @@ class ExprBinder : public ExprMutator, PatternMutator { } } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Clause VisitClause(const Clause& c) final { Pattern pat = VisitPattern(c->lhs); @@ -549,8 +529,7 @@ class ExprBinder : public ExprMutator, PatternMutator { } Var VisitVar(const Var& v) final { - CHECK(!args_map_.count(v)) - << "Cannnot bind an internal pattern variable"; + CHECK(!args_map_.count(v)) << "Cannnot bind an internal pattern variable"; return v; } @@ -567,16 +546,11 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.push_back(param); } } - if (new_body.same_as(func->body) && - new_params.size() == func->params.size()) { + if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { return expr; } - auto ret = Function(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); - std::unordered_set set; + auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); + std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); } @@ -585,11 +559,7 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { new_params.push_back(v); } } - ret = Function(new_params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs); CHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { @@ -597,15 +567,14 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } } -TVM_REGISTER_GLOBAL("relay.ir.Bind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - ObjectRef input = args[0]; - if (input->IsInstance()) { - *ret = Bind(Downcast(input), args[1]); - } else { - CHECK(input->IsInstance()); - *ret = Bind(Downcast(input), args[1]); - } - }); +TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef input = args[0]; + if (input->IsInstance()) { + *ret = Bind(Downcast(input), args[1]); + } else { + CHECK(input->IsInstance()); + *ret = Bind(Downcast(input), args[1]); + } +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 12a80c5698af..5312e6d48447 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -26,11 +26,8 @@ namespace tvm { namespace relay { -Function::Function(tvm::Array params, - Expr body, - Type ret_type, - tvm::Array type_params, - DictAttrs attrs) { +Function::Function(tvm::Array params, Expr body, Type ret_type, + tvm::Array type_params, DictAttrs attrs) { ObjectPtr n = make_object(); CHECK(params.defined()); CHECK(type_params.defined()); @@ 
-45,34 +42,29 @@ Function::Function(tvm::Array params, FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { - Type param_type = (param->type_annotation.defined()) ? param->type_annotation - : IncompleteType(Kind::kType); + Type param_type = + (param->type_annotation.defined()) ? param->type_annotation : IncompleteType(Kind::kType); param_types.push_back(param_type); } - Type ret_type = (this->ret_type.defined()) ? this->ret_type - : IncompleteType(Kind::kType); + Type ret_type = (this->ret_type.defined()) ? this->ret_type : IncompleteType(Kind::kType); return FuncType(param_types, ret_type, this->type_params, {}); } TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_GLOBAL("relay.ir.Function") -.set_body_typed([](tvm::Array params, - Expr body, - Type ret_type, - tvm::Array ty_params, - tvm::DictAttrs attrs) { - return Function(params, body, ret_type, ty_params, attrs); -}); + .set_body_typed([](tvm::Array params, Expr body, Type ret_type, + tvm::Array ty_params, tvm::DictAttrs attrs) { + return Function(params, body, ret_type, ty_params, attrs); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "FunctionNode(" << node->params << ", " << node->ret_type - << ", " << node->body << ", " << node->type_params << ", " - << node->attrs << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "FunctionNode(" << node->params << ", " << node->ret_type << ", " << node->body + << ", " << node->type_params << ", " << node->attrs << ")"; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc new file mode 100644 index 000000000000..456bf02a0611 --- /dev/null +++ b/src/relay/ir/indexed_graph.cc @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/ir/indexed_graph.cc + * \brief Utilties for Creating Indexed Graphs. + */ +#include "indexed_graph.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +// IndexedGraph + +IndexedGraph CreateIndexedGraph(const Expr& expr) { + using NodePtr = std::shared_ptr::Node>; + /*! 
\brief Creator Creates an IndexedGraph and determintes Topological order */ + class Creator : public MixedModeVisitor { + public: + IndexedGraph CreateGraph(const Expr& expr) { + VisitExpr(expr); + graph_.node_map_[expr]->is_external_ = true; + return std::move(graph_); + } + + protected: + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + auto node = std::make_shared::Node>(expr, index_++); + graph_.node_map_[expr] = node; + graph_.topological_order_.push_back(node); + } + IndexedGraph graph_; + size_t index_ = 0; + }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree + * analysis. + * + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * topological order instead of recursing. + */ + class Annotator : public ExprFunctor { + public: + Annotator(const IndexedGraph& graph) : graph_(graph) {} + IndexedGraph Annotate() { + // Visit all of the nodes in topological order to get forward outputs + for (const auto& node : graph_.topological_order_) { + ExprFunctor::VisitExpr(node->ref_, nullptr); + } + // do the dominator analysis + graph_.PostDom(); + return std::move(graph_); + } + + /*! Default visitation pushes the parent to the child's ouputs and the child to the parent's + * inputs*/ + void VisitExpr(const Expr& expr, NodePtr parent) override { + auto current = graph_.node_map_[expr]; + if (parent) { + current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); + } + } + + protected: + IndexedGraph graph_; + void VisitExpr_(const VarNode* op, NodePtr parent) override { + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation); + } + } + + void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + + void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + + void VisitExpr_(const TupleNode* op, NodePtr parent) override { + for (auto field : op->fields) { + this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const FunctionNode* op, NodePtr parent) override { + for (auto param : op->params) { + this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + } + + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const CallNode* op, NodePtr parent) override { + this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const LetNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const IfNode* op, NodePtr parent) override { + this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + + void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { + this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefReadNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + } + + void 
VisitExpr_(const RefWriteNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { + for (const Type& t : op->inputs) { + this->VisitType(t); + } + this->VisitType(op->belong_to); + } + + void VisitExpr_(const MatchNode* op, NodePtr parent) override { + this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); + for (const Clause& c : op->clauses) { + this->VisitClause(c, graph_.node_map_[GetRef(op)]); + } + } + + void VisitClause(const Clause& op, NodePtr parent) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs, parent); + } + + void VisitPattern(const Pattern& p) { return; } + + void VisitType(const Type& t) { return; } + }; + return Annotator(Creator().CreateGraph(expr)).Annotate(); +} + +IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { + using NodePtr = std::shared_ptr::Node>; + /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ + class Creator : public DFPatternVisitor { + public: + IndexedGraph CreateGraph(const DFPattern& pattern) { + VisitDFPattern(pattern); + graph_.node_map_[pattern]->is_external_ = true; + return std::move(graph_); + } + + protected: + void VisitDFPattern(const DFPattern& pattern) override { + if (this->visited_.count(pattern.get()) == 0) { + DFPatternVisitor::VisitDFPattern(pattern); + auto node = std::make_shared::Node>(pattern, index_++); + graph_.node_map_[pattern] = node; + graph_.topological_order_.push_back(node); + } + } + IndexedGraph graph_; + size_t index_ = 0; + }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree + * analysis. + * + * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined + * topological order instead of recursing. + */ + class Annotator : public DFPatternFunctor { + public: + Annotator(const IndexedGraph& graph) : graph_(graph) {} + IndexedGraph Annotate() { + // Visit all of the nodes in topological order to get forward outputs + for (const auto& node : graph_.topological_order_) { + DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + } + graph_.PostDom(); + // do the dominator analysis + return std::move(graph_); + } + + /*! 
Default visitation pushes the parent to the child's ouputs */ + void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { + auto current = graph_.node_map_[pattern]; + if (parent) { + current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); + } + } + + protected: + IndexedGraph graph_; + void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + for (auto arg : op->args) { + VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + } + } + + void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + for (auto field : op->fields) { + VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + + void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} + }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h new file mode 100644 index 000000000000..70508279af21 --- /dev/null +++ b/src/relay/ir/indexed_graph.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/ir/indexed_graph.h + * \brief A pattern matcher for matching dataflow properties. 
+ */ +#ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ +#define TVM_RELAY_IR_INDEXED_GRAPH_H_ + +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A Wrapper around a templated graph type + * Holds a forward-backward indexed representation of the graph and a dominator tree representation + * of the graph + * + * This class is templated and the implementaiton is in the header file so we can analyze both + * DFPattern and Expr with the same infrastructure. + * + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + */ +template +class IndexedGraph { + public: + /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ + struct Node { + /*! \brief Node Constructor + * \param ref The input graph node + * \param index The index of the node in toplogical order + */ + Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + + /*! \brief The input node */ + const T ref_; + /*! \brief The topological order index */ + const size_t index_; + + /*! \brief A boolean to determine if this node is external to the graph */ + bool is_external_ = false; + /*! \brief The forward inputs of the node */ + std::vector inputs_; + /*! \brief The forward outputs/users of the node */ + std::vector outputs_; + + /*! \brief The depth of the node in the dominator tree */ + size_t depth_ = 0; + /*! \brief The dominator parent/final user of the outputs of this node */ + Node* dominator_parent_; + /*! \brief The nodes this node dominates */ + std::vector dominator_children_; + }; + /*! \brief Construct the domination tree inside IndexedGraph */ + void PostDom() { + for (size_t i = topological_order_.size(); i != 0; --i) { + size_t index = i - 1; + auto* current = topological_order_[index].get(); + if (current->is_external_) { + current->depth_ = 1; + current->dominator_parent_ = nullptr; + } else { + auto parent = LeastCommonAncestor(current->outputs_); + current->depth_ = parent ? parent->depth_ + 1 : 1; + current->dominator_parent_ = parent; + parent->dominator_children_.push_back(current); + } + } + } + /*! \brief Map of input nodes to IndexedGraph Nodes */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; + /*! \brief Topological IndexedGraph Nodes */ + std::vector> topological_order_; + + protected: + /*! \brief Find the least common ancestor of all outputs of a node */ + Node* LeastCommonAncestor(const std::vector& outputs) { + if (outputs.size() == 0) { + return nullptr; + } + auto parent = outputs.at(0); + for (size_t i = 1; i < outputs.size(); ++i) { + parent = LeastCommonAncestor(parent, outputs.at(i)); + } + return parent; + } + + /*! \brief Find the least common ancestor of two nodes */ + Node* LeastCommonAncestor(Node* lhs, Node* rhs) { + if (lhs == nullptr || rhs == nullptr) { + return nullptr; + } + while (lhs != rhs) { + CHECK(lhs); + CHECK(rhs); + if (lhs->depth_ < rhs->depth_) { + rhs = rhs->dominator_parent_; + } else if (lhs->depth_ > rhs->depth_) { + lhs = lhs->dominator_parent_; + } else { + rhs = rhs->dominator_parent_; + lhs = lhs->dominator_parent_; + } + } + return lhs; + } +}; + +/*! \brief Create an Indexed Graph based on an Expr */ +IndexedGraph CreateIndexedGraph(const Expr& expr); +/*! 
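The IndexedGraph defined above exposes the topological order, each node's depth, and the immediate dominator computed by PostDom. A small consumer sketch follows; it is not part of the patch, DumpDominators is an invented name, and LOG(INFO) is assumed to be available as it is elsewhere in the codebase.

void DumpDominators(const Expr& expr) {
  auto graph = CreateIndexedGraph(expr);
  for (const auto& node : graph.topological_order_) {
    const auto* idom = node->dominator_parent_;  // nullptr for the external root
    LOG(INFO) << "node #" << node->index_ << " depth=" << node->depth_
              << " idom=" << (idom ? static_cast<int64_t>(idom->index_) : -1);
  }
}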
\brief Create an Indexed Graph based on an DFPattern */ +IndexedGraph CreateIndexedGraph(const DFPattern& pattern); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ diff --git a/src/relay/ir/op_strategy.cc b/src/relay/ir/op_strategy.cc index 4e407dbed655..a946b94cac02 100644 --- a/src/relay/ir/op_strategy.cc +++ b/src/relay/ir/op_strategy.cc @@ -31,21 +31,18 @@ TVM_REGISTER_NODE_TYPE(OpImplementationNode); TVM_REGISTER_NODE_TYPE(OpSpecializationNode); TVM_REGISTER_NODE_TYPE(OpStrategyNode); -Array OpImplementation::Compute(const Attrs& attrs, - const Array& inputs, +Array OpImplementation::Compute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return (*this)->fcompute(attrs, inputs, out_type); } -te::Schedule OpImplementation::Schedule(const Attrs& attrs, - const Array &outs, +te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array& outs, const Target& target) { return (*this)->fschedule(attrs, outs, target); } void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, - tvm::relay::FTVMSchedule fschedule, - std::string name, + tvm::relay::FTVMSchedule fschedule, String name, int plevel) { auto n = make_object(); n->fcompute = fcompute; @@ -55,9 +52,7 @@ void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute, (*this)->implementations.push_back(OpImplementation(n)); } -void OpStrategy::AddImplementation(FTVMCompute fcompute, - FTVMSchedule fschedule, - std::string name, +void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name, int plevel) { auto curr_cond = te::SpecializedCondition::Current(); auto self = this->operator->(); @@ -77,38 +72,37 @@ void OpStrategy::AddImplementation(FTVMCompute fcompute, } TVM_REGISTER_GLOBAL("relay.op._OpImplementationCompute") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpImplementation imp = args[0]; - Attrs attrs = args[1]; - Array inputs = args[2]; - Type out_type = args[3]; - *rv = imp.Compute(attrs, inputs, out_type); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpImplementation imp = args[0]; + Attrs attrs = args[1]; + Array inputs = args[2]; + Type out_type = args[3]; + *rv = imp.Compute(attrs, inputs, out_type); + }); TVM_REGISTER_GLOBAL("relay.op._OpImplementationSchedule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpImplementation imp = args[0]; - Attrs attrs = args[1]; - Array outs = args[2]; - Target target = args[3]; - *rv = imp.Schedule(attrs, outs, target); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpImplementation imp = args[0]; + Attrs attrs = args[1]; + Array outs = args[2]; + Target target = args[3]; + *rv = imp.Schedule(attrs, outs, target); + }); -TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy") -.set_body([](TVMArgs args, TVMRetValue* rv) { - ObjectPtr n = make_object(); - *rv = OpStrategy(n); +TVM_REGISTER_GLOBAL("relay.op._make.OpStrategy").set_body([](TVMArgs args, TVMRetValue* rv) { + ObjectPtr n = make_object(); + *rv = OpStrategy(n); }); TVM_REGISTER_GLOBAL("relay.op._OpStrategyAddImplementation") -.set_body([](TVMArgs args, TVMRetValue* rv) { - OpStrategy strategy = args[0]; - FTVMCompute compute = args[1]; - FTVMSchedule schedule = args[2]; - std::string name = args[3]; - int plevel = args[4]; - strategy.AddImplementation(compute, schedule, name, plevel); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { + OpStrategy strategy = args[0]; + FTVMCompute compute = args[1]; + FTVMSchedule schedule = args[2]; + std::string name = args[3]; + int plevel = 
args[4]; + strategy.AddImplementation(compute, schedule, name, plevel); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/pattern_functor.cc b/src/relay/ir/pattern_functor.cc index 6795884ef438..8c366bad641a 100644 --- a/src/relay/ir/pattern_functor.cc +++ b/src/relay/ir/pattern_functor.cc @@ -27,13 +27,9 @@ namespace tvm { namespace relay { -Pattern PatternMutator::Mutate(const Pattern& pat) { - return (*this)(pat); -} +Pattern PatternMutator::Mutate(const Pattern& pat) { return (*this)(pat); } -Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { - return GetRef(op); -} +Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) { return GetRef(op); } Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) { return PatternVar(VisitVar(op->var)); @@ -55,28 +51,20 @@ Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) { return PatternTuple(pat); } -Type PatternMutator::VisitType(const Type& t) { - return t; -} +Type PatternMutator::VisitType(const Type& t) { return t; } Var PatternMutator::VisitVar(const Var& v) { if (var_map_.count(v) == 0) { - var_map_.insert(std::pair(v, - Var(v->name_hint(), - VisitType(v->type_annotation)))); + var_map_.insert(std::pair(v, Var(v->name_hint(), VisitType(v->type_annotation)))); } return var_map_.at(v); } -Constructor PatternMutator::VisitConstructor(const Constructor& v) { - return v; -} +Constructor PatternMutator::VisitConstructor(const Constructor& v) { return v; } -void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { } +void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) {} -void PatternVisitor::VisitPattern_(const PatternVarNode* op) { - VisitVar(op->var); -} +void PatternVisitor::VisitPattern_(const PatternVarNode* op) { VisitVar(op->var); } void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) { VisitConstructor(op->constructor); @@ -91,11 +79,9 @@ void PatternVisitor::VisitPattern_(const PatternTupleNode* op) { } } -void PatternVisitor::VisitType(const Type& t) { } +void PatternVisitor::VisitType(const Type& t) {} -void PatternVisitor::VisitVar(const Var& v) { - VisitType(v->type_annotation); -} +void PatternVisitor::VisitVar(const Var& v) { VisitType(v->type_annotation); } void PatternVisitor::VisitConstructor(const Constructor& c) { for (const auto& inp : c->inputs) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 06dd2b16661f..b540dd47bcd9 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -22,15 +22,16 @@ * \brief Relay specific transformation passes. */ #include -#include #include #include - +#include namespace tvm { namespace relay { namespace transform { +TVM_REGISTER_PASS_CONFIG_OPTION("relay.fallback_device_type", IntImm); + class FunctionPass; /*! @@ -56,9 +57,7 @@ class FunctionPassNode : public PassNode { FunctionPassNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a function pass on given pass context. @@ -113,14 +112,11 @@ FunctionPass::FunctionPass( } // Perform Module -> Module optimizations at the Function level. 
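Function passes built on the FunctionPass machinery above are normally created through the CreateFunctionPass helper whose signature is updated just below (and declared in relay/transform.h). The no-op sketch here is illustrative only: it is written as if inside namespace tvm::relay::transform and assumes the pass function type TypedPackedFunc<Function(Function, IRModule, PassContext)> together with the aliases re-exported by relay/transform.h.

Pass MakeNoOpFunctionPass() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule m, PassContext ctx) { return f; };  // leave every function untouched
  return CreateFunctionPass(pass_func, /*opt_level=*/0, "NoOpFunctionPass", {});
}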
-IRModule FunctionPassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); - DLOG(INFO) << "Executing function pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; + DLOG(INFO) << "Executing function pass : " << pass_info->name + << " with opt level: " << pass_info->opt_level; pass_ctx.Trace(mod, pass_info, true); // Execute the pass function and return a new module. @@ -130,9 +126,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, // only picks up relay::Function if (auto* n = it.second.as()) { Function func = GetRef(n); - auto updated_func = SkipFunction(func) - ? func - : pass_func(func, updated_mod, pass_ctx); + auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx); updates.push_back({it.first, updated_func}); } } @@ -146,14 +140,12 @@ IRModule FunctionPassNode::operator()(IRModule mod, bool FunctionPassNode::SkipFunction(const Function& func) const { return (func->GetAttr(attr::kCompiler).defined()) || - func->GetAttr(attr::kSkipOptimization, 0) != 0; + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { + int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return FunctionPass(pass_func, pass_info); } @@ -161,18 +153,17 @@ Pass CreateFunctionPass( TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") -.set_body_typed([](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return FunctionPass(pass_func, pass_info); -}); + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return FunctionPass(pass_func, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Function pass: " << info->name - << " at the optimization level " << info->opt_level; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name << " at the optimization level " + << info->opt_level; + }); } // namespace transform } // namespace relay diff --git a/src/relay/op/algorithm/argsort.cc b/src/relay/op/algorithm/argsort.cc index 5b03ceec6ccf..a24097420873 100644 --- a/src/relay/op/algorithm/argsort.cc +++ b/src/relay/op/algorithm/argsort.cc @@ -21,17 +21,15 @@ * \file argsort.cc * \brief Argsort operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(ArgsortAttrs); -bool ArgsortRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ArgsortRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const ArgsortAttrs* param = attrs.as(); @@ -39,18 +37,14 @@ bool ArgsortRel(const Array& types, const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "Argsort: expect input type to be TensorType but get " - << types[0]; + << "Argsort: expect input type to be TensorType but get " << types[0]; return false; } reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; 
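FunctionPassNode::SkipFunction above makes function passes leave untouched any Relay Function that carries a kCompiler attribute or a non-zero kSkipOptimization attribute. A short sketch of opting a function out, not part of the patch and with MarkAsSkipped as an invented name:

Function MarkAsSkipped(Function func) {
  // WithAttr returns a copy of the function with the attribute attached,
  // which SkipFunction then detects on every function pass.
  return WithAttr(std::move(func), attr::kSkipOptimization, Integer(1));
}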
} -Expr MakeArgsort(Expr data, - int axis, - bool is_ascend, - DataType dtype) { +Expr MakeArgsort(Expr data, int axis, bool is_ascend, DataType dtype) { auto attrs = make_object(); attrs->axis = axis; attrs->is_ascend = is_ascend; @@ -59,19 +53,17 @@ Expr MakeArgsort(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.argsort") -.set_body_typed(MakeArgsort); +TVM_REGISTER_GLOBAL("relay.op._make.argsort").set_body_typed(MakeArgsort); RELAY_REGISTER_OP("argsort") -.describe(R"doc(Returns the indices that would sort an + .describe(R"doc(Returns the indices that would sort an input array along the given axis. )doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(6) -.add_type_rel("Argsort", ArgsortRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(6) + .add_type_rel("Argsort", ArgsortRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/algorithm/topk.cc b/src/relay/op/algorithm/topk.cc index 225575c69b00..b02fe86f6baa 100644 --- a/src/relay/op/algorithm/topk.cc +++ b/src/relay/op/algorithm/topk.cc @@ -21,21 +21,21 @@ * \file topk.cc * \brief TopK operators */ -#include #include +#include +#include namespace tvm { namespace relay { +using tir::make_const; TVM_REGISTER_NODE_TYPE(TopKAttrs); -bool TopKRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TopKRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] const TopKAttrs* param = attrs.as(); - CHECK_EQ(types.size(), 2); + CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); CHECK(data); int ndim = data->shape.size(); @@ -46,55 +46,57 @@ bool TopKRel(const Array& types, CHECK(axis >= 0 && axis < ndim); Array out_shape; for (int i = 0; i < ndim; ++i) { - if (i != axis || param->k < 1) { + if (i != axis) { out_shape.push_back(data->shape[i]); + } else if (param->k) { + const Integer& ck = param->k.value(); + if (ck->value < 1) { + out_shape.push_back(data->shape[i]); + } else { + out_shape.push_back(ck); + } } else { - out_shape.push_back(param->k); + out_shape.push_back(Any()); } } auto values_ty = TensorType(out_shape, data->dtype); auto indices_ty = TensorType(out_shape, param->dtype); if (param->ret_type == "both") { - reporter->Assign(types[1], TupleType({values_ty, indices_ty})); + reporter->Assign(types[2], TupleType({values_ty, indices_ty})); } else if (param->ret_type == "values") { - reporter->Assign(types[1], values_ty); + reporter->Assign(types[2], values_ty); } else if (param->ret_type == "indices") { - reporter->Assign(types[1], indices_ty); + reporter->Assign(types[2], indices_ty); } else { LOG(FATAL) << "Unsupported ret type: " << param->ret_type; } return true; } -Expr MakeTopK(Expr data, - int k, - int axis, - std::string ret_type, - bool is_ascend, - DataType dtype) { +Expr MakeTopK(Expr data, Expr k, int axis, String ret_type, bool is_ascend, DataType dtype) { auto attrs = make_object(); - attrs->k = k; + if (const auto& ck = k.as()) { + attrs->k = tvm::Integer(reinterpret_cast(ck->data->data)[0]); + } attrs->axis = axis; attrs->ret_type = ret_type; attrs->is_ascend = is_ascend; attrs->dtype = dtype; static const Op& op = Op::Get("topk"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {data, k}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.topk") -.set_body_typed(MakeTopK); 
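MakeTopK above now takes k as a relay Expr: when k is a ConstantNode its value is folded into attrs->k, otherwise attrs->k stays unset and TopKRel emits Any() for the sorted axis. A construction sketch for the dynamic case, not part of the patch, with MakeDynamicTopK as an invented name and arbitrary shapes and dtypes:

Expr MakeDynamicTopK() {
  auto data = Var("data", TensorType({64, 128}, DataType::Float(32)));
  // k is a runtime scalar, so the output extent along the sorted axis becomes Any().
  auto k = Var("k", TensorType({}, DataType::Int(64)));
  return MakeTopK(data, k, /*axis=*/-1, "both", /*is_ascend=*/false,
                  /*dtype=*/DataType::Int(32));
}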
+TVM_REGISTER_GLOBAL("relay.op._make.topk").set_body_typed(MakeTopK); RELAY_REGISTER_OP("topk") -.describe(R"doc(Get the top k elements in an input tensor along the given axis. + .describe(R"doc(Get the top k elements in an input tensor along the given axis. )doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(6) -.add_type_rel("TopK", TopKRel); + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .add_argument("k", "Tensor", "Number of top elements.") + .set_support_level(6) + .add_type_rel("TopK", TopKRel); } // namespace relay } // namespace tvm - diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index dd1bcdc1b9eb..6be9b0d4a3d5 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -23,12 +23,12 @@ * \brief Registration of annotation operators. */ -#include +#include #include #include #include #include -#include +#include #include "../../transforms/infer_layout_util.h" #include "../type_relations.h" @@ -40,48 +40,46 @@ namespace relay { TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") -.set_body_typed([](Expr data, int device_type) { - auto attrs = make_object(); - attrs->device_type = device_type; - static const Op& op = Op::Get("on_device"); - return Call(op, {data}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr data, int device_type) { + auto attrs = make_object(); + attrs->device_type = device_type; + static const Op& op = Op::Get("on_device"); + return Call(op, {data}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("on_device") -.describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout); + .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); Expr StopFusion(Expr data) { static const Op& op = Op::Get("annotation.stop_fusion"); return Call(op, {data}, Attrs{}, {}); } -TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") -.set_body_typed([](Expr data) { - return StopFusion(data); +TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion").set_body_typed([](Expr data) { + return StopFusion(data); }); RELAY_REGISTER_OP("annotation.stop_fusion") -.describe(R"code(Annotate an expression to prevent it being fused with previous expressions.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input data.") -.add_type_rel("Identity", IdentityRel) -.set_support_level(10) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .describe( + R"code(Annotate an expression to prevent it being fused with previous expressions.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(10) + .set_attr("TOpPattern", 
kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); // relay.annotation.cast_hint TVM_REGISTER_NODE_TYPE(CastHintAttrs); @@ -94,134 +92,127 @@ Expr CastHint(Expr data, DataType dtype) { } RELAY_REGISTER_OP("annotation.cast_hint") -.describe(R"code(Annotate an expression to be cast into specific data type.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input data.") -.add_type_rel("Identity", IdentityRel) -.set_support_level(10) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - + .describe( + R"code(Annotate an expression to be cast into specific data type.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input data.") + .add_type_rel("Identity", IdentityRel) + .set_support_level(10) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); RELAY_REGISTER_OP("annotation.bitpack_start") -.describe(R"code( + .describe(R"code( Mark the start of bitpacking. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); RELAY_REGISTER_OP("annotation.bitpack_end") -.describe(R"code( + .describe(R"code( Mark the end of bitpacking. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") -.set_body_typed([](Expr data) { + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); + +TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint").set_body_typed([](Expr data) { static const Op& op = Op::Get("annotation.checkpoint"); return Call(op, {data}, Attrs{}, {}); }); RELAY_REGISTER_OP("annotation.checkpoint") -.describe(R"code( + .describe(R"code( Mark a checkpoint for checkpointing memory optimization. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - Array outputs; - for (size_t i = 0; i < inputs.size(); ++i) { - outputs.push_back(topi::identity(inputs[i])); - } - return outputs; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + Array outputs; + for (size_t i = 0; i < inputs.size(); ++i) { + outputs.push_back(topi::identity(inputs[i])); + } + return outputs; + }); TVM_REGISTER_NODE_TYPE(CompilerAttrs); RELAY_REGISTER_OP("annotation.compiler_begin") -.describe(R"code( + .describe(R"code( Beginning of a region that is handled by a given compiler. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin") -.set_body_typed([](Expr expr, std::string compiler) { - auto attrs = make_object(); - attrs->compiler = compiler; - static const Op& op = Op::Get("annotation.compiler_begin"); - return Call(op, {expr}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr expr, String compiler) { + auto attrs = make_object(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_begin"); + return Call(op, {expr}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("annotation.compiler_end") -.describe(R"code( + .describe(R"code( End of a region that is handled by a given compiler. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", - [](const Attrs& attrs, const Array& inputs, - const Type& out_dtype) -> Array { - return {topi::identity(inputs[0])}; - }); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end") -.set_body_typed([](Expr expr, std::string compiler) { - auto attrs = make_object(); - attrs->compiler = compiler; - static const Op& op = Op::Get("annotation.compiler_end"); - return Call(op, {expr}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr expr, String compiler) { + auto attrs = make_object(); + attrs->compiler = compiler; + static const Op& op = Op::Get("annotation.compiler_end"); + return Call(op, {expr}, Attrs(attrs), {}); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index 8e8586f9d213..56b7d4405490 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -22,38 +22,39 @@ * \brief Property def of nn operators. */ -#include -#include -#include #include +#include +#include +#include + #include -#include "./type_relations.h" + #include "./op_common.h" +#include "./type_relations.h" namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(DebugAttrs); -Array DebugCompute(const Attrs& attrs, - const Array& inputs, +Array DebugCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return Array{ topi::identity(inputs[0]) }; + return Array{topi::identity(inputs[0])}; } RELAY_REGISTER_OP("debug") -.describe(R"code(Enter the interpreter's debugger. + .describe(R"code(Enter the interpreter's debugger. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("program", "Tuple", "The program to execute before debugging.") -.set_support_level(1) -.set_attrs_type() -.add_type_rel("Debug", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("FTVMCompute", DebugCompute); + .set_num_inputs(1) + .add_argument("program", "Tuple", "The program to execute before debugging.") + .set_support_level(1) + .set_attrs_type() + .add_type_rel("Debug", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("FTVMCompute", DebugCompute); -Expr MakeDebug(Expr expr, std::string name) { +Expr MakeDebug(Expr expr, String name) { auto dattrs = make_object(); if (name.size() > 0) { dattrs->debug_func = EnvFunc::Get(name); @@ -64,9 +65,7 @@ Expr MakeDebug(Expr expr, std::string name) { return Call(op, {expr}, Attrs(dattrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.debug") -.set_body_typed(MakeDebug); +TVM_REGISTER_GLOBAL("relay.op._make.debug").set_body_typed(MakeDebug); } // namespace relay } // namespace tvm - diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 4aae549f217b..923965f98192 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -26,14 +26,14 @@ * used as "barrier" to avoid fusing operators belonging to differen devices. */ -#include #include #include #include #include +#include -#include "type_relations.h" #include "../transforms/infer_layout_util.h" +#include "type_relations.h" namespace tvm { namespace relay { @@ -42,27 +42,25 @@ namespace relay { TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_GLOBAL("relay.op._make.device_copy") -.set_body_typed([](Expr data, int src_dev_type, - int dst_dev_type) { - auto attrs = make_object(); - attrs->src_dev_type = src_dev_type; - attrs->dst_dev_type = dst_dev_type; - static const Op& op = Op::Get("device_copy"); - return Call(op, {data}, Attrs(attrs), {}); -}); + .set_body_typed([](Expr data, int src_dev_type, int dst_dev_type) { + auto attrs = make_object(); + attrs->src_dev_type = src_dev_type; + attrs->dst_dev_type = dst_dev_type; + static const Op& op = Op::Get("device_copy"); + return Call(op, {data}, Attrs(attrs), {}); + }); RELAY_REGISTER_OP("device_copy") -.describe(R"code( + .describe(R"code( Copy data from one tensor to another. The source and destination might be on different devices. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_support_level(10) -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kOpaque) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout); + .set_num_inputs(1) + .set_support_level(10) + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kOpaque) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/dilation2d.cc b/src/relay/op/image/dilation2d.cc index 7146f3736dd6..462f11f56d0d 100644 --- a/src/relay/op/image/dilation2d.cc +++ b/src/relay/op/image/dilation2d.cc @@ -21,9 +21,10 @@ * \file dilation2d.cc * \brief Morphological dilation operator */ -#include -#include #include +#include +#include + #include "../op_common.h" namespace tvm { @@ -32,27 +33,20 @@ namespace relay { // relay.image.dilation2d TVM_REGISTER_NODE_TYPE(Dilation2DAttrs); -template -Array > Dilation2DInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +template +Array > Dilation2DInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { const T* params = attrs.as(); - return Array >{{params->data_layout, params->kernel_layout}, - {params->data_layout}}; + return Array >{{params->data_layout, params->kernel_layout}, {params->data_layout}}; } // Positional relay function to create dilation2d operator // used by frontend FFI. -Expr MakeDilation2D(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilations, - std::string data_layout, - std::string kernel_layout, +Expr MakeDilation2D(Expr data, Expr weight, Array strides, Array padding, + Array dilations, String data_layout, String kernel_layout, DataType out_dtype) { auto attrs = make_object(); attrs->strides = std::move(strides); @@ -67,7 +61,7 @@ Expr MakeDilation2D(Expr data, template bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* weight = types[1].as(); @@ -113,15 +107,13 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -136,26 +128,24 @@ bool Dilation2DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d") -.set_body_typed(MakeDilation2D); - +TVM_REGISTER_GLOBAL("relay.op.image._make.dilation2d").set_body_typed(MakeDilation2D); RELAY_REGISTER_OP("image.dilation2d") -.describe(R"code(Computes grayscale dilation of 4D input and 3D filter. + .describe(R"code(Computes grayscale dilation of 4D input and 3D filter. - **data**: This depends on the `layout` parameter. 
Input is 4D array of shape (batch_size, in_channels, height, width) if `layout` is `NCHW`. - **weight**: (in_channels, height, width) - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Dilation2D", Dilation2DRel) -.set_attr("FInferCorrectLayout", - Dilation2DInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Dilation2D", Dilation2DRel) + .set_attr("FInferCorrectLayout", + Dilation2DInferCorrectLayout); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/grid_sample.cc b/src/relay/op/image/grid_sample.cc new file mode 100644 index 000000000000..bc6989155323 --- /dev/null +++ b/src/relay/op/image/grid_sample.cc @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file grid_sample.cc + * \brief affine_grid and grid_sample operator + */ +#include +#include +#include + +#include "../op_common.h" + +namespace tvm { +namespace relay { + +// relay.image.affine_grid +TVM_REGISTER_NODE_TYPE(AffineGridAttrs); + +bool AffineGridRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + auto batch_size = data->shape[0]; + + const AffineGridAttrs* param = attrs.as(); + CHECK(param != nullptr); + + Array oshape; + + CHECK(data->shape.size() == 3U && reporter->AssertEQ(data->shape[1], 2) && + reporter->AssertEQ(data->shape[2], 3)) + << "data should be an" + "affine matrix with shape [batch_size, 2, 3]"; + CHECK(param->target_shape.defined() && param->target_shape.size() == 2) + << "target_shape should be 2D"; + oshape.push_back(batch_size); + oshape.push_back(2); + oshape.push_back(param->target_shape[0]); + oshape.push_back(param->target_shape[1]); + + // assign output type + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + +// Positional relay function to create affine_grid operator +// used by frontend FFI. 
+Expr MakeAffineGrid(Expr data, Array target_shape) { + auto attrs = make_object(); + attrs->target_shape = std::move(target_shape); + static const Op& op = Op::Get("image.affine_grid"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.affine_grid").set_body_typed(MakeAffineGrid); + +RELAY_REGISTER_OP("image.affine_grid") + .describe(R"code(affine_grid operator that generates 2D sampling grid. + +This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform +sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine +transformation is then applied on the sampling grid. + +- **data**: data is 3D array of shape [batch, 2, 3], which defines an affine transformation. + +- **out**: out is 4D array of shape [batch, 2, height, width], where each vector + :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)` + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The affine matrix.") + .set_support_level(5) + .add_type_rel("AffineGrid", AffineGridRel) + .set_attr("TOpPattern", kInjective); + +// relay.image.grid_sample +TVM_REGISTER_NODE_TYPE(GridSampleAttrs); + +bool GridSampleRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* grid = types[1].as(); + if (!data || !grid) return false; + const auto* param = attrs.as(); + CHECK(param); + static const Layout kNCHW("NCHW"); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, grid->shape[2]); + oshape.Set(3, grid->shape[3]); + // assign output type + reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + return true; +} + +// Positional relay function to create affine_grid operator +// used by frontend FFI. +Expr MakeGridSample(Expr data, Expr grid, String method, String layout) { + auto attrs = make_object(); + attrs->method = std::move(method); + attrs->layout = std::move(layout); + static const Op& op = Op::Get("image.grid_sample"); + return Call(op, {data, grid}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.grid_sample").set_body_typed(MakeGridSample); + +RELAY_REGISTER_OP("image.grid_sample") + .describe(R"code(Applies grid sampling to input feature map. + +Given :math:`data` and :math:`grid`, then the output is computed by + +.. math:: + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}) + +:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and +:math:`G()` denotes the interpolation function. +The out-boundary points will be padded with zeros. The shape of the output will be +(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + +The operator assumes that :math:`data` has 'NCHW' layout and :math:`grid` has been normalized to [-1, 1]. + +grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. 
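A self-contained sketch of the affine_grid semantics described above: build a target grid normalized to [-1, 1] and push each (x, y, 1) coordinate through the 2x3 affine matrix. The exact endpoint convention of the normalization is an assumption of this sketch, and AffineGrid is an illustrative name:

#include <iostream>
#include <vector>

// Build a (2, H, W) grid of target coordinates normalized to [-1, 1], then map each
// homogeneous coordinate (x, y, 1) through the 2x3 affine matrix theta.
std::vector<float> AffineGrid(const float theta[2][3], int H, int W) {
  std::vector<float> out(2 * H * W);  // layout [2][H][W], matching the documented output
  for (int h = 0; h < H; ++h) {
    for (int w = 0; w < W; ++w) {
      // Normalized target coordinates in [-1, 1] (endpoint handling assumed).
      float x = (W > 1) ? -1.0f + 2.0f * w / (W - 1) : 0.0f;
      float y = (H > 1) ? -1.0f + 2.0f * h / (H - 1) : 0.0f;
      // Apply the affine transform to (x, y, 1).
      float xs = theta[0][0] * x + theta[0][1] * y + theta[0][2];
      float ys = theta[1][0] * x + theta[1][1] * y + theta[1][2];
      out[0 * H * W + h * W + w] = xs;  // channel 0: x coordinate
      out[1 * H * W + h * W + w] = ys;  // channel 1: y coordinate
    }
  }
  return out;
}

int main() {
  const float identity[2][3] = {{1, 0, 0}, {0, 1, 0}};
  std::vector<float> grid = AffineGrid(identity, 2, 3);
  // With the identity transform the grid is just the normalized coordinates.
  std::cout << grid[0] << ' ' << grid[1] << ' ' << grid[2] << '\n';  // -1 0 1
}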
+ +- **data**: data is 4D array of shape + (batch_size, channels, in_height, in_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +- **grid**: out is 4D array of shape [batch, 2, out_height, out_width], where each vector + :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)` + +- **out**: out is 4D array of shape + (batch, in_channel, out_height, out_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("GridSample", GridSampleRel) + .set_attr("TOpPattern", kInjective); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index c8f976256600..b6d2c71d7eda 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -21,9 +21,10 @@ * \file resize.cc * \brief Image resize operators */ -#include -#include #include +#include +#include + #include "../op_common.h" namespace tvm { @@ -31,9 +32,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); -bool ResizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -46,8 +45,8 @@ bool ResizeRel(const Array& types, const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) - << "Resize only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Resize only support input layouts that are convertible from NCHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); oshape.Set(2, param->size[0]); @@ -59,20 +58,14 @@ bool ResizeRel(const Array& types, } // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - out_dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); return true; } // Positional relay function to create image operator // used by frontend FFI. -Expr MakeResize(Expr data, - Array size, - std::string layout, - std::string method, - std::string coordinate_transformation_mode, - DataType out_dtype) { +Expr MakeResize(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, DataType out_dtype) { auto attrs = make_object(); attrs->size = std::move(size); attrs->layout = std::move(layout); @@ -83,13 +76,10 @@ Expr MakeResize(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.image._make.resize") -.set_body_typed(MakeResize); - +TVM_REGISTER_GLOBAL("relay.op.image._make.resize").set_body_typed(MakeResize); RELAY_REGISTER_OP("image.resize") -.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. + .describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation. 
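A single-pixel, single-channel sketch of the grid_sample formula quoted above, assuming bilinear interpolation for G(), zero padding outside the input, and an align-corners style denormalization; all three are assumptions of this sketch, since the diff leaves the interpolation to the `method` attribute:

#include <cmath>
#include <iostream>
#include <vector>

// Sample one location: the grid supplies normalized source coordinates in [-1, 1],
// which are denormalized to pixel space and read with bilinear interpolation,
// treating anything outside the input as zero.
float SampleAt(const std::vector<float>& data, int H, int W, float xn, float yn) {
  float x = (xn + 1.0f) * 0.5f * (W - 1);  // denormalize (align-corners assumed)
  float y = (yn + 1.0f) * 0.5f * (H - 1);
  auto at = [&](int yy, int xx) -> float {
    if (yy < 0 || yy >= H || xx < 0 || xx >= W) return 0.0f;  // zero padding
    return data[yy * W + xx];
  };
  int x0 = static_cast<int>(std::floor(x)), y0 = static_cast<int>(std::floor(y));
  float dx = x - x0, dy = y - y0;
  return at(y0, x0) * (1 - dx) * (1 - dy) + at(y0, x0 + 1) * dx * (1 - dy) +
         at(y0 + 1, x0) * (1 - dx) * dy + at(y0 + 1, x0 + 1) * dx * dy;
}

int main() {
  std::vector<float> img = {0, 1, 2, 3};                    // 2x2 image
  std::cout << SampleAt(img, 2, 2, 0.0f, 0.0f) << '\n';     // centre -> 1.5
  std::cout << SampleAt(img, 2, 2, -1.0f, -1.0f) << '\n';   // top-left corner -> 0
}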
- **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -102,26 +92,93 @@ RELAY_REGISTER_OP("image.resize") for layout NHWC (batch_size, size[0], size[1], channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(5) -.add_type_rel("Resize", ResizeRel) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize", ResizeRel) + .set_attr("TOpPattern", kInjective); + +TVM_REGISTER_NODE_TYPE(Resize3dAttrs); + +bool Resize3dRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCDHW("NCDHW"); + + const Resize3dAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->layout); + auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); + CHECK(layout_converter.defined()) + << "Resize3d only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(2, param->size[0]); + oshape.Set(3, param->size[1]); + oshape.Set(4, param->size[2]); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + // assign output type + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), out_dtype)); + return true; +} + +// Positional relay function to create image operator +// used by frontend FFI. +Expr MakeResize3d(Expr data, Array size, String layout, String method, + String coordinate_transformation_mode, DataType out_dtype) { + auto attrs = make_object(); + attrs->size = std::move(size); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->coordinate_transformation_mode = coordinate_transformation_mode; + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("image.resize3d"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.resize3d").set_body_typed(MakeResize3d); + +RELAY_REGISTER_OP("image.resize3d") + .describe(R"code( +Perform resize3d to input array with nearest neighbour or bilinear interpolation. 
+ +- **data**: data is 5D array of shape + (batch_size, channels, in_depth, in_height, in_width) for NCDHW + (batch_size, in_depth, in_height, in_width, channels) for NDHWC + +- **out**: Output is 5D array of shape + for layout NCDHW + (batch_size, channels, size[0], size[1], size[2]) + + for layout NDHWC + (batch_size, size[0], size[1], size[2], channels) +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("Resize3d", Resize3dRel) + .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); -bool CropAndResizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CropAndResizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* boxes = types[1].as(); const auto* box_indices = types[2].as(); - if (data == nullptr || boxes == nullptr || - box_indices == nullptr) return false; + if (data == nullptr || boxes == nullptr || box_indices == nullptr) return false; const CropAndResizeAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -137,24 +194,17 @@ bool CropAndResizeRel(const Array& types, const Layout in_layout(param->layout); auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(0, box_indices->shape[0]); + oshape.Set(0, boxes->shape[0]); oshape.Set(2, crop_size[0]); oshape.Set(3, crop_size[1]); auto bshape = layout_converter.BackwardShape(oshape); // assign output type - reporter->Assign(types[3], - TensorType(layout_converter.BackwardShape(oshape), - out_dtype)); + reporter->Assign(types[3], TensorType(bshape, out_dtype)); return true; } -Expr MakeCropAndResize(Expr data, - Expr boxes, - Expr box_indices, - Array crop_size, - std::string layout, - std::string method, - double extrapolation_value, +Expr MakeCropAndResize(Expr data, Expr boxes, Expr box_indices, Array crop_size, + String layout, String method, double extrapolation_value, DataType out_dtype) { auto attrs = make_object(); attrs->crop_size = std::move(crop_size); @@ -166,12 +216,11 @@ Expr MakeCropAndResize(Expr data, return Call(op, {data, boxes, box_indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize") -.set_body_typed(MakeCropAndResize); - +TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize").set_body_typed(MakeCropAndResize); RELAY_REGISTER_OP("image.crop_and_resize") - .describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. + .describe( + R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. 
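CropAndResizeRel above takes the output batch extent from the number of boxes and the spatial extents from crop_size, while carrying the channel extent over from the input. A one-line shape helper restating that rule for NCHW layout; CropAndResizeShapeNCHW is an illustrative name:

#include <cstdint>
#include <iostream>
#include <vector>

// Output shape of crop_and_resize for NCHW data: (num_boxes, channels, crop_h, crop_w).
std::vector<int64_t> CropAndResizeShapeNCHW(const std::vector<int64_t>& data_shape,
                                            int64_t num_boxes, int64_t crop_h, int64_t crop_w) {
  return {num_boxes, data_shape[1], crop_h, crop_w};
}

int main() {
  // 8 images of 3x224x224, 32 boxes, 14x14 crops -> (32, 3, 14, 14).
  for (int64_t d : CropAndResizeShapeNCHW({8, 3, 224, 224}, 32, 14, 14)) std::cout << d << ' ';
  std::cout << '\n';
}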
- **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -184,14 +233,14 @@ RELAY_REGISTER_OP("image.crop_and_resize") for layout NHWC (batch_size, crop_size[0], crop_size[1], channels) )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("boxes", "Tensor", "The boxes tensor.") -.add_argument("box_indices", "Tensor", "The box indices tensor.") -.set_attrs_type() -.set_support_level(5) -.add_type_rel("CropAndResize", CropAndResizeRel) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("boxes", "Tensor", "The boxes tensor.") + .add_argument("box_indices", "Tensor", "The box indices tensor.") + .set_attrs_type() + .set_support_level(5) + .add_type_rel("CropAndResize", CropAndResizeRel) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 0a7142df572f..e5081adbf6a7 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include "../../transforms/infer_layout_util.h" #include "../op_common.h" @@ -91,7 +92,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") }); TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") - .set_body_typed([](Expr storage, tvm::relay::Expr shape, DataType dtype, + .set_body_typed([](Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); attrs->dtype = dtype; @@ -101,27 +102,27 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") attrs->const_shape = Downcast(shape); } static const Op& op = Op::Get("memory.alloc_tensor"); - return Call(op, {storage, shape}, Attrs(attrs), {}); + return Call(op, {storage, offset, shape}, Attrs(attrs), {}); }); std::vector FromConstShape(Constant konst) { runtime::NDArray shape = konst->data; std::vector raw_shape; - DLTensor tensor = shape.ToDLPack()->dl_tensor; - CHECK_EQ(tensor.ndim, 1u); - CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code; - - CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32) - << "found " << static_cast(tensor.dtype.bits); - - if (tensor.dtype.bits == 32) { - const int32_t* int_ptr = reinterpret_cast(tensor.data); - for (auto i = 0; i < tensor.shape[0]; i++) { + CHECK_EQ(shape->ndim, 1u); + CHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " + << runtime::DLDataType2String(shape->dtype); + CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32) + << "The dtype of constant shape must be int32 or int64, but got" + << runtime::DLDataType2String(shape->dtype); + + if (shape->dtype.bits == 32) { + const int32_t* int_ptr = reinterpret_cast(shape->data); + for (auto i = 0; i < shape->shape[0]; i++) { raw_shape.push_back(int_ptr[i]); } - } else if (tensor.dtype.bits == 64) { - const int64_t* int_ptr = reinterpret_cast(tensor.data); - for (auto i = 0; i < tensor.shape[0]; i++) { + } else if (shape->dtype.bits == 64) { + const int64_t* int_ptr = reinterpret_cast(shape->data); + for (auto i = 0; i < shape->shape[0]; i++) { raw_shape.push_back(int_ptr[i]); } } @@ -131,7 +132,7 @@ std::vector FromConstShape(Constant konst) { bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3u); + CHECK_EQ(types.size(), 4u); auto alloc_attrs = attrs.as(); CHECK(alloc_attrs 
!= nullptr) << "must be alloc_tensor attributes"; // First argument should be storage. @@ -140,18 +141,28 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs auto storage_name = mod->GetGlobalTypeVar("Storage"); auto storage = relay::TypeCall(storage_name, {}); reporter->Assign(types[0], storage); - // Second argument should be shape tensor. - auto tt = types[1].as(); + // Second argument should be the offset. + auto offset_type = types[1].as(); + CHECK(offset_type != nullptr) << "must be a scalar type"; + + // Third argument should be shape tensor. + auto tt = types[2].as(); CHECK(tt != nullptr) << "must be tensor type"; - auto rank = tt->shape[0].as(); - CHECK(rank != nullptr); - auto dims = rank->value; + + // Be careful about having to allocate scalars. + int64_t dims = 0; + if (tt->shape.size() != 0) { + auto rank = tt->shape[0].as(); + CHECK(rank != nullptr); + dims = rank->value; + } // Constant node case. Type alloc_type; if (alloc_attrs->const_shape.defined()) { auto con = alloc_attrs->const_shape; auto sh = FromConstShape(con); + CHECK_EQ(sh.size(), dims); Array out_shape; for (auto i = 0u; i < dims; i++) { out_shape.push_back(tvm::Integer(sh[i])); @@ -164,14 +175,15 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs return true; } - reporter->Assign(types[2], alloc_type); + reporter->Assign(types[3], alloc_type); return true; } RELAY_REGISTER_OP("memory.alloc_tensor") .describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(3) .add_argument("storage", "Storage", "The storage to allocate from.") + .add_argument("offset", "Tensor", "The offset into the backing storage.") .add_argument("shape", "Tensor", "The shape of the tensor to allocate.") .add_type_rel("AllocTensor", AllocTensorRel) .set_support_level(10) @@ -329,14 +341,12 @@ Expr ToTupleType(const Type& t, const std::vector& exprs) { } } -TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType") -.set_body_typed([](Type type) { +TVM_REGISTER_GLOBAL("relay.op.memory._make.FlattenTupleType").set_body_typed([](Type type) { auto types = FlattenTupleType(type); return Array(types.begin(), types.end()); }); -TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType") -.set_body_typed([](Type type, Expr expr) { +TVM_REGISTER_GLOBAL("relay.op.memory._make.FromTupleType").set_body_typed([](Type type, Expr expr) { auto exprs = FromTupleType(type, expr); return Array(exprs.begin(), exprs.end()); }); @@ -358,12 +368,23 @@ bool ShapeFuncRel(const Array& types, int num_inputs, const Attrs& attrs, auto tuple = TupleType(func_type->arg_types); auto in_types = FlattenTupleType(tuple); auto out_types = FlattenTupleType(func_type->ret_type); + Array is_input; + for (size_t i = 0; i < func_type->arg_types.size(); ++i) { + auto const& aty = func_type->arg_types[i]; + size_t num_types = 1; + if (aty.as()) { + num_types = FlattenTupleType(aty).size(); + } + for (size_t j = 0; j < num_types; ++j) { + is_input.push_back(shape_func_attrs->is_input[i]); + } + } Array shape_func_ins, shape_func_outs; for (size_t i = 0; i < in_types.size(); i++) { auto in_type = in_types[i]; - if (shape_func_attrs->is_input[i]) { + if (is_input[i]) { shape_func_ins.push_back(in_type); } else { auto shape = RankShape(in_type->shape); diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d2174579cc31..022ca5cc96d8 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -22,12 +22,12 @@ * \brief 
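The ShapeFuncRel change above replicates each per-argument is_input flag once per flattened tuple field before pairing the flags with the flattened input types. A small sketch of that expansion, where field_counts stands in for the flattened field count of each argument (1 for a plain tensor, the flattened size for a tuple):

#include <iostream>
#include <vector>

// Expand per-argument flags to per-field flags: flag i is repeated once for every
// flattened tensor field of argument i, mirroring the loop added to ShapeFuncRel.
std::vector<bool> ExpandIsInput(const std::vector<bool>& per_arg,
                                const std::vector<int>& field_counts) {
  std::vector<bool> per_field;
  for (size_t i = 0; i < per_arg.size(); ++i) {
    for (int j = 0; j < field_counts[i]; ++j) per_field.push_back(per_arg[i]);
  }
  return per_field;
}

int main() {
  // Two arguments: a plain tensor (flag 1) and a 3-field tuple (flag 0).
  for (bool b : ExpandIsInput({true, false}, {1, 3})) std::cout << b << ' ';  // 1 0 0 0
  std::cout << '\n';
}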
Property def of bitserial operators. */ -#include #include #include +#include -#include "../op_common.h" #include "../../transforms/infer_layout_util.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -86,7 +86,7 @@ bool BitPackRel(const Array& types, int num_inputs, const Attrs& attrs, } Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type, - std::string name) { + String name) { auto attrs = make_object(); attrs->bits = bits; attrs->pack_axis = pack_axis; @@ -109,11 +109,11 @@ efficient implementation of bitserial operations. packed must be divisible by number of bits. - **out**: Packed tensor with shape appropriately compressed. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "Input data.") -.set_support_level(2) -.add_type_rel("BitPack", BitPackRel); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "Input data.") + .set_support_level(2) + .add_type_rel("BitPack", BitPackRel); // relay.nn.bitserial_conv2d TVM_REGISTER_NODE_TYPE(BinaryConv2DAttrs); @@ -137,10 +137,8 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr Array oshape({dshape_nchw[0], param->channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set( - 2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); - oshape.Set( - 3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); + oshape.Set(2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1); + oshape.Set(3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1); DataType out_dtype = param->out_dtype; oshape = trans_in_layout.BackwardShape(oshape); // assign output type @@ -152,7 +150,7 @@ bool BinaryConv2DRel(const Array& types, int num_inputs, const Attrs& attr // used by frontend FFI. Expr MakeBinaryConv2D(Expr data, Expr weight, Array strides, Array padding, IndexExpr channels, Array kernel_size, int activation_bits, - int weight_bits, std::string data_layout, std::string kernel_layout, + int weight_bits, String data_layout, String kernel_layout, DataType pack_dtype, DataType out_dtype, bool unipolar) { auto attrs = make_object(); attrs->strides = std::move(strides); @@ -187,14 +185,14 @@ on some platforms. - **out**: Output with same layout as input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("BinaryConv2D", BinaryConv2DRel) -.set_attr("FInferCorrectLayout", - BinaryConv2DInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("BinaryConv2D", BinaryConv2DRel) + .set_attr("FInferCorrectLayout", + BinaryConv2DInferCorrectLayout); // relay.nn.bitserial_dense TVM_REGISTER_NODE_TYPE(BinaryDenseAttrs); @@ -248,12 +246,12 @@ RELAY_REGISTER_OP("nn.bitserial_dense") - **out**: `(x1, x2, ..., xn, units)`. 
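The bitpack docstring above leaves most of the packed layout implicit, so the following is one plausible 1-D reading of bit-plane packing rather than the operator's exact behaviour: bit b of eight consecutive elements is packed into one uint8 word of plane b. BitPack1D and its layout are assumptions of this sketch, and the packed axis length is assumed divisible by 8:

#include <cstdint>
#include <iostream>
#include <vector>

// Decompose each element into `bits` bit planes and pack 8 consecutive elements'
// bits into one uint8 word per plane (a sketch of bit-serial packing, not the
// operator's guaranteed layout).
std::vector<std::vector<uint8_t>> BitPack1D(const std::vector<uint8_t>& data, int bits) {
  const size_t words = data.size() / 8;
  std::vector<std::vector<uint8_t>> planes(bits, std::vector<uint8_t>(words, 0));
  for (int b = 0; b < bits; ++b) {
    for (size_t i = 0; i < data.size(); ++i) {
      uint8_t bit = (data[i] >> b) & 1;
      planes[b][i / 8] |= static_cast<uint8_t>(bit << (i % 8));
    }
  }
  return planes;
}

int main() {
  std::vector<uint8_t> data = {1, 0, 1, 1, 0, 0, 1, 0};  // 8 one-bit activations
  auto planes = BitPack1D(data, /*bits=*/1);
  std::cout << static_cast<int>(planes[0][0]) << '\n';   // 0b01001101 = 77
}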
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "2D Tensor", "Input data.") -.add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(1) -.add_type_rel("BinaryDense", BinaryDenseRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "2D Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("BinaryDense", BinaryDenseRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 66dab57fd947..6c6eb1ecb8b2 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -21,33 +21,25 @@ * \file convolution.cc * \brief Convolution operators */ -#include -#include -#include +#include "convolution.h" + #include +#include +#include + #include #include "../../transforms/infer_layout_util.h" #include "../op_common.h" -#include "convolution.h" namespace tvm { namespace relay { template -Expr MakeConv(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, - std::string op_name) { +Expr MakeConv(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, std::string kernel_layout, + std::string out_layout, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -64,19 +56,10 @@ Expr MakeConv(Expr data, } template -Expr MakeConvWinograd(Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, +Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; @@ -94,9 +77,7 @@ Expr MakeConvWinograd(Expr data, return Call(op, {data, weight}, Attrs(attrs), {}); } -Expr MakeConvWinogradWeightTransform(Expr weight, - int tile_size, - std::string op_name) { +Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) { auto attrs = make_object(); attrs->tile_size = tile_size; const Op& op = Op::Get(op_name); @@ -104,20 +85,11 @@ Expr MakeConvWinogradWeightTransform(Expr weight, } template -Expr MakeConvTranspose(Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype, - std::string op_name) { +Expr MakeConvTranspose(Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, std::string data_layout, + std::string kernel_layout, std::string out_layout, + Array output_padding, DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -135,21 +107,11 @@ Expr MakeConvTranspose(Expr data, 
} template -Expr MakeDeformableConv(Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype, - std::string op_name) { +Expr MakeDeformableConv(Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, int deformable_groups, + int groups, int channels, Array kernel_size, + std::string data_layout, std::string kernel_layout, std::string out_layout, + DataType out_dtype, std::string op_name) { auto attrs = make_object(); attrs->strides = strides; attrs->padding = padding; @@ -166,32 +128,21 @@ Expr MakeDeformableConv(Expr data, return Call(op, {data, offset, weight}, Attrs{attrs}, {}); } - // relay.nn.conv1d TVM_REGISTER_NODE_TYPE(Conv1DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv1d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv1d"); + }); RELAY_REGISTER_OP("nn.conv1d") -.describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). + .describe(R"code(1D convolution layer (e.g. spatial convolution over sequences). This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. @@ -203,40 +154,29 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_width) if `layout` is `NCW`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv1D", Conv1DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv1D", Conv1DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv2d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv2d"); + }); RELAY_REGISTER_OP("nn.conv2d") -.describe(R"code(2D convolution layer (e.g. spatial convolution over images). + .describe(R"code(2D convolution layer (e.g. spatial convolution over images). This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. @@ -248,40 +188,29 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv2D", Conv2DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // relay.nn.conv3d TVM_REGISTER_NODE_TYPE(Conv3DAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.conv3d"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.conv3d"); + }); RELAY_REGISTER_OP("nn.conv3d") -.describe(R"code(3D convolution layer (e.g. 
convolution over 3D image data, + .describe(R"code(3D convolution layer (e.g. convolution over 3D image data, like Magnetic Resonance Imaging (MRI) data in medicine). This layer creates a convolution kernel that is convolved @@ -294,40 +223,74 @@ with the layer input to produce a tensor of outputs. (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv3D", Conv3DRel) -.set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv3D", Conv3DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); + +// relay.nn.conv3d_transpose +TVM_REGISTER_NODE_TYPE(Conv3DTransposeAttrs); + +TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d_transpose") + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv3d_transpose"); + }); + +RELAY_REGISTER_OP("nn.conv3d_transpose") + .describe(R"code(Transposed 3D convolution layer (sometimes called Deconvolution 3D). +The need for transposed convolutions generally arises +from the desire to use a transformation going in the opposite direction +of a normal convolution, i.e., from something that has the shape of the +output of some convolution to something that has the shape of its input +while maintaining a connectivity pattern that is compatible with +said convolution. + +- **data**: This depends on the `layout` parameter. Input is 5D array of shape + (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`. +- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1], kernel_size[2]) +- **bias**: (channels,) +- **out**: This depends on the `layout` parameter. Output is 5D array of shape + (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. 
+ + out_depth and out_height and out_width are calculated as:: + out_depth = (depth-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] + out_height = (height-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] + out_width = (width-1)*strides[2]-2*padding[2]+kernel_size[2]+output_padding[2] + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout) + .add_type_rel("Conv3DTranspose", Conv3DTransposeRel); // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - return MakeConvTranspose( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose"); + }); RELAY_REGISTER_OP("nn.conv2d_transpose") -.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). + .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). 
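The transposed-convolution output formulas quoted in the conv3d_transpose docstring above are the same per-axis rule the 1D and 2D variants document. A one-liner restating it, with a quick numeric check:

#include <cstdint>
#include <iostream>

// Per-axis output extent of a transposed convolution, exactly as stated in the docstring:
//   out = (in - 1) * stride - 2 * padding + kernel + output_padding
int64_t TransposedOutDim(int64_t in, int64_t stride, int64_t padding, int64_t kernel,
                         int64_t output_padding) {
  return (in - 1) * stride - 2 * padding + kernel + output_padding;
}

int main() {
  // A stride-2, 4-tap kernel, padding-1 transposed conv doubles a 16-wide axis to 32.
  std::cout << TransposedOutDim(16, 2, 1, 4, 0) << '\n';  // 32
}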
The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -348,40 +311,30 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout) -.add_type_rel("Conv2DTranspose", Conv2DTransposeRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout) + .add_type_rel("Conv2DTranspose", Conv2DTransposeRel); // relay.nn.conv1d_transpose TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - Array output_padding, - DataType out_dtype) { - return MakeConvTranspose( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, Array output_padding, DataType out_dtype) { + return MakeConvTranspose( + data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose"); + }); RELAY_REGISTER_OP("nn.conv1d_transpose") -.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). + .describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution). The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction @@ -401,39 +354,29 @@ said convolution. 
out_width = (width-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(2) -.add_type_rel("Conv1DTranspose", Conv1DTransposeRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(2) + .add_type_rel("Conv1DTranspose", Conv1DTransposeRel); // relay.nn.contrib_conv2d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") -.set_body_typed([](Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConvWinograd( - data, weight, tile_size, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform"); -}); - + .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, String data_layout, + String kernel_layout, String out_layout, DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv2d_winograd_without_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") -.describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. + .describe(R"code(Compute conv2d with winograd algorithm. Only supports NCHW layout. This operator assumes the weight tensor is already pre-transformed by nn.contrib_conv2d_winograd_weight_transform. 
@@ -444,64 +387,54 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") - **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinograd", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinograd", Conv2DWinogradRel) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); // relay.nn.contrib_conv2d_winograd_weight_transform TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") -.set_body_typed([](Expr weight, - int tile_size) { - return MakeConvWinogradWeightTransform( - weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform"); -}); + .set_body_typed([](Expr weight, int tile_size) { + return MakeConvWinogradWeightTransform(weight, tile_size, + "nn.contrib_conv2d_winograd_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") -.describe(R"code(Weight transformation of winograd fast convolution algorithm. + .describe(R"code(Weight transformation of winograd fast convolution algorithm. Separate this into another operator in order to enable Precompute Pass to compute the weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel); // relay.nn.contrib_conv3d_winograd_without_weight_transform TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs); TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform") -.set_body_typed([](Expr data, - Expr weight, - int tile_size, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConvWinograd( - data, weight, tile_size, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform"); -}); + .set_body_typed([](Expr data, Expr weight, int tile_size, Array strides, + Array padding, Array dilation, int groups, + IndexExpr channels, Array kernel_size, String data_layout, + String kernel_layout, String out_layout, DataType out_dtype) { + return MakeConvWinograd( + data, weight, tile_size, strides, padding, dilation, groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv3d_winograd_without_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") -.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. 
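The weight-transform ops registered above exist so the Winograd transform of a constant kernel can be computed once ahead of time instead of on every call of the convolution. A minimal sketch of the shape bookkeeping, assuming the usual alpha = tile_size + kernel_size - 1 relation that Conv2DWinogradWeightTransformRel (further down in convolution.h) also encodes; the concrete sizes are illustrative:

#include <array>
#include <cstdio>

int main() {
  // OIHW kernel (out_channels, in_channels, kH, kW), e.g. 64 x 64 x 3 x 3.
  std::array<int, 4> weight{64, 64, 3, 3};
  int tile_size = 4;                      // output tile produced per Winograd call
  int alpha = tile_size + weight[2] - 1;  // transform size: 4 + 3 - 1 = 6
  // The pre-transformed weight is laid out as (alpha, alpha, out_channels, in_channels).
  std::printf("(%d, %d, %d, %d)\n", alpha, alpha, weight[0], weight[1]);  // (6, 6, 64, 64)
  return 0;
}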
+ .describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout. This operator assumes the weight tensor is already pre-transformed by nn.contrib_conv3d_winograd_weight_transform. @@ -512,22 +445,21 @@ RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform") - **out**: Output is 5D array of shape (batch_size, channels, depth, out_height, out_width) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv3DWinograd", Conv3DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv3DWinograd", Conv3DWinogradRel) + .set_attr("FInferCorrectLayout", + ConvInferCorrectLayout); // relay.nn.contrib_conv3d_winograd_weight_transform TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform") -.set_body_typed([](Expr weight, - int tile_size) { - return MakeConvWinogradWeightTransform( - weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform"); -}); + .set_body_typed([](Expr weight, int tile_size) { + return MakeConvWinogradWeightTransform(weight, tile_size, + "nn.contrib_conv3d_winograd_weight_transform"); + }); RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform") .describe(R"code(Weight transformation of winograd fast 3d convolution algorithm. @@ -537,18 +469,16 @@ weight transformation in advance. - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel); // relay.nn.contrib_conv2d_winograd_nnpack_weight_transform TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs); -Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, - int convolution_algorithm, +Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, int convolution_algorithm, DataType out_dtype) { auto attrs = make_object(); attrs->convolution_algorithm = convolution_algorithm; @@ -558,99 +488,75 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, } TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform") -.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); + .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform") -.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. + .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. Separate this into another symbol in order to enable Precompute Pass to compute the weight transformation in advance. 
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel); // Positional relay function to create conv2d NCHWc operator // used by frontend FFI. TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc"); -}); + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_conv2d_NCHWc"); + }); RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") -.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. + .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. - **data**: Input is 5D packed tensor. - **weight**: 6D packed tensor. - **out**: Output is 5D packed tensor )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2DNCHWc", Conv2DWinogradRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); // Positional relay function to create depthwise conv2d NCHWc operator // used by frontend FFI. 
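For readers unfamiliar with the packed layouts that contrib_conv2d_NCHWc consumes ("Input is 5D packed tensor", "6D packed tensor" above): the extra trailing dimensions come from splitting the channel axes into fixed-size blocks. A hedged example of the shape arithmetic; the block size 8 and the tensor sizes are illustrative, not mandated by the op:

#include <cstdio>

int main() {
  // NCHW data (1, 64, 56, 56) packed with an inner channel block of 8 -> NCHW8c, a 5-D tensor.
  int N = 1, C = 64, H = 56, W = 56, c_block = 8;
  std::printf("data   NCHW%dc  : (%d, %d, %d, %d, %d)\n", c_block, N, C / c_block, H, W, c_block);

  // OIHW weight (128, 64, 3, 3) packed with 8-wide input/output blocks -> OIHW8i8o, a 6-D tensor.
  int O = 128, I = 64, kH = 3, kW = 3, i_block = 8, o_block = 8;
  std::printf("weight OIHW%di%do: (%d, %d, %d, %d, %d, %d)\n", i_block, o_block, O / o_block,
              I / i_block, kH, kW, i_block, o_block);
  return 0;
}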
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") -.set_body_typed([](Expr data, - Expr weight, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeConv( - data, weight, strides, padding, dilation, - groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc"); -}); - + .set_body_typed([](Expr data, Expr weight, Array strides, Array padding, + Array dilation, int groups, IndexExpr channels, + Array kernel_size, String data_layout, String kernel_layout, + String out_layout, DataType out_dtype) { + return MakeConv(data, weight, strides, padding, dilation, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, + "nn.contrib_depthwise_conv2d_NCHWc"); + }); RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") -.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. + .describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout. - **data**: Input is 5D packed tensor. - **weight**: 6D packed tensor. - **out**: Output is 5D packed tensor )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) -.set_attr("FInferCorrectLayout", - ConvInferCorrectLayout); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(10) + .add_type_rel("Conv2D", Conv2DRel) + .set_attr("FInferCorrectLayout", ConvInferCorrectLayout); TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs); @@ -674,36 +580,25 @@ along the channel axis, and also evenly split `weight` along the first dimension the convolution on the *i*-th part of the data with the *i*-th weight part. The output is obtained by concating all the *g* results. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("offset", "Tensor", "The offset tensor.") -.add_argument("weight", "Tensor", "The weight tensor.") -.set_support_level(5) -.add_type_rel("DeformableConv2D", DeformableConv2DRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("offset", "Tensor", "The offset tensor.") + .add_argument("weight", "Tensor", "The weight tensor.") + .set_support_level(5) + .add_type_rel("DeformableConv2D", DeformableConv2DRel); // Positional relay function to create deformable_conv2d operator // used by frontend FFI. 
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d") -.set_body_typed([](Expr data, - Expr offset, - Expr weight, - Array strides, - Array padding, - Array dilation, - int deformable_groups, - int groups, - int channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - return MakeDeformableConv( - data, offset, weight, strides, padding, dilation, - deformable_groups, groups, channels, kernel_size, data_layout, - kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); -}); + .set_body_typed([](Expr data, Expr offset, Expr weight, Array strides, + Array padding, Array dilation, int deformable_groups, + int groups, int channels, Array kernel_size, String data_layout, + String kernel_layout, String out_layout, DataType out_dtype) { + return MakeDeformableConv( + data, offset, weight, strides, padding, dilation, deformable_groups, groups, channels, + kernel_size, data_layout, kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d"); + }); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 6c5aebe2bd4c..0c5b20a153cf 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -24,7 +24,6 @@ #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_ #define TVM_RELAY_OP_NN_CONVOLUTION_H_ -#include #include #include @@ -36,7 +35,6 @@ namespace tvm { namespace relay { - // Standard convolution operator shape relations template bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -93,7 +91,7 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, auto wshape = trans_kernel_layout.ForwardShape(weight->shape); if (param->kernel_size.defined()) { // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) ) + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) << "Conv1D: shape of weight is inconsistent with kernel_size, " << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } @@ -111,7 +109,8 @@ bool Conv1DRel(const Array& types, int num_inputs, const Attrs& attrs, if (!dshape_ncw[2].as()) { oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize, - param->strides[0]) + 1); + param->strides[0]) + + 1); } else { oshape.Set(2, dshape_ncw[2]); } @@ -160,8 +159,8 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); bool is_depthwise = false; if (param->groups > 1) { - CHECK(weight && weight->shape.defined()) << - "Weight shape must be specified when groups is greater than 1."; + CHECK(weight && weight->shape.defined()) + << "Weight shape must be specified when groups is greater than 1."; Array wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape); if (tvm::tir::ExprDeepEqual()(param->groups, dshape_nchw[1]) && tvm::tir::ExprDeepEqual()(param->groups, wshape_oihw[0])) { @@ -223,15 +222,13 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(3, 
indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -337,22 +334,19 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, - param->strides[0]) + 1); + oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z, param->strides[0]) + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[3].as()) { - oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, - param->strides[1]) + 1); + oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y, param->strides[1]) + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, - param->strides[2]) + 1); + oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x, param->strides[2]) + 1); } else { oshape.Set(4, dshape_ncdhw[4]); } @@ -366,7 +360,6 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } - // Winograd convolution shape relations inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -379,15 +372,14 @@ inline bool Conv2DWinogradWeightTransformRel(const Array& types, int num_i CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout"; - std::vector oshape { + std::vector oshape{ param->tile_size + data->shape[2] - 1, param->tile_size + data->shape[3] - 1, data->shape[0], data->shape[1], }; - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } @@ -405,7 +397,7 @@ inline bool Conv3DWinogradWeightTransformRel(const Array& types, int num_i // Shape of packed weights depends on whether depth is being transformed or not. Array oshape({0, 0, 0, data->shape[0], data->shape[1]}); auto* depth_imm = data->shape[2].as(); - bool transform_depth = (depth_imm->value > 2)&&(depth_imm->value < 8); + bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8); if (transform_depth) { oshape.Set(0, param->tile_size + data->shape[2] - 1); oshape.Set(1, param->tile_size + data->shape[3] - 1); @@ -450,10 +442,8 @@ inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array& types, int return true; } -template -bool Conv2DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, +template +bool Conv2DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -468,13 +458,13 @@ bool Conv2DWinogradRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? 
param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); @@ -509,14 +499,12 @@ bool Conv2DWinogradRel(const Array& types, IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); if (!dshape_nchw[2].as()) { - oshape.Set(2, (dshape_nchw[2] + pad_h - - dilated_ksize_y) / param->strides[0] + 1); + oshape.Set(2, (dshape_nchw[2] + pad_h - dilated_ksize_y) / param->strides[0] + 1); } else { oshape.Set(2, dshape_nchw[2]); } if (!dshape_nchw[3].as()) { - oshape.Set(3, (dshape_nchw[3] + pad_w - - dilated_ksize_x) / param->strides[1] + 1); + oshape.Set(3, (dshape_nchw[3] + pad_w - dilated_ksize_x) / param->strides[1] + 1); } else { oshape.Set(3, dshape_nchw[3]); } @@ -531,11 +519,8 @@ bool Conv2DWinogradRel(const Array& types, return true; } - -template -bool Conv3DWinogradRel(const Array& types, - int num_inputs, - const Attrs& attrs, +template +bool Conv3DWinogradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -550,13 +535,13 @@ bool Conv3DWinogradRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIDHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); @@ -592,20 +577,17 @@ bool Conv3DWinogradRel(const Array& types, IndexExpr pad_d, pad_h, pad_w; GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); if (!dshape_ncdhw[2].as()) { - oshape.Set(2, (dshape_ncdhw[2] + pad_d - - dilated_ksize_d) / param->strides[0] + 1); + oshape.Set(2, (dshape_ncdhw[2] + pad_d - dilated_ksize_d) / param->strides[0] + 1); } else { oshape.Set(2, dshape_ncdhw[2]); } if (!dshape_ncdhw[2].as()) { - oshape.Set(3, (dshape_ncdhw[3] + pad_h - - dilated_ksize_y) / param->strides[1] + 1); + oshape.Set(3, (dshape_ncdhw[3] + pad_h - dilated_ksize_y) / param->strides[1] + 1); } else { oshape.Set(3, dshape_ncdhw[3]); } if (!dshape_ncdhw[4].as()) { - oshape.Set(4, (dshape_ncdhw[4] + pad_w - - dilated_ksize_x) / param->strides[2] + 1); + oshape.Set(4, (dshape_ncdhw[4] + pad_w - dilated_ksize_x) / param->strides[2] + 1); } else { oshape.Set(4, dshape_ncdhw[4]); } @@ -620,12 +602,9 @@ bool Conv3DWinogradRel(const Array& types, return true; } - // Transposed convolution shape relations template -bool Conv1DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Conv1DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -642,19 +621,19 @@ bool Conv1DTransposeRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCW." 
- << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW); CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCW." - << " But got " << out_layout; + << "Conv only support output layouts that are convertible from NCW." + << " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -665,9 +644,8 @@ bool Conv1DTransposeRel(const Array& types, CHECK_EQ(param->kernel_size.size(), 1); CHECK_EQ(param->dilation.size(), 1); - Array wshape({dshape_ncw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0]}); + Array wshape( + {dshape_ncw[1], indexdiv(param->channels, param->groups), param->kernel_size[0]}); wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; @@ -684,14 +662,12 @@ bool Conv1DTransposeRel(const Array& types, // check the size CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2])) << "Conv1D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[1])) << "Conv1D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); + << " channels=" << param->channels << " wshape=" << Array(wshape); } CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0])); channels = wshape[1]; @@ -701,8 +677,8 @@ bool Conv1DTransposeRel(const Array& types, IndexExpr pad_w; GetPaddingWidth(param->padding, &pad_w); Array oshape({dshape_ncw[0], channels, 0}); - oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - - pad_w + param->output_padding[0])); + oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - pad_w + + param->output_padding[0])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -713,11 +689,105 @@ bool Conv1DTransposeRel(const Array& types, return true; } +template +bool Conv3DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const Conv3DTransposeAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW); + CHECK(trans_in_layout.defined()) + << "Conv3d_transpose only support input layouts that are convertible from NCDHW." 
+ << " But got " << in_layout; + + const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW); + CHECK(trans_kernel_layout.defined()) + << "Conv3d_transpose only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW); + CHECK(trans_out_layout.defined()) + << "Conv3d_transpose only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x; + + auto dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + CHECK_EQ(param->kernel_size.size(), 3); + CHECK_EQ(param->dilation.size(), 3); + + Array wshape({dshape_ncdhw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}); + + wshape = trans_kernel_layout.BackwardShape(wshape); + dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + channels = param->channels; + + // assign result to reporter + reporter->Assign(types[1], TensorType(wshape, data->dtype)); + } else { + // use weight to infer the conv shape. + if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + CHECK_EQ(param->kernel_size.size(), 3); + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3]) && + reporter->AssertEQ(param->kernel_size[2], wshape[4])) + << "Conv3D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[1])) + << "Conv3D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << Array(wshape); + } + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[0])); + channels = wshape[1]; + dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; + dilated_ksize_y = 1 + (wshape[4] - 1) * param->dilation[2]; + } + + // dilation + Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + IndexExpr pad_d, pad_h, pad_w; + GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w); + oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - pad_d + + param->output_padding[0])); + oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - pad_h + + param->output_padding[1])); + oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - pad_w + + param->output_padding[2])); + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return true; +} template -bool Conv2DTransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& 
reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -734,19 +804,19 @@ bool Conv2DTransposeRel(const Array& types, const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "Conv only support input layouts that are convertible from NCHW." + << " But got " << in_layout; const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; + << "Conv only support kernel layouts that are convertible from OIHW." + << " But got " << kernel_layout; Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; + << "Conv only support output layouts that are convertible from NCHW." + << " But got " << out_layout; IndexExpr channels, dilated_ksize_y, dilated_ksize_x; @@ -757,10 +827,8 @@ bool Conv2DTransposeRel(const Array& types, CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); - Array wshape({dshape_nchw[1], - indexdiv(param->channels, param->groups), - param->kernel_size[0], - param->kernel_size[1]}); + Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), + param->kernel_size[0], param->kernel_size[1]}); wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; @@ -779,14 +847,12 @@ bool Conv2DTransposeRel(const Array& types, CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << Array(wshape); + << " kernel_size=" << param->kernel_size << " wshape=" << Array(wshape); } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[1])) << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << Array(wshape); + << " channels=" << param->channels << " wshape=" << Array(wshape); } CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0])); channels = wshape[1]; @@ -797,10 +863,10 @@ bool Conv2DTransposeRel(const Array& types, Array oshape({dshape_nchw[0], channels, 0, 0}); IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - - pad_h + param->output_padding[0])); - oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - - pad_w + param->output_padding[1])); + oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - pad_h + + param->output_padding[0])); + oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - pad_w + + param->output_padding[1])); DataType out_dtype = param->out_dtype; if (out_dtype.bits() == 0) { @@ -811,7 +877,6 @@ bool Conv2DTransposeRel(const Array& types, return true; } - // Deformable Convolution shape relations. 
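One detail that is easy to miss in the shape relations above: the forward-style relations (Conv2DRel above, DeformableConv2DRel just below) build the kernel as (out_channels, in_channels/groups, kH, kW), whereas the *TransposeRel functions put the data's channel count on the first axis and the op's output channels on the second. A small illustration of the two conventions (sizes are examples):

#include <cstdio>

int main() {
  int in_channels = 64, out_channels = 128, kH = 3, kW = 3, groups = 1;

  // Forward / deformable convolution relations: (O, I/groups, kH, kW).
  std::printf("forward weight   : (%d, %d, %d, %d)\n", out_channels, in_channels / groups, kH, kW);
  // Transposed convolution relations: (I, O/groups, kH, kW).
  std::printf("transposed weight: (%d, %d, %d, %d)\n", in_channels, out_channels / groups, kH, kW);
  return 0;
}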
template bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -831,11 +896,8 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& if (param->kernel_size.defined() && param->channels.defined()) { CHECK_EQ(param->kernel_size.size(), 2); CHECK_EQ(param->dilation.size(), 2); - Array wshape( - {param->channels, - indexdiv(data->shape[1], param->groups), - param->kernel_size[0], - param->kernel_size[1]}); + Array wshape({param->channels, indexdiv(data->shape[1], param->groups), + param->kernel_size[0], param->kernel_size[1]}); channels = param->channels; ksize_y = param->kernel_size[0]; ksize_x = param->kernel_size[1]; @@ -853,14 +915,12 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && reporter->AssertEQ(param->kernel_size[1], wshape[3])) << "DeformableConv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape; + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; } if (param->channels.defined()) { CHECK(reporter->AssertEQ(param->channels, wshape[0])) << "DeformableConv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << wshape; + << " channels=" << param->channels << " wshape=" << wshape; } CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1])); channels = wshape[0]; @@ -874,15 +934,13 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& IndexExpr pad_h, pad_w; GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); - oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, - param->strides[0]) + 1); - oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, - param->strides[1]) + 1); + oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1); + oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1); DataType out_dtype = param->out_dtype; // infer offset shape - Array offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, - oshape[2], oshape[3]}); + Array offset_shape( + {data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]}); reporter->Assign(types[1], TensorType(offset_shape, data->dtype)); if (out_dtype.bits() == 0) { out_dtype = data->dtype; @@ -892,23 +950,20 @@ bool DeformableConv2DRel(const Array& types, int num_inputs, const Attrs& return true; } - -template -Array > ConvInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +template +Array > ConvInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { const T* params = attrs.as(); // We always make other operators to fit the layouts of convolution layers // So this inference ignores all inputs - return Array >{{params->data_layout, params->kernel_layout}, - {params->out_layout == "" ? - params->data_layout : params->out_layout}}; + return Array >{ + {params->data_layout, params->kernel_layout}, + {params->out_layout == "" ? 
params->data_layout : params->out_layout}}; } - } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/src/relay/op/nn/correlation.cc b/src/relay/op/nn/correlation.cc new file mode 100644 index 000000000000..67f42b7d3e85 --- /dev/null +++ b/src/relay/op/nn/correlation.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file correlation.cc + * \brief Correlation operators + */ +#include +#include +#include +#include +#include + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relay { + +// relay.nn.correlation +TVM_REGISTER_NODE_TYPE(CorrelationAttrs); + +Array> CorrelationInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + const auto* params = attrs.as(); + Layout layout{params->layout}; + return Array>{{layout, layout}, {layout}}; +} + +// Positional relay function to create correlation operator +// used by frontend FFI. +Expr MakeCorrelation(Expr data1, Expr data2, int kernel_size, int max_displacement, int stride1, + int stride2, Array padding, bool is_multiply, String layout) { + auto attrs = make_object(); + attrs->kernel_size = kernel_size; + attrs->max_displacement = max_displacement; + attrs->stride1 = stride1; + attrs->stride2 = stride2; + attrs->padding = std::move(padding); + attrs->is_multiply = is_multiply; + attrs->layout = std::move(layout); + static const Op& op = Op::Get("nn.correlation"); + return Call(op, {data1, data2}, Attrs(attrs), {}); +} + +bool CorrelationRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data1 = types[0].as(); + const auto* data2 = types[1].as(); + if (data1 == nullptr || data2 == nullptr) return false; + + const CorrelationAttrs* param = attrs.as(); + CHECK(param != nullptr); + CHECK_EQ(param->layout, "NCHW") << "layout not supported."; + IndexExpr pad_h, pad_w; + GetPaddingHeightWidth(param->padding, &pad_h, &pad_w); + IndexExpr padded_height = data1->shape[2] + pad_h; + IndexExpr padded_width = data2->shape[3] + pad_w; + int kernel_radius = (param->kernel_size - 1) / 2; + int border_size = param->max_displacement + kernel_radius; + int displacement_radius = param->max_displacement / param->stride2; + int displacement_size = 2 * displacement_radius + 1; + int out_channel = displacement_size * displacement_size; + IndexExpr out_height = + indexdiv((padded_height - 2 * border_size + param->stride1 - 1), param->stride1); + IndexExpr out_width = + indexdiv((padded_width - 2 * border_size + param->stride1 - 1), param->stride1); + Array oshape{data1->shape[0], out_channel, out_height, out_width}; + // assign output type + 
reporter->Assign(types[2], TensorType(oshape, data1->dtype)); + return true; +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.correlation").set_body_typed(MakeCorrelation); + +RELAY_REGISTER_OP("nn.correlation") + .describe(R"code(Applies correlation to inputs. + +The correlation layer performs multiplicative patch comparisons between two feature maps. +Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and :math:`c` being their width, height, and number of channels, +the correlation layer lets the network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`. + +For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at :math:`x_{1}` in the first map and +:math:`x_{2}` in the second map is then defined as: + +.. math:: + c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} \langle f_{1}(x_{1} + o), f_{2}(x_{2} + o) \rangle + +for a square patch of size :math:`K:=2k+1`. + +Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other +data. For this reason, it has no training weights. + +Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all patch combinations involves :math:`w^{2}*h^{2}` such computations. + +Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`, +by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood +centered around :math:`x_{1}`. + +The final output is defined by the following expression: + +.. math:: + out[n, q, i, j] = c(x_{i, j}, x_{q}) + +where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data1", "Tensor", "Input data1 to the correlation.") + .add_argument("data2", "Tensor", "Input data2 to the correlation.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", CorrelationInferCorrectLayout) + .add_type_rel("Correlation", CorrelationRel); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index b9ba74f9e95d..d65fc27472c0 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -22,20 +22,23 @@ * \brief Property def of nn operators.
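To make the correlation shape relation concrete, the snippet below repeats the same arithmetic as CorrelationRel above for an illustrative FlowNet-style configuration (all numbers are examples, not defaults of the op):

#include <cstdio>

int main() {
  // Two NCHW feature maps of shape (1, 256, 48, 64).
  int H = 48, W = 64;
  int kernel_size = 1, max_displacement = 20, stride1 = 1, stride2 = 2;
  int pad = 20;  // symmetric spatial padding

  int kernel_radius = (kernel_size - 1) / 2;                         // 0
  int border_size = max_displacement + kernel_radius;                // 20
  int displacement_radius = max_displacement / stride2;              // 10
  int displacement_size = 2 * displacement_radius + 1;               // 21
  int out_channel = displacement_size * displacement_size;           // 441

  int padded_h = H + 2 * pad;                                        // 88
  int padded_w = W + 2 * pad;                                        // 104
  int out_h = (padded_h - 2 * border_size + stride1 - 1) / stride1;  // 48
  int out_w = (padded_w - 2 * border_size + stride1 - 1) / stride1;  // 64

  std::printf("output shape: (1, %d, %d, %d)\n", out_channel, out_h, out_w);  // (1, 441, 48, 64)
  return 0;
}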
*/ -#include -#include -#include -#include +#include "nn.h" + #include #include -#include #include -#include +#include +#include +#include +#include +#include + #include -#include "../type_relations.h" +#include + #include "../../transforms/infer_layout_util.h" #include "../op_common.h" -#include "nn.h" +#include "../type_relations.h" namespace tvm { namespace relay { @@ -43,9 +46,7 @@ namespace relay { // relay.nn.bias_add TVM_REGISTER_NODE_TYPE(BiasAddAttrs); -bool BiasAddRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -61,45 +62,36 @@ bool BiasAddRel(const Array& types, << "axis " << param->axis << " is out of range"; // assign output type - reporter->Assign(types[1], TensorType( - {data->shape[axis]}, data->dtype)); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); reporter->Assign(types[2], types[0]); return true; } - // Positional relay function to create dense operator used by frontend FFI. -Expr MakeBiasAdd(Expr data, - Expr bias, - int axis) { +Expr MakeBiasAdd(Expr data, Expr bias, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.bias_add"); return Call(op, {data, bias}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add") -.set_body_typed(MakeBiasAdd); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.bias_add").set_body_typed(MakeBiasAdd); RELAY_REGISTER_OP("nn.bias_add") -.describe(R"code(Add bias to an axis of the input. + .describe(R"code(Add bias to an axis of the input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("bias", "1D Tensor", "Bias.") -.set_support_level(1) -.add_type_rel("BiasAdd", BiasAddRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; -}); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("bias", "1D Tensor", "Bias.") + .set_support_level(1) + .add_type_rel("BiasAdd", BiasAddRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return tvm::Array{topi::nn::bias_add(inputs[0], inputs[1], param->axis)}; + }); // relay.nn.fifo_buffer TVM_REGISTER_NODE_TYPE(FIFOBufferAttrs); @@ -111,9 +103,7 @@ Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) { return Call(op, {input, buffer}, Attrs(attrs), {}); } -bool FIFOBufferRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool FIFOBufferRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* input = types[0].as(); @@ -125,9 +115,8 @@ bool FIFOBufferRel(const Array& types, CHECK(param != nullptr); CHECK_EQ(input->shape.size(), buffer->shape.size()); - const size_t buffer_axis - = static_cast(param->axis < 0 ? static_cast(buffer->shape.size()) + param->axis - : param->axis); + const size_t buffer_axis = static_cast( + param->axis < 0 ? 
static_cast(buffer->shape.size()) + param->axis : param->axis); reporter->Assert(buffer_axis < buffer->shape.size()); for (size_t i = 0; i < buffer->shape.size(); ++i) { @@ -143,11 +132,10 @@ bool FIFOBufferRel(const Array& types, return true; } -TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer") -.set_body_typed(MakeFIFOBuffer); +TVM_REGISTER_GLOBAL("relay.op.nn._make.fifo_buffer").set_body_typed(MakeFIFOBuffer); RELAY_REGISTER_OP("nn.fifo_buffer") -.describe(R"code(FIFO buffer + .describe(R"code(FIFO buffer Compute equivalent of ``` @@ -159,23 +147,18 @@ Useful for * Encoding explicit re-use of computation in convolution ops operated on a sliding window input * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "Latest input") -.add_argument("buffer", "Tensor", - "Buffer storing latest [length_buffer] inputs") -.set_support_level(3) -.add_type_rel("FIFOBuffer", FIFOBufferRel); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Latest input") + .add_argument("buffer", "Tensor", "Buffer storing latest [length_buffer] inputs") + .set_support_level(3) + .add_type_rel("FIFOBuffer", FIFOBufferRel); // relay.nn.dense TVM_REGISTER_NODE_TYPE(DenseAttrs); // Positional relay function to create dense operator used by frontend FFI. -Expr MakeDense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { +Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; @@ -183,70 +166,58 @@ Expr MakeDense(Expr data, return Call(op, {data, weight}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.dense") -.set_body_typed(MakeDense); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense); RELAY_REGISTER_OP("nn.dense") -.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` - **out**: `(x1, x2, ..., xn, units)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("weight", "2D Tensor", "Weight matrix.") -.set_support_level(1) -.add_type_rel("Dense", DenseRel); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight", "2D Tensor", "Weight matrix.") + .set_support_level(1) + .add_type_rel("Dense", DenseRel); // relay.leaky_relu TVM_REGISTER_NODE_TYPE(LeakyReluAttrs); // Positional relay function to create leaky relu operator used by frontend FFI. -Expr MakeLeakyRelu(Expr data, - double alpha) { +Expr MakeLeakyRelu(Expr data, double alpha) { auto attrs = make_object(); attrs->alpha = alpha; static const Op& op = Op::Get("nn.leaky_relu"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu") -.set_body_typed(MakeLeakyRelu); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.leaky_relu").set_body_typed(MakeLeakyRelu); RELAY_REGISTER_OP("nn.leaky_relu") -.describe(R"code(Leaky version of a Rectified Linear Unit. + .describe(R"code(Leaky version of a Rectified Linear Unit. `y = x > 0 ? 
x : alpha * x` )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input data.") -.set_support_level(3) -.add_type_rel("Identity", IdentityRel) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return Array{ topi::leaky_relu(inputs[0], param->alpha) }; -}); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(3) + .add_type_rel("Identity", IdentityRel) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return Array{topi::leaky_relu(inputs[0], param->alpha)}; + }); // relay.prelu TVM_REGISTER_NODE_TYPE(PReluAttrs); -bool PReluRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool PReluRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -256,7 +227,7 @@ bool PReluRel(const Array& types, CHECK(param != nullptr); CHECK(param->axis < static_cast(data->shape.size())) - << "Wrong axis (" << param->axis << ")value."; + << "Wrong axis (" << param->axis << ")value."; // assign alpha type Array alpha_shape({data->shape[param->axis]}); @@ -267,72 +238,59 @@ bool PReluRel(const Array& types, return true; } -template -Array > PReluInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { - +template +Array> PReluInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { CHECK_EQ(old_in_layouts.size(), 2U); CHECK_EQ(old_in_types.size(), 2U); Layout data_layout = old_in_layouts[0]; if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 2U); } - return Array >{{data_layout, Layout("C")}, - {data_layout}}; + return Array>{{data_layout, Layout("C")}, {data_layout}}; } // Positional relay function to create prelu operator used by frontend FFI. -Expr MakePRelu(Expr data, - Expr alpha, - int axis) { +Expr MakePRelu(Expr data, Expr alpha, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.prelu"); return Call(op, {data, alpha}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu") -.set_body_typed(MakePRelu); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.prelu").set_body_typed(MakePRelu); RELAY_REGISTER_OP("nn.prelu") -.describe(R"code(Parametric version of a Rectified Linear Unit. + .describe(R"code(Parametric version of a Rectified Linear Unit. It accepts two arguments: an input ``x`` and a channelwise slope ``alpha`` and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`, where :math:`*` is an channelwise multiplication for each sample in the batch. 
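leaky_relu and prelu compute the same piecewise function; the difference is only whether the negative-side slope is a scalar attribute or a learned per-channel tensor broadcast along the chosen axis. A small sketch of that channel-wise broadcast (shapes and values are illustrative):

#include <cstdio>
#include <vector>

int main() {
  // NCHW-style input of shape (1, 2, 1, 2), one alpha per channel (axis = 1).
  std::vector<float> x{-1.0f, 2.0f, -3.0f, 4.0f};  // channel 0: {-1, 2}, channel 1: {-3, 4}
  std::vector<float> alpha{0.1f, 0.2f};
  int C = 2, HW = 2;

  for (int c = 0; c < C; ++c) {
    for (int i = 0; i < HW; ++i) {
      float v = x[c * HW + i];
      float y = v > 0 ? v : alpha[c] * v;  // PReLU: x > 0 ? x : alpha[c] * x
      std::printf("%g ", y);
    }
  }
  std::printf("\n");  // prints: -0.1 2 -0.6 4
  return 0;
}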
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "Input data.") -.add_argument("alpha", "Tensor", "Input channelwise alpha.") -.set_support_level(3) -.add_type_rel("PRelu", PReluRel) -.set_attr("FInferCorrectLayout", PReluInferCorrectLayout) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - return Array{ topi::prelu(inputs[0], inputs[1], param->axis)}; -}); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "Input data.") + .add_argument("alpha", "Tensor", "Input channelwise alpha.") + .set_support_level(3) + .add_type_rel("PRelu", PReluRel) + .set_attr("FInferCorrectLayout", PReluInferCorrectLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return Array{topi::prelu(inputs[0], inputs[1], param->axis)}; + }); // relay.softmax TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); -TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") -.set_body_typed([](Expr data, int axis) { +TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.softmax"); return Call(op, {data}, Attrs(attrs), {}); }); - RELAY_REGISTER_OP("nn.softmax") .describe(R"code(Softmax layer. @@ -343,16 +301,14 @@ RELAY_REGISTER_OP("nn.softmax") - **data**: The input data )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel); // relay.nn.log_softmax -TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") -.set_body_typed([](Expr data, int axis) { +TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax"); @@ -369,26 +325,22 @@ RELAY_REGISTER_OP("nn.log_softmax") - **data**: The input data )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* param = attrs.as(); - CHECK(param != nullptr); - CHECK(param->axis == -1 || param->axis == static_cast(inputs[0].ndim()) - 1) - << "log_softmax currently only works on last dimension"; - return Array{ topi::nn::log_softmax(inputs[0]) }; -}); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + CHECK(param->axis == -1 || param->axis == static_cast(inputs[0].ndim()) - 1) + << "log_softmax currently only works on last dimension"; + return Array{topi::nn::log_softmax(inputs[0])}; + }); // relay.nn.batch_flatten -bool BatchFlattenRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchFlattenRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = 
types[0].as(); @@ -418,13 +370,10 @@ Expr MakeBatchFlatten(Expr data) { return Call(op, {data}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten") -.set_body_typed(MakeBatchFlatten); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_flatten").set_body_typed(MakeBatchFlatten); RELAY_REGISTER_OP("nn.batch_flatten") -.describe(R"code(Flattens the input into a 2-D array. + .describe(R"code(Flattens the input into a 2-D array. For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes the input array into an output array of shape ``(d1, d2*...*dk)``. @@ -445,53 +394,42 @@ Example:: [ 1., 2., 3., 4., 5., 6., 7., 8., 9.]] )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("BatchFlatten", BatchFlattenRel) -.set_attr( - "FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - return Array{ topi::nn::flatten(inputs[0]) }; -}); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("BatchFlatten", BatchFlattenRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::nn::flatten(inputs[0])}; + }); // relu -TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") -.set_body_typed([](Expr data) { - static const Op& op = Op::Get("nn.relu"); - return Call(op, {data}, Attrs(), {}); - }); +TVM_REGISTER_GLOBAL("relay.op.nn._make.relu").set_body_typed([](Expr data) { + static const Op& op = Op::Get("nn.relu"); + return Call(op, {data}, Attrs(), {}); +}); RELAY_REGISTER_OP("nn.relu") -.describe(R"code(Returns the relu input array, computed element-wise. + .describe(R"code(Returns the relu input array, computed element-wise. .. math:: max(x, 0) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("Identity", IdentityRel) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - return Array{ topi::relu(inputs[0], 0.0f) }; -}); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("Identity", IdentityRel) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + return Array{topi::relu(inputs[0], 0.0f)}; + }); // Positional relay function to create LRN operator used by frontend FFI. TVM_REGISTER_NODE_TYPE(LRNAttrs); -Expr MakeLRN(Expr data, - int size, - int axis, - double alpha, - double beta, - double bias) { +Expr MakeLRN(Expr data, int size, int axis, double alpha, double beta, double bias) { auto attrs = make_object(); attrs->size = size; attrs->axis = axis; @@ -502,11 +440,10 @@ Expr MakeLRN(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn") -.set_body_typed(MakeLRN); +TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn").set_body_typed(MakeLRN); RELAY_REGISTER_OP("nn.lrn") -.describe(R"code(LRN layer. + .describe(R"code(LRN layer. Normalize the input in a local region across or within feature maps. Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta, @@ -519,20 +456,16 @@ centered at that value (zero padding is added where necessary). - **data**: The input tensor. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Identity", IdentityRel); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Identity", IdentityRel); // Positional relay function to create L2Normalize operator used by frontend FFI. TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); -Expr MakeL2Normalize(Expr data, - double eps, - Array axis) { +Expr MakeL2Normalize(Expr data, double eps, Array axis) { auto attrs = make_object(); attrs->eps = eps; attrs->axis = std::move(axis); @@ -540,11 +473,10 @@ Expr MakeL2Normalize(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize") -.set_body_typed(MakeL2Normalize); +TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize").set_body_typed(MakeL2Normalize); RELAY_REGISTER_OP("nn.l2_normalize") -.describe(R"code(L2 Normalization layer. + .describe(R"code(L2 Normalization layer. Normalizes along dimension axis using an L2 norm @@ -553,19 +485,17 @@ Normalizes along dimension axis using an L2 norm - **data**: The input tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Identity", IdentityRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .add_type_rel("Identity", IdentityRel); // Dropout TVM_REGISTER_NODE_TYPE(DropoutAttrs); -bool DropoutRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DropoutRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -585,22 +515,21 @@ Expr MakeDropout(Expr data, double rate) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout") -.set_body_typed(MakeDropout); +TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout").set_body_typed(MakeDropout); RELAY_REGISTER_OP("nn.dropout") -.describe(R"code(Applies the dropout operation to the input array. + .describe(R"code(Applies the dropout operation to the input array. During training, each element of the input is set to zero with probability ``p``. The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged. 
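For example, with ``rate = 0.5`` every retained element is multiplied by ``1/(1 - 0.5) = 2``, so the expected value of each output element equals the corresponding input element.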
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input to which dropout will be applied.") -.set_support_level(1) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.add_type_rel("Dropout", DropoutRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input to which dropout will be applied.") + .set_support_level(1) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .add_type_rel("Dropout", DropoutRel); // batch_norm TVM_REGISTER_NODE_TYPE(BatchNormAttrs); @@ -639,9 +568,7 @@ Array> BatchNormInferCorrectLayout(const Attrs& attrs, {ret, c_layout, c_layout}}; } -bool BatchNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 6); const auto* data = types[0].as(); @@ -663,8 +590,7 @@ bool BatchNormRel(const Array& types, // output is a tuple of the normed data (same shape as input), new running mean, // and new running average (the latter two are both vectors of length dim) std::vector fields; - auto vec_ty = TensorType(Array({data->shape[axis]}), - data->dtype); + auto vec_ty = TensorType(Array({data->shape[axis]}), data->dtype); fields.push_back(TensorType(data->shape, data->dtype)); fields.push_back(vec_ty); fields.push_back(vec_ty); @@ -672,8 +598,8 @@ bool BatchNormRel(const Array& types, return true; } -Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, - int axis, double epsilon, bool center, bool scale) { +Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var, int axis, + double epsilon, bool center, bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -683,11 +609,10 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm") -.set_body_typed(MakeBatchNorm); +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm").set_body_typed(MakeBatchNorm); RELAY_REGISTER_OP("nn.batch_norm") -.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). + .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). Normalizes the input at each batch, i.e. applies a transformation that maintains the mean activation close to 0 and the activation standard deviation close to 1. @@ -723,24 +648,21 @@ axis to be the last item in the input shape. .. note:: This operator can be optimized away for inference. 
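For instance, for an ``NCHW`` input of shape ``(N, C, H, W)`` normalized over ``axis = 1``, the result is a tuple of three tensors with shapes ``(N, C, H, W)``, ``(C,)`` and ``(C,)``: the normalized data together with the new running mean and running variance.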
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "Input to which batch_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.add_argument("moving_mean", "Tensor", "Running mean of input.") -.add_argument("moving_var", "Tensor", "Running variance of input.") -.set_attr("FInferCorrectLayout", BatchNormInferCorrectLayout) -.set_support_level(1) -.add_type_rel("BatchNorm", BatchNormRel); - + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "Input to which batch_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .add_argument("moving_mean", "Tensor", "Running mean of input.") + .add_argument("moving_var", "Tensor", "Running variance of input.") + .set_attr("FInferCorrectLayout", BatchNormInferCorrectLayout) + .set_support_level(1) + .add_type_rel("BatchNorm", BatchNormRel); // instance_norm TVM_REGISTER_NODE_TYPE(InstanceNormAttrs); -bool InstanceNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool InstanceNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -755,8 +677,8 @@ bool InstanceNormRel(const Array& types, return true; } -Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, - bool center, bool scale) { +Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, + bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -767,12 +689,12 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon } TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeInstanceNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeInstanceNorm, args, rv); + }); RELAY_REGISTER_OP("nn.instance_norm") -.describe(R"code(Instance Normalization (Ulyanov and et al., 2016) + .describe(R"code(Instance Normalization (Ulyanov and et al., 2016) Applies instance normalization to the n-dimensional input array. .. math:: @@ -796,21 +718,18 @@ to be the last item in the input shape. This operator can be optimized away for inference. 
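Concretely, for an ``NCHW`` input with ``axis = 1`` the mean and variance are computed over the spatial positions ``(H, W)`` independently for every sample and every channel, so no statistics are shared across the batch.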
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which instance_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("InstanceNorm", InstanceNormRel); - + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which instance_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("InstanceNorm", InstanceNormRel); // layer_norm TVM_REGISTER_NODE_TYPE(LayerNormAttrs); -bool LayerNormRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool LayerNormRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); @@ -825,8 +744,8 @@ bool LayerNormRel(const Array& types, return true; } -Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, - bool center, bool scale) { +Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, bool center, + bool scale) { auto attrs = make_object(); attrs->axis = axis; attrs->epsilon = epsilon; @@ -837,25 +756,94 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon, } TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLayerNorm, args, rv); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeLayerNorm, args, rv); + }); RELAY_REGISTER_OP("nn.layer_norm") -.describe(R"code( + .describe(R"code( )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "Input to which layer_norm will be applied.") -.add_argument("gamma", "Tensor", "The gamma scale factor.") -.add_argument("beta", "Tensor", "The beta offset factor.") -.set_support_level(1) -.add_type_rel("LayerNorm", LayerNormRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which layer_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("LayerNorm", LayerNormRel); + +// group_norm +TVM_REGISTER_NODE_TYPE(GroupNormAttrs); + +bool GroupNormRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + if (data == nullptr) return false; + const GroupNormAttrs* param = attrs.as(); + int axis = param->axis >= 0 ? 
param->axis : param->axis + data->shape.size(); + CHECK(axis >= 0 && axis < (int)data->shape.size()); + reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[2], TensorType({data->shape[axis]}, data->dtype)); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + + return true; +} + +Expr MakeGroupNorm(Expr data, Expr gamma, Expr beta, int num_groups, int axis, double epsilon, + bool center, bool scale) { + auto attrs = make_object(); + attrs->num_groups = num_groups; + attrs->axis = axis; + attrs->epsilon = epsilon; + attrs->center = center; + attrs->scale = scale; + static const Op& op = Op::Get("nn.group_norm"); + return Call(op, {data, gamma, beta}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.group_norm") + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeGroupNorm, args, rv); + }); + +RELAY_REGISTER_OP("nn.group_norm") + .describe(R"code( +Group normalization normalizes over group of channels for each training examples. +We can say that, Group Norm is in between Instance Norm and Layer Norm. When we put +all the channels into a single group, group normalization becomes Layer normalization. +And, when we put each channel into different groups it becomes Instance normalization + +https://arxiv.org/pdf/1803.08494.pdf + +Applies group normalization to the n-dimensional input array by seperating the input channels +into 'num_groups' groups, each containing 'num_channels / num_groups' channels. +The mean and standard-deviation are calculated separately over the each group. gamma and +beta are learnable per-channel affine transform parameter vectors of size num_channels. + +.. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}} + * gamma + beta + +Unlike batch normalization, the mean and var are computed along a group of channels. + +If the input has size k on axis 1, then both gamma and beta have shape (k,). + +.. note:: + + This operator can be optimized away for inference. + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input to which group_norm will be applied.") + .add_argument("gamma", "Tensor", "The gamma scale factor.") + .add_argument("beta", "Tensor", "The beta offset factor.") + .set_support_level(1) + .add_type_rel("GroupNorm", GroupNormRel); // relay.nn.batch_matmul -bool BatchMatmulRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BatchMatmulRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* x = types[0].as(); @@ -864,12 +852,10 @@ bool BatchMatmulRel(const Array& types, CHECK(x->shape.size() == 3 && y->shape.size() == 3); CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "BatchDot: batch dimension doesn't match, " - << " x shape=" << x->shape - << ", y shape=" << y->shape; + << " x shape=" << x->shape << ", y shape=" << y->shape; CHECK(reporter->AssertEQ(x->shape[2], y->shape[2])) << "BatchDot: shapes of x and y is inconsistent, " - << " x shape=" << x->shape - << ", y shape=" << y->shape; + << " x shape=" << x->shape << ", y shape=" << y->shape; Array oshape = x->shape; oshape.Set(2, y->shape[1]); @@ -879,21 +865,16 @@ bool BatchMatmulRel(const Array& types, return true; } - // Positional relay function to create batch_matmul operator used by frontend FFI. 
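// Shape sketch with purely illustrative values: for x : Tensor[(8, 32, 64)]
// and y : Tensor[(8, 16, 64)], BatchMatmulRel above checks that the batch
// dimensions (8) and the trailing reduction dimensions (64) agree and infers
// the result type Tensor[(8, 32, 16)], i.e. y is consumed in transposed form:
//
//   Expr z = MakeBatchMatmul(x, y);  // z : Tensor[(8, 32, 16)]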
-Expr MakeBatchMatmul(Expr x, - Expr y) { +Expr MakeBatchMatmul(Expr x, Expr y) { static const Op& op = Op::Get("nn.batch_matmul"); return Call(op, {x, y}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul") -.set_body_typed(MakeBatchMatmul); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_matmul").set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") -.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` + .describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y` are data in batch. .. math:: @@ -905,34 +886,31 @@ are data in batch. - **out**: `(b, m, n)`. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "3D Tensor", "First input.") -.add_argument("y", "3D Tensor", "Second input.") -.set_support_level(10) -.add_type_rel("BatchMatmul", BatchMatmulRel); - + .set_num_inputs(2) + .add_argument("x", "3D Tensor", "First input.") + .add_argument("y", "3D Tensor", "Second input.") + .set_support_level(10) + .add_type_rel("BatchMatmul", BatchMatmulRel); // relay.nn.cross_entropy -bool CrossEntropyRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool CrossEntropyRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* x = types[0].as(); const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; CHECK(x->shape.size() == 2 && y->shape.size() == 2) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; CHECK(reporter->AssertEQ(x->shape[1], y->shape[1])) - << "CrossEntropy: shapes of x and y is inconsistent, " - << "x shape = " << x->shape << ", " - << "y shape = " << y->shape; + << "CrossEntropy: shapes of x and y is inconsistent, " + << "x shape = " << x->shape << ", " + << "y shape = " << y->shape; // assign output type reporter->Assign(types[2], TensorType({}, x->dtype)); return true; @@ -944,22 +922,61 @@ Expr MakeCrossEntropy(Expr predictions, Expr targets) { return Call(op, {predictions, targets}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy") -.set_body_typed(MakeCrossEntropy); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy").set_body_typed(MakeCrossEntropy); RELAY_REGISTER_OP("nn.cross_entropy") -.describe(R"code( + .describe(R"code( Computes cross entropy given predictions and targets. Do log on the data - do not accept logits. 
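Both ``x`` and ``y`` must be 2-D tensors with the same ``(batch, classes)`` shape, and the result is a scalar; ``CrossEntropyRel`` above enforces these shape constraints.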
)code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "1D Tensor", "Predictions.") -.add_argument("y", "1D Tensor", "Targets.") -.set_support_level(10) -.add_type_rel("CrossEntropy", CrossEntropyRel); + .set_num_inputs(2) + .add_argument("x", "1D Tensor", "Predictions.") + .add_argument("y", "1D Tensor", "Targets.") + .set_support_level(10) + .add_type_rel("CrossEntropy", CrossEntropyRel); +// relay.nn.dilate +TVM_REGISTER_NODE_TYPE(DilateAttrs); + +bool DilateRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* x = types[0].as(); + const DilateAttrs* param = attrs.as(); + if (x == nullptr) return false; + CHECK_EQ(x->shape.size(), param->strides.size()); + + std::vector oshape; + for (size_t i = 0; i < param->strides.size(); ++i) { + if (!x->shape[i].as()) { + oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1); + } else { + oshape.push_back(x->shape[i]); + } + } + + reporter->Assign(types[1], TensorType(Array(oshape), x->dtype)); + return true; +} + +// Positional relay function to create dilate operator used by frontend FFI. +Expr MakeDilate(Expr data, Array strides) { + auto attrs = make_object(); + attrs->strides = std::move(strides); + static const Op& op = Op::Get("nn.dilate"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate").set_body_typed(MakeDilate); + +RELAY_REGISTER_OP("nn.dilate") + .describe(R"code( +Dilate data with zeros. +)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("x", "1D Tensor", "Data to dilate.") + .set_support_level(10) + .add_type_rel("Dilate", DilateRel); // Positional relay function to create cross_entropy_with_logits operator used by frontend FFI. Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { @@ -967,21 +984,19 @@ Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) { return Call(op, {predictions, targets}, Attrs(), {}); } - TVM_REGISTER_GLOBAL("relay.op.nn._make.cross_entropy_with_logits") -.set_body_typed(MakeCrossEntropyWithLogits); - + .set_body_typed(MakeCrossEntropyWithLogits); RELAY_REGISTER_OP("nn.cross_entropy_with_logits") -.describe(R"code( + .describe(R"code( Computes cross entropy given predictions and targets. Accept logits. 
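Unlike ``nn.cross_entropy`` above, the predictions are expected to already be log-probabilities, so no ``log`` is applied before the reduction; the same 2-D shape requirements apply, since both ops use ``CrossEntropyRel``.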
)code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("x", "1D Tensor", "Predictions.") -.add_argument("y", "1D Tensor", "Targets.") -.set_support_level(10) -.add_type_rel("CrossEntropy", CrossEntropyRel); + .set_num_inputs(2) + .add_argument("x", "1D Tensor", "Predictions.") + .add_argument("y", "1D Tensor", "Targets.") + .set_support_level(10) + .add_type_rel("CrossEntropy", CrossEntropyRel); // Depth to space and space to depth TVM_REGISTER_NODE_TYPE(SubPixelAttrs); @@ -1009,15 +1024,14 @@ bool DepthToSpaceRel(const Array& types, int num_inputs, const Attrs& attr oshape.Set(3, oshape[3] * block_size); // Assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } // Positional relay function to create DepthToSpace operator // used by frontend FFI -Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string mode) { +Expr MakeDepthToSpace(Expr data, int block_size, String layout, String mode) { auto attrs = make_object(); attrs->block_size = block_size; attrs->layout = std::move(layout); @@ -1067,15 +1081,14 @@ bool SpaceToDepthRel(const Array& types, int num_inputs, const Attrs& attr oshape.Set(3, indexdiv(oshape[3], block_size)); // Assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } // Positional relay function to create SpaceToDepth operator // used by frontend FFI -Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) { +Expr MakeSpaceToDepth(Expr data, int block_size, String layout) { auto attrs = make_object(); attrs->block_size = block_size; attrs->layout = std::move(layout); diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h index dc876e863ad0..0fb02638db07 100644 --- a/src/relay/op/nn/nn.h +++ b/src/relay/op/nn/nn.h @@ -24,6 +24,11 @@ #ifndef TVM_RELAY_OP_NN_NN_H_ #define TVM_RELAY_OP_NN_NN_H_ +#include +#include +#include +#include + #include namespace tvm { @@ -58,8 +63,7 @@ bool DenseRel(const Array& types, int num_inputs, const Attrs& attrs, if (weight == nullptr) return false; Array wshape = weight->shape; CHECK(static_cast(weight->shape.size()) == 2); - CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], - weight->shape[1])) + CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1])) << "DenseRel: input dimension doesn't match," << " data shape=" << data->shape << ", weight shape=" << weight->shape; oshape.Set((oshape.size() - 1), wshape[0]); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index abff06ef9d88..aba87e2017a0 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -21,12 +21,14 @@ * \file pad.cc * \brief Implementation of operator pad */ +#include +#include +#include #include #include -#include -#include -#include + #include + #include "../op_common.h" namespace tvm { @@ -35,13 +37,11 @@ namespace relay { // relay.nn.pad TVM_REGISTER_NODE_TYPE(PadAttrs); -Array > PadInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array> PadInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. 
- PadAttrs *params = const_cast(attrs.as()); + PadAttrs* params = const_cast(attrs.as()); Layout ret; // If new_in_layouts are defined, this code tries to modify the layout. @@ -108,12 +108,10 @@ Array > PadInferCorrectLayout( } } - return Array >{{ret}, {ret}}; + return Array>{{ret}, {ret}}; } -bool PadRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool PadRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -124,28 +122,26 @@ bool PadRel(const Array& types, // check that pad widths match lengths CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; // each pad width element should be a pair of positive integers std::vector oshape; for (size_t i = 0; i < param->pad_width.size(); i++) { CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; auto width1 = tir::as_const_int(param->pad_width[i][0]); auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; + CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; if (!data->shape[i].as()) { auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); @@ -155,21 +151,17 @@ bool PadRel(const Array& types, } } - reporter->Assign(types[1], TensorType(Array(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } -Array PadCompute(const Attrs& attrs, - const Array& inputs, +Array PadCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); auto pad_width = param->pad_width; - CHECK(pad_width.size() == inputs[0].ndim() && - pad_width[0].size() == 2) - << "Illegal pad_width"; + CHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width"; Array pad_before; for (size_t i = 0; i < pad_width.size(); ++i) { pad_before.push_back(pad_width[i][0]); @@ -179,18 +171,13 @@ Array PadCompute(const Attrs& attrs, pad_after.push_back(pad_width[i][1]); } const auto* out_ttype = out_type.as(); - return Array{ topi::pad(inputs[0], pad_before, pad_after, - tvm::tir::make_const(out_ttype->dtype, param->pad_value), - "T_pad", - topi::kElementWise, - param->pad_mode) }; + return Array{topi::pad(inputs[0], pad_before, pad_after, + tvm::tir::make_const(out_ttype->dtype, param->pad_value), + 
"T_pad", topi::kElementWise, param->pad_mode)}; } // Handler to create a call to the padding op used by front-end FFI -Expr MakePad(Expr data, - Array > pad_width, - double pad_value, - std::string pad_mode) { +Expr MakePad(Expr data, Array> pad_width, double pad_value, String pad_mode) { auto attrs = make_object(); attrs->pad_value = pad_value; attrs->pad_width = std::move(pad_width); @@ -199,29 +186,25 @@ Expr MakePad(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.pad") -.set_body_typed(MakePad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.pad").set_body_typed(MakePad); RELAY_REGISTER_OP("nn.pad") -.describe(R"code(Pad for n-D tensor. + .describe(R"code(Pad for n-D tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("Pad", PadRel) -.set_attr("FInferCorrectLayout", PadInferCorrectLayout) -.set_attr("TOpPattern", kInjective) -.set_attr("FTVMCompute", PadCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("Pad", PadRel) + .set_attr("FInferCorrectLayout", PadInferCorrectLayout) + .set_attr("TOpPattern", kInjective) + .set_attr("FTVMCompute", PadCompute); // relay.nn.mirror_pad TVM_REGISTER_NODE_TYPE(MirrorPadAttrs); -bool MirrorPadRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MirrorPadRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -232,40 +215,37 @@ bool MirrorPadRel(const Array& types, // check that pad widths match lengths CHECK(data->shape.size() == param->pad_width.size()) - << "There should be as many pad width pairs as shape dimensions " - << "but the shape has " << data->shape.size() << " dimensions " - << "and there are " << param->pad_width.size() << " pad width pairs."; + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; // each pad width element should be a pair of positive integers std::vector oshape; for (size_t i = 0; i < param->pad_width.size(); i++) { CHECK(param->pad_width[i].size() == 2) - << "Each pad width element should be a pair but at index " << i - << " there are " << param->pad_width[i].size() << " elements."; + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; auto width1 = tir::as_const_int(param->pad_width[i][0]); auto width2 = tir::as_const_int(param->pad_width[i][1]); CHECK(width1 != nullptr); CHECK(width2 != nullptr); - CHECK(*width1 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width1 << "."; - CHECK(*width2 >= 0) - << "Param width elements should be positive but first pad width at " - << "index " << i << " is " << *width2 << "."; + CHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + CHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); oshape.push_back(data->shape[i] + padding); } - reporter->Assign(types[1], TensorType(Array(oshape), - 
data->dtype)); + reporter->Assign(types[1], TensorType(Array(oshape), data->dtype)); return true; } // Handler to create a call to the padding op used by front-end FFI -Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mode) { +Expr MakeMirrorPad(Expr data, Array> pad_width, String mode) { auto attrs = make_object(); attrs->mode = mode; attrs->pad_width = std::move(pad_width); @@ -273,19 +253,18 @@ Expr MakeMirrorPad(Expr data, Array > pad_width, std::string mo return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad") -.set_body_typed(MakeMirrorPad); +TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad); RELAY_REGISTER_OP("nn.mirror_pad") -.describe(R"code(MirrorPad for n-D tensor. + .describe(R"code(MirrorPad for n-D tensor. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MirrorPad", MirrorPadRel) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MirrorPad", MirrorPadRel) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index c20793d9ac28..e54a5f32fc88 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -21,12 +21,14 @@ * \file pooling.cc * \brief Pooling operators */ -#include +#include +#include #include #include -#include -#include +#include + #include + #include "../../transforms/infer_layout_util.h" namespace tvm { @@ -37,13 +39,12 @@ TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs); template -Array > PoolInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array > PoolInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. - T *params = const_cast(attrs.as()); + T* params = const_cast(attrs.as()); if (new_in_layouts.defined()) { // Set the pool with the new layout. 
@@ -56,13 +57,8 @@ Array > PoolInferCorrectLayout( } template -Expr MakeMaxPool(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - std::string op_name) { +Expr MakeMaxPool(Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -74,14 +70,9 @@ Expr MakeMaxPool(Expr data, } template -Expr MakeAvgPool(Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad, - std::string op_name) { +Expr MakeAvgPool(Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, bool count_include_pad, + String op_name) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -94,9 +85,7 @@ Expr MakeAvgPool(Expr data, } template -bool Pool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -112,8 +101,7 @@ bool Pool2DRel(const Array& types, Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -140,8 +128,9 @@ bool Pool2DRel(const Array& types, oshape[hidx] = dshape[hidx]; } else { if (param->ceil_mode) { - oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; + oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] + param->strides[0] - 1) / + param->strides[0]) + + 1; } else { oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1; } @@ -150,8 +139,9 @@ bool Pool2DRel(const Array& types, oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + - param->strides[1] - 1) / param->strides[1]) + 1; + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] + param->strides[1] - 1) / + param->strides[1]) + + 1; } else { oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1; } @@ -162,9 +152,8 @@ bool Pool2DRel(const Array& types, return true; } -template -Array Pool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); @@ -182,9 +171,7 @@ Array Pool2DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool2d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U || - inputs[0].ndim() == 6U) + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool2D only support 4-D input (e.g., NCHW)" << " or 5-D input (e.g. NCHWc on for vector instructions)" << " or 6-D input (e.g. 
NCHWnc for tensor accelerators)"; @@ -199,30 +186,23 @@ Array Pool2DCompute(const Attrs& attrs, } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, + layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool2d"); -}); - + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool2d"); + }); RELAY_REGISTER_OP("nn.max_pool2d") -.describe(R"code(Max pooling operation for two dimensional data. + .describe(R"code(Max pooling operation for two dimensional data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -242,30 +222,25 @@ RELAY_REGISTER_OP("nn.max_pool2d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2D", Pool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool2DCompute); // AvgPool2D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool2d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool2d"); + }); RELAY_REGISTER_OP("nn.avg_pool2d") -.describe(R"code( + .describe(R"code( Average pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape @@ -286,24 +261,24 @@ Average pooling operation for one dimensional data. equation. 
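For example, a ``32 x 32`` input pooled with ``pool_size = (2, 2)``, ``strides = (2, 2)`` and no padding gives ``out_height = floor((32 + 0 + 0 - 2) / 2) + 1 = 16``, and likewise ``out_width = 16`` (illustrative values plugged into the formula above).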
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool2D", Pool2DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool2D", Pool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool2DCompute); // relay.nn.global_pool_2d & relay.nn.max_pool_2d TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs); -bool GlobalPool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GlobalPool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) << "Pool2D only support input >= 2-D: input must have height and width"; @@ -313,8 +288,7 @@ bool GlobalPool2DRel(const Array& types, Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -327,44 +301,38 @@ bool GlobalPool2DRel(const Array& types, return true; } - -template -Array GlobalPool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array GlobalPool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) - << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; + << "global_avg_pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "global_avg_pool2d does not support input split on height"; + << "global_avg_pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "global_avg_pool2d does not support input split on width"; + << "global_avg_pool2d does not support input split on width"; CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; - return Array{ - topi::nn::global_pool(inputs[0], mode, layout.name()) }; + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; + return Array{topi::nn::global_pool(inputs[0], mode, layout.name())}; } -Expr MakeGlobalAvgPool2D(Expr data, - std::string layout) { +Expr MakeGlobalAvgPool2D(Expr data, String layout) { auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_avg_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d") -.set_body_typed(MakeGlobalAvgPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_avg_pool2d").set_body_typed(MakeGlobalAvgPool2D); // 
GlobalAvgPool RELAY_REGISTER_OP("nn.global_avg_pool2d") -.describe(R"code(Global average pooling operation for 2D data. + .describe(R"code(Global average pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -372,30 +340,26 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", GlobalPool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("GlobalAvgPool2D", GlobalPool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", GlobalPool2DCompute); // GlobalMaxPool -Expr MakeGlobalMaxPool2D(Expr data, - std::string layout) { +Expr MakeGlobalMaxPool2D(Expr data, String layout) { auto attrs = make_object(); attrs->layout = std::move(layout); static const Op& op = Op::Get("nn.global_max_pool2d"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d") -.set_body_typed(MakeGlobalMaxPool2D); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d").set_body_typed(MakeGlobalMaxPool2D); RELAY_REGISTER_OP("nn.global_max_pool2d") -.describe(R"code(Global max pooling operation for 2D data. + .describe(R"code(Global max pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -403,44 +367,40 @@ RELAY_REGISTER_OP("nn.global_max_pool2d") (batch_size, channels, 1, 1) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", GlobalPool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("GlobalMaxPool2D", GlobalPool2DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", GlobalPool2DCompute); // relay.nn.adaptive_pool_2d TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs); -bool AdaptivePool2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool AdaptivePool2DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 2U) - << "Pool2D only support input >= 2-D: input must have height and width"; + << "Pool2D only support input >= 2-D: input must have height and width"; const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool2D layout must have H and W, which cannot be split"; + << "Invalid layout " << layout << ". 
Pool2D layout must have H and W, which cannot be split"; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); Array oshape(dshape); auto output_size = param->output_size; - CHECK_LE(output_size.size(), 2U) - << "output_size can have up to 2 elements."; + CHECK_LE(output_size.size(), 2U) << "output_size can have up to 2 elements."; IndexExpr output_height, output_width; if (output_size.empty()) { output_height = dshape[hidx]; @@ -461,24 +421,23 @@ bool AdaptivePool2DRel(const Array& types, return true; } -template -Array AdaptivePool2DCompute(const Attrs& attrs, - const Array& inputs, +template +Array AdaptivePool2DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCHW).defined()) - << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; + << "Adaptive pool2d currently only supports layouts that are convertible from NCHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool2d does not support input split on height"; + << "Adaptive pool2d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool2d does not support input split on width"; + << "Adaptive pool2d does not support input split on width"; CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) - << "Pool2D only support 4-D input (e.g., NCHW)" - << " or 5-D input (last dimension is a split of channel)"; + << "Pool2D only support 4-D input (e.g., NCHW)" + << " or 5-D input (last dimension is a split of channel)"; auto output_size = param->output_size; const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); @@ -494,15 +453,12 @@ Array AdaptivePool2DCompute(const Attrs& attrs, output_height = output_size[0]; output_width = output_size[1]; } - return Array{ - topi::nn::adaptive_pool(inputs[0], Array{ output_height, output_width }, - mode, layout.name()) }; + return Array{topi::nn::adaptive_pool( + inputs[0], Array{output_height, output_width}, mode, layout.name())}; } // relay.nn.adaptive_avg_pool2d -Expr MakeAdaptiveAvgPool2D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveAvgPool2D(Expr data, Array output_size, String layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -510,11 +466,10 @@ Expr MakeAdaptiveAvgPool2D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d") -.set_body_typed(MakeAdaptiveAvgPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d").set_body_typed(MakeAdaptiveAvgPool2D); RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") - .describe(R"code(Adaptive average pooling operation for 2D data. + .describe(R"code(Adaptive average pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -528,19 +483,17 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d") (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. 
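For example, ``output_size = (1, 1)`` reduces an ``NCHW`` input to shape ``(batch_size, channels, 1, 1)``, matching the behaviour of ``nn.global_avg_pool2d`` above, while leaving ``output_size`` unspecified keeps the input height and width unchanged.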
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveAvgPool2D", AdaptivePool2DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool2DCompute); // relay.nn.adaptive_max_pool2d -Expr MakeAdaptiveMaxPool2D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveMaxPool2D(Expr data, Array output_size, String layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -548,11 +501,10 @@ Expr MakeAdaptiveMaxPool2D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d") -.set_body_typed(MakeAdaptiveMaxPool2D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d").set_body_typed(MakeAdaptiveMaxPool2D); RELAY_REGISTER_OP("nn.adaptive_max_pool2d") - .describe(R"code(Adaptive max pooling operation for 2D data. + .describe(R"code(Adaptive max pooling operation for 2D data. - **data**: This depends on the `layout` parameter. Input is 4D array of shape (batch_size, channels, height, width) if `layout` is `NCHW`. @@ -566,45 +518,43 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool2d") (batch_size, channels, output_height, output_width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool2DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveMaxPool2D", AdaptivePool2DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool2DCompute); TVM_REGISTER_NODE_TYPE(AdaptivePool3DAttrs); -bool AdaptivePool3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool AdaptivePool3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); - if (data == nullptr) { return false; } + if (data == nullptr) { + return false; + } const auto dshape = data->shape; CHECK_GE(dshape.size(), 3U) - << "Pool3D only support input >= 3-D: input must have depth, height and width"; + << "Pool3D only support input >= 3-D: input must have depth, height and width"; const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && - !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool3D layout must have D, H and W, which cannot be split"; + !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) + << "Invalid layout " << layout + << ". 
Pool3D layout must have D, H and W, which cannot be split"; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); const auto widx = layout.IndexOf(LayoutAxis::Get('W')); Array oshape(dshape); auto output_size = param->output_size; - CHECK_LE(output_size.size(), 3U) - << "output_size can have up to 3 elements."; + CHECK_LE(output_size.size(), 3U) << "output_size can have up to 3 elements."; IndexExpr output_depth, output_height, output_width; if (output_size.empty()) { output_depth = dshape[didx]; @@ -629,26 +579,25 @@ bool AdaptivePool3DRel(const Array& types, return true; } -template -Array AdaptivePool3DCompute(const Attrs& attrs, - const Array& inputs, +template +Array AdaptivePool3DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCDHW("NCDHW"); const auto* param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(tir::BijectiveLayout(layout, kNCDHW).defined()) - << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; + << "Adaptive pool3d currently only supports layouts that are convertible from NCDHW"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) - << "Adaptive pool3d does not support input split on depth"; + << "Adaptive pool3d does not support input split on depth"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) - << "Adaptive pool3d does not support input split on height"; + << "Adaptive pool3d does not support input split on height"; CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) - << "Adaptive pool3d does not support input split on width"; + << "Adaptive pool3d does not support input split on width"; CHECK(inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) - << "Pool3D only support 5-D input (e.g., NCDHW)" - << " or 6-D input (last dimension is a split of channel)"; + << "Pool3D only support 5-D input (e.g., NCDHW)" + << " or 6-D input (last dimension is a split of channel)"; auto output_size = param->output_size; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); @@ -669,16 +618,12 @@ Array AdaptivePool3DCompute(const Attrs& attrs, output_width = output_size[2]; } - auto osize = Array{ output_depth, output_height, output_width }; - return Array { - topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name()) - }; + auto osize = Array{output_depth, output_height, output_width}; + return Array{topi::nn::adaptive_pool3d(inputs[0], osize, mode, layout.name())}; } // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveMaxPool3D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveMaxPool3D(Expr data, Array output_size, String layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -686,11 +631,10 @@ Expr MakeAdaptiveMaxPool3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d") -.set_body_typed(MakeAdaptiveMaxPool3D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d").set_body_typed(MakeAdaptiveMaxPool3D); RELAY_REGISTER_OP("nn.adaptive_max_pool3d") - .describe(R"code(Adaptive max pooling operation for 3D data. + .describe(R"code(Adaptive max pooling operation for 3D data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. 
@@ -704,19 +648,17 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool3d") (batch_size, channels, output_depth, output_height, output_width) if `layout` is `NCDHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool3DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveMaxPool3D", AdaptivePool3DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool3DCompute); // relay.nn.adaptive_max_pool3d -Expr MakeAdaptiveAvgPool3D(Expr data, - Array output_size, - std::string layout) { +Expr MakeAdaptiveAvgPool3D(Expr data, Array output_size, String layout) { auto attrs = make_object(); attrs->output_size = std::move(output_size); attrs->layout = std::move(layout); @@ -724,11 +666,10 @@ Expr MakeAdaptiveAvgPool3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d") -.set_body_typed(MakeAdaptiveAvgPool3D); +TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d").set_body_typed(MakeAdaptiveAvgPool3D); RELAY_REGISTER_OP("nn.adaptive_avg_pool3d") - .describe(R"code(Adaptive avg pooling operation for 3D data. + .describe(R"code(Adaptive avg pooling operation for 3D data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. - **output_size**: If this argument is not provided, input depth, height and width will be used @@ -740,15 +681,14 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool3d") - **out**: This depends on the `layout` parameter. Output is 5D array of shape (batch_size, channels, output_depth, output_height, output_width) if `layout` is `NCDHW`. 
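For example, ``output_size = (4, 8, 8)`` produces an output of shape ``(batch_size, channels, 4, 8, 8)`` for an ``NCDHW`` input, and ``output_size = (1, 1, 1)`` reduces every feature map to a single averaged value per channel (illustrative sizes).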
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel) -.set_attr("FInferCorrectLayout", - PoolInferCorrectLayout) -.set_attr("FTVMCompute", AdaptivePool3DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("AdaptiveAvgPool3D", AdaptivePool3DRel) + .set_attr("FInferCorrectLayout", + PoolInferCorrectLayout) + .set_attr("FTVMCompute", AdaptivePool3DCompute); bool Pool2DGradRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { @@ -763,8 +703,7 @@ bool Pool2DGradRel(const Array& types, int num_inputs, const Attrs& attrs, } template -Array Pool2DGradCompute(const Attrs& attrs, - const Array& inputs, +Array Pool2DGradCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCHW("NCHW"); const auto* param = attrs.as(); @@ -802,17 +741,18 @@ Array Pool2DGradCompute(const Attrs& attrs, if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; return Array{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + mode, ceil_mode, layout.name(), + count_include_pad)}; } else { return Array{topi::nn::pool_grad(inputs[0], inputs[1], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + mode, ceil_mode, layout.name())}; } } - // MaxPool2DGrad Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, - Array strides, Array padding, std::string layout, bool ceil_mode) { + Array strides, Array padding, String layout, + bool ceil_mode) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -825,7 +765,6 @@ Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array pool_size, TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad); - RELAY_REGISTER_OP("nn.max_pool2d_grad") .describe(R"code(Gradient of max pooling operation for two dimensional data. @@ -849,18 +788,17 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad") (batch_size, channels, height, width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2DGrad", Pool2DGradRel) -.set_attr("FTVMCompute", Pool2DGradCompute); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2DGrad", Pool2DGradRel) + .set_attr("FTVMCompute", Pool2DGradCompute); // AvgPool2DGrad Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, - Array strides, Array padding, std::string layout, bool ceil_mode, - bool count_include_pad) { + Array strides, Array padding, String layout, + bool ceil_mode, bool count_include_pad) { auto attrs = make_object(); attrs->pool_size = std::move(pool_size); attrs->strides = std::move(strides); @@ -874,7 +812,6 @@ Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array pool_size, TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad); - RELAY_REGISTER_OP("nn.avg_pool2d_grad") .describe(R"code(Gradient of average pooling operation for two dimensional data. 
@@ -898,22 +835,19 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad") (batch_size, channels, height, width) if `layout` is `NCHW`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool2DGrad", Pool2DGradRel) -.set_attr("FTVMCompute", Pool2DGradCompute); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool2DGrad", Pool2DGradRel) + .set_attr("FTVMCompute", Pool2DGradCompute); // relay.nn.max_pool1d & relay.nn.avg_pool1d TVM_REGISTER_NODE_TYPE(MaxPool1DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool1DAttrs); template -bool Pool1DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool1DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -921,15 +855,13 @@ bool Pool1DRel(const Array& types, if (data == nullptr) return false; const auto dshape = data->shape; - CHECK_GE(dshape.size(), 1U) - << "Pool1D only support input >= 1-D: input must have width"; + CHECK_GE(dshape.size(), 1U) << "Pool1D only support input >= 1-D: input must have width"; const auto param = attrs.as(); CHECK(param != nullptr); Layout layout(param->layout); CHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool1D layout must have W, which cannot be split"; + << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split"; const auto widx = layout.IndexOf(LayoutAxis::Get('W')); @@ -949,8 +881,9 @@ bool Pool1DRel(const Array& types, oshape[widx] = dshape[widx]; } else { if (param->ceil_mode) { - oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + - param->strides[0] - 1) / param->strides[0]) + 1; + oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0] + param->strides[0] - 1) / + param->strides[0]) + + 1; } else { oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[0]) / param->strides[0]) + 1; } @@ -961,10 +894,8 @@ bool Pool1DRel(const Array& types, return true; } - -template -Array Pool1DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool1DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCW("NCW"); const auto* param = attrs.as(); @@ -980,9 +911,7 @@ Array Pool1DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool1d does not support input split on width"; - CHECK(inputs[0].ndim() == 3U || - inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U) + CHECK(inputs[0].ndim() == 3U || inputs[0].ndim() == 4U || inputs[0].ndim() == 5U) << "Pool1D only support 3-D input (e.g., NCW)" << " or 4-D input (e.g. NCWc on for vector instructions)" << " or 5-D input (e.g. 
NCWnc for tensor accelerators)"; @@ -993,29 +922,23 @@ Array Pool1DCompute(const Attrs& attrs, if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, + ceil_mode, layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool1d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool1d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool1d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool1d"); + }); RELAY_REGISTER_OP("nn.max_pool1d") -.describe(R"code(Max pooling operation for one dimensional data. + .describe(R"code(Max pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 3D array of shape (batch_size, channels, width) if `layout` is `NCW`. @@ -1033,30 +956,25 @@ RELAY_REGISTER_OP("nn.max_pool1d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool1D", Pool1DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool1DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool1D", Pool1DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool1DCompute); // AvgPool1D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool1d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool1d"); + }); RELAY_REGISTER_OP("nn.avg_pool1d") -.describe(R"code( + .describe(R"code( Average pooling operation for one dimensional data. - **data**: This depends on the `layout` parameter. Input is 3D array of shape @@ -1075,23 +993,20 @@ Average pooling operation for one dimensional data. equation. 
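For reference, the output-width rule that Pool1DRel implements above, and that the `ceil_mode` sentence in both pool1d descriptions refers to, is (the sample numbers are purely illustrative):

  out_width = floor((w + pad_w - pool_size) / stride) + 1                when ceil_mode is false
  out_width = floor((w + pad_w - pool_size + stride - 1) / stride) + 1   when ceil_mode is true

  e.g. w = 32, pad_w = 2, pool_size = 3, stride = 2 gives 16 without ceil_mode and 17 with it.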
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool1D", Pool1DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool1DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool1D", Pool1DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool1DCompute); // relay.nn.max_pool3d & relay.nn.avg_pool3d TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs); TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs); template -bool Pool3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool Pool3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1108,8 +1023,8 @@ bool Pool3DRel(const Array& types, CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) - << "Invalid layout " << layout - << ". Pool3D layout must have D, H and W, which cannot be split"; + << "Invalid layout " << layout + << ". Pool3D layout must have D, H and W, which cannot be split"; const auto didx = layout.IndexOf(LayoutAxis::Get('D')); const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); @@ -1143,8 +1058,9 @@ bool Pool3DRel(const Array& types, oshape[ii] = dshape[ii]; } else { if (param->ceil_mode) { - oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + - param->strides[i] - 1) / param->strides[i]) + 1; + oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i] + param->strides[i] - 1) / + param->strides[i]) + + 1; } else { oshape[ii] = ((dshape[ii] + pad[i] - param->pool_size[i]) / param->strides[i]) + 1; } @@ -1156,10 +1072,8 @@ bool Pool3DRel(const Array& types, return true; } - -template -Array Pool3DCompute(const Attrs& attrs, - const Array& inputs, +template +Array Pool3DCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { static const Layout kNCDHW("NCDHW"); const auto* param = attrs.as(); @@ -1179,9 +1093,7 @@ Array Pool3DCompute(const Attrs& attrs, CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) << "max_pool3d does not support input split on width"; - CHECK(inputs[0].ndim() == 4U || - inputs[0].ndim() == 5U || - inputs[0].ndim() == 6U) + CHECK(inputs[0].ndim() == 4U || inputs[0].ndim() == 5U || inputs[0].ndim() == 6U) << "Pool3D only support 5-D input (e.g., NCDHW)" << " or 6-D input (e.g. NCDHWc on for vector instructions)" << " or 7-D input (e.g. 
NCDHWnc for tensor accelerators)"; @@ -1197,29 +1109,23 @@ Array Pool3DCompute(const Attrs& attrs, } if (mode == topi::nn::kAvgPool) { bool count_include_pad = reinterpret_cast(param)->count_include_pad; - return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name(), count_include_pad)}; + return Array{topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, + ceil_mode, layout.name(), count_include_pad)}; } else { return Array{ - topi::nn::pool3d(inputs[0], pool_size, strides, padding, - mode, ceil_mode, layout.name())}; + topi::nn::pool3d(inputs[0], pool_size, strides, padding, mode, ceil_mode, layout.name())}; } } TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode) { - return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, - "nn.max_pool3d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode) { + return MakeMaxPool(data, pool_size, strides, padding, layout, ceil_mode, + "nn.max_pool3d"); + }); RELAY_REGISTER_OP("nn.max_pool3d") -.describe(R"code(Max pooling operation for three dimensional data. + .describe(R"code(Max pooling operation for three dimensional data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape (batch_size, channels, depth, height, width) if `layout` is `NCDHW`. @@ -1240,30 +1146,25 @@ RELAY_REGISTER_OP("nn.max_pool3d") equation. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("MaxPool3D", Pool3DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool3DCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("MaxPool3D", Pool3DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool3DCompute); // AvgPool3D TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") -.set_body_typed([](Expr data, - Array pool_size, - Array strides, - Array padding, - std::string layout, - bool ceil_mode, - bool count_include_pad) { - return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, - count_include_pad, "nn.avg_pool3d"); -}); + .set_body_typed([](Expr data, Array pool_size, Array strides, + Array padding, String layout, bool ceil_mode, + bool count_include_pad) { + return MakeAvgPool(data, pool_size, strides, padding, layout, ceil_mode, + count_include_pad, "nn.avg_pool3d"); + }); RELAY_REGISTER_OP("nn.avg_pool3d") -.describe(R"code( + .describe(R"code( Average pooling operation for three dimensional data. - **data**: This depends on the `layout` parameter. Input is 5D array of shape @@ -1285,13 +1186,13 @@ Average pooling operation for three dimensional data. equation. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("AvgPool3D", Pool3DRel) -.set_attr("FInferCorrectLayout", PoolInferCorrectLayout) -.set_attr("FTVMCompute", Pool3DCompute); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("AvgPool3D", Pool3DRel) + .set_attr("FInferCorrectLayout", PoolInferCorrectLayout) + .set_attr("FTVMCompute", Pool3DCompute); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index c761c3f8466e..0aca00ce80a4 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -22,9 +22,10 @@ * \brief Property def of nn.sparse_dense operator. */ -#include -#include #include +#include +#include + #include #include "../../transforms/infer_layout_util.h" @@ -53,9 +54,8 @@ bool SparseDenseRel(const Array& types, int num_inputs, const Attrs& attrs if (weight_data->shape.size() == 3) { // BSR case. - Array oshape({ - data->shape[0], - (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); + Array oshape( + {data->shape[0], (weight_indptr->shape[0] - 1) * weight_data->shape[1]}); reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } @@ -71,32 +71,32 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig } TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSparseDense, args, rv); -}); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeSparseDense, args, rv); + }); RELAY_REGISTER_OP("nn.sparse_dense") -.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. + .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse. - **data**: `(x1, x2, ..., xn, input_dim)` - **weight**: `(units, input_dim)` - **out**: `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(4) -.add_argument("data", "nD Tensor", "Input data.") -.add_argument("weight_data", "1D Tensor", "Weight data matrix.") -.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") -.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") -.set_support_level(1) -.add_type_rel("SparseDense", SparseDenseRel); + .set_attrs_type() + .set_num_inputs(4) + .add_argument("data", "nD Tensor", "Input data.") + .add_argument("weight_data", "1D Tensor", "Weight data matrix.") + .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.") + .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.") + .set_support_level(1) + .add_type_rel("SparseDense", SparseDenseRel); // relay.nn.sparse_transpose TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs); bool SparseTransposeRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* sparse_data = types[0].as(); CHECK_EQ(sparse_data->shape.size(), 1); @@ -119,24 +119,22 @@ Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indp return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose") -.set_body_typed(MakeSparseTranspose); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose").set_body_typed(MakeSparseTranspose); RELAY_REGISTER_OP("nn.sparse_transpose") -.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix + .describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix - **input**: `(N, N)` - **out**: `(N, N)`. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") -.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") -.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") -.set_support_level(1) -.add_type_rel("SparseTranspose", SparseTransposeRel); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("sparse_data", "1D Tensor", "Sparse data matrix.") + .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") + .add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.") + .set_support_level(1) + .add_type_rel("SparseTranspose", SparseTransposeRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 63bd42d8f508..cb20881c1c5f 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -21,11 +21,13 @@ * \file upsampling.cc * \brief upsampling operator */ -#include -#include #include +#include #include +#include + #include + #include "../op_common.h" namespace tvm { @@ -35,13 +37,12 @@ TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs); template -Array > UpsamplingInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +Array > UpsamplingInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. 
- T *params = const_cast(attrs.as()); + T* params = const_cast(attrs.as()); if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 1); @@ -49,12 +50,12 @@ Array > UpsamplingInferCorrectLayout( Layout raw_layout(params->layout); Layout input = new_in_layouts[0]; if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) && - input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && - !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))&& + input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) && + !input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h')) && (input.IndexOf(LayoutAxis::Get('D')) == -1 || - (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && - !input.Contains(LayoutAxis::Get('d'))))) { - params->layout = input.name(); // modify self to follow the input layout + (input.IndexOf(LayoutAxis::Get('D')) == raw_layout.IndexOf(LayoutAxis::Get('D')) && + !input.Contains(LayoutAxis::Get('d'))))) { + params->layout = input.name(); // modify self to follow the input layout } } @@ -62,9 +63,7 @@ Array > UpsamplingInferCorrectLayout( return Array >{{inferred_layout}, {inferred_layout}}; } -bool UpSamplingRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UpSamplingRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -78,28 +77,21 @@ bool UpSamplingRel(const Array& types, auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); CHECK(layout_converter.defined()) - << "UpSampling only support input layouts that are convertible from NCHW." - << " But got " << in_layout; + << "UpSampling only support input layouts that are convertible from NCHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); - oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); + oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); + oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } - // Positional relay function to create upsampling operator // used by frontend FFI. -Expr MakeUpSampling(Expr data, - double scale_h, - double scale_w, - std::string layout, - std::string method, +Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method, bool align_corners) { auto attrs = make_object(); attrs->layout = std::move(layout); @@ -111,12 +103,11 @@ Expr MakeUpSampling(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling") -.set_body_typed(MakeUpSampling); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling").set_body_typed(MakeUpSampling); RELAY_REGISTER_OP("nn.upsampling") -.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. + .describe( + R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. 
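A small sketch of the per-axis rounding UpSamplingRel applies above after converting the layout to NCHW; the function name and sizes are illustrative only:

  #include <cmath>
  #include <cstdint>

  // Each spatial extent is scaled and rounded to the nearest integer.
  int64_t UpsampledExtent(int64_t in, double scale) {
    return static_cast<int64_t>(std::round(in * scale));
  }
  // NCHW (1, 3, 32, 32) with scale_h = scale_w = 2.0 becomes (1, 3, 64, 64);
  // UpsampledExtent(10, 1.5) == 15.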
- **data**: data is 4D array of shape (batch_size, channels, in_height, in_width) for NCHW @@ -130,20 +121,17 @@ RELAY_REGISTER_OP("nn.upsampling") (batch_size, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("UpSampling", UpSamplingRel) -.set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("UpSampling", UpSamplingRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); // UpSampling3D -bool UpSampling3DRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UpSampling3DRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -157,30 +145,23 @@ bool UpSampling3DRel(const Array& types, auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); CHECK(layout_converter.defined()) - << "UpSampling3D only support input layouts that are convertible from NCDHW." - << " But got " << in_layout; + << "UpSampling3D only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; auto oshape = layout_converter.ForwardShape(data->shape); - oshape.Set(2, tir::CastNode::make(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); - oshape.Set(3, tir::CastNode::make(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); - oshape.Set(4, tir::CastNode::make(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); + oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); + oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); + oshape.Set(4, tir::Cast(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); // assign output type - reporter->Assign(types[1], - TensorType(layout_converter.BackwardShape(oshape), - data->dtype)); + reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); return true; } // Positional relay function to create upsampling3d operator // used by frontend FFI. -Expr MakeUpSampling3D(Expr data, - double scale_d, - double scale_h, - double scale_w, - std::string layout, - std::string method, - std::string coordinate_transformation_mode) { +Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, String layout, + String method, String coordinate_transformation_mode) { auto attrs = make_object(); attrs->layout = std::move(layout); attrs->method = std::move(method); @@ -192,12 +173,10 @@ Expr MakeUpSampling3D(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d") -.set_body_typed(MakeUpSampling3D); - +TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D); RELAY_REGISTER_OP("nn.upsampling3d") -.describe(R"code(Perform upsampling on input array with nearest neighbour or + .describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. - **data**: data is 5D array of shape @@ -212,14 +191,14 @@ bilinear interpolation. 
(batch_size, in_depth*scale, in_height*scale, in_width*scale, channels) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(2) -.add_type_rel("UpSampling3D", UpSampling3DRel) -.set_attr("FInferCorrectLayout", - UpsamplingInferCorrectLayout) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2) + .add_type_rel("UpSampling3D", UpSampling3DRel) + .set_attr("FInferCorrectLayout", + UpsamplingInferCorrectLayout) + .set_attr("TOpPattern", kInjective); } // namespace relay } // namespace tvm diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index 2d89d778e62c..cbb8cec2d43b 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -28,11 +28,13 @@ #include #include #include -#include + #include #include -#include "type_relations.h" +#include + #include "../transforms/infer_layout_util.h" +#include "type_relations.h" namespace tvm { namespace relay { @@ -47,21 +49,18 @@ namespace relay { * \param OpName the name of registry. */ -#define RELAY_REGISTER_UNARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") \ - .add_type_rel("Identity", IdentityRel) \ - .set_attr("TOpPattern", kElemWise) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - ElemwiseArbitraryLayout) \ - +#define RELAY_REGISTER_UNARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") \ + .add_type_rel("Identity", IdentityRel) \ + .set_attr("TOpPattern", kElemWise) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) /*! Quick helper macro * - Expose a positional make function to construct the node. @@ -73,42 +72,37 @@ namespace relay { * * \param OpName the name of registry. */ -#define RELAY_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("Broadcast", BroadcastRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) +#define RELAY_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." 
OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("Broadcast", BroadcastRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", BinaryBroadcastLayout) // Comparisons -#define RELAY_REGISTER_CMP_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {lhs, rhs}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(2) \ - .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") \ - .add_type_rel("BroadcastComp", BroadcastCompRel) \ - .set_attr("TOpPattern", kBroadcast) \ - .set_attr("TOpIsStateful", false) \ - .set_attr("FInferCorrectLayout", \ - BinaryBroadcastLayout) - +#define RELAY_REGISTER_CMP_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {lhs, rhs}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(2) \ + .add_argument("lhs", "Tensor", "The left hand side tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .add_type_rel("BroadcastComp", BroadcastCompRel) \ + .set_attr("TOpPattern", kBroadcast) \ + .set_attr("TOpIsStateful", false) \ + .set_attr("FInferCorrectLayout", BinaryBroadcastLayout) /*! \brief A helper class for matching and rewriting operators. */ -template +template class OpMatch { public: using MatchFunc = @@ -145,7 +139,7 @@ class OpMatch { private: /*! \brief The match function map. */ - std::unordered_map match_map_; + std::unordered_map match_map_; /*! \brief An optional default case. */ MatchFunc default_; }; @@ -157,8 +151,7 @@ inline void GetPaddingWidth(const Array& padding, IndexExpr* pad_w) { } else if (padding.size() == 2) { *pad_w = padding[0] + padding[1]; } else { - CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " - << padding.size(); + CHECK_EQ(padding.size(), 4) << " Expected padding size of 1 or 2, found " << padding.size(); } } @@ -175,8 +168,7 @@ inline void GetPaddingHeightWidth(const Array& padding, IndexExpr* pa *pad_h = padding[0] + padding[2]; *pad_w = padding[1] + padding[3]; } else { - CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " - << padding.size(); + CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got " << padding.size(); } } @@ -196,8 +188,7 @@ inline void GetPaddingDepthHeightWidth(const Array& padding, IndexExp *pad_h = padding[1] + padding[4]; *pad_w = padding[2] + padding[5]; } else { - CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " - << padding.size(); + CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got " << padding.size(); } } diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc index 0f47c9aa2553..026dfc21dd5f 100644 --- a/src/relay/op/tensor/binary.cc +++ b/src/relay/op/tensor/binary.cc @@ -21,166 +21,145 @@ * \file binary.cc * \brief binary broadcast operators. 
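As a brief worked example of the GetPadding* helpers shown just above in op_common.h (the numbers are illustrative): GetPaddingHeightWidth with the 4-element padding (1, 1, 2, 2) yields pad_h = 1 + 2 = 3 and pad_w = 1 + 2 = 3, and GetPaddingDepthHeightWidth with (0, 1, 1, 0, 2, 2) yields pad_d = 0, pad_h = 3 and pad_w = 3; these summed totals are what the pooling shape relations feed into their output-extent arithmetic.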
*/ +#include #include #include -#include -#include "../type_relations.h" + #include "../op_common.h" +#include "../type_relations.h" namespace tvm { namespace relay { -#define RELAY_BINARY_COMPUTE(FTOPI) \ - [] (const Attrs& attrs, \ - const Array& inputs, \ - const Type& out_type) -> Array { \ - CHECK_EQ(inputs.size(), 2U); \ - return {FTOPI(inputs[0], inputs[1])}; \ - } \ +#define RELAY_BINARY_COMPUTE(FTOPI) \ + [](const Attrs& attrs, const Array& inputs, \ + const Type& out_type) -> Array { \ + CHECK_EQ(inputs.size(), 2U); \ + return {FTOPI(inputs[0], inputs[1])}; \ + } // Addition RELAY_REGISTER_BINARY_OP("add") -.describe("Elementwise add with with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); + .describe("Elementwise add with with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::add)); // Subtraction RELAY_REGISTER_BINARY_OP("subtract") -.describe("Elementwise substract with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); + .describe("Elementwise substract with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::subtract)); // Right shift RELAY_REGISTER_BINARY_OP("right_shift") -.describe("Elementwise right shift with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); - + .describe("Elementwise right shift with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::right_shift)); RELAY_REGISTER_BINARY_OP("left_shift") -.describe("Elementwise left shift with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); - + .describe("Elementwise left shift with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::left_shift)); RELAY_REGISTER_BINARY_OP("maximum") -.describe("Elementwise maximum of two tensors with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); - + .describe("Elementwise maximum of two tensors with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::maximum)); RELAY_REGISTER_BINARY_OP("minimum") -.describe("Elementwise minimum of two tensors with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); - + .describe("Elementwise minimum of two tensors with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::minimum)); RELAY_REGISTER_BINARY_OP("divide") -.describe("Elementwise divide with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); - + .describe("Elementwise divide with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::divide)); RELAY_REGISTER_BINARY_OP("floor_divide") -.describe("Elementwise floor divide with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide)); - + .describe("Elementwise floor divide with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_divide)); RELAY_REGISTER_BINARY_OP("multiply") -.describe("Elementwise multiply with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); - + .describe("Elementwise multiply with broadcasting") + .set_support_level(1) + 
.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::multiply)); RELAY_REGISTER_BINARY_OP("power") -.describe("Elementwise power with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); - + .describe("Elementwise power with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::power)); RELAY_REGISTER_BINARY_OP("mod") -.describe("Elementwise mod with broadcasting") -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); - + .describe("Elementwise mod with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::mod)); RELAY_REGISTER_BINARY_OP("floor_mod") - .describe("Elementwise floor mod with broadcasting") - .set_support_level(1) - .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod)); - + .describe("Elementwise floor mod with broadcasting") + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::floor_mod)); RELAY_REGISTER_BINARY_OP("logical_and") -.describe("Elementwise logical AND with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and)); - + .describe("Elementwise logical AND with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_and)); RELAY_REGISTER_BINARY_OP("logical_or") -.describe("Elementwise logical OR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); - + .describe("Elementwise logical OR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or)); RELAY_REGISTER_BINARY_OP("logical_xor") -.describe("Elementwise logical XOR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor)); - + .describe("Elementwise logical XOR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor)); RELAY_REGISTER_BINARY_OP("bitwise_and") -.describe("Elementwise bitwise AND with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); - + .describe("Elementwise bitwise AND with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_and)); RELAY_REGISTER_BINARY_OP("bitwise_or") -.describe("Elementwise bitwise OR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); - + .describe("Elementwise bitwise OR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_or)); RELAY_REGISTER_BINARY_OP("bitwise_xor") -.describe("Elementwise bitwise XOR with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); - + .describe("Elementwise bitwise XOR with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::bitwise_xor)); RELAY_REGISTER_CMP_OP("equal") -.describe("Elementwise equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); - + .describe("Elementwise equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::equal)); RELAY_REGISTER_CMP_OP("not_equal") -.describe("Elementwise not equal with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); - + .describe("Elementwise 
not equal with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::not_equal)); RELAY_REGISTER_CMP_OP("less") -.describe("Elementwise less than with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); - + .describe("Elementwise less than with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less)); RELAY_REGISTER_CMP_OP("less_equal") -.describe("Elementwise less than or equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); - + .describe("Elementwise less than or equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::less_equal)); RELAY_REGISTER_CMP_OP("greater") -.describe("Elementwise greater than compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); - + .describe("Elementwise greater than compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater)); RELAY_REGISTER_CMP_OP("greater_equal") -.describe("Elementwise greater than or equal compare with broadcasting") -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); + .describe("Elementwise greater than or equal compare with broadcasting") + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_BINARY_COMPUTE(topi::greater_equal)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 3f220fb64ad5..d526cef5bf62 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -21,13 +21,15 @@ * \file reduce.cc * \brief Reduction operators. */ -#include -#include -#include #include #include -#include +#include +#include +#include + #include +#include + #include "../op_common.h" #include "../type_relations.h" @@ -37,14 +39,13 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ReduceAttrs); /*! -* \brief GetReduceAxes, get the new axis from indim and other arguments -* \param indim Number of dimensions of input data. -* \param axis The input axis vector. -* \param exclude Whether 'axis' input given is the excluded axis. -* \return r_axes The new reduced axes of the output. -*/ -inline std::vector GetReduceAxes(const uint32_t indim, - const Array& inaxis, + * \brief GetReduceAxes, get the new axis from indim and other arguments + * \param indim Number of dimensions of input data. + * \param axis The input axis vector. + * \param exclude Whether 'axis' input given is the excluded axis. + * \return r_axes The new reduced axes of the output. 
+ */ +inline std::vector GetReduceAxes(const uint32_t indim, const Array& inaxis, bool exclude) { if (!inaxis.defined()) { std::vector r_axes(indim); @@ -60,16 +61,13 @@ inline std::vector GetReduceAxes(const uint32_t indim, } // Check out of bounds error - CHECK(axis >= 0) - << "Axis out of bounds in reduce operator."; - CHECK(axis < indim) - << "Axis out of bounds in reduce operator."; + CHECK(axis >= 0) << "Axis out of bounds in reduce operator."; + CHECK(axis < indim) << "Axis out of bounds in reduce operator."; in_axes.push_back(axis); } CHECK(in_axes[in_axes.size() - 1] < indim) - << "Reduction axis " << in_axes[in_axes.size() - 1] - << " exceeds input dimensions " << indim; + << "Reduction axis " << in_axes[in_axes.size() - 1] << " exceeds input dimensions " << indim; std::sort(in_axes.begin(), in_axes.end()); @@ -81,18 +79,16 @@ inline std::vector GetReduceAxes(const uint32_t indim, std::vector r_axes(r_size); for (uint32_t i = 0, j = 0, k = 0; i < indim; ++i) { if (j < in_axes.size() && in_axes[j] == i) { - ++j; - continue; + ++j; + continue; } r_axes[k++] = i; } return r_axes; } - // Get axis under exclude condition. -Array GetExcludeAxes(size_t indim, - const Array& inaxis) { +Array GetExcludeAxes(size_t indim, const Array& inaxis) { CHECK(inaxis.defined()) << "Cannot set exclude when axis=None"; std::vector axis_flag(indim, true); for (auto i : inaxis) { @@ -101,10 +97,8 @@ Array GetExcludeAxes(size_t indim, axis = axis + static_cast(indim); } // Check out of bounds error - CHECK_GE(axis, 0) - << "Axis out of bounds in reduce operator."; - CHECK_LT(axis, static_cast(indim)) - << "Axis out of bounds in reduce operator."; + CHECK_GE(axis, 0) << "Axis out of bounds in reduce operator."; + CHECK_LT(axis, static_cast(indim)) << "Axis out of bounds in reduce operator."; axis_flag[axis] = false; } @@ -177,34 +171,32 @@ Array> ReduceInferCorrectLayout(const Attrs& attrs, return Array>{{ret}, {ret}}; } -template -Array ReduceCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type, - F f) { +template +Array ReduceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type, F f) { const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); if (inputs[0]->shape.size() == 0) { - return { topi::identity(inputs[0]) }; + return {topi::identity(inputs[0])}; } auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); if (axes.size() == 0) { - return { topi::identity(inputs[0]) }; + return {topi::identity(inputs[0])}; } } - return { f(inputs[0], axes, param->keepdims, false) }; + return {f(inputs[0], axes, param->keepdims, false)}; } /*! -* \brief ReduceShapeImpl get the outshape for the reduction operator -* \param in_shape Shape of input data. -* \param param ReduceAttrs details. -* \param reporter The reporter to report solution to. -* \return oshape Output shape inferred. -*/ -inline std::vector ReduceShapeImpl(const std::vector &in_shape, + * \brief ReduceShapeImpl get the outshape for the reduction operator + * \param in_shape Shape of input data. + * \param param ReduceAttrs details. + * \param reporter The reporter to report solution to. + * \return oshape Output shape inferred. 
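For instance (illustrative shapes, not taken from the patch): reducing an input of shape (2, 3, 4) over axis 1 produces (2, 4) by default and (2, 1, 4) when keepdims is set, which is the kind of shape ReduceShapeImpl reports back through the TypeReporter.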
+ */ +inline std::vector ReduceShapeImpl(const std::vector& in_shape, const ReduceAttrs* param, const TypeReporter& reporter) { uint32_t indim = in_shape.size(); @@ -225,9 +217,9 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } if (is_dynamic_input) { - CHECK(reporter->Assert(max_shape < tir::make_const( - DataType::Int(64), std::numeric_limits::max()))) - << "The maximum possible index of reduced shape cannot be more than int32 max."; + CHECK(reporter->Assert(max_shape < + tir::make_const(DataType::Int(64), std::numeric_limits::max()))) + << "The maximum possible index of reduced shape cannot be more than int32 max."; } if (param->keepdims) { @@ -255,16 +247,14 @@ inline std::vector ReduceShapeImpl(const std::vector &in_s } /*! -* \brief ArgReduceRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool ArgReduceRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { + * \brief ArgReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ArgReduceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) return false; @@ -281,15 +271,13 @@ bool ArgReduceRel(const Array& types, } /*! -* \brief ReduceRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. -*/ -bool ReduceRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief ReduceRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool ReduceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -305,70 +293,57 @@ bool ReduceRel(const Array& types, return true; } -#define RELAY_REGISTER_REDUCE_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ - .set_body_typed([]( \ - Expr data, \ - Array axis, \ - bool keepdims, \ - bool exclude) { \ - auto attrs = make_object(); \ - attrs->axis = std::move(axis); \ - attrs->keepdims = keepdims; \ - attrs->exclude = exclude; \ - static const Op& op = Op::Get(OpName); \ - return Call(op, {data}, Attrs(attrs), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") - - -Array ArgMaxCompute(const Attrs& attrs, - const Array& inputs, +#define RELAY_REGISTER_REDUCE_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.op._make." 
OpName) \ + .set_body_typed([](Expr data, Array axis, bool keepdims, bool exclude) { \ + auto attrs = make_object(); \ + attrs->axis = std::move(axis); \ + attrs->keepdims = keepdims; \ + attrs->exclude = exclude; \ + static const Op& op = Op::Get(OpName); \ + return Call(op, {data}, Attrs(attrs), {}); \ + }); \ + RELAY_REGISTER_OP(OpName).set_num_inputs(1).add_argument("data", "Tensor", "The input tensor.") + +Array ArgMaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::argmax); } - RELAY_REGISTER_REDUCE_OP("argmax") -.describe(R"code(Creates an operation that finds the indices of the maximum + .describe(R"code(Creates an operation that finds the indices of the maximum values over a given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel) -.set_attr("FTVMCompute", ArgMaxCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("ArgReduce", ArgReduceRel) + .set_attr("FTVMCompute", ArgMaxCompute) + .set_attr("TOpPattern", kCommReduce); -Array ArgMinCompute(const Attrs& attrs, - const Array& inputs, +Array ArgMinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::argmin); } RELAY_REGISTER_REDUCE_OP("argmin") -.describe(R"code(Creates an operation that finds the indices of the minimum + .describe(R"code(Creates an operation that finds the indices of the minimum values over a given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("ArgReduce", ArgReduceRel) -.set_attr("FTVMCompute", ArgMinCompute) -.set_attr("TOpPattern", kCommReduce); - -Array SumCompute(const Attrs& attrs, - const Array& inputs, + .set_attrs_type() + .set_support_level(4) + .add_type_rel("ArgReduce", ArgReduceRel) + .set_attr("FTVMCompute", ArgMinCompute) + .set_attr("TOpPattern", kCommReduce); + +Array SumCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::sum); } - RELAY_REGISTER_REDUCE_OP("sum") -.describe(R"code(Computes the sum of array elements over given axes. + .describe(R"code(Computes the sum of array elements over given axes. Example:: @@ -385,23 +360,20 @@ Example:: [ 12. 19. 27.] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) -.set_attr("FTVMCompute", SumCompute) -.set_attr("TOpPattern", kCommReduce); - - -Array AllCompute(const Attrs& attrs, - const Array& inputs, + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FInferCorrectLayout", ReduceInferCorrectLayout) + .set_attr("FTVMCompute", SumCompute) + .set_attr("TOpPattern", kCommReduce); + +Array AllCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::all); } - RELAY_REGISTER_REDUCE_OP("all") -.describe(R"code(Computes the logical AND of boolean array elements over given axes. + .describe(R"code(Computes the logical AND of boolean array elements over given axes. 
Example:: @@ -422,22 +394,19 @@ Example:: [False, True, False]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", AllCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", AllCompute) + .set_attr("TOpPattern", kCommReduce); -Array AnyCompute(const Attrs& attrs, - const Array& inputs, +Array AnyCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::any); } - RELAY_REGISTER_REDUCE_OP("any") -.describe(R"code(Computes the logical OR of boolean array elements over given axes. + .describe(R"code(Computes the logical OR of boolean array elements over given axes. Example:: @@ -458,56 +427,49 @@ Example:: [False, True, True]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", AnyCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", AnyCompute) + .set_attr("TOpPattern", kCommReduce); -Array MaxCompute(const Attrs& attrs, - const Array& inputs, +Array MaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::max); } RELAY_REGISTER_REDUCE_OP("max") -.describe(R"code(Computes the max of array elements over given axes. + .describe(R"code(Computes the max of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MaxCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MaxCompute) + .set_attr("TOpPattern", kCommReduce); -Array MinCompute(const Attrs& attrs, - const Array& inputs, +Array MinCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::min); } - RELAY_REGISTER_REDUCE_OP("min") -.describe(R"code(Computes the min of array elements over given axes. + .describe(R"code(Computes the min of array elements over given axes. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MinCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MinCompute) + .set_attr("TOpPattern", kCommReduce); -Array ProdCompute(const Attrs& attrs, - const Array& inputs, +Array ProdCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return ReduceCompute(attrs, inputs, out_type, topi::prod); } RELAY_REGISTER_REDUCE_OP("prod") -.describe(R"code(Computes the products of array elements over given axes. + .describe(R"code(Computes the products of array elements over given axes. 
Example:: @@ -522,32 +484,27 @@ Example:: [ 36 480 2058] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", ProdCompute) -.set_attr("TOpPattern", kCommReduce); + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", ProdCompute) + .set_attr("TOpPattern", kCommReduce); - -Array MeanCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type) { +Array MeanCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); CHECK(param != nullptr); auto axes = param->axis; - for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), - param->axis, - param->exclude)) { + for (int64_t i : GetReduceAxes(inputs[0]->shape.size(), param->axis, param->exclude)) { count *= inputs[0]->shape[i]; } auto res = ReduceCompute(attrs, inputs, out_type, topi::sum); return {topi::divide(res[0], count)}; } - RELAY_REGISTER_REDUCE_OP("mean") -.describe(R"code(Computes the mean of array elements over given axes. + .describe(R"code(Computes the mean of array elements over given axes. Example:: @@ -562,16 +519,13 @@ Example:: [ 2. 3.16666667 4.5] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.add_type_rel("Reduce", ReduceRel) -.set_attr("FTVMCompute", MeanCompute) -.set_attr("TOpPattern", kCommReduce); - + .set_attrs_type() + .set_support_level(4) + .add_type_rel("Reduce", ReduceRel) + .set_attr("FTVMCompute", MeanCompute) + .set_attr("TOpPattern", kCommReduce); -bool VarianceRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool VarianceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -593,8 +547,7 @@ bool VarianceRel(const Array& types, return true; } -Array VarianceCompute(const Attrs& attrs, - const Array& inputs, +Array VarianceCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { IndexExpr count = tir::make_const(inputs[0]->dtype, 1); const ReduceAttrs* param = attrs.as(); @@ -602,9 +555,7 @@ Array VarianceCompute(const Attrs& attrs, auto axes = param->axis; auto data = inputs[0]; auto mean = inputs[1]; - for (int64_t i : GetReduceAxes(data->shape.size(), - param->axis, - param->exclude)) { + for (int64_t i : GetReduceAxes(data->shape.size(), param->axis, param->exclude)) { count *= data->shape[i]; } std::vector expand_shape; @@ -614,11 +565,7 @@ Array VarianceCompute(const Attrs& attrs, return {var}; } -Expr MakeVariance(Expr data, - Expr mean, - Array axis, - bool keepdims, - bool exclude) { +Expr MakeVariance(Expr data, Expr mean, Array axis, bool keepdims, bool exclude) { auto attrs = make_object(); attrs->axis = std::move(axis); attrs->keepdims = keepdims; @@ -627,23 +574,22 @@ Expr MakeVariance(Expr data, return Call(op, {data, mean}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._variance") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.op._make._variance").set_body([](const TVMArgs& args, TVMRetValue* rv) { runtime::detail::unpack_call(MakeVariance, args, rv); }); RELAY_REGISTER_OP("variance") -.describe(R"code(Computes the variance of array elements over given axes. + .describe(R"code(Computes the variance of array elements over given axes. 
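A small worked instance with illustrative values: for data = [1, 2, 3, 4] and the precomputed mean input 2.5, the reduction gives ((1 - 2.5)^2 + (2 - 2.5)^2 + (3 - 2.5)^2 + (4 - 2.5)^2) / 4 = 5 / 4 = 1.25.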
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_support_level(4) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("mean", "Tensor", "The mean tensor.") -.add_type_rel("Variance", VarianceRel) -.set_attr("FTVMCompute", VarianceCompute) -.set_attr("TOpPattern", kCommReduce); + .set_attrs_type() + .set_support_level(4) + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("mean", "Tensor", "The mean tensor.") + .add_type_rel("Variance", VarianceRel) + .set_attr("FTVMCompute", VarianceCompute) + .set_attr("TOpPattern", kCommReduce); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 7aa8bf1863a1..2a7e4e21e68b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -21,24 +21,26 @@ * \file transform.cc * \brief Transform operators. */ -#include +#include "transform.h" + +#include +#include +#include +#include +#include #include #include -#include -#include -#include +#include #include -#include -#include -#include -#include -#include +#include +#include +#include + #include -#include "../op_common.h" -#include "../../../arith/compute_expr.h" + #include "../../transforms/infer_layout_util.h" #include "../../transforms/pattern_util.h" -#include "transform.h" +#include "../op_common.h" namespace tvm { namespace relay { @@ -47,115 +49,95 @@ using tir::IntImmNode; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); -bool CastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; + << "cast: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); - reporter->Assign(types[1], TensorType( - data->shape, param->dtype)); + reporter->Assign(types[1], TensorType(data->shape, param->dtype)); return true; } -Array CastCompute(const Attrs& attrs, - const Array& inputs, +Array CastCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const CastAttrs *param = attrs.as(); + const CastAttrs* param = attrs.as(); CHECK(param != nullptr); DataType dtype = param->dtype; - return { topi::cast(inputs[0], dtype) }; + return {topi::cast(inputs[0], dtype)}; } -Expr MakeCast(Expr data, - DataType dtype) { +Expr MakeCast(Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("cast"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.ir.cast") -.set_body_typed(MakeCast); +TVM_REGISTER_GLOBAL("relay.ir.cast").set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") -.describe(R"code(Cast the data into a new data type. + .describe(R"code(Cast the data into a new data type. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Cast", CastRel) -.set_attr("FTVMCompute", CastCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Cast", CastRel) + .set_attr("FTVMCompute", CastCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.cast_like -bool CastLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CastLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TensorType but get " - << types[0]; + << "cast: expect input type to be TensorType but get " << types[0]; return false; } const auto* dtype_like = types[1].as(); if (dtype_like == nullptr) { CHECK(types[1].as()) - << "cast: expect input type to be TensorType but get " - << types[1]; + << "cast: expect input type to be TensorType but get " << types[1]; return false; } reporter->Assign(types[2], TensorType(data->shape, dtype_like->dtype)); return true; } - -Array CastLikeCompute(const Attrs& attrs, - const Array& inputs, +Array CastLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::cast(inputs[0], inputs[1]->dtype) }; + return {topi::cast(inputs[0], inputs[1]->dtype)}; } - -Expr MakeCastLike(Expr data, - Expr dtype_like) { +Expr MakeCastLike(Expr data, Expr dtype_like) { static const Op& op = Op::Get("cast_like"); return Call(op, {data, dtype_like}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.ir.cast_like") -.set_body_typed(MakeCastLike); +TVM_REGISTER_GLOBAL("relay.ir.cast_like").set_body_typed(MakeCastLike); RELAY_REGISTER_OP("cast_like") -.describe(R"code(Cast the data into the type of another tensor. + .describe(R"code(Cast the data into the type of another tensor. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("dtype_like", "Tensor", "The tensor to cast to.") -.set_support_level(3) -.add_type_rel("CastLike", CastLikeRel) -.set_attr("FTVMCompute", CastLikeCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); - - -Array ReinterpretCompute(const Attrs& attrs, - const Array& inputs, + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("dtype_like", "Tensor", "The tensor to cast to.") + .set_support_level(3) + .add_type_rel("CastLike", CastLikeRel) + .set_attr("FTVMCompute", CastLikeCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + +Array ReinterpretCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const CastAttrs* param = attrs.as(); CHECK(param != nullptr); @@ -175,44 +157,39 @@ TVM_REGISTER_GLOBAL("relay._make.reinterpret").set_body([](const TVMArgs& args, }); RELAY_REGISTER_OP("reinterpret") -.describe(R"code(Reinterpret the data into a new data type. + .describe(R"code(Reinterpret the data into a new data type. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reinterpret", CastRel) -.set_attr("FTVMCompute", ReinterpretCompute) -.set_attr("TOpPattern", kElemWise) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reinterpret", CastRel) + .set_attr("FTVMCompute", ReinterpretCompute) + .set_attr("TOpPattern", kElemWise) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.expand_dims TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs); -bool ExpandDimsRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ExpandDimsRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "expand_dims: expect input type to be TensorType but get " - << types[0]; + << "expand_dims: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int axis = param->axis; const int num_newaxis = param->num_newaxis; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; const int pivot = axis < 0 ? ndim + axis + 1 : axis; std::vector oshape; oshape.reserve(ndim + num_newaxis); @@ -229,17 +206,14 @@ bool ExpandDimsRel(const Array& types, return true; } -Array ExpandDimsCompute(const Attrs& attrs, - const Array& inputs, +Array ExpandDimsCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ExpandDimsAttrs *param = attrs.as(); + const ExpandDimsAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::expand_dims(inputs[0], param->axis, param->num_newaxis) }; + return {topi::expand_dims(inputs[0], param->axis, param->num_newaxis)}; } -Expr MakeExpandDims(Expr data, - int axis, - int num_newaxis) { +Expr MakeExpandDims(Expr data, int axis, int num_newaxis) { auto attrs = make_object(); attrs->axis = axis; attrs->num_newaxis = num_newaxis; @@ -247,75 +221,68 @@ Expr MakeExpandDims(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.expand_dims") -.set_body_typed(MakeExpandDims); +TVM_REGISTER_GLOBAL("relay.op._make.expand_dims").set_body_typed(MakeExpandDims); RELAY_REGISTER_OP("expand_dims") -.describe(R"code(Insert `num_newaxis` axises at the position given by `axis` + .describe(R"code(Insert `num_newaxis` axises at the position given by `axis` - **data**: The input data to the operator. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(1) -.add_type_rel("ExpandDims", ExpandDimsRel) -.set_attr("FTVMCompute", ExpandDimsCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(1) + .add_type_rel("ExpandDims", ExpandDimsRel) + .set_attr("FTVMCompute", ExpandDimsCompute) + .set_attr("TOpPattern", kBroadcast); // relay.concatenate TVM_REGISTER_NODE_TYPE(ConcatenateAttrs); -Array ConcatenateCompute(const Attrs& attrs, - const Array& inputs, +Array ConcatenateCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ConcatenateAttrs *param = attrs.as(); + const ConcatenateAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::concatenate(inputs, param->axis) }; + return {topi::concatenate(inputs, param->axis)}; } -Expr MakeConcatenate(Expr data, - int axis) { +Expr MakeConcatenate(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("concatenate"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.concatenate") -.set_body_typed(MakeConcatenate); +TVM_REGISTER_GLOBAL("relay.op._make.concatenate").set_body_typed(MakeConcatenate); RELAY_REGISTER_OP("concatenate") -.describe(R"code(Concatenate the input tensors along the given axis. + .describe(R"code(Concatenate the input tensors along the given axis. - **data** : A list of tensors. - **axis** : The axis along which the tensors are concatenated. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(1) -.add_type_rel("Concatenate", ConcatenateRel) -.set_attr("FInferCorrectLayout", ConcatenateLayout) -.set_attr("FTVMCompute", ConcatenateCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(1) + .add_type_rel("Concatenate", ConcatenateRel) + .set_attr("FInferCorrectLayout", ConcatenateLayout) + .set_attr("FTVMCompute", ConcatenateCompute) + .set_attr("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(StackAttrs); -bool StackRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StackRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { CHECK(types[0].as()) - << "cast: expect input type to be TupleType but get " - << types[0]; + << "cast: expect input type to be TupleType but get " << types[0]; return false; } const auto* param = attrs.as(); @@ -324,11 +291,9 @@ bool StackRel(const Array& types, // Sanity check: axis int axis = param->axis; - CHECK(-ndim <= axis && axis < ndim) - << "stack only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; - axis = axis < 0 ? ndim + axis + 1: axis; + CHECK(-ndim <= axis && axis < ndim) << "stack only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; + axis = axis < 0 ? ndim + axis + 1 : axis; // Sanity check: ndim and dtype. 
const DataType dtype = first->dtype; @@ -341,8 +306,9 @@ bool StackRel(const Array& types, for (size_t j = 0; j < first->shape.size(); ++j) { if (j == static_cast(axis)) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error("relay.stack requires all tensors have the same shape " - "on non-stacking axes"); + throw Error( + "relay.stack requires all tensors have the same shape " + "on non-stacking axes"); } } @@ -361,55 +327,49 @@ bool StackRel(const Array& types, return true; } -Array StackCompute(const Attrs& attrs, - const Array& inputs, +Array StackCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const StackAttrs *param = attrs.as(); + const StackAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::stack(inputs, param->axis) }; + return {topi::stack(inputs, param->axis)}; } -Expr MakeStack(Expr data, - int axis) { +Expr MakeStack(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("stack"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.stack") -.set_body_typed(MakeStack); +TVM_REGISTER_GLOBAL("relay.op._make.stack").set_body_typed(MakeStack); RELAY_REGISTER_OP("stack") -.describe(R"code(Stack the input tensors along the given axis. + .describe(R"code(Stack the input tensors along the given axis. - **data** : A list of tensors. - **axis** : The axis along which the tensors are stacked. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input list of tensors.") -.set_support_level(3) -.add_type_rel("Stack", StackRel) -.set_attr("FTVMCompute", StackCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input list of tensors.") + .set_support_level(3) + .add_type_rel("Stack", StackRel) + .set_attr("FTVMCompute", StackCompute) + .set_attr("TOpPattern", kInjective); /* relay.transpose */ TVM_REGISTER_NODE_TYPE(TransposeAttrs); -bool TransposeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TransposeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "transpose: expect input type to be TensorType but get " - << types[0]; + << "transpose: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); @@ -417,8 +377,8 @@ bool TransposeRel(const Array& types, const Array& axes = param->axes; // check dimension match CHECK(!axes.defined() || static_cast(axes.size()) == ndim) - << "Dimension mismatch: axes has " << axes.size() << " elements" - << ", but data.ndim = " << ndim; + << "Dimension mismatch: axes has " << axes.size() << " elements" + << ", but data.ndim = " << ndim; // construct int_axes std::vector int_axes; int_axes.reserve(ndim); @@ -433,9 +393,8 @@ bool TransposeRel(const Array& types, int64_t axis = e; // sanity check for axis and ndim CHECK(-ndim <= axis && axis < ndim) - << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; axis = axis < 0 ? 
axis + ndim : axis; // sanity check for duplication CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; @@ -452,69 +411,83 @@ bool TransposeRel(const Array& types, return true; } -Array TransposeCompute(const Attrs& attrs, - const Array& inputs, +Array TransposeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ topi::transpose(inputs[0], param->axes) }; + return Array{topi::transpose(inputs[0], param->axes)}; } -Expr MakeTranspose(Expr data, - Array axes) { +Expr MakeTranspose(Expr data, Array axes) { auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("transpose"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.transpose") -.set_body_typed(MakeTranspose); +TVM_REGISTER_GLOBAL("relay.op._make.transpose").set_body_typed(MakeTranspose); RELAY_REGISTER_OP("transpose") -.describe(R"code(Permutes the dimensions of an array. + .describe(R"code(Permutes the dimensions of an array. - **data**: The input data to the operator. - **axes**: The target axes order, reverse order if not specified. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Transpose", TransposeRel) -.set_attr("FTVMCompute", TransposeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Transpose", TransposeRel) + .set_attr("FTVMCompute", TransposeCompute) + .set_attr("TOpPattern", kInjective); /* relay.reshape */ TVM_REGISTER_NODE_TYPE(ReshapeAttrs); -bool ReshapeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // types: [data, result] - CHECK_EQ(types.size(), 2); + const auto* param = attrs.as(); + if (param->reverse) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + } else { + // types: [data, newshape, result] + CHECK_EQ(types.size(), 3); + } const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "reshape: expect input type to be TensorType but get " - << types[0]; + << "reshape: expect input type to be TensorType but get " << types[0]; return false; } - const auto* param = attrs.as(); + Array oshape; Array data_shape; Array newshape; - if (param->reverse) { - data_shape.assign(data->shape.rbegin(), data->shape.rend()); - newshape.assign(param->newshape.rbegin(), param->newshape.rend()); + + if (param->newshape) { + auto temp = param->newshape.value(); + if (param->reverse) { + data_shape.Assign(data->shape.rbegin(), data->shape.rend()); + newshape.Assign(temp.rbegin(), temp.rend()); + } else { + data_shape = data->shape; + newshape = temp; + } } else { - data_shape = data->shape; - newshape = param->newshape; + const auto* newshape = types[1].as(); + + // Doesn't support dynamic output rank + for (int i = 0; i < newshape->shape[0].as()->value; i++) { + oshape.push_back(Any()); + } + + reporter->Assign(types[2], TensorType(oshape, data->dtype)); + return true; } - Array oshape; + std::unordered_set used_input_dims; std::unordered_set used_output_dims; size_t src_idx = 0; @@ -534,8 +507,7 @@ bool ReshapeRel(const Array& types, oshape.push_back(data_shape[src_idx++]); } else if (svalue == -1) { // inference based on rest - CHECK_LT(infer_idx, 0) - << "One and 
only one dim can be inferred"; + CHECK_LT(infer_idx, 0) << "One and only one dim can be inferred"; infer_idx = i; oshape.push_back(1); ++src_idx; @@ -554,8 +526,8 @@ bool ReshapeRel(const Array& types, used_input_dims.insert(src_idx); IndexExpr d2 = data_shape[src_idx++]; used_output_dims.insert(oshape.size()); - if (d1.as() || d2.as()) { - oshape.push_back(Any::make()); + if (d1.as() || d2.as()) { + oshape.push_back(Any()); } else { oshape.push_back(d1 * d2); } @@ -569,11 +541,10 @@ bool ReshapeRel(const Array& types, Integer d1 = newshape[++i]; Integer d2 = newshape[++i]; if (d1->value == -1) { - CHECK(d2->value != -1) - << "Split dims cannot both be -1."; + CHECK(d2->value != -1) << "Split dims cannot both be -1."; used_output_dims.insert(oshape.size()); - if (d0.as()) { - oshape.push_back(Any::make()); + if (d0.as()) { + oshape.push_back(Any()); } else { oshape.push_back(indexdiv(d0, d2)); } @@ -584,8 +555,8 @@ bool ReshapeRel(const Array& types, oshape.push_back(d1); used_output_dims.insert(oshape.size()); if (d2->value == -1) { - if (d0.as()) { - oshape.push_back(Any::make()); + if (d0.as()) { + oshape.push_back(Any()); } else { oshape.push_back(indexdiv(d0, d1)); } @@ -604,19 +575,19 @@ bool ReshapeRel(const Array& types, if (used_input_dims.count(i) != 0) { continue; } - if (data_shape[i].as()) { - infer_dim = Any::make(); + if (data_shape[i].as()) { + infer_dim = Any(); break; } infer_dim *= data_shape[i]; } - if (!infer_dim.as()) { + if (!infer_dim.as()) { for (size_t i = 0; i < oshape.size(); ++i) { if (used_output_dims.count(i) != 0) { continue; } - if (oshape[i].as()) { - infer_dim = Any::make(); + if (oshape[i].as()) { + infer_dim = Any(); break; } infer_dim = indexdiv(infer_dim, oshape[i]); @@ -626,16 +597,15 @@ bool ReshapeRel(const Array& types, } if (param->reverse) { - reporter->Assign(types[1], TensorType( - Array(oshape.rbegin(), oshape.rend()), data->dtype)); + reporter->Assign(types[1], + TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); } else { - reporter->Assign(types[1], TensorType(oshape, data->dtype)); + reporter->Assign(types[2], TensorType(oshape, data->dtype)); } return true; } -Array ReshapeCompute(const Attrs& attrs, - const Array& inputs, +Array ReshapeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); @@ -647,23 +617,24 @@ Array ReshapeCompute(const Attrs& attrs, newshape.push_back(val); } } - return { topi::reshape(inputs[0], newshape) }; + return {topi::reshape(inputs[0], newshape)}; } -Expr MakeReshape(Expr data, - Array newshape) { +Expr MakeReshape(Expr data, Expr newshape) { auto attrs = make_object(); - attrs->newshape = std::move(newshape); + if (const ConstantNode* c = newshape.as()) { + CHECK_EQ(c->data->ndim, 1); + attrs->newshape = ToVector(c->data); + } attrs->reverse = false; static const Op& op = Op::Get("reshape"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {data, newshape}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.reshape") -.set_body_typed(MakeReshape); +TVM_REGISTER_GLOBAL("relay.op._make.reshape").set_body_typed(MakeReshape); RELAY_REGISTER_OP("reshape") -.describe(R"code(Reshapes the input array. + .describe(R"code(Reshapes the input array. 
Example:: @@ -713,26 +684,24 @@ Example:: - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); - + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("newshape", "Tensor", "The shape of output tensor.") + .set_support_level(3) + .add_type_rel("Reshape", ReshapeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); /*! -* \brief ReshapeLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool ReshapeLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief ReshapeLikeRel User defined type constraint function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return False if the relation has not been resolved, it might be resolved later. + * True if this relation has been resolved. + */ +bool ReshapeLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -753,43 +722,36 @@ bool ReshapeLikeRel(const Array& types, } if (is_static_shape) { CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) - << "Reshape inputs size should be compatible."; + << "Reshape inputs size should be compatible."; } reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype)); return true; } - -Expr MakeReshapeLike(Expr data, - Expr shape_like) { +Expr MakeReshapeLike(Expr data, Expr shape_like) { static const Op& op = Op::Get("reshape_like"); return Call(op, {data, shape_like}, Attrs(), {}); } - -TVM_REGISTER_GLOBAL("relay.op._make.reshape_like") -.set_body_typed(MakeReshapeLike); - +TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike); RELAY_REGISTER_OP("reshape_like") -.describe(R"code(Reshapes the input array by the size of another array. + .describe(R"code(Reshapes the input array by the size of another array. For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes the input array into an output array with the same shape as the second input array. .. note:: Sizes for both array should be compatible. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("shape_like", "Tensor", "Shape tensor.") -.set_support_level(3) -.add_type_rel("ReshapeLike", ReshapeLikeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape_like", "Tensor", "Shape tensor.") + .set_support_level(3) + .add_type_rel("ReshapeLike", ReshapeLikeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); // ArgWhere -bool ArgWhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ArgWhereRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); @@ -797,41 +759,89 @@ bool ArgWhereRel(const Array& types, const auto& input_shape = tt->shape; const auto& input_rank = input_shape.size(); std::vector result_shape; - result_shape.push_back(Any::make()); + result_shape.push_back(Any()); result_shape.push_back(IntImm(DataType::Int(32), input_rank)); reporter->Assign(types[1], TensorType(result_shape, DataType::Int(32))); return true; } -TVM_REGISTER_GLOBAL("relay.op._make.argwhere") -.set_body_typed([](Expr data) { +TVM_REGISTER_GLOBAL("relay.op._make.argwhere").set_body_typed([](Expr data) { static const Op& op = Op::Get("argwhere"); return Call(op, {data}, Attrs(), {}); }); RELAY_REGISTER_OP("argwhere") -.describe(R"doc(Find the indices of elements of a tensor that are + .describe(R"doc(Find the indices of elements of a tensor that are non-zero)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("condition", "Tensor", "The input condition tensor.") -.add_type_rel("ArgWhere", ArgWhereRel) -.set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kOpaque) -.set_support_level(10); + .set_num_inputs(1) + .add_argument("condition", "Tensor", "The input condition tensor.") + .add_type_rel("ArgWhere", ArgWhereRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kOpaque) + .set_support_level(10); + +// Scatter +TVM_REGISTER_NODE_TYPE(ScatterAttrs); + +// Scatter +bool ScatterRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 3); + CHECK_EQ(types.size(), 4); + auto data = types[0].as(); + if (data == nullptr) { + return false; + } + auto indices = types[1].as(); + if (indices == nullptr) { + return false; + } + auto updates = types[2].as(); + if (updates == nullptr) { + return false; + } + CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + const auto param = attrs.as(); + CHECK(param != nullptr); + reporter->Assign(types[3], TensorType(data->shape, data->dtype)); + return true; +} + +TVM_REGISTER_GLOBAL("relay.op._make.scatter") + .set_body_typed([](Expr data, Expr indices, Expr updates, int axis) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("scatter"); + return Call(op, {data, indices, updates}, Attrs(attrs), {}); + }); + +RELAY_REGISTER_OP("scatter") + .describe( + R"doc(Update data at positions defined by indices with values in updates)doc" TVM_ADD_FILELINE) + .set_num_inputs(3) + .add_argument("data", "Tensor", "The input data tensor.") + .add_argument("indicies", "Tensor", "The indicies location tensor.") + .add_argument("updates", "Tensor", "The values to update the input with.") + .add_type_rel("Scatter", ScatterRel) + .set_attr("TOpIsStateful", false) 
+ .set_attr("TOpPattern", kOpaque) + .set_support_level(10); // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); -bool TakeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TakeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + return false; + } const auto* indices = types[1].as(); - CHECK(indices != nullptr); + if (indices == nullptr) { + return false; + } CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; const auto param = attrs.as(); CHECK(param != nullptr); @@ -847,9 +857,8 @@ bool TakeRel(const Array& types, const auto ndim_indices = static_cast(indices->shape.size()); int axis = static_cast(param->axis->value); if (axis < 0) axis += ndim_data; - CHECK_LE(axis, ndim_data) - << "axis should be with in data shape" - << ", but got = " << axis; + CHECK_LE(axis, ndim_data) << "axis should be with in data shape" + << ", but got = " << axis; oshape.reserve(ndim_data - 1 + ndim_indices); for (int i = 0; i < axis; ++i) { @@ -858,7 +867,7 @@ bool TakeRel(const Array& types, for (int i = 0; i < ndim_indices; ++i) { oshape.emplace_back(indices->shape[i]); } - for (int i = axis+1; i < ndim_data; ++i) { + for (int i = axis + 1; i < ndim_data; ++i) { oshape.emplace_back(data->shape[i]); } @@ -866,22 +875,18 @@ bool TakeRel(const Array& types, return true; } -Array TakeCompute(const Attrs& attrs, - const Array& inputs, +Array TakeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); if (!param->axis.defined()) { - return Array{ topi::take(inputs[0], inputs[1], param->mode) }; + return Array{topi::take(inputs[0], inputs[1], param->mode)}; } else { - return Array{ topi::take(inputs[0], inputs[1], param->axis, param->mode) }; + return Array{topi::take(inputs[0], inputs[1], param->axis, param->mode)}; } } -Expr MakeTake(Expr data, - Expr indices, - Integer axis, - std::string mode) { +Expr MakeTake(Expr data, Expr indices, Integer axis, String mode) { auto attrs = make_object(); attrs->axis = std::move(axis); attrs->mode = std::move(mode); @@ -889,11 +894,10 @@ Expr MakeTake(Expr data, return Call(op, {data, indices}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.take") -.set_body_typed(MakeTake); +TVM_REGISTER_GLOBAL("relay.op._make.take").set_body_typed(MakeTake); RELAY_REGISTER_OP("take") -.describe(R"code(Take elements from an array along an axis. + .describe(R"code(Take elements from an array along an axis. 
When axis is not None, this function does the same thing as 'fancy' indexing (indexing arrays using arrays); however, it can be easier to use if you need @@ -915,26 +919,24 @@ Examples:: [ 4., 3.]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("indices", "Tensor", "The indices tensor.") -.set_support_level(3) -.add_type_rel("Take", TakeRel) -.set_attr("FTVMCompute", TakeCompute) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("indices", "Tensor", "The indices tensor.") + .set_support_level(3) + .add_type_rel("Take", TakeRel) + .set_attr("FTVMCompute", TakeCompute) + .set_attr("TOpPattern", kInjective); // Init ops TVM_REGISTER_NODE_TYPE(InitOpAttrs); -bool FullRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool FullRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); + CHECK_EQ(types.size(), 3); const InitOpAttrs* param = attrs.as(); const auto* fill_value = types[0].as(); + const auto* fill_shape = types[1].as(); if (fill_value == nullptr) { return false; } @@ -945,101 +947,127 @@ bool FullRel(const Array& types, } CHECK_EQ(fill_value->shape.size(), 0) - << "Fill value should be a scalar but has dimension " - << fill_value->shape.size() << "."; + << "Fill value should be a scalar but has dimension " << fill_value->shape.size() << "."; - reporter->Assign(types[1], TensorType(param->shape, out_dtype)); + const IntImmNode* shape_shape = fill_shape->shape[0].as(); + CHECK(shape_shape) << "Parameter shape must have static shape"; + + std::vector oshape; + if (param->shape) { + const Array& cshape_array = param->shape.value(); + for (size_t i = 0; i < cshape_array.size(); ++i) { + oshape.push_back(cshape_array[i]); + } + } else { + for (int i = 0; i < shape_shape->value; ++i) { + oshape.push_back(Any()); + } + } + reporter->Assign(types[2], TensorType(oshape, out_dtype)); return true; } -Array FullCompute(const Attrs& attrs, - const Array& inputs, +Array FullCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); - return { topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]()) }; + return {topi::full(out_ttype->shape, out_ttype->dtype, inputs[0]())}; } -Expr MakeFull(Expr fill_value, - Array shape, - DataType dtype) { +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype) { auto attrs = make_object(); - attrs->shape = std::move(shape); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } attrs->dtype = std::move(dtype); static const Op& op = Op::Get("full"); - return Call(op, {fill_value}, Attrs(attrs), {}); + return Call(op, {fill_value, shape}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.full") -.set_body_typed(MakeFull); +TVM_REGISTER_GLOBAL("relay.op._make.full").set_body_typed(MakeFull); RELAY_REGISTER_OP("full") -.describe(R"code(Fill array with scalar value. + .describe(R"code(Fill array with scalar value. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("fill_value", "double", "The value to fill.") -.set_support_level(3) -.add_type_rel("Full", FullRel) -.set_attr("FTVMCompute", FullCompute) -.set_attr("TOpPattern", kElemWise); - -bool InitOpRel(const Array& types, - int num_inputs, - const Attrs& attrs, + .set_attrs_type() + .set_num_inputs(2) + .add_argument("fill_value", "double", "The value to fill.") + .add_argument("shape", "Tensor", "Target shape.") + .set_support_level(3) + .add_type_rel("Full", FullRel) + .set_attr("FTVMCompute", FullCompute) + .set_attr("TOpPattern", kElemWise); + +bool InitOpRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 1); + CHECK_EQ(types.size(), 2); const InitOpAttrs* param = attrs.as(); + const auto* fill_shape = types[0].as(); + DataType out_dtype = param->dtype; - reporter->Assign(types[0], TensorType(param->shape, param->dtype)); + const IntImmNode* shape_shape = fill_shape->shape[0].as(); + CHECK(shape_shape) << "Parameter shape must have static shape"; + + std::vector oshape; + if (param->shape) { + const Array& cshape_array = param->shape.value(); + for (size_t i = 0; i < cshape_array.size(); ++i) { + oshape.push_back(cshape_array[i]); + } + } else { + for (int i = 0; i < shape_shape->value; ++i) { + oshape.push_back(Any()); + } + } + reporter->Assign(types[1], TensorType(oshape, out_dtype)); return true; } -Expr MakeZeros(Array shape, - DataType dtype) { +Expr MakeZeros(Expr shape, DataType dtype) { auto attrs = make_object(); - attrs->shape = std::move(shape); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } attrs->dtype = std::move(dtype); static const Op& op = Op::Get("zeros"); - return Call(op, {}, Attrs(attrs), {}); + return Call(op, {shape}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.zeros") -.set_body_typed(MakeZeros); +TVM_REGISTER_GLOBAL("relay.op._make.zeros").set_body_typed(MakeZeros); RELAY_REGISTER_OP("zeros") -.describe(R"code(Fill array with zeros. + .describe(R"code(Fill array with zeros. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(0) -.set_support_level(3) -.add_type_rel("InitOp", InitOpRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Tensor", "Target shape.") + .set_support_level(3) + .add_type_rel("InitOp", InitOpRel); -Expr MakeOnes(Array shape, - DataType dtype) { +Expr MakeOnes(Expr shape, DataType dtype) { auto attrs = make_object(); - attrs->shape = std::move(shape); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } attrs->dtype = std::move(dtype); static const Op& op = Op::Get("ones"); - return Call(op, {}, Attrs(attrs), {}); + return Call(op, {shape}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.ones") -.set_body_typed(MakeOnes); +TVM_REGISTER_GLOBAL("relay.op._make.ones").set_body_typed(MakeOnes); RELAY_REGISTER_OP("ones") -.describe(R"code(Fill array with ones. + .describe(R"code(Fill array with ones. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(0) -.set_support_level(3) -.add_type_rel("InitOp", InitOpRel); - -bool FullLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Tensor", "Target shape.") + .set_support_level(3) + .add_type_rel("InitOp", InitOpRel); + +bool FullLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -1052,85 +1080,42 @@ bool FullLikeRel(const Array& types, } CHECK_EQ(fill_value->shape.size(), 0) - << "The fill value should be a scalar but here it has dimension " - << fill_value->shape.size() << "."; + << "The fill value should be a scalar but here it has dimension " << fill_value->shape.size() + << "."; reporter->Assign(types[2], TensorType(data->shape, data->dtype)); return true; } -Array FullLikeCompute(const Attrs& attrs, - const Array& inputs, +Array FullLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::full_like(inputs[0], inputs[1]()) }; + return {topi::full_like(inputs[0], inputs[1]())}; } -Expr MakeFullLike(Expr data, - Expr fill_value) { +Expr MakeFullLike(Expr data, Expr fill_value) { static const Op& op = Op::Get("full_like"); return Call(op, {data, fill_value}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.full_like") -.set_body_typed(MakeFullLike); +TVM_REGISTER_GLOBAL("relay.op._make.full_like").set_body_typed(MakeFullLike); RELAY_REGISTER_OP("full_like") -.describe(R"code(Return an scalar value array with the same shape + .describe(R"code(Return an scalar value array with the same shape and type as the input array. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("fill_value", "double", "Scalar value to fill.") -.set_support_level(3) -.add_type_rel("FullLike", FullLikeRel) -.set_attr("FTVMCompute", FullLikeCompute) -.set_attr("TOpPattern", kElemWise); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("fill_value", "double", "Scalar value to fill.") + .set_support_level(3) + .add_type_rel("FullLike", FullLikeRel) + .set_attr("FTVMCompute", FullLikeCompute) + .set_attr("TOpPattern", kElemWise); // arange operator TVM_REGISTER_NODE_TYPE(ArangeAttrs); -double ToScalar(const runtime::NDArray& array) { - if (array->dtype.code == kDLInt) { - if (array->dtype.bits == 8) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 16) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[0]; - } - } else if (array->dtype.code == kDLUInt) { - if (array->dtype.bits == 8) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 16) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[0]; - } - } else if (array->dtype.code == kDLFloat) { -#if (__ARM_FP16_FORMAT_IEEE == 1) - if (array->dtype.bits == 16) { - return reinterpret_cast<__fp16*>(array->data)[0]; - } -#endif - if (array->dtype.bits == 32) { - return reinterpret_cast(array->data)[0]; - } else if (array->dtype.bits == 64) { - return reinterpret_cast(array->data)[0]; - } - } - LOG(FATAL) << "Unknown data type: 
" << tvm::runtime::DLDataType2String(array->dtype); - // make compiler happy - return -std::numeric_limits::infinity(); -} - -bool ArangeRel(const Array& types, - int num_inputs, - const Attrs& raw_attrs, +bool ArangeRel(const Array& types, int num_inputs, const Attrs& raw_attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const ArangeAttrs* attrs = raw_attrs.as(); @@ -1140,50 +1125,46 @@ bool ArangeRel(const Array& types, reporter->Assign(types[1], types[2]); reporter->Assign(types[2], TensorType({}, attrs->dtype)); - if ((cstart = attrs->start.as()) && - (cstop = attrs->stop.as()) && + if ((cstart = attrs->start.as()) && (cstop = attrs->stop.as()) && (cstep = attrs->step.as())) { double start = ToScalar(cstart->data); double stop = ToScalar(cstop->data); double step = ToScalar(cstep->data); int32_t num_elem = static_cast(std::ceil((stop - start) / step)); - CHECK_GT(num_elem, 0) - << "Invalid arange attributes (start, stop, step): " << attrs->start - << ", " << attrs->stop << ", " << attrs->step; + CHECK_GT(num_elem, 0) << "Invalid arange attributes (start, stop, step): " << attrs->start + << ", " << attrs->stop << ", " << attrs->step; reporter->Assign(types[3], TensorType({num_elem}, attrs->dtype)); return true; } else { - reporter->Assign(types[3], TensorType({Any::make()}, attrs->dtype)); + reporter->Assign(types[3], TensorType({Any()}, attrs->dtype)); return true; } } -inline te::Tensor DynamicArange(const te::Tensor& start, - const te::Tensor& stop, - const te::Tensor& step, - tvm::DataType dtype, - std::string name = "tensor", - std::string tag = topi::kInjective) { +inline te::Tensor DynamicArange(const te::Tensor& start, const te::Tensor& stop, + const te::Tensor& step, tvm::DataType dtype, + std::string name = "T_arange_dynamic", + std::string tag = topi::kInjective) { tvm::PrimExpr num_elem = tvm::tir::Var("num_elem"); - return te::compute({num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start[0] + step[0] * indices[0]); - }, name, tag); + return te::compute( + {num_elem}, + [&](const Array& indices) { + return tvm::cast(dtype, start[0] + step[0] * indices[0]); + }, + name, tag); } -Array ArangeCompute(const Attrs& attrs, - const Array& inputs, +Array ArangeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const ArangeAttrs* param = attrs.as(); + CHECK(param != nullptr); te::Tensor start = inputs[0]; - te::Tensor stop = inputs[1]; + te::Tensor stop = inputs[1]; te::Tensor step = inputs[2]; - return { DynamicArange(start, stop, step, param->dtype) }; + return {DynamicArange(start, stop, step, param->dtype)}; } -Expr MakeArange(Expr start, - Expr stop, - Expr step, - DataType dtype) { +Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype) { auto attrs = make_object(); attrs->start = start; attrs->stop = stop; @@ -1193,8 +1174,7 @@ Expr MakeArange(Expr start, return Call(op, {start, stop, step}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.arange") -.set_body_typed(MakeArange); +TVM_REGISTER_GLOBAL("relay.op._make.arange").set_body_typed(MakeArange); // An issue with the existing design is that we require dependency // to type the operator precisely. @@ -1210,45 +1190,40 @@ TVM_REGISTER_GLOBAL("relay.op._make.arange") // In general I think we should avoid this pattern, and introduce // a secondary shape analysis to recover more precise information. RELAY_REGISTER_OP("arange") -.describe(R"code(Returns evenly spaced values within a given interval. 
+ .describe(R"code(Returns evenly spaced values within a given interval. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.set_support_level(3) -.add_type_rel("Arange", ArangeRel) -.set_attr("FTVMCompute", ArangeCompute) -// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape -.set_attr("TOpPattern", kOpaque) -.set_attr("AnyCodegenStrategy", kVariableDimensions); + .set_attrs_type() + .set_num_inputs(3) + .set_support_level(3) + .add_type_rel("Arange", ArangeRel) + .set_attr("FTVMCompute", ArangeCompute) + // TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape + .set_attr("TOpPattern", kOpaque) + .set_attr("AnyCodegenStrategy", kVariableDimensions); // repeat operator TVM_REGISTER_NODE_TYPE(RepeatAttrs); -bool RepeatRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool RepeatRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "repeat: expect input type to be TensorType but get " - << types[0]; + << "repeat: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int repeats = param->repeats; const int axis = param->axis; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; const int pivot = axis < 0 ? ndim + axis : axis; std::vector oshape; oshape.reserve(ndim + repeats); @@ -1263,17 +1238,14 @@ bool RepeatRel(const Array& types, return true; } -Array RepeatCompute(const Attrs& attrs, - const Array& inputs, +Array RepeatCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const RepeatAttrs *param = attrs.as(); + const RepeatAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::repeat(inputs[0], param->repeats, param->axis) }; + return {topi::repeat(inputs[0], param->repeats, param->axis)}; } -Expr MakeRepeat(Expr data, - int repeats, - int axis) { +Expr MakeRepeat(Expr data, int repeats, int axis) { auto attrs = make_object(); attrs->repeats = repeats; attrs->axis = axis; @@ -1281,50 +1253,45 @@ Expr MakeRepeat(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.repeat") -.set_body_typed(MakeRepeat); +TVM_REGISTER_GLOBAL("relay.op._make.repeat").set_body_typed(MakeRepeat); RELAY_REGISTER_OP("repeat") -.describe(R"code(Repeat elements of an array `repeats` times along axis `axis` + .describe(R"code(Repeat elements of an array `repeats` times along axis `axis` - **data**: The input data to the operator. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Repeat", RepeatRel) -.set_attr("FTVMCompute", RepeatCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Repeat", RepeatRel) + .set_attr("FTVMCompute", RepeatCompute) + .set_attr("TOpPattern", kBroadcast); // tile operator TVM_REGISTER_NODE_TYPE(TileAttrs); -bool TileRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TileRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "tile: expect input type to be TensorType but get " - << types[0]; + << "tile: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const size_t ndim = data->shape.size(); const Array& reps = param->reps; // check dimension match - CHECK(reps.defined()) - << "repetition array is not defined. data.ndim = " << ndim; + CHECK(reps.defined()) << "repetition array is not defined. data.ndim = " << ndim; const size_t rndim = reps.size(); for (size_t i = 0; i < rndim; ++i) { if (const tvm::tir::IntImmNode* val = reps[i].as()) { - CHECK_GT(val->value, 0) - << "Tile reps value should always be larger than 0, but get: " << val->value; + CHECK_GT(val->value, 0) << "Tile reps value should always be larger than 0, but get: " + << val->value; } } size_t tndim = (ndim > rndim) ? ndim : rndim; @@ -1364,7 +1331,7 @@ bool TileRel(const Array& types, for (size_t i = 0; i < tndim; ++i) { // Save Any if it is dynamic shape if (!data_shape[i].as()) { - oshape.emplace_back(Any::make()); + oshape.emplace_back(Any()); } else { oshape.emplace_back(data_shape[i] * reps_shape[i]); } @@ -1373,103 +1340,91 @@ bool TileRel(const Array& types, return true; } -Array TileCompute(const Attrs& attrs, - const Array& inputs, +Array TileCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const TileAttrs *param = attrs.as(); + const TileAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::tile(inputs[0], param->reps) }; + return {topi::tile(inputs[0], param->reps)}; } -Expr MakeTile(Expr data, - Array reps) { +Expr MakeTile(Expr data, Array reps) { auto attrs = make_object(); attrs->reps = reps; static const Op& op = Op::Get("tile"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.tile") -.set_body_typed(MakeTile); +TVM_REGISTER_GLOBAL("relay.op._make.tile").set_body_typed(MakeTile); RELAY_REGISTER_OP("tile") -.describe(R"code(Repeat the whole array multiple times. + .describe(R"code(Repeat the whole array multiple times. - **data**: The input data to the operator. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Tile", TileRel) -.set_attr("FTVMCompute", TileCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Tile", TileRel) + .set_attr("FTVMCompute", TileCompute) + .set_attr("TOpPattern", kBroadcast); // reverse operator TVM_REGISTER_NODE_TYPE(ReverseAttrs); -bool ReverseRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool ReverseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); if (data == nullptr) { CHECK(types[0].as()) - << "reverse: expect input type to be TensorType but get " - << types[0]; + << "reverse: expect input type to be TensorType but get " << types[0]; return false; } const auto* param = attrs.as(); const int ndim = static_cast(data->shape.size()); const int axis = param->axis; CHECK(-ndim <= axis && axis < ndim) - << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; + << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; reporter->Assign(types[1], types[0]); return true; } -Array ReverseCompute(const Attrs& attrs, - const Array& inputs, +Array ReverseCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const ReverseAttrs *param = attrs.as(); + const ReverseAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::flip(inputs[0], param->axis) }; + return {topi::flip(inputs[0], param->axis)}; } -Expr MakeReverse(Expr data, - int axis) { +Expr MakeReverse(Expr data, int axis) { auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("reverse"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.reverse") -.set_body_typed(MakeReverse); +TVM_REGISTER_GLOBAL("relay.op._make.reverse").set_body_typed(MakeReverse); RELAY_REGISTER_OP("reverse") -.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. + .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. - **data**: The input data to the operator. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Reverse", ReverseRel) -.set_attr("FTVMCompute", ReverseCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Reverse", ReverseRel) + .set_attr("FTVMCompute", ReverseCompute) + .set_attr("TOpPattern", kInjective); // where operator -bool WhereRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool WhereRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4U); const auto* condition = types[0].as(); @@ -1483,17 +1438,16 @@ bool WhereRel(const Array& types, CHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size"; if (cond_shape.size() != x_shape.size()) { - CHECK_EQ(cond_shape.size(), 1) - << "Shape of condition " << condition->shape - << " must be either equal to x or has dimension of 1."; + CHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape + << " must be either equal to x or has dimension of 1."; } for (size_t i = 0; i < x_shape.size(); i++) { CHECK(reporter->AssertEQ(x_shape[i], y_shape[i])) << "x and y must have the same shape: " << x_shape << " vs " << y_shape; if (i < cond_shape.size()) { - CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) - << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; + CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i])) + << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; } } reporter->Assign(types[3], TensorType(x_shape, x->dtype)); @@ -1506,17 +1460,15 @@ Expr MakeWhere(const Expr& condition, const Expr& x, const Expr& y) { return Call(op, {condition, x, y}); } -Array WhereCompute(const Attrs& attrs, - const Array& inputs, +Array WhereCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::where(inputs[0], inputs[1], inputs[2]) }; + return {topi::where(inputs[0], inputs[1], inputs[2])}; } -TVM_REGISTER_GLOBAL("relay.op._make.where") -.set_body_typed(MakeWhere); +TVM_REGISTER_GLOBAL("relay.op._make.where").set_body_typed(MakeWhere); RELAY_REGISTER_OP("where") -.describe(R"code( + .describe(R"code( Return the elements, either from x or y, depending on the condition. 
Given three ndarrays, condition, x, and y, return an ndarray with the elements @@ -1544,34 +1496,28 @@ Examples:: where(cond, x, y) = [[1, 2], [7, 8]] )code" TVM_ADD_FILELINE) -.add_argument("condition", "Tensor", "Condition array") -.add_argument("x", "Tensor", "First array to be selected") -.add_argument("y", "Tensor", "Second array to be selected") -.set_num_inputs(3) -.set_support_level(4) -.add_type_rel("Where", WhereRel) -.set_attr("FTVMCompute", WhereCompute) -.set_attr("TOpPattern", kBroadcast); - + .add_argument("condition", "Tensor", "Condition array") + .add_argument("x", "Tensor", "First array to be selected") + .add_argument("y", "Tensor", "Second array to be selected") + .set_num_inputs(3) + .set_support_level(4) + .add_type_rel("Where", WhereRel) + .set_attr("FTVMCompute", WhereCompute) + .set_attr("TOpPattern", kBroadcast); // Squeeze TVM_REGISTER_NODE_TYPE(SqueezeAttrs); -Expr MakeSqueeze(Expr data, - Array axis) { +Expr MakeSqueeze(Expr data, Array axis) { auto attrs = make_object(); attrs->axis = std::move(axis); static const Op& op = Op::Get("squeeze"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.squeeze") -.set_body_typed(MakeSqueeze); - +TVM_REGISTER_GLOBAL("relay.op._make.squeeze").set_body_typed(MakeSqueeze); -bool SqueezeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -1595,7 +1541,7 @@ bool SqueezeRel(const Array& types, } } else { // pair up original shape with a boolean which control whether it will be in the final shape. - std::vector > original_shape; + std::vector> original_shape; for (const auto& e : data->shape) { original_shape.push_back(std::pair(e, true)); } @@ -1622,268 +1568,267 @@ bool SqueezeRel(const Array& types, return true; } -Array SqueezeCompute(const Attrs& attrs, - const Array& inputs, +Array SqueezeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - const SqueezeAttrs *param = attrs.as(); + const SqueezeAttrs* param = attrs.as(); CHECK(param != nullptr); - return { topi::squeeze(inputs[0], param->axis) }; + return {topi::squeeze(inputs[0], param->axis)}; } - RELAY_REGISTER_OP("squeeze") -.describe(R"code(Squeeze the input tensor at the dimensions given by axes + .describe(R"code(Squeeze the input tensor at the dimensions given by axes - **data**: The input data to the operator. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Squeeze", SqueezeRel) -.set_attr("FTVMCompute", SqueezeCompute) -.set_attr("TOpPattern", kInjective); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Squeeze", SqueezeRel) + .set_attr("FTVMCompute", SqueezeCompute) + .set_attr("TOpPattern", kInjective); // CollapseSumLike: -> B where BroadCast(A, B) = A -bool CollapseSumLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool CollapseSumLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); reporter->Assign(types[2], types[1]); return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter); } -Expr MakeCollapseSumLike(Expr data, - Expr collapse_type) { +Expr MakeCollapseSumLike(Expr data, Expr collapse_type) { static const Op& op = Op::Get("collapse_sum_like"); return Call(op, {data, collapse_type}, Attrs(), {}); } -Array CollapseSumLikeCompute(const Attrs& attrs, - const Array& inputs, +Array CollapseSumLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::collapse_sum(inputs[0], out_ttype->shape) }; + return {topi::collapse_sum(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like") -.set_body_typed(MakeCollapseSumLike); +TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_like").set_body_typed(MakeCollapseSumLike); RELAY_REGISTER_OP("collapse_sum_like") -.describe(R"code(Collapse the first input to match the shape of the second input. + .describe(R"code(Collapse the first input to match the shape of the second input. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") -.set_support_level(10) -.add_type_rel("CollapseSumLike", CollapseSumLikeRel) -.set_attr("FTVMCompute", CollapseSumLikeCompute) -.set_attr("TOpPattern", kCommReduce); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("collapse_type", "Tensor", "Provide the type to collapse to.") + .set_support_level(10) + .add_type_rel("CollapseSumLike", CollapseSumLikeRel) + .set_attr("FTVMCompute", CollapseSumLikeCompute) + .set_attr("TOpPattern", kCommReduce); // BroadCastTo: -> B where BroadCast(A, B) = B -bool BroadCastToRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadCastToRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - auto ioattrs = attrs.as(); - CHECK(ioattrs); - auto intt = types[0].as(); - if (intt == nullptr) { return false; } - auto type = TensorType(ioattrs->shape, intt->dtype); - reporter->Assign(types[1], type); - return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); + CHECK_EQ(types.size(), 3); + const InitOpAttrs* param = attrs.as(); + const auto* target_shape = types[1].as(); + DataType out_dtype = types[0].as()->dtype; + + const IntImmNode* shape_shape = target_shape->shape[0].as(); + CHECK(shape_shape) << "Parameter shape must have static shape"; + + std::vector oshape; + if (param->shape) { + const Array& cshape_array = param->shape.value(); + for (size_t i = 0; i < cshape_array.size(); ++i) { + oshape.push_back(cshape_array[i]); + } + } else { + for (int i = 0; i < shape_shape->value; ++i) { + oshape.push_back(Any()); + } + } + reporter->Assign(types[2], TensorType(oshape, out_dtype)); + return BroadcastRel({types[0], types[2], types[2]}, 2, Attrs(), reporter); } -Expr MakeBroadCastTo(Expr data, Array shape) { +Expr MakeBroadCastTo(Expr data, Expr shape) { static const Op& op = Op::Get("broadcast_to"); auto attrs = make_object(); - attrs->shape = std::move(shape); - return Call(op, {data}, Attrs(attrs), {}); + if (const auto* cshape = shape.as()) { + attrs->shape = ToVector(cshape->data); + } + return Call(op, {data, shape}, Attrs(attrs), {}); } -Array BroadCastToCompute(const Attrs& attrs, - const Array& inputs, +Array BroadCastToCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - auto ioattrs = attrs.as(); - CHECK(ioattrs != nullptr); - return { topi::broadcast_to(inputs[0], ioattrs->shape) }; + const auto* out_ttype = out_type.as(); + return {topi::broadcast_to(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to") -.set_body_typed(MakeBroadCastTo); +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to").set_body_typed(MakeBroadCastTo); RELAY_REGISTER_OP("broadcast_to") -.describe(R"code(Broadcast the first input to match the shape argument. + .describe(R"code(Broadcast the first input to match the shape argument. 
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(4) -.add_type_rel("BroadCastTo", BroadCastToRel) -.set_attr("FTVMCompute", BroadCastToCompute) -.set_attr("TOpPattern", kBroadcast); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape", "Tensor", "Target shape.") + .set_support_level(4) + .add_type_rel("BroadCastTo", BroadCastToRel) + .set_attr("FTVMCompute", BroadCastToCompute) + .set_attr("TOpPattern", kBroadcast); // BroadCastToLike: -> B where BroadCast(A, B) = B -bool BroadCastToLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadCastToLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); reporter->Assign(types[2], types[1]); return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter); } -Expr MakeBroadCastToLike(Expr data, - Expr broadcast_type) { +Expr MakeBroadCastToLike(Expr data, Expr broadcast_type) { static const Op& op = Op::Get("broadcast_to_like"); return Call(op, {data, broadcast_type}, Attrs(), {}); } -Array BroadCastToLikeCompute(const Attrs& attrs, - const Array& inputs, +Array BroadCastToLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::broadcast_to(inputs[0], out_ttype->shape) }; + return {topi::broadcast_to(inputs[0], out_ttype->shape)}; } -TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like") -.set_body_typed(MakeBroadCastToLike); +TVM_REGISTER_GLOBAL("relay.op._make.broadcast_to_like").set_body_typed(MakeBroadCastToLike); RELAY_REGISTER_OP("broadcast_to_like") -.describe(R"code(Broadcast the first input to match the shape of the second input. + .describe(R"code(Broadcast the first input to match the shape of the second input. )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") -.set_support_level(10) -.add_type_rel("BroadCastToLike", BroadCastToLikeRel) -.set_attr("FTVMCompute", BroadCastToLikeCompute) -.set_attr("TOpPattern", kBroadcast); - + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.") + .set_support_level(10) + .add_type_rel("BroadCastToLike", BroadCastToLikeRel) + .set_attr("FTVMCompute", BroadCastToLikeCompute) + .set_attr("TOpPattern", kBroadcast); // Adapter function to make int array. 
Array GetIntArray(Array arr) { for (size_t i = 0; i < arr.size(); ++i) { - CHECK(!arr[i].defined() || arr[i].as()) - << "Expect an int array"; + CHECK(!arr[i].defined() || arr[i].as()) << "Expect an int array"; } - return Downcast >(arr); + return Downcast>(arr); } - // strided_slice TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); -bool StridedSliceRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 2); - const auto* data = types[0].as(); - if (data == nullptr) return false; - const StridedSliceAttrs *param = attrs.as(); +bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 5); + const StridedSliceAttrs* param = attrs.as(); CHECK(param != nullptr); - + const auto* data = types[0].as(); + CHECK(data != nullptr); auto dshape = data->shape; - auto num_axis = dshape.size(); - - std::vector stride_vec; - for (Integer i : param->strides) { - CHECK(i.defined()); - stride_vec.push_back(i->value); - } - for (size_t i = stride_vec.size(); i < num_axis; ++i) { - stride_vec.push_back(1); - } - const int64_t max_range = std::numeric_limits::max(); - - std::vector begin_vec; - for (size_t i = 0; i < param->begin.size(); ++i) { - if (!param->begin[i].defined()) { - // value=None + int64_t num_axis = dshape.size(); + + // calculate output shape + std::vector oshape(num_axis); + if (param->begin && param->end && param->strides) { + // stride will be set as 1 if slice mode is enabled + std::vector stride_vec(num_axis, 1); + if (param->slice_mode == "end") { + for (size_t i = 0; i < param->strides.value().size(); ++i) { + CHECK(param->strides.value()[i].defined()); + stride_vec[i] = param->strides.value()[i]->value; + } + } + const int64_t max_range = std::numeric_limits::max(); + std::vector begin_vec; + for (size_t i = 0; i < param->begin.value().size(); ++i) { + if (!param->begin.value()[i].defined()) { + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(param->begin.value()[i]->value); + } + } + for (int64_t i = begin_vec.size(); i < num_axis; ++i) { begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(param->begin[i]->value); } - } - for (size_t i = begin_vec.size(); i < num_axis; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } - std::vector end_vec; - for (size_t i = 0; i < param->end.size(); ++i) { - // allow end to be None - if (!param->end[i].defined()) { + std::vector end_vec; + for (size_t i = 0; i < param->end.value().size(); ++i) { + // allow end to be None + if (!param->end.value()[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (param->slice_mode == "size") { + if (param->end.value()[i]->value < 0) { + end_vec.push_back(max_range); + } else { + end_vec.push_back(begin_vec[i] + param->end.value()[i]->value); + } + } else if (param->slice_mode == "end") { + end_vec.push_back(param->end.value()[i]->value); + } else { + LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode; + } + } + for (int64_t i = end_vec.size(); i < num_axis; ++i) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); - } else { - end_vec.push_back(param->end[i]->value); } - } - for (size_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 
0 : max_range); - } - - std::vector oshape(dshape.size()); - for (size_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; - - if ((stride_v == 1 && - begin_v == 0 && - end_v == max_range) || - (stride_v == -1 && - begin_v == max_range && - end_v == 0)) { - // Quick path, do not slice this dimension. - oshape[i] = dshape[i]; - continue; + + for (int64_t i = 0; i < num_axis; ++i) { + int64_t stride_v = stride_vec[i]; + int64_t begin_v = begin_vec[i]; + int64_t end_v = end_vec[i]; + + if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || + (stride_v == -1 && begin_v == max_range && end_v == 0)) { + // Quick path, do not slice this dimension. + oshape[i] = dshape[i]; + continue; + } + // Normal path, require the shape to be concrete integer. + // Require concrete integer as symbolic inference of min/max + // can get complicated and not very helpful. + const int64_t* p_dim_size = tir::as_const_int(dshape[i]); + if (!p_dim_size) { + oshape[i] = dshape[i]; + continue; + } + int64_t dim_size = p_dim_size[0]; + begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; + end_v = (end_v < 0) ? dim_size + end_v : end_v; + + int64_t slice_range, step; + if (stride_v < 0) { + if (end_v < -1) end_v = -1; + CHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; + begin_v = std::min(dim_size - 1, begin_v); + slice_range = begin_v - end_v; + step = -stride_v; + } else { + if (begin_v < 0) begin_v = 0; + CHECK_GE(stride_v, 0); + CHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; + end_v = std::min(dim_size, end_v); + slice_range = end_v - begin_v; + step = stride_v; + } + oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. - const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - CHECK(p_dim_size) - << "strided_slice requires sliced dimension to be concrete int"; - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? 
dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - CHECK_LT(end_v, begin_v) - << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - CHECK_GE(stride_v, 0); - CHECK_LT(begin_v, end_v) - << "strided_slice get empty slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; + } else { + for (int64_t i = 0; i < num_axis; ++i) { + oshape[i] = Any(); } - oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } - reporter->Assign(types[1], TensorType(oshape, data->dtype)); + + reporter->Assign(types[4], TensorType(oshape, data->dtype)); return true; } - -Array > StridedSliceInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { - +Array> StridedSliceInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Array> old_in_shapes; for (auto old_in_t : old_in_types) { CHECK(old_in_t.as()); @@ -1891,22 +1836,39 @@ Array > StridedSliceInferCorrectLayout( } CHECK(old_in_layouts.defined()); - CHECK_EQ(old_in_layouts.size(), 1); + CHECK_GE(old_in_layouts.size(), 1); CHECK(old_in_shapes.defined()); - CHECK_EQ(old_in_shapes.size(), 1); + CHECK_GE(old_in_shapes.size(), 1); auto layout = old_in_layouts[0]; if (layout.defined() && new_in_layouts.defined()) { - CHECK_EQ(new_in_layouts.size(), 1); + CHECK_GE(new_in_layouts.size(), 1); auto new_layout = new_in_layouts[0]; auto shape = old_in_shapes[0]; // NOTE: Discard "const" qualifier here. - auto *params = const_cast(attrs.as()); + auto* params = const_cast(attrs.as()); + CHECK(params != nullptr); + Array begin, end, strides; + if (params->begin && params->end && params->strides) { + for (Integer i : params->strides.value()) { + CHECK(i.defined()); + strides.push_back(params->slice_mode == "size" ? 1 : i->value); + } + + for (Integer i : params->begin.value()) { + CHECK(i.defined()); + begin.push_back(i->value); + } + for (Integer i : params->end.value()) { + CHECK(i.defined()); + end.push_back(i->value); + } + } Array new_begin, new_end; - for (size_t i = 0; i < params->begin.size(); i++) { + for (size_t i = 0; i < begin.size(); i++) { const LayoutAxis& axis = layout[i]; if (!axis.IsPrimal()) { // original layout that contains splitted axes is not supported @@ -1914,62 +1876,118 @@ Array > StridedSliceInferCorrectLayout( } auto factor = new_layout.FactorOf(axis); if (factor == -1) { - new_begin.push_back(params->begin[i]); - new_end.push_back(params->end[i]); + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); } else { - if (params->strides.defined() && i < params->strides.size()) { - auto stride = params->strides[i]; + if (strides.defined() && i < strides.size()) { + auto stride = strides[i]; // arbitrary stride is not supported if (stride.defined() && stride->value != 1) { return {{Layout::Undef()}, {Layout::Undef()}}; } } - int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0; - int64_t end = params->end[i].defined() ? params->end[i]->value : - shape[i].as()->value; - if (begin % factor || end % factor) { + int64_t bg = begin[i].defined() ? 
begin[i]->value : 0; + int64_t ed; + if (!end[i].defined()) { + ed = shape[i].as()->value; + } else if (params->slice_mode == "size") { + if (end[i]->value < 0) { + ed = shape[i].as()->value; + } else { + ed = bg + end[i]->value; + } + } else { + ed = end[i]->value; + } + + if (bg % factor || ed % factor) { // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } - new_begin.push_back(tvm::Integer(begin / factor)); - new_end.push_back(tvm::Integer(end / factor)); + new_begin.push_back(tvm::Integer(bg / factor)); + new_end.push_back(tvm::Integer(ed / factor)); } } + layout = new_layout; params->begin = new_begin; params->end = new_end; } - return {{layout}, {layout}}; + return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}}; +} + +inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const te::Tensor& begin, + const te::Tensor& end, const te::Tensor& strides, + std::string name = "T_strided_slice_dynamic", + std::string tag = topi::kInjective) { + int64_t src_tensor_dim = input->shape.size(); + Array out_shape; + for (int64_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(tvm::tir::Var("dim")); + } + // TODO(yongwww): move the compute into topi + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (int32_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides(i) + begin(i)); + } + return input(real_indices); + }, + name, tag); } +Array StridedSliceCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const StridedSliceAttrs* param = attrs.as(); + CHECK(param != nullptr); + if (param->begin && param->end && param->strides) { + Array begin, end, strides; + begin = param->begin.value(); + end = param->end.value(); + strides = param->strides.value(); + return Array{ + topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; + } else { + te::Tensor data = inputs[0]; + te::Tensor begin = inputs[1]; + te::Tensor end = inputs[2]; + te::Tensor strides = inputs[3]; + // Dynamic computation + int64_t attr_size = data->shape.size(); + CHECK(begin->shape[0].as()->value == attr_size && + end->shape[0].as()->value == attr_size && + strides->shape[0].as()->value == attr_size) + << "begin, end, and strides are required to have the same length" + << " if they are non-constant."; + return Array{DynamicStridedSlice(data, begin, end, strides)}; + } +} // Positional relay function to create StridedSlice operator used by frontend FFI. 
-Expr MakeStridedSlice(Expr data, - Array begin, - Array end, - Array strides) { +Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode) { auto attrs = make_object(); - attrs->begin = std::move(begin); - attrs->end = std::move(end); - attrs->strides = std::move(strides); + const ConstantNode *cbegin, *cend, *cstrides; + if ((cbegin = begin.as()) && (cend = end.as()) && + (cstrides = strides.as())) { + CHECK_EQ(cbegin->data->ndim, 1); + CHECK_EQ(cend->data->ndim, 1); + CHECK_EQ(cstrides->data->ndim, 1); + Array begin, end, strides; + begin = ToVector(cbegin->data); + end = ToVector(cend->data); + strides = ToVector(cstrides->data); + attrs->begin = begin; + attrs->end = end; + attrs->strides = strides; + } + attrs->slice_mode = slice_mode; static const Op& op = Op::Get("strided_slice"); - return Call(op, {data}, Attrs(attrs), {}); + return Call(op, {data, begin, end, strides}, Attrs(attrs), {}); } -Array StridedSliceCompute(const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const StridedSliceAttrs *param = attrs.as(); - CHECK(param != nullptr); - return Array{ - topi::strided_slice(inputs[0], param->begin, param->end, param->strides) - }; -} - - -TVM_REGISTER_GLOBAL("relay.op._make.strided_slice") -.set_body_typed(MakeStridedSlice); - +TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice); RELAY_REGISTER_OP("strided_slice") .describe(R"code(Strided slice of an array. @@ -1995,40 +2013,37 @@ Examples:: [[ 5., 6.], [ 7., 8.]]] )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(4) -.set_attrs_type() -.add_type_rel("StridedSlice", StridedSliceRel) -.set_attr("FTVMCompute", StridedSliceCompute) -.set_attr("TOpPattern", kInjective) -.set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); + .set_num_inputs(4) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("begin", "Tensor", "The indices to begin with in the slicing.") + .add_argument("end", "Tensor", "Indices indicating end of the slice.") + .add_argument("strides", "Tensor", "The stride values.") + .add_argument("slice_mode", "Tensor", "The slice mode.") + .set_support_level(4) + .set_attrs_type() + .add_type_rel("StridedSlice", StridedSliceRel) + .set_attr("FTVMCompute", StridedSliceCompute) + .set_attr("TOpPattern", kInjective) + .set_attr("AnyCodegenStrategy", kVariableDimensions) + .set_attr("FInferCorrectLayout", StridedSliceInferCorrectLayout); // strided_set -bool StridedSetRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool StridedSetRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 6); reporter->Assign(types[5], types[0]); return true; } -Expr MakeStridedSet(Expr data, - Expr v, - Expr begin, - Expr end, - Expr strides) { +Expr MakeStridedSet(Expr data, Expr v, Expr begin, Expr end, Expr strides) { static const Op& op = Op::Get("strided_set"); return Call(op, {data, v, begin, end, strides}, {}); } -TVM_REGISTER_GLOBAL("relay.op._make.strided_set") -.set_body_typed(MakeStridedSet); - +TVM_REGISTER_GLOBAL("relay.op._make.strided_set").set_body_typed(MakeStridedSet); RELAY_REGISTER_OP("strided_set") - .describe(R"code(Strided set of an array. + .describe(R"code(Strided set of an array. 
Example:: x = [[ 1., 4., 7., 10.], @@ -2043,22 +2058,20 @@ Example:: [ 2., 44., 55., 66.], [ 3., 6., 9., 12.]] )code" TVM_ADD_FILELINE) -.set_num_inputs(5) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("v", "Tensor", "The data to set.") -.add_argument("begin", "Tensor", "Indices for the start of the slice.") -.add_argument("end", "Tensor", "Indices indicating the end of the slice.") -.add_argument("strides", "Tensor", "The strides values.") -.set_support_level(4) -.set_attr("TOpPattern", kInjective) -.add_type_rel("StridedSet", StridedSetRel); + .set_num_inputs(5) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("v", "Tensor", "The data to set.") + .add_argument("begin", "Tensor", "Indices for the start of the slice.") + .add_argument("end", "Tensor", "Indices indicating the end of the slice.") + .add_argument("strides", "Tensor", "The strides values.") + .set_support_level(4) + .set_attr("TOpPattern", kInjective) + .add_type_rel("StridedSet", StridedSetRel); // relay.split TVM_REGISTER_NODE_TYPE(SplitAttrs); -bool SplitRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, result] CHECK_EQ(types.size(), 2); @@ -2071,25 +2084,29 @@ bool SplitRel(const Array& types, if (axis < 0) { axis += data->shape.size(); } - CHECK_LT(axis, data->shape.size()) - << "axis should be within the input dimension range."; - CHECK_GE(axis, 0) - << "axis should be within the input dimension range."; + CHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; + CHECK_GE(axis, 0) << "axis should be within the input dimension range."; if (const IntImmNode* sections = param->indices_or_sections.as()) { - CHECK(reporter->Assert(indexmod(data->shape[axis], - sections->value) == tir::make_zero(DataType::Int(64)))) - << "indices_or_sections need to be able to divide input.shape[axis]"; + if (!data->shape[axis].as()) { + CHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == + tir::make_zero(DataType::Int(64)))) + << "indices_or_sections need to be able to divide input.shape[axis]"; + } std::vector fields; for (int i = 0; i < sections->value; ++i) { - std::vector oshape(data->shape.begin(), data->shape.end()); + std::vector oshape(data->shape.begin(), data->shape.end()); + if (data->shape[axis].as()) { + oshape[axis] = Any(); + } else { oshape[axis] = indexdiv(oshape[axis], sections->value); - auto vec_type = TensorType(oshape, data->dtype); - fields.push_back(vec_type); + } + auto vec_type = TensorType(oshape, data->dtype); + fields.push_back(vec_type); } reporter->Assign(types[1], TupleType(Array(fields))); } else { - auto indices = param->indices_or_sections.as()->data; + auto indices = Downcast>(param->indices_or_sections); auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; for (unsigned int i = 0; i < indices.size(); ++i) { @@ -2101,10 +2118,16 @@ bool SplitRel(const Array& types, auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); } - CHECK(reporter->Assert(begin < data->shape[axis])) - << "The sum of sections must match the input.shape[axis]"; + if (!data->shape[axis].as()) { + CHECK(reporter->Assert(begin < data->shape[axis])) + << "The sum of sections must match the input.shape[axis]"; + } std::vector oshape(data->shape.begin(), data->shape.end()); - oshape[axis] = data->shape[axis] - begin; + if (data->shape[axis].as()) { + 
oshape[axis] = Any(); + } else { + oshape[axis] = data->shape[axis] - begin; + } auto vec_type = TensorType(oshape, data->dtype); fields.push_back(vec_type); reporter->Assign(types[1], TupleType(Array(fields))); @@ -2112,25 +2135,21 @@ bool SplitRel(const Array& types, return true; } -Array SplitCompute(const Attrs& attrs, - const Array& inputs, +Array SplitCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto param = attrs.as(); CHECK(param != nullptr); if (const IntImmNode* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; - return Array{ - topi::split_sections(inputs[0], num_sections, param->axis) }; + return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { - auto indices = Downcast >(param->indices_or_sections); - return Array{ topi::split(inputs[0], indices, param->axis) }; + auto indices = Downcast>(param->indices_or_sections); + return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, - ObjectRef indices_or_sections, - int axis) { +Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -2138,22 +2157,20 @@ Expr MakeSplit(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - if (args.type_codes[1] == kDLInt) { - // Note: we change it from Int(64) to Int(32) for now as - // combine_parallel_dense will transform the graph with Int(32). - // More invetigation is needs to check which one we should use. - *rv = MakeSplit(args[0], - tir::make_const(DataType::Int(32), static_cast(args[1])), - args[2]); - } else { - *rv = MakeSplit(args[0], args[1], args[2]); - } +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { + if (args.type_codes[1] == kDLInt) { + // Note: we change it from Int(64) to Int(32) for now as + // combine_parallel_dense will transform the graph with Int(32). + // More invetigation is needs to check which one we should use. + *rv = + MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); + } else { + *rv = MakeSplit(args[0], args[1], args[2]); + } }); RELAY_REGISTER_OP("split") -.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. + .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. Indices or sections to split into. Accepts an int or a tuple If indices_or_sections is an integer, the input will be divided equally @@ -2163,29 +2180,26 @@ If indices_or_sections is a tuple of sorted integers, the entries indicate where along axis the array is split. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("Split", SplitRel) -.set_attr("FTVMCompute", SplitCompute) -.set_attr("TOpPattern", kInjective); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("Split", SplitRel) + .set_attr("FTVMCompute", SplitCompute) + .set_attr("TOpPattern", kInjective); // relay.slice_like TVM_REGISTER_NODE_TYPE(SliceLikeAttrs); /*! -* \brief SliceLikeRel User defined type constraint function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. 
-* \param reporter The reporter to report solution to. -* \return False if the relation has not been resolved, it might be resolved later. -* True if this relation has been resolved. -*/ -bool SliceLikeRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief SliceLikeRel User defined type constraint function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return False if the relation has not been resolved, it might be resolved later. + * True if this relation has been resolved. + */ +bool SliceLikeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -2210,8 +2224,8 @@ bool SliceLikeRel(const Array& types, if (i < target_shape.size()) { oshape[i] = target_shape[i]; CHECK(reporter->Assert(oshape[i] <= dshape[i])) - << "End index of axis " << i << " exceeds input shape: " - << oshape[i] << " vs " << dshape[i]; + << "End index of axis " << i << " exceeds input shape: " << oshape[i] << " vs " + << dshape[i]; } } } else { @@ -2222,12 +2236,11 @@ bool SliceLikeRel(const Array& types, axis += dshape.size(); } CHECK(axis < static_cast(target_shape.size())) - << "Axis " << axis << " exceeds dimension " - << target_shape.size() << " of target_shape."; + << "Axis " << axis << " exceeds dimension " << target_shape.size() << " of target_shape."; oshape[axis] = target_shape[axis]; CHECK(reporter->Assert(oshape[axis] <= dshape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << oshape[axis] << " vs " << dshape[axis]; + << "End index of axis " << axis << " exceeds input shape: " << oshape[axis] << " vs " + << dshape[axis]; } } @@ -2235,18 +2248,14 @@ bool SliceLikeRel(const Array& types, return true; } - -Expr MakeSliceLike(Expr data, - Expr shape_like, - Array axes) { +Expr MakeSliceLike(Expr data, Expr shape_like, Array axes) { auto attrs = make_object(); attrs->axes = std::move(axes); static const Op& op = Op::Get("slice_like"); return Call(op, {data, shape_like}, Attrs(attrs), {}); } -Array SliceLikeCompute(const Attrs& attrs, - const Array& inputs, +Array SliceLikeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); @@ -2262,11 +2271,10 @@ Array SliceLikeCompute(const Attrs& attrs, for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { end_idx.Set(i, target_shape[i]); - CHECK_LE(topi::GetConstInt(end_idx[i]), - topi::GetConstInt(src_shape[i])) - << "End index of axis " << i << " exceeds input shape: " - << topi::GetConstInt(end_idx[i]) << " vs " - << topi::GetConstInt(src_shape[i]); + CHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) + << "End index of axis " << i + << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " + << topi::GetConstInt(src_shape[i]); } } } else { @@ -2275,77 +2283,64 @@ Array SliceLikeCompute(const Attrs& attrs, axis = static_cast(src_shape.size()) + axis; } end_idx.Set(axis, target_shape[axis]); - CHECK_LE(topi::GetConstInt(end_idx[axis]), - topi::GetConstInt(src_shape[axis])) - << "End index of axis " << axis << " exceeds input shape: " - << topi::GetConstInt(end_idx[axis]) << " vs " - << topi::GetConstInt(src_shape[axis]); + CHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) + << "End index of axis " << axis + << " exceeds input shape: " << 
topi::GetConstInt(end_idx[axis]) << " vs " + << topi::GetConstInt(src_shape[axis]); } } - return Array{ - topi::strided_slice(inputs[0], - GetIntArray(begin_idx), - GetIntArray(end_idx), - GetIntArray(strides)) - }; + return Array{topi::strided_slice(inputs[0], GetIntArray(begin_idx), + GetIntArray(end_idx), GetIntArray(strides), "end")}; } - -TVM_REGISTER_GLOBAL("relay.op._make.slice_like") -.set_body_typed(MakeSliceLike); - +TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike); RELAY_REGISTER_OP("slice_like") -.describe(R"code(Slice the first input with respect to the second input. + .describe(R"code(Slice the first input with respect to the second input. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("shape_like", "Tensor", "Shape tensor.") -.set_support_level(10) -.add_type_rel("SliceLike", SliceLikeRel) -.set_attr("FTVMCompute", SliceLikeCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("shape_like", "Tensor", "Shape tensor.") + .set_support_level(10) + .add_type_rel("SliceLike", SliceLikeRel) + .set_attr("FTVMCompute", SliceLikeCompute) + .set_attr("TOpPattern", kInjective); // relay.layout_transform TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs); -Array LayoutTransformCompute(const Attrs& attrs, - const Array& inputs, +Array LayoutTransformCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array{ - topi::layout_transform(inputs[0], param->src_layout, param->dst_layout) - }; + return Array{topi::layout_transform(inputs[0], param->src_layout, param->dst_layout)}; } -bool LayoutTransformRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool LayoutTransformRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const auto* data = types[0].as(); - CHECK(data != nullptr); + if (data == nullptr) { + CHECK(types[0].as()) + << "LayoutTransform: expect input data type to be TensorType but get " << types[0]; + return false; + } const LayoutTransformAttrs* params = attrs.as(); Layout src_layout(params->src_layout); Layout dst_layout(params->dst_layout); - CHECK(src_layout.defined() && dst_layout.defined()) - << "cannot convert from/to undefined layout"; - + CHECK(src_layout.defined() && dst_layout.defined()) << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout, dst_layout); CHECK(layout_converter.defined()) - << "cannot convert from " << params->src_layout << " to " << params->dst_layout; + << "cannot convert from " << params->src_layout << " to " << params->dst_layout; const auto& out_shape = layout_converter.ForwardShape(data->shape); reporter->Assign(types[1], TensorType(out_shape, data->dtype)); return true; } -Expr MakeLayoutTransform(Expr data, - std::string src_layout, - std::string dst_layout) { +Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout) { auto attrs = make_object(); attrs->src_layout = std::move(src_layout); attrs->dst_layout = std::move(dst_layout); @@ -2353,27 +2348,24 @@ Expr MakeLayoutTransform(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.layout_transform") -.set_body_typed(MakeLayoutTransform); +TVM_REGISTER_GLOBAL("relay.op._make.layout_transform").set_body_typed(MakeLayoutTransform);
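(Editor's note, not part of the patch: a minimal sketch of how the tir::BijectiveLayout machinery used by LayoutTransformRel above maps a tensor shape from a source layout to a destination layout. The layouts, the input shape, and the helper name are illustrative assumptions; only the Layout/BijectiveLayout/ForwardShape calls are taken from the hunk itself.)

#include <tvm/tir/data_layout.h>

// Mirrors what LayoutTransformRel does: build a bijective converter between
// the two layouts and push the input shape through ForwardShape to obtain the
// inferred output shape.
tvm::Array<tvm::PrimExpr> ExampleLayoutTransformShape() {
  tvm::tir::Layout src("NCHW");
  tvm::tir::Layout dst("NCHW16c");
  tvm::tir::BijectiveLayout converter(src, dst);
  // A (1, 64, 56, 56) NCHW shape becomes (1, 4, 56, 56, 16) in NCHW16c.
  tvm::Array<tvm::PrimExpr> in_shape{1, 64, 56, 56};
  return converter.ForwardShape(in_shape);
}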
RELAY_REGISTER_OP("layout_transform") -.describe(R"code(Transform the input data layout. + .describe(R"code(Transform the input data layout. For transforming from NCHW to N16cHWC, the `__layout_transform__` operator reshapes the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("layout_transform", LayoutTransformRel) -.set_support_level(5) -.set_attr("FTVMCompute", LayoutTransformCompute); - + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("layout_transform", LayoutTransformRel) + .set_support_level(5) + .set_attr("FTVMCompute", LayoutTransformCompute); /* relay._contrib_reverse_reshape */ -Expr MakeReverseReshape(Expr data, - Array newshape) { +Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = true; @@ -2381,11 +2373,10 @@ Expr MakeReverseReshape(Expr data, return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape") -.set_body_typed(MakeReverseReshape); +TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape); RELAY_REGISTER_OP("_contrib_reverse_reshape") -.describe(R"code(Reshapes the input array where the special values are inferred from + .describe(R"code(Reshapes the input array where the special values are inferred from right to left. Example:: @@ -2398,18 +2389,98 @@ example below:: - data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (40,5) )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(10) -.add_type_rel("Reshape", ReshapeRel) -.set_attr("FTVMCompute", ReshapeCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(10) + .add_type_rel("Reshape", ReshapeRel) + .set_attr("FTVMCompute", ReshapeCompute) + .set_attr("TOpPattern", kInjective); + +// gather operator +TVM_REGISTER_NODE_TYPE(GatherAttrs); + +bool GatherRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, indices, result] + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* indices = types[1].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "Gather: expect input data type to be TensorType but get " << types[0]; + return false; + } + if (indices == nullptr) { + CHECK(types[1].as()) + << "Gather: expect indices type to be TensorType but get " << types[1]; + return false; + } + CHECK(indices->dtype.is_int()) << "indices of take must be tensor of integer"; + const auto param = attrs.as(); + CHECK(param != nullptr); + CHECK(param->axis.defined()); + + const auto ndim_data = data->shape.size(); + const auto ndim_indices = indices->shape.size(); + int axis = param->axis->value; + CHECK_EQ(ndim_data, ndim_indices); + CHECK_GE(axis, 0); + CHECK_LT(axis, ndim_data); + + std::vector oshape; + oshape.reserve(ndim_data); + for (size_t i = 0; i < ndim_data; ++i) { + if (i == (size_t)axis) { + const int64_t* indice_shape_i = tir::as_const_int(indices->shape[i]); + CHECK_GE(*indice_shape_i, 1); + } else { + CHECK(reporter->AssertEQ(indices->shape[i], data->shape[i])); + } + oshape.emplace_back(indices->shape[i]); + } + reporter->Assign(types[2], 
TensorType(oshape, data->dtype)); + return true; +} + +Array GatherCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* param = attrs.as(); + return {topi::gather(inputs[0], param->axis, inputs[1])}; +} + +Expr MakeGather(Expr data, Integer axis, Expr indices) { + auto attrs = make_object(); + attrs->axis = std::move(axis); + static const Op& op = Op::Get("gather"); + return Call(op, {data, indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.gather").set_body_typed(MakeGather); + +RELAY_REGISTER_OP("gather") + .describe(R"code(Gather values along given axis from given indices. + +E.g. for a 3D tensor, output is computed as: + + out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 + out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 + out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 + +``indices`` must have same shape as ``data``, except at dimension ``axis`` +which must just be not null. Output will have same shape as ``indices``. +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input data to the operator.") + .add_argument("indices", "Tensor", "The indices of values to gather.") + .set_support_level(3) + .add_type_rel("Gather", GatherRel) + .set_attr("FTVMCompute", GatherCompute) + .set_attr("TOpPattern", kInjective); // gather_nd operator -bool GatherNDRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GatherNDRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, indices, result] CHECK_EQ(types.size(), 3); @@ -2417,48 +2488,43 @@ bool GatherNDRel(const Array& types, const auto* indices = types[1].as(); if (data == nullptr) { CHECK(types[0].as()) - << "GatherND: expect input data type to be TensorType but get " - << types[0]; + << "GatherND: expect input data type to be TensorType but get " << types[0]; return false; } if (indices == nullptr) { CHECK(types[1].as()) - << "GatherND: expect indices type to be TensorType but get " - << types[1]; + << "GatherND: expect indices type to be TensorType but get " << types[1]; return false; } const size_t ndim = data->shape.size(); const IntImmNode* mdim = indices->shape[0].as(); const size_t kdim = indices->shape.size() - 1; - CHECK(size_t(mdim->value) <= ndim) - << "GatherND: indices shape does satisfy."; + CHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does satisfy."; Array oshape; - for (size_t i = 1; i < kdim + 1; ++i) - oshape.push_back(indices->shape[i]); - for (size_t i = mdim->value; i < ndim; ++i) - oshape.push_back(data->shape[i]); + for (size_t i = 1; i < kdim + 1; ++i) oshape.push_back(indices->shape[i]); + for (size_t i = mdim->value; i < ndim; ++i) oshape.push_back(data->shape[i]); + if (oshape.size() == 0) { + oshape.push_back(tir::make_const(DataType::Int(32), 1)); + } reporter->Assign(types[2], TensorType(oshape, data->dtype)); return true; } -Array GatherNDCompute(const Attrs& attrs, - const Array& inputs, +Array GatherNDCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { - return { topi::gather_nd(inputs[0], inputs[1]) }; + return {topi::gather_nd(inputs[0], inputs[1])}; } -Expr MakeGatherND(Expr data, - Expr indices) { +Expr MakeGatherND(Expr data, Expr indices) { static const Op& op = Op::Get("gather_nd"); return Call(op, {data, indices}, {}); } -TVM_REGISTER_GLOBAL("relay.op._make.gather_nd") -.set_body_typed(MakeGatherND); 
+TVM_REGISTER_GLOBAL("relay.op._make.gather_nd").set_body_typed(MakeGatherND); RELAY_REGISTER_OP("gather_nd") -.describe(R"code(Gather elements or slices from data and store to + .describe(R"code(Gather elements or slices from data and store to a tensor whose shape is defined by indices. Given data with shape (X_0, X_1, ..., X_{N-1}) and indices with @@ -2466,19 +2532,17 @@ shape (M, Y_0, ..., Y_{K-1}), the output will have shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), where M <= N. If M == N, output shape will simply be (Y_0, ..., Y_{K-1}). )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(3) -.add_type_rel("GatherND", GatherNDRel) -.set_attr("FTVMCompute", GatherNDCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(3) + .add_type_rel("GatherND", GatherNDRel) + .set_attr("FTVMCompute", GatherNDCompute) + .set_attr("TOpPattern", kInjective); // relay.sequence_mask TVM_REGISTER_NODE_TYPE(SequenceMaskAttrs); -bool SequenceMaskRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SequenceMaskRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [data, valid_length, result] CHECK_EQ(types.size(), 3); @@ -2495,19 +2559,15 @@ bool SequenceMaskRel(const Array& types, return true; } -Array SequenceMaskCompute(const Attrs& attrs, - const Array& inputs, +Array SequenceMaskCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); return Array{ - topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis) }; + topi::sequence_mask(inputs[0], inputs[1], param->mask_value, param->axis)}; } -Expr MakeSequenceMask(Expr data, - Expr valid_length, - double mask_value, - int axis) { +Expr MakeSequenceMask(Expr data, Expr valid_length, double mask_value, int axis) { auto attrs = make_object(); attrs->mask_value = std::move(mask_value); attrs->axis = std::move(axis); @@ -2515,11 +2575,11 @@ Expr MakeSequenceMask(Expr data, return Call(op, {data, valid_length}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask") -.set_body_typed(MakeSequenceMask); +TVM_REGISTER_GLOBAL("relay.op._make.sequence_mask").set_body_typed(MakeSequenceMask); RELAY_REGISTER_OP("sequence_mask") -.describe(R"code(Sets all elements outside the expected length of the sequence to a constant value. + .describe( + R"code(Sets all elements outside the expected length of the sequence to a constant value. This function takes an n-dimensional input array of the form [MAX_LENGTH, batch_size, ...] or [batch_size, MAX_LENGTH, ...] and returns an array of the same shape. 
@@ -2567,21 +2627,19 @@ Examples:: [[ 0.1, 0.1, 0.1], [ 16., 17., 18.]]] )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.") -.set_support_level(10) -.add_type_rel("SequenceMask", SequenceMaskRel) -.set_attr("FTVMCompute", SequenceMaskCompute) -.set_attr("TOpPattern", kInjective); + .set_attrs_type() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("valid_length", "Tensor", "The real (valid) length of each sequence.") + .set_support_level(10) + .add_type_rel("SequenceMask", SequenceMaskRel) + .set_attr("FTVMCompute", SequenceMaskCompute) + .set_attr("TOpPattern", kInjective); // relay.one_hot TVM_REGISTER_NODE_TYPE(OneHotAttrs); -bool OneHotRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool OneHotRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // `types` contains: [indices, on_value, off_value, result] CHECK_EQ(types.size(), 4); @@ -2607,27 +2665,15 @@ bool OneHotRel(const Array& types, return true; } -Array OneHotCompute(const Attrs& attrs, - const Array& inputs, +Array OneHotCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { const auto* param = attrs.as(); CHECK(param != nullptr); - return Array { - topi::one_hot(inputs[0], - inputs[1](), - inputs[2](), - param->depth, - param->axis, - param->dtype) - }; -} - -Expr MakeOneHot(Expr indices, - Expr on_value, - Expr off_value, - int depth, - int axis, - DataType dtype) { + return Array{ + topi::one_hot(inputs[0], inputs[1](), inputs[2](), param->depth, param->axis, param->dtype)}; +} + +Expr MakeOneHot(Expr indices, Expr on_value, Expr off_value, int depth, int axis, DataType dtype) { auto attrs = make_object(); attrs->depth = std::move(depth); attrs->axis = axis; @@ -2636,11 +2682,10 @@ Expr MakeOneHot(Expr indices, return Call(op, {indices, on_value, off_value}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.one_hot") -.set_body_typed(MakeOneHot); +TVM_REGISTER_GLOBAL("relay.op._make.one_hot").set_body_typed(MakeOneHot); RELAY_REGISTER_OP("one_hot") -.describe(R"code(Returns a one-hot tensor where the locations represented by indices take value 1, + .describe(R"code(Returns a one-hot tensor where the locations represented by indices take value 1, other locations take value 0. Final dimension is x depth. **indices** Locations to set to 1. @@ -2654,42 +2699,36 @@ RELAY_REGISTER_OP("one_hot") **axis** Axis to fill.
**dtype**)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("indices", "Tensor", "Locations to set to on_value.") -.add_argument("on_value", "Expr", "Value to fill at indices.") -.add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") -.set_support_level(10) -.add_type_rel("OneHot", OneHotRel) -.set_attr("FTVMCompute", OneHotCompute) -.set_attr("TOpPattern", kOutEWiseFusable); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("indices", "Tensor", "Locations to set to on_value.") + .add_argument("on_value", "Expr", "Value to fill at indices.") + .add_argument("off_value", "Expr", "Value to fill at all other positions besides indices.") + .set_support_level(10) + .add_type_rel("OneHot", OneHotRel) + .set_attr("FTVMCompute", OneHotCompute) + .set_attr("TOpPattern", kOutEWiseFusable); /* relay.unravel_index */ -bool UnRavelIndexRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool UnRavelIndexRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); const auto* indices = types[0].as(); if (indices == nullptr) { CHECK(types[0].as()) - << "unravel_index: expect input type to be TensorType but get " - << types[0]; + << "unravel_index: expect input type to be TensorType but get " << types[0]; return false; } - CHECK(indices->dtype.is_int()) - << "indices of unravel_index must be tensor of integer"; + CHECK(indices->dtype.is_int()) << "indices of unravel_index must be tensor of integer"; const auto* shape = types[1].as(); if (shape == nullptr) { CHECK(types[1].as()) - << "unravel_index: expect input type to be TensorType but get " - << types[1]; + << "unravel_index: expect input type to be TensorType but get " << types[1]; return false; } - CHECK(indices->dtype.is_int()) - << "shape of unravel_index must be tensor of integer"; + CHECK(indices->dtype.is_int()) << "shape of unravel_index must be tensor of integer"; Array indices_shape; Array shape_shape; @@ -2705,32 +2744,104 @@ bool UnRavelIndexRel(const Array& types, return true; } -Array UnRavelIndexCompute(const Attrs& attrs, - const Array& inputs, +Array UnRavelIndexCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { return Array{topi::unravel_index(inputs[0], inputs[1])}; } -Expr MakeUnRavelIndex(Expr data, - Expr shape) { +Expr MakeUnRavelIndex(Expr data, Expr shape) { static const Op& op = Op::Get("unravel_index"); return Call(op, {data, shape}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.unravel_index") -.set_body_typed(MakeUnRavelIndex); +TVM_REGISTER_GLOBAL("relay.op._make.unravel_index").set_body_typed(MakeUnRavelIndex); RELAY_REGISTER_OP("unravel_index") -.describe(R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. + .describe( + R"code(Converts a flat index or array of flat indices into a tuple of coordinate arrays. 
Example:: - unravel_index([22, 41, 37], (7, 6)) = [[3, 6, 6], [4, 5, 1]] )code" TVM_ADD_FILELINE) -.set_num_inputs(2) -.set_support_level(3) -.add_type_rel("UnRavelIndexRel", UnRavelIndexRel) -.set_attr("FTVMCompute", UnRavelIndexCompute) -.set_attr("TOpPattern", kInjective); + .set_num_inputs(2) + .set_support_level(3) + .add_type_rel("UnRavelIndexRel", UnRavelIndexRel) + .set_attr("FTVMCompute", UnRavelIndexCompute) + .set_attr("TOpPattern", kInjective); + +// sparse_to_dense +TVM_REGISTER_NODE_TYPE(SparseToDenseAttrs); + +bool SparseToDenseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 3); + auto sparse_indices = types[0].as(); + auto sparse_values = types[1].as(); + auto default_value = types[2].as(); + CHECK(sparse_indices != nullptr && sparse_values != nullptr && default_value != nullptr); + + CHECK(sparse_indices->dtype.is_int()) << "sparse_indices must be tensor of integers"; + + CHECK_LE(sparse_indices->shape.size(), 3) + << "sparse_indices must be a tensor of either 0D, 1D or 2D"; + + CHECK_LE(sparse_values->shape.size(), 2) << "sparse_values must be a tensor of either 0D or 1D"; + + CHECK_EQ(default_value->shape.size(), 0) << "default_value should be a scalar"; + + const auto* param = attrs.as(); + CHECK(param != nullptr); + + Array oshape; + for (auto i : param->output_shape) { + oshape.push_back(i); + } + reporter->Assign(types[3], TensorType(oshape, sparse_values->dtype)); + return true; +} + +Array SparseToDenseCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + CHECK_EQ(inputs.size(), 3); + const auto* param = attrs.as(); + CHECK(param != nullptr); + return {topi::sparse_to_dense(inputs[0], param->output_shape, inputs[1], inputs[2]())}; +} + +TVM_REGISTER_GLOBAL("relay.op._make.sparse_to_dense") + .set_body_typed([](Expr indices, Array output_shape, Expr values, Expr default_value) { + auto attrs = make_object(); + attrs->output_shape = std::move(output_shape); + static const Op& op = Op::Get("sparse_to_dense"); + return Call(op, {indices, values, default_value}, Attrs(attrs)); + }); + +RELAY_REGISTER_OP("sparse_to_dense") + .describe(R"code(A dense tensor from a sparse representation. + + - **sparse_indices**: A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values + + - **output_shape**: A list of integers. Shape of the dense output tensor. + + - **sparse_values**: A 0-D or 1-D tensor containing the sparse values for the sparse indices. + + - **default_value**: A 0-D tensor containing the default value for the remaining locations. Defaults to 0. + + Example:: + - sparse_to_dense([[0, 0], [1, 2]], [3, 4], [1, 2], 0) = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]] + + )code" TVM_ADD_FILELINE) + .set_num_inputs(3) + .set_support_level(3) + .set_attrs_type() + .add_argument("sparse_indices", "Tensor", "Contains sparse indices.") + .add_argument("sparse_values", "Tensor", "Contains values for sparse indices.") + .add_argument("default_value", "Tensor", "Value to set for non-sparse indices.
Defaults to 0.") + .add_type_rel("SparseToDense", SparseToDenseRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kOpaque) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", SparseToDenseCompute); } // namespace relay } // namespace tvm diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index a64dcd5a6b30..7149417aa9b5 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -26,33 +26,34 @@ #include #include -#include +#include + #include #include #include #include #include +#include namespace tvm { namespace relay { +extern Expr MakeReshape(Expr data, Expr newshape); + template -bool ConcatenateRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { // types: [data, result] CHECK_EQ(types.size(), 2); /* If we receive a tuple we can continue, if we receive * anything but an incomplete type we should signal an * error. - */ + */ const auto* tensor_tuple = types[0].as(); if (tensor_tuple == nullptr) { throw Error( - ErrorBuilder() - << "concatenate requires a tuple of tensors as the first argument, found " - << PrettyPrint(types[0])); + ErrorBuilder() << "concatenate requires a tuple of tensors as the first argument, found " + << PrettyPrint(types[0])); } else if (types[0].as() != nullptr) { return false; } @@ -69,10 +70,8 @@ bool ConcatenateRel(const Array& types, // Sanity check: axis int axis = param->axis; if (!(-ndim <= axis && axis < ndim)) { - throw Error(ErrorBuilder() << - "concatenate only accepts `axis` in [-ndim, ndim)" << - ", but got axis = " << axis << - ", and ndim = " << ndim); + throw Error(ErrorBuilder() << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim); } axis = axis < 0 ? 
ndim + axis : axis; @@ -91,33 +90,33 @@ bool ConcatenateRel(const Array& types, if (e_dtype != dtype) { throw Error("relay.concatenate requires all tensors have the same dtype"); } - for (size_t j = 0; j < first->shape.size(); ++j) { - if (j == static_cast(axis)) continue; - if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; - throw Error("relay.concatenate requires all tensors have the same shape " - "on non-concatenating axes"); - } } // Calculate shape std::vector oshape(first->shape.begin(), first->shape.end()); - IndexExpr &concat_dim = oshape[axis]; - bool has_any = false; - if (concat_dim.as()) { - has_any = true; - } else { - for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { - const auto& e = Downcast(tensor_tuple->fields[i]); - if (e->shape[axis].as()) { - has_any = true; - break; + int data_length = static_cast(tensor_tuple->fields.size()); + for (int i = 0; i < ndim; ++i) { + std::vector non_any; + for (int j = 0; j < data_length; ++j) { + const auto& e = Downcast(tensor_tuple->fields[j]); + if (!e->shape[i].as()) { + non_any.push_back(e->shape[i]); + // accumulate axis dimension + if (j > 0 && i == axis && !oshape[i].as()) { + oshape[i] += e->shape[i]; + } + } + } + int non_any_size = static_cast(non_any.size()); + if (non_any_size != data_length) oshape[i] = Any(); + if (i != axis) { + for (int k = 1; k < non_any_size; k++) { + if (reporter->AssertEQ(non_any[0], non_any[k])) continue; + throw Error( + "relay.concatenate requires all tensors have the same shape " + "on non-concatenating axes"); } - concat_dim += e->shape[axis]; } - } - - if (has_any) { - concat_dim = Any::make(); } auto rtype = TensorType(oshape, dtype); @@ -125,11 +124,10 @@ bool ConcatenateRel(const Array& types, return true; } -static inline Array> ConcatenateLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +static inline Array> ConcatenateLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { ConcatenateAttrs* param = const_cast(attrs.as()); Array> old_in_shapes; @@ -141,8 +139,8 @@ static inline Array> ConcatenateLayout( } } - size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : - static_cast(param->axis); + size_t axis = + param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast(param->axis); Layout ret; bool is_new_layout_selected = false; @@ -175,11 +173,11 @@ static inline Array> ConcatenateLayout( } if (ret.ndim() <= axis || !ret[axis].IsPrimal()) { - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } } - return Array > {Array(old_in_layouts.size(), ret), {ret}}; + return Array>{Array(old_in_layouts.size(), ret), {ret}}; } } // namespace relay diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 10da11d8c7ac..6b72670babaf 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -21,322 +21,388 @@ * \file unary.cc * \brief Unary operators. 
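// Editor's note: a standalone sketch (not part of this patch) of the output-shape rule that the
// rewritten ConcatenateRel above implements, using -1 as a stand-in for relay's Any dimension.
// Along the concatenation axis the known sizes are summed; on every other axis the known sizes
// must agree; any axis that sees an unknown input size becomes unknown in the output.
#include <cassert>
#include <cstdio>
#include <vector>

std::vector<long> ConcatShape(const std::vector<std::vector<long>>& shapes, size_t axis) {
  size_t ndim = shapes[0].size();
  std::vector<long> out(ndim, 0);
  for (size_t i = 0; i < ndim; ++i) {
    std::vector<long> known;  // sizes that are not Any on this axis
    long sum = 0;
    for (const auto& s : shapes) {
      if (s[i] >= 0) { known.push_back(s[i]); sum += s[i]; }
    }
    bool has_any = known.size() != shapes.size();
    if (i == axis) {
      out[i] = has_any ? -1 : sum;  // an unknown input size makes the whole axis unknown
    } else {
      for (size_t k = 1; k < known.size(); ++k) assert(known[k] == known[0]);
      out[i] = has_any ? -1 : known[0];
    }
  }
  return out;
}

int main() {
  // (2, 3) ++ (2, ?) along axis 1 -> (2, ?); (2, 3) ++ (4, 3) along axis 0 -> (6, 3).
  auto a = ConcatShape({{2, 3}, {2, -1}}, 1);
  auto b = ConcatShape({{2, 3}, {4, 3}}, 0);
  std::printf("(%ld, %ld) and (%ld, %ld)\n", a[0], a[1], b[0], b[1]);
  return 0;
}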
*/ -#include -#include -#include #include #include -#include "../type_relations.h" +#include +#include +#include + #include "../op_common.h" +#include "../type_relations.h" namespace tvm { namespace relay { -#define RELAY_UNARY_COMPUTE(FTOPI) \ - [] (const Attrs& attrs, \ - const Array& inputs, \ - const Type& out_type) -> Array { \ - return {FTOPI(inputs[0])}; \ - } \ - +#define RELAY_UNARY_COMPUTE(FTOPI) \ + [](const Attrs& attrs, const Array& inputs, \ + const Type& out_type) -> Array { return {FTOPI(inputs[0])}; } RELAY_REGISTER_UNARY_OP("log") -.describe(R"code(Returns the log input array, computed element-wise. + .describe(R"code(Returns the log input array, computed element-wise. .. math:: log(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log)); + +RELAY_REGISTER_UNARY_OP("log2") + .describe(R"code(Returns the log to base 2 of input array, computed element-wise. +.. math:: + log2(x) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2)); + +RELAY_REGISTER_UNARY_OP("log10") + .describe(R"code(Returns the log to base 10 of input array, computed element-wise. + +.. math:: + log10(x) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10)); RELAY_REGISTER_UNARY_OP("tan") -.describe(R"code(Returns the tan of input array, computed element-wise. + .describe(R"code(Returns the tan of input array, computed element-wise. .. math:: Y = tan(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan)); RELAY_REGISTER_UNARY_OP("cos") -.describe(R"code(Returns the cos of input array, computed element-wise. + .describe(R"code(Returns the cos of input array, computed element-wise. .. math:: Y = cos(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos)); +RELAY_REGISTER_UNARY_OP("cosh") + .describe(R"code(Returns the cosh of input array, computed element-wise. + +.. math:: + Y = cosh(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh)); RELAY_REGISTER_UNARY_OP("sin") -.describe(R"code(Returns the sin of input array, computed element-wise. + .describe(R"code(Returns the sin of input array, computed element-wise. .. math:: Y = sin(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin)); + +RELAY_REGISTER_UNARY_OP("sinh") + .describe(R"code(Returns the sinh of input array, computed element-wise. +.. math:: + Y = sinh(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh)); + +RELAY_REGISTER_UNARY_OP("acos") + .describe(R"code(Returns the acos of input array, computed element-wise. + +.. math:: + Y = acos(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::acos)); + +RELAY_REGISTER_UNARY_OP("acosh") + .describe(R"code(Returns the acosh of input array, computed element-wise. + +.. 
math:: + Y = acosh(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::acosh)); + +RELAY_REGISTER_UNARY_OP("asin") + .describe(R"code(Returns the asin of input array, computed element-wise. + +.. math:: + Y = asin(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::asin)); + +RELAY_REGISTER_UNARY_OP("asinh") + .describe(R"code(Returns the asinh of input array, computed element-wise. + +.. math:: + Y = asinh(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::asinh)); RELAY_REGISTER_UNARY_OP("atan") -.describe(R"code(Returns the atan of input array, computed element-wise. + .describe(R"code(Returns the atan of input array, computed element-wise. .. math:: Y = atan(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atan)); +RELAY_REGISTER_UNARY_OP("atanh") + .describe(R"code(Returns the atanh of input array, computed element-wise. + +.. math:: + Y = atanh(X) + +)code" TVM_ADD_FILELINE) + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::atanh)); RELAY_REGISTER_UNARY_OP("exp") -.describe(R"code(Returns the exp input array, computed element-wise. + .describe(R"code(Returns the exp input array, computed element-wise. .. math:: \exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); RELAY_REGISTER_UNARY_OP("fast_exp") -.describe(R"code(Returns the fast_exp input array, computed element-wise. + .describe(R"code(Returns the fast_exp input array, computed element-wise. .. math:: \fast_exp(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); RELAY_REGISTER_UNARY_OP("erf") -.describe(R"code(Returns the error function value for input array, computed element-wise. + .describe(R"code(Returns the error function value for input array, computed element-wise. .. math:: \erf(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf)); RELAY_REGISTER_UNARY_OP("fast_erf") -.describe(R"code(Returns the error function value for input array, computed element-wise. + .describe(R"code(Returns the error function value for input array, computed element-wise. .. math:: \fast_erf(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf)); RELAY_REGISTER_UNARY_OP("sqrt") -.describe(R"code(Returns the sqrt input array, computed element-wise. + .describe(R"code(Returns the sqrt input array, computed element-wise. .. math:: sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sqrt)); RELAY_REGISTER_UNARY_OP("rsqrt") -.describe(R"code(Returns the rsqrt input array, computed element-wise. + .describe(R"code(Returns the rsqrt input array, computed element-wise. .. 
math:: 1/sqrt(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt)); + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::rsqrt)); RELAY_REGISTER_UNARY_OP("zeros_like") -.describe(R"code(Returns an array of zeros, with same type and shape as the input. + .describe(R"code(Returns an array of zeros, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(4); + .set_support_level(4); RELAY_REGISTER_UNARY_OP("ones_like") -.describe(R"code(Returns an array of ones, with same type and shape as the input. + .describe(R"code(Returns an array of ones, with same type and shape as the input. )code" TVM_ADD_FILELINE) -.set_support_level(4); + .set_support_level(4); RELAY_REGISTER_UNARY_OP("sigmoid") -.describe(R"code(Returns the sigmoid input array, computed element-wise. + .describe(R"code(Returns the sigmoid input array, computed element-wise. .. math:: sigmoid(x) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sigmoid)); RELAY_REGISTER_UNARY_OP("copy") -.describe(R"code(Copy a tensor. + .describe(R"code(Copy a tensor. )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::identity)); // relay.clip TVM_REGISTER_NODE_TYPE(ClipAttrs); -TVM_REGISTER_GLOBAL("relay.op._make.clip") -.set_body_typed([](Expr a, double a_min, double a_max) { - auto attrs = make_object(); - attrs->a_min = a_min; - attrs->a_max = a_max; - static const Op& op = Op::Get("clip"); +TVM_REGISTER_GLOBAL("relay.op._make.clip").set_body_typed([](Expr a, double a_min, double a_max) { + auto attrs = make_object(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); return Call(op, {a}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("clip") -.describe(R"code(Clip tensor values. + .describe(R"code(Clip tensor values. This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("Identity", IdentityRel) -.set_attr("TOpPattern", kElemWise) -.set_attr("TOpIsStateful", false) -.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) -.set_attrs_type() -.set_support_level(3); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("Identity", IdentityRel) + .set_attr("TOpPattern", kElemWise) + .set_attr("TOpIsStateful", false) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attrs_type() + .set_support_level(3); RELAY_REGISTER_UNARY_OP("floor") -.describe(R"code(Returns the floor of input array, computed element-wise. + .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::floor)); RELAY_REGISTER_UNARY_OP("ceil") -.describe(R"code(Returns the ceil of input array, computed element-wise. + .describe(R"code(Returns the ceil of input array, computed element-wise. .. 
math:: ceil(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::ceil)); RELAY_REGISTER_UNARY_OP("trunc") -.describe(R"code(Returns the trunc of input array, computed element-wise. + .describe(R"code(Returns the trunc of input array, computed element-wise. .. math:: trunc(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::trunc)); RELAY_REGISTER_UNARY_OP("round") -.describe(R"code(Returns the round of input array, computed element-wise. + .describe(R"code(Returns the round of input array, computed element-wise. .. math:: round(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round)); RELAY_REGISTER_UNARY_OP("sign") -.describe(R"code(Returns the sign of input array, computed element-wise. + .describe(R"code(Returns the sign of input array, computed element-wise. .. numpy:: sign(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign)); RELAY_REGISTER_UNARY_OP("abs") -.describe(R"code(Returns the abs of input array, computed element-wise. + .describe(R"code(Returns the abs of input array, computed element-wise. .. math:: abs(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::abs)); RELAY_REGISTER_UNARY_OP("tanh") -.describe(R"code(Returns the tanh of input array, computed element-wise. + .describe(R"code(Returns the tanh of input array, computed element-wise. .. math:: Y = sinh(X) / cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); RELAY_REGISTER_UNARY_OP("fast_tanh") -.describe(R"code(Returns the fast_tanh of input array, computed element-wise. + .describe(R"code(Returns the fast_tanh of input array, computed element-wise. .. math:: Y = sinh(X) / cosh(X) )code" TVM_ADD_FILELINE) -.set_support_level(1) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); - + .set_support_level(1) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); RELAY_REGISTER_UNARY_OP("negative") -.describe(R"code(Returns the numeric negative of input array, computed element-wise. + .describe(R"code(Returns the numeric negative of input array, computed element-wise. .. math:: -(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); - + .set_support_level(3) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::negative)); RELAY_REGISTER_UNARY_OP("logical_not") -.describe(R"code(Returns the logical inverse of input array, computed element-wise. + .describe(R"code(Returns the logical inverse of input array, computed element-wise. .. 
math:: !(x) )code" TVM_ADD_FILELINE) -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); - + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); RELAY_REGISTER_UNARY_OP("bitwise_not") -.describe(R"code(Returns the bitwise inverse of input array, computed element-wise. + .describe(R"code(Returns the bitwise inverse of input array, computed element-wise. .. math:: ~(x) )code" TVM_ADD_FILELINE) -.set_support_level(4) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); - + .set_support_level(4) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::bitwise_not)); // shape_of TVM_REGISTER_NODE_TYPE(ShapeOfAttrs); -bool ShapeOfRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool ShapeOfRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); - CHECK(tt != nullptr); + if (tt == nullptr) { + return false; + } const auto* param = attrs.as(); CHECK(param != nullptr); auto rank_shape = RankShape(tt->shape); @@ -344,8 +410,7 @@ bool ShapeOfRel(const Array& types, return true; } -Array ShapeOfCompute(const Attrs& attrs, - const Array& inputs, +Array ShapeOfCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { CHECK_EQ(inputs.size(), 1); const auto* param = attrs.as(); @@ -353,8 +418,7 @@ Array ShapeOfCompute(const Attrs& attrs, return {topi::shape(inputs[0], param->dtype)}; } -TVM_REGISTER_GLOBAL("relay.op._make.shape_of") -.set_body_typed([](Expr data, DataType dtype) { +TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("shape_of"); @@ -362,29 +426,25 @@ TVM_REGISTER_GLOBAL("relay.op._make.shape_of") }); RELAY_REGISTER_OP("shape_of") -.describe(R"code(Returns a tensor representing the shape of a tensor. + .describe(R"code(Returns a tensor representing the shape of a tensor. 
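// Editor's note: an illustrative sketch (not part of this patch) of what shape_of produces: a
// 1-D int32 tensor whose single dimension is the rank of the input, holding the input's shape
// values, e.g. a (2, 3, 4) input yields the tensor [2, 3, 4] of shape (3,). The change above
// also makes ShapeOfRel return false (defer) instead of asserting when the input type is not
// known yet, so type inference can retry once more information is available.
#include <cstdio>
#include <vector>

std::vector<long> ShapeOfSketch(const std::vector<long>& input_shape) {
  return input_shape;  // the result tensor has shape {input_shape.size()}
}

int main() {
  auto s = ShapeOfSketch({2, 3, 4});
  std::printf("rank=%zu values=[%ld, %ld, %ld]\n", s.size(), s[0], s[1], s[2]);
  return 0;
}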
)code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("ShapeOf", ShapeOfRel) -.set_attr("TOpIsStateful", false) -// Use kOpaque for shape_of op for now since it won't be performance critic, -// and it makes things easier for dynamic shape func -.set_attr("TOpPattern", kOpaque) -.set_attr("FInferCorrectLayout", - ElemwiseArbitraryLayout) -.set_support_level(10) -.set_attr("FTVMCompute", ShapeOfCompute); - + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("ShapeOf", ShapeOfRel) + .set_attr("TOpIsStateful", false) + // Use kOpaque for shape_of op for now since it won't be performance critic, + // and it makes things easier for dynamic shape func + .set_attr("TOpPattern", kOpaque) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_support_level(10) + .set_attr("FTVMCompute", ShapeOfCompute); TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs); -bool NdarraySizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { +bool NdarraySizeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); CHECK(tt != nullptr); @@ -394,8 +454,7 @@ bool NdarraySizeRel(const Array& types, return true; } -Array NdarraySizeCompute(const Attrs& attrs, - const Array& inputs, +Array NdarraySizeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type) { CHECK_EQ(inputs.size(), 1); const auto* param = attrs.as(); @@ -403,8 +462,7 @@ Array NdarraySizeCompute(const Attrs& attrs, return Array{topi::ndarray_size(inputs[0], param->dtype)}; } -TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size") -.set_body_typed([](Expr data, DataType dtype) { +TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size").set_body_typed([](Expr data, DataType dtype) { auto attrs = make_object(); attrs->dtype = dtype; static const Op& op = Op::Get("ndarray_size"); @@ -412,46 +470,45 @@ TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size") }); RELAY_REGISTER_OP("ndarray_size") -.describe(R"code(Returns a tensor representing the number of elements of input tensor. + .describe(R"code(Returns a tensor representing the number of elements of input tensor. )code" TVM_ADD_FILELINE) -.set_num_inputs(1) -.set_attrs_type() -.add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("NdarraySize", NdarraySizeRel) -.set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kInjective) -.set_attr("FInferCorrectLayout", -ElemwiseArbitraryLayout) -.set_support_level(10) -.set_attr("FTVMCompute", NdarraySizeCompute); + .set_num_inputs(1) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_type_rel("NdarraySize", NdarraySizeRel) + .set_attr("TOpIsStateful", false) + .set_attr("TOpPattern", kInjective) + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_support_level(10) + .set_attr("FTVMCompute", NdarraySizeCompute); RELAY_REGISTER_UNARY_OP("isnan") -.describe(R"code(Returns whether the input contains any NaN, computed element-wise. + .describe(R"code(Returns whether the input contains any NaN, computed element-wise. .. 
math:: isnan(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan)); RELAY_REGISTER_UNARY_OP("isfinite") -.describe(R"code(Returns the finiteness of input, computed element-wise. + .describe(R"code(Returns the finiteness of input, computed element-wise. .. math:: isfinite(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isfinite)); RELAY_REGISTER_UNARY_OP("isinf") -.describe(R"code(Returns the infiniteness of input, computed element-wise. + .describe(R"code(Returns the infiniteness of input, computed element-wise. .. math:: isinf(x) )code" TVM_ADD_FILELINE) -.set_support_level(3) -.add_type_rel("IdentityCompRel", IdentityCompRel) -.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf)); + .set_support_level(3) + .add_type_rel("IdentityCompRel", IdentityCompRel) + .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isinf)); } // namespace relay } // namespace tvm diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index f9653e24b1b9..46143d16c96d 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,18 +22,19 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include "./type_relations.h" + +#include #include #include -#include +#include + #include -#include "./type_relations.h" namespace tvm { namespace relay { -bool IdentityRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { for (size_t i = 1; i < types.size(); ++i) { reporter->Assign(types[i], types[0]); @@ -41,14 +42,14 @@ bool IdentityRel(const Array& types, return true; } -bool EqualCheck(const IndexExpr& lhs, - const IndexExpr& rhs) { +bool EqualCheck(const IndexExpr& lhs, const IndexExpr& rhs) { IndexExpr diff = lhs - rhs; if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } // symbolic - diff = tvm::tir::CanonicalSimplify(diff); + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } @@ -62,9 +63,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) { return false; } -Type ConcreteBroadcast(const TensorType& t1, - const TensorType& t2, - DataType output_dtype) { +Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) { std::vector oshape; size_t ndim1 = t1->shape.size(); size_t ndim2 = t2->shape.size(); @@ -76,18 +75,16 @@ Type ConcreteBroadcast(const TensorType& t1, oshape.push_back(s2); } else if (EqualConstInt(s2, 1)) { oshape.push_back(s1); - } else if (s1.as()) { + } else if (s1.as()) { // s1 == 1 || s1 == s2 oshape.push_back(s2); - } else if (s2.as()) { + } else if (s2.as()) { // s2 == 1 || s2 == s1 oshape.push_back(s1); } else if (EqualCheck(s1, s2)) { oshape.push_back(s1); } else { - throw Error(ErrorBuilder() - << "Incompatible broadcast type " - << t1 << " and " << t2); + throw Error(ErrorBuilder() << "Incompatible broadcast type " << t1 << " and " << t2); } } @@ -96,13 +93,10 @@ Type 
ConcreteBroadcast(const TensorType& t1, for (; i <= max_ndim; ++i) { oshape.push_back(rshape[max_ndim - i]); } - return TensorType(Array( - oshape.rbegin(), oshape.rend()), output_dtype); + return TensorType(Array(oshape.rbegin(), oshape.rend()), output_dtype); } -bool BroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] @@ -110,17 +104,15 @@ bool BroadcastRel(const Array& types, if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], - ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); + reporter->Assign( + types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), t0->dtype)); return true; } } return false; } -bool BroadcastCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] @@ -128,17 +120,15 @@ bool BroadcastCompRel(const Array& types, if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); - reporter->Assign(types[2], - ConcreteBroadcast(GetRef(t0), GetRef(t1), DataType::Bool())); + reporter->Assign(types[2], ConcreteBroadcast(GetRef(t0), GetRef(t1), + DataType::Bool())); return true; } } return false; } -bool IdentityCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { if (auto* t0 = types[0].as()) { Type out_type = TensorType(GetRef(t0)->shape, DataType::Bool()); @@ -152,7 +142,7 @@ Array RankShape(const Array& shape) { if (shape.size() == 0) { return {}; } else { - return { tvm::Integer(shape.size()) }; + return {tvm::Integer(shape.size())}; } } diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h index 48a545bddd0b..acd4b2dae1be 100644 --- a/src/relay/op/type_relations.h +++ b/src/relay/op/type_relations.h @@ -27,6 +27,7 @@ #include #include + #include namespace tvm { @@ -40,9 +41,7 @@ namespace relay { * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool IdentityRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool IdentityRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); /*! @@ -55,9 +54,7 @@ bool IdentityRel(const Array& types, * \param reporter The reporter. * \return true whether relation has been resolved. */ -bool BroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); /*! @@ -74,15 +71,11 @@ bool BroadcastRel(const Array& types, * \param reporter The reporter. * \return true whether relation has been resolved. 
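// Editor's note: a standalone sketch (not part of this patch) of the broadcast rule that
// ConcreteBroadcast above implements, with -1 standing in for relay's Any dimension: a size of 1
// broadcasts against anything, equal sizes pass through, and an unknown size adopts the other
// operand's size (the "s1 == 1 || s1 == s2" assumption noted in the comments above). Shapes are
// aligned from the trailing dimension.
#include <cstdio>
#include <stdexcept>
#include <vector>

std::vector<long> BroadcastShape(std::vector<long> a, std::vector<long> b) {
  // Left-pad the shorter shape with 1s so both operands have the same rank.
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  std::vector<long> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    long s1 = a[i], s2 = b[i];
    if (s1 == 1) out[i] = s2;
    else if (s2 == 1) out[i] = s1;
    else if (s1 < 0) out[i] = s2;  // Any: assume it is 1 or equal to s2
    else if (s2 < 0) out[i] = s1;
    else if (s1 == s2) out[i] = s1;
    else throw std::runtime_error("incompatible broadcast");
  }
  return out;
}

int main() {
  auto r = BroadcastShape({4, 1, 3}, {2, 3});  // -> (4, 2, 3)
  std::printf("(%ld, %ld, %ld)\n", r[0], r[1], r[2]);
  return 0;
}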
*/ -bool BroadcastCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter); -bool IdentityCompRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter); +bool IdentityCompRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter); Array RankShape(const Array& shape); diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index cafe9b6dd0c3..18a2edb4540a 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -21,46 +21,38 @@ * \file multibox_op.cc * \brief Multibox related operators */ -#include -#include #include +#include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(MultiBoxPriorAttrs); -bool MultiboxPriorRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MultiboxPriorRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); const MultiBoxPriorAttrs* param = attrs.as(); const auto& dshape = data->shape; CHECK_EQ(dshape.size(), 4) << "Input data should be 4D: " - "[batch, channel, height, width]"; + "[batch, channel, height, width]"; IndexExpr in_height = dshape[2]; IndexExpr in_width = dshape[3]; int num_sizes = static_cast(param->sizes.size()); int num_ratios = static_cast(param->ratios.size()); // since input sizes are same in each batch, we could share MultiBoxPrior - std::vector oshape( - {1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); + std::vector oshape({1, in_height * in_width * (num_sizes + num_ratios - 1), 4}); // assign output type reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } - -Expr MakeMultiBoxPrior(Expr data, - Array sizes, - Array ratios, - Array steps, - Array offsets, - bool clip) { +Expr MakeMultiBoxPrior(Expr data, Array sizes, Array ratios, + Array steps, Array offsets, bool clip) { auto attrs = make_object(); attrs->sizes = std::move(sizes); attrs->ratios = std::move(ratios); @@ -71,25 +63,20 @@ Expr MakeMultiBoxPrior(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior") -.set_body_typed(MakeMultiBoxPrior); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_prior").set_body_typed(MakeMultiBoxPrior); RELAY_REGISTER_OP("vision.multibox_prior") -.describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." + .describe(R"doc("Generate prior(anchor) boxes from data, sizes and ratios." 
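// Editor's note: a small standalone check (not part of this patch) of the output shape computed
// in MultiboxPriorRel above: for a 4-D NCHW input the op emits
// (1, H * W * (num_sizes + num_ratios - 1), 4) prior boxes. Values below are illustrative.
#include <cstdio>

int main() {
  int in_height = 32, in_width = 32;  // input feature map H x W
  int num_sizes = 3, num_ratios = 2;  // lengths of the `sizes` and `ratios` attributes
  long num_boxes = static_cast<long>(in_height) * in_width * (num_sizes + num_ratios - 1);
  std::printf("output shape = (1, %ld, 4)\n", num_boxes);  // (1, 4096, 4)
  return 0;
}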
)doc" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(1) -.add_argument("data", "Tensor", "The input tensor.") -.set_support_level(5) -.add_type_rel("MultiBoxPrior", MultiboxPriorRel); + .set_attrs_type() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(5) + .add_type_rel("MultiBoxPrior", MultiboxPriorRel); TVM_REGISTER_NODE_TYPE(MultiBoxTransformLocAttrs); -bool MultiBoxTransformLocRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool MultiBoxTransformLocRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); @@ -102,20 +89,15 @@ bool MultiBoxTransformLocRel(const Array& types, const auto& loc_shape = loc_pred->shape; const auto& anchor_shape = anchor->shape; - CHECK_EQ(cls_shape.size(), 3U) - << "The dimension of class probability should be 3, but received " - << cls_shape.size(); + CHECK_EQ(cls_shape.size(), 3U) << "The dimension of class probability should be 3, but received " + << cls_shape.size(); CHECK_EQ(loc_shape.size(), 2U) - << "The dimension of location prediction should be 2, but received " - << loc_shape.size(); + << "The dimension of location prediction should be 2, but received " << loc_shape.size(); CHECK_EQ(anchor_shape.size(), 3U) - << "The dimension of anchor should be 3, but received " - << anchor_shape.size(); + << "The dimension of anchor should be 3, but received " << anchor_shape.size(); - CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) - << "Number of anchors mismatch found"; - CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) - << "# anchors mismatch with # loc."; + CHECK(reporter->AssertEQ(cls_shape[2], anchor_shape[1])) << "Number of anchors mismatch found"; + CHECK(reporter->AssertEQ(cls_shape[2] * 4, loc_shape[1])) << "# anchors mismatch with # loc."; CHECK(reporter->Assert(anchor_shape[1] > 0)) << "Number of anchors must > 0."; CHECK(reporter->AssertEQ(anchor_shape[2], 4)); @@ -130,12 +112,8 @@ bool MultiBoxTransformLocRel(const Array& types, return true; } -Expr MakeMultiBoxTransformLoc(Expr cls_prob, - Expr loc_pred, - Expr anchor, - bool clip, - double threshold, - Array variances) { +Expr MakeMultiBoxTransformLoc(Expr cls_prob, Expr loc_pred, Expr anchor, bool clip, + double threshold, Array variances) { auto attrs = make_object(); attrs->clip = std::move(clip); attrs->threshold = std::move(threshold); @@ -145,18 +123,18 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, } TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc") -.set_body_typed(MakeMultiBoxTransformLoc); + .set_body_typed(MakeMultiBoxTransformLoc); RELAY_REGISTER_OP("vision.multibox_transform_loc") -.describe(R"doc("Location transformation for multibox detection." + .describe(R"doc("Location transformation for multibox detection." 
)doc" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("cls_prob", "Tensor", "Class probabilities.") -.add_argument("loc_pred", "Tensor", "Location regression predictions.") -.add_argument("anchor", "Tensor", "Multibox prior anchor boxes") -.add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel) -.set_support_level(5); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("cls_prob", "Tensor", "Class probabilities.") + .add_argument("loc_pred", "Tensor", "Location regression predictions.") + .add_argument("anchor", "Tensor", "Multibox prior anchor boxes") + .add_type_rel("MultiBoxTransformLoc", MultiBoxTransformLocRel) + .set_support_level(5); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 25743f98bc0b..7486db790780 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -21,17 +21,15 @@ * \file nms.cc * \brief Non-maximum suppression operators */ -#include #include +#include namespace tvm { namespace relay { TVM_REGISTER_NODE_TYPE(GetValidCountsAttrs); -bool GetValidCountRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool GetValidCountRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -39,19 +37,18 @@ bool GetValidCountRel(const Array& types, CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; std::vector oshape({data->shape[0]}); + std::vector oshape_indices({data->shape[0], data->shape[1]}); std::vector fields; fields.push_back(TensorType(oshape, DataType::Int(32))); fields.push_back(TensorType(data->shape, data->dtype)); + fields.push_back(TensorType(oshape_indices, DataType::Int(32))); // assign output type reporter->Assign(types[1], TupleType(Array(fields))); return true; } -Expr MakeGetValidCounts(Expr data, - double score_threshold, - int id_index, - int score_index) { +Expr MakeGetValidCounts(Expr data, double score_threshold, int id_index, int score_index) { auto attrs = make_object(); attrs->score_threshold = score_threshold; attrs->id_index = id_index; @@ -60,33 +57,26 @@ Expr MakeGetValidCounts(Expr data, return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts") -.set_body_typed(MakeGetValidCounts); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.get_valid_counts").set_body_typed(MakeGetValidCounts); RELAY_REGISTER_OP("vision.get_valid_counts") -.describe(R"doc(Get valid count of bounding boxes given + .describe(R"doc(Get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. 
)doc" TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("data", "Tensor", "Input data.") -.set_support_level(5) -.add_type_rel("GetValidCount", GetValidCountRel); - + .set_num_inputs(1) + .add_argument("data", "Tensor", "Input data.") + .set_support_level(5) + .add_type_rel("GetValidCount", GetValidCountRel); TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); -bool NMSRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); + CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); - const NonMaximumSuppressionAttrs* param = - attrs.as(); + const NonMaximumSuppressionAttrs* param = attrs.as(); const auto& dshape = data->shape; const auto& vshape = valid_count->shape; CHECK_EQ(dshape.size(), 3) << "Input data should be 3-D."; @@ -94,26 +84,22 @@ bool NMSRel(const Array& types, // assign output type if (param->return_indices) { + std::vector fields; + // dynamic happens for return_indices in TensorFlow & ONNX std::vector oshape({dshape[0], dshape[1]}); - reporter->Assign(types[2], TensorType(oshape, DataType::Int(32))); + fields.push_back(TensorType(oshape, DataType::Int(32))); + std::vector countshape({dshape[0], 1}); + fields.push_back(TensorType(countshape, DataType::Int(32))); + reporter->Assign(types[3], TupleType(Array(fields))); } else { - reporter->Assign(types[2], TensorType(dshape, data->dtype)); + reporter->Assign(types[3], TensorType(dshape, data->dtype)); } return true; } - -Expr MakeNMS(Expr data, - Expr valid_count, - int max_output_size, - double iou_threshold, - bool force_suppress, - int top_k, - int coord_start, - int score_index, - int id_index, - bool return_indices, - bool invalid_to_bottom) { +Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold, + bool force_suppress, int top_k, int coord_start, int score_index, int id_index, + bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; @@ -125,24 +111,23 @@ Expr MakeNMS(Expr data, attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; static const Op& op = Op::Get("vision.non_max_suppression"); - return Call(op, {data, valid_count}, Attrs(attrs), {}); + return Call(op, {data, valid_count, indices}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression") -.set_body_typed(MakeNMS); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); RELAY_REGISTER_OP("vision.non_max_suppression") -.describe(R"doc(Non-maximum suppression. The input boxes should -be in the format of [class_id, score, left, top, right, bottom]. -Set id_index to be -1 to ignore class_id axis. + .describe(R"doc(Non-maximum suppression. The input boxes should +be in the format of [class_id, score, left, top, right, bottom] +or [score, left, top, right, bottom]. Set id_index to be -1 to +ignore class_id axis. 
)doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "Input data.") -.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") -.set_support_level(5) -.add_type_rel("NMS", NMSRel); + .set_num_inputs(3) + .add_argument("data", "Tensor", "Input data.") + .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") + .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") + .set_support_level(5) + .add_type_rel("NMS", NMSRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 6b221a279bac..f7e1ecb82dcb 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -21,9 +21,9 @@ * \file rcnn_op.cc * \brief Faster RCNN and Mask RCNN operators */ +#include #include #include -#include namespace tvm { namespace relay { @@ -36,6 +36,8 @@ bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); const auto* rois = types[1].as(); + CHECK(data); + CHECK(rois); const auto& dshape = data->shape; const auto& rshape = rois->shape; CHECK(roi_align_attrs); @@ -50,7 +52,7 @@ bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, } Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spatial_scale, - int sample_ratio, std::string layout) { + int sample_ratio, String layout) { auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; @@ -60,8 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spa return Call(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align") -.set_body_typed(MakeROIAlign); +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align").set_body_typed(MakeROIAlign); RELAY_REGISTER_OP("vision.roi_align") .describe(R"doc(ROI Align operator. @@ -73,16 +74,16 @@ RELAY_REGISTER_OP("vision.roi_align") - **out**: This depends on the `layout` parameter. Output is 4D array of shape (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. )doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("rois", "Tensor", "The input rois") -.set_support_level(5) -.add_type_rel("ROIAlign", ROIAlignRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", "The input rois") + .set_support_level(5) + .add_type_rel("ROIAlign", ROIAlignRel); TVM_REGISTER_NODE_TYPE(ROIPoolAttrs); bool ROIPoolRel(const Array& types, int num_inputs, const Attrs& attrs, - const TypeReporter& reporter) { + const TypeReporter& reporter) { auto roi_pool_attrs = attrs.as(); CHECK_EQ(types.size(), 3); const auto* data = types[0].as(); @@ -101,7 +102,7 @@ bool ROIPoolRel(const Array& types, int num_inputs, const Attrs& attrs, } Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spatial_scale, - std::string layout) { + String layout) { auto attrs = make_object(); attrs->pooled_size = pooled_size; attrs->spatial_scale = spatial_scale; @@ -110,8 +111,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spat return Call(op, {data, rois}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool") -.set_body_typed(MakeROIPool); +TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool").set_body_typed(MakeROIPool); RELAY_REGISTER_OP("vision.roi_pool") .describe(R"doc(ROI Pool operator. 
@@ -123,11 +123,11 @@ RELAY_REGISTER_OP("vision.roi_pool") - **out**: This depends on the `layout` parameter. Output is 4D array of shape (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. )doc" TVM_ADD_FILELINE) -.set_num_inputs(2) -.add_argument("data", "Tensor", "The input tensor.") -.add_argument("rois", "Tensor", "The input rois") -.set_support_level(5) -.add_type_rel("ROIPool", ROIPoolRel); + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("rois", "Tensor", "The input rois") + .set_support_level(5) + .add_type_rel("ROIPool", ROIPoolRel); TVM_REGISTER_NODE_TYPE(ProposalAttrs); @@ -153,16 +153,14 @@ bool ProposalRel(const Array& types, int num_inputs, const Attrs& attrs, auto batch = cls_prob->shape[0]; - std::vector oshape( - {batch * proposal_attrs->rpn_post_nms_top_n, 5}); + std::vector oshape({batch * proposal_attrs->rpn_post_nms_top_n, 5}); reporter->Assign(types[3], TensorType(oshape, cls_prob->dtype)); return true; } Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array scales, Array ratios, int feature_stride, double threshold, - int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, - bool iou_loss) { + int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, bool iou_loss) { auto attrs = make_object(); attrs->scales = scales; attrs->ratios = ratios; @@ -176,8 +174,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array return Call(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal") -.set_body_typed(MakeProposal); +TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal").set_body_typed(MakeProposal); RELAY_REGISTER_OP("vision.proposal") .describe(R"code(Generate region proposals via RPN. @@ -187,12 +184,12 @@ RELAY_REGISTER_OP("vision.proposal") - **im_info**: 2-D with shape [batch, 3]. - **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5]. )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") -.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") -.add_argument("im_info", "Tensor", "Image size and scale") -.set_support_level(5) -.add_type_rel("Proposal", ProposalRel); + .set_num_inputs(3) + .add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") + .add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") + .add_argument("im_info", "Tensor", "Image size and scale") + .set_support_level(5) + .add_type_rel("Proposal", ProposalRel); } // namespace relay } // namespace tvm diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 58596778de1d..e54473f68ef7 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -21,10 +21,12 @@ * \file yolo.cc * \brief Yolo related operators */ -#include -#include #include +#include +#include + #include + #include "../op_common.h" #include "../type_relations.h" @@ -34,15 +36,13 @@ namespace relay { TVM_REGISTER_NODE_TYPE(YoloReorgAttrs); /*! -* \brief YoloReorgRel Output type and shape relation evaluation function. -* \param num_inputs Number of input types in the args. -* \param attrs The additional attributes of the operator. -* \param reporter The reporter to report solution to. -* \return false if This relation cannot be resolved. true if this relation has been resolved. 
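// Editor's note: a trivial standalone check (not part of this patch) of the output shape assigned
// in ProposalRel above: the op emits a 2-D tensor of shape (batch * rpn_post_nms_top_n, 5). The
// attribute values below are illustrative.
#include <cstdio>

int main() {
  long batch = 2, rpn_post_nms_top_n = 300;
  std::printf("proposal output shape = (%ld, 5)\n", batch * rpn_post_nms_top_n);  // (600, 5)
  return 0;
}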
-*/ -bool YoloReorgRel(const Array& types, - int num_inputs, - const Attrs& attrs, + * \brief YoloReorgRel Output type and shape relation evaluation function. + * \param num_inputs Number of input types in the args. + * \param attrs The additional attributes of the operator. + * \param reporter The reporter to report solution to. + * \return false if This relation cannot be resolved. true if this relation has been resolved. + */ +bool YoloReorgRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -60,34 +60,29 @@ bool YoloReorgRel(const Array& types, return true; } -Expr MakeYoloReorg(Expr data, - Integer stride) { +Expr MakeYoloReorg(Expr data, Integer stride) { auto attrs = make_object(); attrs->stride = stride; static const Op& op = Op::Get("vision.yolo_reorg"); return Call(op, {data}, Attrs(attrs), {}); } - -TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg") -.set_body_typed(MakeYoloReorg); - +TVM_REGISTER_GLOBAL("relay.op.vision._make.yolo_reorg").set_body_typed(MakeYoloReorg); RELAY_REGISTER_OP("vision.yolo_reorg") -.describe(R"doc("Yolo reorg operation. This layer reorganize the output. + .describe(R"doc("Yolo reorg operation. This layer reorganize the output. Its function is mostly shape transform.")doc" TVM_ADD_FILELINE) -.add_argument("data", "Tensor", "The input tensor.") -.set_num_inputs(1) -.set_support_level(5) -.set_attrs_type() -.add_type_rel("YoloReorg", YoloReorgRel) -.set_attr("FTVMCompute", [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) { - const auto* params = attrs.as(); - CHECK(params != nullptr); - return Array{ topi::vision::reorg(inputs[0], params->stride) }; -}); + .add_argument("data", "Tensor", "The input tensor.") + .set_num_inputs(1) + .set_support_level(5) + .set_attrs_type() + .add_type_rel("YoloReorg", YoloReorgRel) + .set_attr("FTVMCompute", [](const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* params = attrs.as(); + CHECK(params != nullptr); + return Array{topi::vision::reorg(inputs[0], params->stride)}; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc index d8752d8030d7..b0dc3e4af5c4 100644 --- a/src/relay/qnn/op/add.cc +++ b/src/relay/qnn/op/add.cc @@ -23,6 +23,7 @@ */ #include #include + #include "op_common.h" namespace tvm { @@ -44,7 +45,6 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Get the input dtype and shape. QnnBinaryOpTensorType input_type(arg_types, 0); - // FIXME (anijain2305) - The lowering can be further optimized. Instead of inserting requantize in // the start, we can insert requantize at the end if both input tensors have same qnn params. In // that case, we can first add the tensors, subtract the zero point, and requantize at the end. @@ -65,18 +65,14 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Q_c = Q_a' + Q_b' - zp_c // The add op is done in int32 precision. - - // Requantize LHS if necessary. Computes Q_a' - auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, - args.lhs_zero_point, - args.output_scale, args.output_zero_point, - input_type.shape); + auto requantized_lhs = + RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Requantize RHS if necessary. 
Computes Q_b' - auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, - args.rhs_zero_point, - args.output_scale, args.output_zero_point, - input_type.shape); + auto requantized_rhs = + RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Computes Q_a' + Q_b' auto output = Add(requantized_lhs, requantized_rhs); @@ -92,9 +88,9 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // QNN Addition operator. QNN_REGISTER_BINARY_OP("add") -.describe("Elementwise add with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize); + .describe("Elementwise add with with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 650dcb962d44..bda8cf878793 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -22,13 +22,14 @@ * \brief QNN concatenate operator. It concatenates quantized input tensors along a given axis. */ -#include #include #include #include +#include + #include "../../op/tensor/transform.h" -#include "../../transforms/pattern_util.h" #include "../../transforms/infer_layout_util.h" +#include "../../transforms/pattern_util.h" #include "../util.h" namespace tvm { @@ -42,10 +43,9 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at // Check the scale and zero point types const auto* input_scales_tuple = types[1].as(); if (input_scales_tuple == nullptr) { - throw Error( - ErrorBuilder() - << "qnn concatenate requires a tuple of scales as the second argument, found " - << PrettyPrint(types[1])); + throw Error(ErrorBuilder() + << "qnn concatenate requires a tuple of scales as the second argument, found " + << PrettyPrint(types[1])); } for (const auto& input_scale : input_scales_tuple->fields) { CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx] @@ -53,10 +53,9 @@ bool QnnConcatenateRel(const Array& types, int num_inputs, const Attrs& at const auto* input_zero_points_tuple = types[2].as(); if (input_zero_points_tuple == nullptr) { - throw Error( - ErrorBuilder() - << "qnn concatenate requires a tuple of zero_points as the third argument, found " - << PrettyPrint(types[2])); + throw Error(ErrorBuilder() + << "qnn concatenate requires a tuple of zero_points as the third argument, found " + << PrettyPrint(types[2])); } for (const auto& input_zero_point : input_zero_points_tuple->fields) { CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx] @@ -113,9 +112,8 @@ Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Ex auto attrs = make_object(); attrs->axis = axis; static const Op& op = Op::Get("qnn.concatenate"); - return Call(op, - {data, input_scales, input_zero_points, output_scale, output_zero_point}, - Attrs(attrs), {}); + return Call(op, {data, input_scales, input_zero_points, output_scale, output_zero_point}, + Attrs(attrs), {}); } /* @@ -149,8 +147,16 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, // If the output qnn params do not match the input qnn params, we can call requantize on the input // expr first, followed by a concatenate on the requantized input exprs. 
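// Editor's note: a standalone numeric sketch (not part of this patch) of the lowering described
// in the comments of QnnAddCanonicalize above: both inputs are first requantized to the output
// scale and zero point (Q' = round(S_in / S_out * (Q - zp_in)) + zp_out), the results are added
// in int32, and one extra output zero point is subtracted so that Q_c = Q_a' + Q_b' - zp_c. The
// requantize helper here uses plain round-to-nearest and only illustrates the arithmetic, not
// TVM's exact rounding mode.
#include <cmath>
#include <cstdint>
#include <cstdio>

int32_t Requantize(int32_t q, float in_scale, int32_t in_zp, float out_scale, int32_t out_zp) {
  return static_cast<int32_t>(std::lround((in_scale / out_scale) * (q - in_zp))) + out_zp;
}

int main() {
  // a = 0.1 * (q_a - 0), b = 0.2 * (q_b - 0); output scale 0.2, output zero point 10.
  float sa = 0.1f, sb = 0.2f, sc = 0.2f;
  int32_t zpa = 0, zpb = 0, zpc = 10;
  int32_t q_a = 40, q_b = 15;                        // real values 4.0 and 3.0
  int32_t q_a2 = Requantize(q_a, sa, zpa, sc, zpc);  // 30
  int32_t q_b2 = Requantize(q_b, sb, zpb, sc, zpc);  // 25
  int32_t q_c = q_a2 + q_b2 - zpc;                   // 45 -> 0.2 * (45 - 10) = 7.0 = 4.0 + 3.0
  std::printf("q_c = %d (dequantized %.1f)\n", q_c, sc * (q_c - zpc));
  return 0;
}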
- auto tuple_data = data.as(); - CHECK(tuple_data != nullptr); + Array tuple_exprs; + if (data->IsInstance()) { + tuple_exprs = data.as()->fields; + } else if (data->IsInstance()) { // if the data is a CallNode, use TupleGetItems + auto call = Downcast(data); + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + tuple_exprs.push_back(TupleGetItem(call, i)); + } + } + CHECK(!tuple_exprs.empty()); auto tuple_input_scales = input_scales.as(); CHECK(tuple_input_scales != nullptr); @@ -160,7 +166,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, int idx = 0; Array requantized_exprs; - for (auto quantized_expr : tuple_data->fields) { + for (auto quantized_expr : tuple_exprs) { // Get the input scale for the idx quantized input tensor. auto input_scale = tuple_input_scales->fields[idx]; @@ -188,22 +194,23 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.concatenate") -.describe(R"code(Concatenate the quantized input tensors along the given axis. + .describe(R"code(Concatenate the quantized input tensors along the given axis. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "The tensor to concatenate.") -.add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.") -.add_argument("input_zero_points", "Tensor", "The quantization zero_points of the input tensors.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("QnnConcatenate", QnnConcatenateRel) -.set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) -.set_attr("FInferCorrectLayout", QnnConcatenateLayout); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate") -.set_body_typed(MakeQnnConcatenate); + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "The tensor to concatenate.") + .add_argument("input_scales", "Tensor", "The quantization scales of the input tensors.") + .add_argument("input_zero_points", "Tensor", + "The quantization zero_points of the input tensors.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("QnnConcatenate", QnnConcatenateRel) + .set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) + .set_attr("FInferCorrectLayout", QnnConcatenateLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.concatenate").set_body_typed(MakeQnnConcatenate); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 37186283ba51..9412ab4393c5 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -21,15 +21,16 @@ * \file src/relay/qnn/op/convolution.cc * \brief Property def of qnn convolution operator. 
*/ -#include +#include "../../op/nn/convolution.h" + #include #include #include #include #include #include +#include -#include "../../op/nn/convolution.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -88,9 +89,8 @@ Array> QnnConvInferCorrectLayout(const Attrs& attrs, } bool is_depthwise(const Conv2DAttrs* param) { - return param->channels.defined() && - tvm::tir::ExprDeepEqual()(param->channels, param->groups) && - param->groups != 1; + return param->channels.defined() && tvm::tir::ExprDeepEqual()(param->channels, param->groups) && + param->groups != 1; } // Workload - batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier @@ -201,8 +201,8 @@ Expr Conv2DPadInput(const Expr& data, const Expr& input_zero_point, const Conv2D auto pad_left_value = get_const_int(param->padding[1]); auto pad_bottom_value = get_const_int(param->padding[2]); auto pad_right_value = get_const_int(param->padding[3]); - bool do_pad = pad_top_value != 0 || pad_left_value != 0 || - pad_bottom_value != 0 || pad_right_value != 0; + bool do_pad = + pad_top_value != 0 || pad_left_value != 0 || pad_bottom_value != 0 || pad_right_value != 0; if (do_pad) { Array pad_n({0, 0}); Array pad_c({0, 0}); @@ -662,8 +662,8 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point, Expr input_scale, Expr kernel_scale, Array strides, Array padding, Array dilation, int groups, - IndexExpr channels, Array kernel_size, std::string data_layout, - std::string kernel_layout, std::string out_layout, DataType out_dtype) { + IndexExpr channels, Array kernel_size, String data_layout, + String kernel_layout, String out_layout, DataType out_dtype) { auto attrs = make_object(); attrs->strides = std::move(strides); attrs->padding = std::move(padding); @@ -676,13 +676,12 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze attrs->out_layout = std::move(out_layout); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.conv2d"); - return Call( - op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, - Attrs(attrs), {}); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); } RELAY_REGISTER_OP("qnn.conv2d") -.describe(R"code(2D quantized convolution layer. + .describe(R"code(2D quantized convolution layer. This operator convolves quantized weight with quantized data. The scale of the output quantized tensor is the product of the weight_scale and input_scale of the input quantized tensors. The zero point of the output quantized tensor is @@ -694,18 +693,19 @@ operator to understand how to scale back the int32 output to (u)int8. - **out**: This depends on the `layout` parameter. Output is 4D array of shape (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(6) -.add_argument("data", "Tensor", "The quantized input data tensor.") -.add_argument("weight", "Tensor", "The quantized weight tensor.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") -.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") -.set_support_level(11) -.add_type_rel("QnnConv2D", QnnConv2DRel) -.set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) -.set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "Tensor", "The quantized input data tensor.") + .add_argument("weight", "Tensor", "The quantized weight tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QnnConv2D", QnnConv2DRel) + .set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) + .set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc index 7b9733c36586..464b3f9aeff3 100644 --- a/src/relay/qnn/op/dense.cc +++ b/src/relay/qnn/op/dense.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../op/nn/nn.h" #include "../../transforms/pattern_util.h" #include "../util.h" @@ -72,9 +73,8 @@ Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kern attrs->units = std::move(units); attrs->out_dtype = out_dtype; static const Op& op = Op::Get("qnn.dense"); - return Call( - op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, - Attrs(attrs), {}); + return Call(op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale}, + Attrs(attrs), {}); } Expr DenseFirstTerm(const Expr& quantized_data, const Expr& quantized_kernel, @@ -173,25 +173,25 @@ Expr QnnDenseCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.dense") -.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. + .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`. - **data**: quantized(int8, unit8) `(x1, x2, ..., xn, input_dim)` - **weight**: quantized(int8, unit8) `(units, input_dim)` - **out**: quantized(int32) `(x1, x2, ..., xn, units)`. 
)code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(6) -.add_argument("data", "quantized nD Tensor", "Input data.") -.add_argument("weight", "quantized 2D Tensor", "Weight matrix.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") -.add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") -.set_support_level(11) -.add_type_rel("QDense", QnnDenseRel) -.set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense") -.set_body_typed(MakeQuantizedDense); + .set_attrs_type() + .set_num_inputs(6) + .add_argument("data", "quantized nD Tensor", "Input data.") + .add_argument("weight", "quantized 2D Tensor", "Weight matrix.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("weight_scale", "Tensor", "The quantization scale of the weight tensor.") + .add_argument("weight_zero_point", "Tensor", + "The quantization zero_point of the weight tensor.") + .set_support_level(11) + .add_type_rel("QDense", QnnDenseRel) + .set_attr("FTVMQnnCanonicalize", QnnDenseCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dense").set_body_typed(MakeQuantizedDense); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 69389a7317aa..7c014d71a76a 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" @@ -33,19 +34,16 @@ namespace tvm { namespace relay { namespace qnn { -bool DequantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); CHECK(data != nullptr); const auto input_dtype = data->dtype; - CHECK(input_dtype == DataType::Int(8) || - input_dtype == DataType::UInt(8) || + CHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) || input_dtype == DataType::Int(32)) - << "Input type should be one of the quantized types [unit8, int8, int32] but was " - << input_dtype; + << "Input type should be one of the quantized types [unit8, int8, int32] but was " + << input_dtype; // Check the types of scale and zero points. CHECK(IsScalarType(types[1], DataType::Float(32))); // input_scale @@ -83,20 +81,19 @@ Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.dequantize") -.describe(R"code(Dequantizes the input and produces float32 output. + .describe(R"code(Dequantizes the input and produces float32 output. The input is always quantized (int8, uint8) and will be converted to float32 given input scale and zero_point. - **data**: Quantized tensor of any shape to dequantize. 
The input data can be of floating point )code" TVM_ADD_FILELINE) -.set_num_inputs(3) -.add_argument("data", "Tensor", "The tensor to dequantize.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.set_support_level(11) -.add_type_rel("Dequantize", DequantizeRel) -.set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); + .set_num_inputs(3) + .add_argument("data", "Tensor", "The tensor to dequantize.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .set_support_level(11) + .add_type_rel("Dequantize", DequantizeRel) + .set_attr("FTVMQnnCanonicalize", DequantizeQnnCanonicalize); -TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize") -.set_body_typed(MakeDequantize); +TVM_REGISTER_GLOBAL("relay.qnn.op._make.dequantize").set_body_typed(MakeDequantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/mul.cc b/src/relay/qnn/op/mul.cc index 5f9251b35080..ec74b799407b 100644 --- a/src/relay/qnn/op/mul.cc +++ b/src/relay/qnn/op/mul.cc @@ -24,6 +24,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" #include "op_common.h" @@ -85,21 +86,17 @@ Expr QnnMulCanonicalize(const Attrs& attrs, const Array& new_args, auto new_input_zero_point = zero_scalar; // Requantize to get Q_c - output = Requantize(output, input_type.shape, - new_input_scale, - new_input_zero_point, - args.output_scale, - args.output_zero_point, - input_type.dtype); + output = Requantize(output, input_type.shape, new_input_scale, new_input_zero_point, + args.output_scale, args.output_zero_point, input_type.dtype); return output; } // QNN Multiplication operator. QNN_REGISTER_BINARY_OP("mul") -.describe("Elementwise mul with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnMulCanonicalize); + .describe("Elementwise mul with with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnMulCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/op_common.h b/src/relay/qnn/op/op_common.h index f780f70dc7b3..50fc0cda30cf 100644 --- a/src/relay/qnn/op/op_common.h +++ b/src/relay/qnn/op/op_common.h @@ -28,7 +28,9 @@ #include #include #include + #include + #include "../../op/type_relations.h" #include "../../transforms/infer_layout_util.h" #include "../util.h" @@ -87,10 +89,9 @@ struct QnnBinaryOpArguments { */ struct QnnBinaryOpTensorType { DataType dtype; - Array shape; + Array shape; - explicit QnnBinaryOpTensorType(const Array& arg_types, - const int32_t arg_idx) { + explicit QnnBinaryOpTensorType(const Array& arg_types, const int32_t arg_idx) { CHECK_EQ(arg_types.size(), kNumQnnBinaryOpArgTypes); auto tensor_type = arg_types[arg_idx].as(); CHECK(tensor_type != nullptr); @@ -109,8 +110,7 @@ struct QnnBinaryOpTensorType { * \return New expression with target dtype and possibly lower * precision. 
*/ -inline Expr ConvertDtype(const Expr& expr, - const DataType& target_dtype) { +inline Expr ConvertDtype(const Expr& expr, const DataType& target_dtype) { auto q_min = GetQmin(target_dtype); auto q_max = GetQmax(target_dtype); auto output = Clip(expr, q_min, q_max); @@ -134,18 +134,15 @@ inline Expr ConvertDtype(const Expr& expr, * it simply casts the given expression to Int32 as no requantization is * needed in this case. */ -inline Expr RequantizeOrUpcast(const Expr& expr, - const Expr& expr_scale, - const Expr& expr_zero_point, - const Expr& target_scale, - const Expr& target_zero_point, - const Array& expr_shape, +inline Expr RequantizeOrUpcast(const Expr& expr, const Expr& expr_scale, + const Expr& expr_zero_point, const Expr& target_scale, + const Expr& target_zero_point, const Array& expr_shape, const DataType& target_dtype = DataType::Int(32)) { auto result = expr; if (!IsEqualScalar(expr_scale, target_scale) || !IsEqualScalar(expr_zero_point, target_zero_point)) { - result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, - target_scale, target_zero_point, target_dtype); + result = Requantize(expr, expr_shape, expr_scale, expr_zero_point, target_scale, + target_zero_point, target_dtype); } else { result = Cast(result, target_dtype); } @@ -153,27 +150,23 @@ inline Expr RequantizeOrUpcast(const Expr& expr, } /*! \brief Infer layout for QNN binary broadcast operators */ -inline Array > QnnBinaryBroadcastLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array& old_in_types) { +inline Array > QnnBinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { // Use Relay Binary Broadcast Infer correct layout. auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these // tensors can be treated as C. Layout channel_layout = Layout("C"); - Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, + Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, channel_layout, channel_layout, channel_layout, channel_layout}; Array output_layouts = layouts[1]; return {input_layouts, output_layouts}; } - -static inline bool QnnBroadcastRel(const Array& types, - int num_inputs, - const Attrs& attrs, +static inline bool QnnBroadcastRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), kNumQnnBinaryOpArgTypes); @@ -201,28 +194,28 @@ static inline bool QnnBroadcastRel(const Array& types, * * \param OpName the name of registry. */ -#define QNN_REGISTER_BINARY_OP(OpName) \ - TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ - .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ - Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ - static const Op& op = Op::Get("qnn." OpName); \ - return Call(op, {lhs, rhs, \ - lhs_scale, lhs_zero_point, \ - rhs_scale, rhs_zero_point, \ - output_scale, output_zero_point}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP("qnn." 
OpName) \ - .set_num_inputs(kNumQnnBinaryOpInputs) \ - .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ - .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ - .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ - .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ - .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ - .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ - .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ - .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ - .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) +#define QNN_REGISTER_BINARY_OP(OpName) \ + TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ + .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ + Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ + static const Op& op = Op::Get("qnn." OpName); \ + return Call(op, \ + {lhs, rhs, lhs_scale, lhs_zero_point, rhs_scale, rhs_zero_point, output_scale, \ + output_zero_point}, \ + Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP("qnn." OpName) \ + .set_num_inputs(kNumQnnBinaryOpInputs) \ + .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ + .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ + .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ + .add_argument("lhs_zero_point", "Tensor", "The zero_point of the lhs tensor.") \ + .add_argument("rhs_scale", "Tensor", "The scale of the rhs tensor.") \ + .add_argument("rhs_zero_point", "Tensor", "The zero_point of the rhs tensor.") \ + .add_argument("output_scale", "Tensor", "The scale of the output tensor.") \ + .add_argument("output_zero_point", "Tensor", "The zero_point of the output tensor.") \ + .add_type_rel("QnnBroadcast", QnnBroadcastRel) \ + .set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout) } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 43ba4b6b1ba4..28f0b8994a01 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -26,6 +26,7 @@ #include #include #include + #include "../../transforms/pattern_util.h" #include "../util.h" @@ -35,24 +36,21 @@ namespace qnn { TVM_REGISTER_NODE_TYPE(QuantizeAttrs); -bool QuantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 4); const auto* data = types[0].as(); CHECK(data != nullptr); const auto input_dtype = data->dtype; CHECK(input_dtype == DataType::Float(32)) - << "Input type should be one of float32 but was " << input_dtype; + << "Input type should be one of float32 but was " << input_dtype; const auto* quantize_attrs = attrs.as(); int axis = quantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1: axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; CHECK_LT(axis, static_cast(data->shape.size())) << "axis " << quantize_attrs->axis << " is out of range"; - CHECK_GE(axis, 0) - << "axis " << quantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. 
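The op_common.h pieces reformatted just above (RequantizeOrUpcast and the QNN_REGISTER_BINARY_OP macro) are the shared machinery behind the qnn.add and qnn.subtract canonicalizations: each operand is brought onto the output quantization grid in int32, the operands are combined, and the result is clipped back to the narrow output dtype. Because requantization adds the output zero point to each operand, the add case has to subtract one copy of it from the sum. A scalar sketch of that bookkeeping in plain C++ (ToOutputGrid and all numeric values are illustrative, and float rounding stands in for qnn.requantize):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    // Requantize one operand onto the output grid, kept in int32 (the "upcast").
    int32_t ToOutputGrid(int8_t q, float s_in, int32_t zp_in, float s_out, int32_t zp_out) {
      float real = s_in * static_cast<float>(q - zp_in);
      return static_cast<int32_t>(std::lround(real / s_out)) + zp_out;
    }

    int main() {
      // lhs: scale 0.1, zp 3; rhs: scale 0.2, zp -4; output: scale 0.25, zp 5.
      int8_t lhs = 53, rhs = 21;                           // real values 5.0 and 5.0
      int32_t qa = ToOutputGrid(lhs, 0.1f, 3, 0.25f, 5);   // ~ zp_out + a / s_out = 25
      int32_t qb = ToOutputGrid(rhs, 0.2f, -4, 0.25f, 5);  // ~ zp_out + b / s_out = 25
      // qa + qb carries zp_out twice; subtract one copy to land back on the grid.
      int32_t sum = qa + qb - 5;
      int8_t out = static_cast<int8_t>(std::clamp(sum, -128, 127));
      std::cout << static_cast<int>(out) << "\n";  // 45 == (5.0 + 5.0) / 0.25 + 5
    }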
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale @@ -130,7 +128,7 @@ Expr QuantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, } RELAY_REGISTER_OP("qnn.quantize") -.describe(R"code(Quantizes the input and produces quantized output. + .describe(R"code(Quantizes the input and produces quantized output. The input can be either float or quantized(int8, unit8). If the input is float, this op takes scale and zero point and quantize the float value to quantized output, in int8 or uint8 format. If the input is quantized value, @@ -140,17 +138,17 @@ scale and zero point. - **data**: Tensor of any shape to quantize. The input data can be of floating point or quantized. )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(3) -.add_argument("data", "Tensor", "The tensor to quantize.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("Quantize", QuantizeRel) -.set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize") -.set_body_typed(MakeQuantize); + .set_attrs_type() + .set_num_inputs(3) + .add_argument("data", "Tensor", "The tensor to quantize.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("Quantize", QuantizeRel) + .set_attr("FTVMQnnCanonicalize", QuantizeQnnCanonicalize); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.quantize").set_body_typed(MakeQuantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index a2a46497e197..bdeaf05c86bd 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -25,8 +25,9 @@ #include #include #include -#include "../../transforms/pattern_util.h" + #include "../../transforms/infer_layout_util.h" +#include "../../transforms/pattern_util.h" #include "../util.h" namespace tvm { @@ -68,7 +69,7 @@ Array> RequantizeInferCorrectLayout(const Attrs& attrs, for (auto iter_var : new_in_layouts[0]->axes) { const auto& layout_axis = LayoutAxis::Get(iter_var); const std::string& layout_dim = layout_axis.name(); - if (old_dim == layout_dim) { + if (old_dim == layout_dim) { new_axis = tvm::Integer(axis_index); } // Collect only the primal axis. @@ -249,18 +250,16 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* data = types[0].as(); CHECK(data != nullptr); const auto in_dtype = data->dtype; - CHECK(in_dtype == DataType::Int(8) || - in_dtype == DataType::UInt(8) || + CHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) || in_dtype == DataType::Int(32)) << "Input type should be one of [int8, uint8, int32] but was " << in_dtype; const RequantizeAttrs* requantize_attrs = attrs.as(); int axis = requantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1: axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; CHECK_LT(axis, static_cast(data->shape.size())) << "axis " << requantize_attrs->axis << " is out of range"; - CHECK_GE(axis, 0) - << "axis " << requantize_attrs->axis << " is out of range"; + CHECK_GE(axis, 0) << "axis " << requantize_attrs->axis << " is out of range"; // Check and assign types for scale and zero points. 
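qnn.quantize, registered in this hunk, maps float32 input onto the grid given by output_scale and output_zero_point; the axis attribute picks the channel dimension when the scale is a vector, which is exactly the shape AssignType checks against data->shape[axis]. A scalar sketch of the mapping in plain C++ (QuantizeToInt8 and the sample data/scales are made up; the per-channel loop mirrors the axis handling):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Affine quantization: Q = clamp(round(x / scale) + zero_point, qmin, qmax).
    int8_t QuantizeToInt8(float x, float scale, int32_t zero_point) {
      int32_t q = static_cast<int32_t>(std::lround(x / scale)) + zero_point;
      return static_cast<int8_t>(std::clamp(q, -128, 127));
    }

    int main() {
      // A 2x3 tensor with one scale per channel along axis 0 (the per-channel case).
      std::vector<std::vector<float>> data = {{0.5f, -1.0f, 2.0f}, {0.5f, -1.0f, 2.0f}};
      std::vector<float> scales = {0.05f, 0.02f};  // one scale per output channel
      const int32_t zero_point = 0;

      for (size_t c = 0; c < data.size(); ++c) {
        for (float x : data[c])
          std::cout << static_cast<int>(QuantizeToInt8(x, scales[c], zero_point)) << " ";
        std::cout << "\n";  // channel 0: 10 -20 40   channel 1: 25 -50 100
      }
    }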
AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // input_scale @@ -272,8 +271,7 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const Array oshape = data->shape; // assign output type auto out_dtype = requantize_attrs->out_dtype; - CHECK(out_dtype == DataType::Int(8) || - out_dtype == DataType::UInt(8) || + CHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) || out_dtype == DataType::Int(32)) << "Output type should be one of [int8, uint8, int32] but was " << out_dtype; reporter->Assign(types[5], TensorType(oshape, out_dtype)); @@ -283,18 +281,18 @@ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, // Positional relay function to create qnn requantize operator // used by frontend FFI. Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale, - Expr output_zero_point, int axis, std::string rounding, DataType out_dtype) { + Expr output_zero_point, int axis, String rounding, DataType out_dtype) { auto attrs = make_object(); attrs->axis = axis; attrs->rounding = std::move(rounding); attrs->out_dtype = std::move(out_dtype); static const Op& op = Op::Get("qnn.requantize"); return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point}, - Attrs(attrs), {}); + Attrs(attrs), {}); } RELAY_REGISTER_OP("qnn.requantize") -.describe(R"code(Requantize operator. + .describe(R"code(Requantize operator. The requantize operator converts one quantized tensor to another quantized tensor. For the output tensor, we are provided with output scale and zero point. The computation looks like this @@ -302,20 +300,20 @@ point. The computation looks like this Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) )code" TVM_ADD_FILELINE) -.set_attrs_type() -.set_num_inputs(5) -.add_argument("data", "Tensor", "The quantized input tensor.") -.add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") -.add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") -.add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") -.add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") -.set_support_level(11) -.add_type_rel("Requantize", RequantizeRel) -.set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) -.set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); - -TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize") -.set_body_typed(MakeRequantize); + .set_attrs_type() + .set_num_inputs(5) + .add_argument("data", "Tensor", "The quantized input tensor.") + .add_argument("input_scale", "Tensor", "The quantization scale of the input tensor.") + .add_argument("input_zero_point", "Tensor", "The quantization zero_point of the input tensor.") + .add_argument("output_scale", "Tensor", "The quantization scale of the output tensor.") + .add_argument("output_zero_point", "Tensor", + "The quantization zero_point of the output tensor.") + .set_support_level(11) + .add_type_rel("Requantize", RequantizeRel) + .set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) + .set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); + +TVM_REGISTER_GLOBAL("relay.qnn.op._make.requantize").set_body_typed(MakeRequantize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/subtract.cc b/src/relay/qnn/op/subtract.cc index c6ce3e33f48f..b928bd5e465c 100644 --- a/src/relay/qnn/op/subtract.cc +++ 
b/src/relay/qnn/op/subtract.cc @@ -23,6 +23,7 @@ */ #include #include + #include "op_common.h" namespace tvm { @@ -36,8 +37,7 @@ namespace qnn { * \param arg_types The types of input and output. * \return The sequence of Relay ops for add op. */ -Expr QnnSubtractCanonicalize(const Attrs& attrs, - const Array& new_args, +Expr QnnSubtractCanonicalize(const Attrs& attrs, const Array& new_args, const Array& arg_types) { // Get the args. QnnBinaryOpArguments args(new_args); @@ -66,17 +66,13 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, // The subtract op is done in int32 precision. // Requantize LHS if necessary. Computes Q_a' - auto requantized_lhs = RequantizeOrUpcast(args.lhs, args.lhs_scale, - args.lhs_zero_point, - args.output_scale, - args.output_zero_point, - input_type.shape); + auto requantized_lhs = + RequantizeOrUpcast(args.lhs, args.lhs_scale, args.lhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Requantize RHS if necessary. Computes Q_b' - auto requantized_rhs = RequantizeOrUpcast(args.rhs, args.rhs_scale, - args.rhs_zero_point, - args.output_scale, - args.output_zero_point, - input_type.shape); + auto requantized_rhs = + RequantizeOrUpcast(args.rhs, args.rhs_scale, args.rhs_zero_point, args.output_scale, + args.output_zero_point, input_type.shape); // Computes Q_a' - Q_b' auto output = Subtract(requantized_lhs, requantized_rhs); @@ -93,10 +89,9 @@ Expr QnnSubtractCanonicalize(const Attrs& attrs, // QNN Subtraction operator. QNN_REGISTER_BINARY_OP("subtract") -.describe("Elementwise subtract with with broadcasting for quantized tensors.") -.set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); - + .describe("Elementwise subtract with with broadcasting for quantized tensors.") + .set_support_level(11) + .set_attr("FTVMQnnCanonicalize", QnnSubtractCanonicalize); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/util.cc b/src/relay/qnn/util.cc index 91fe3ca2a948..4daa5c9334de 100644 --- a/src/relay/qnn/util.cc +++ b/src/relay/qnn/util.cc @@ -23,6 +23,7 @@ */ #include "util.h" + #include "../transforms/pattern_util.h" namespace tvm { @@ -48,8 +49,7 @@ namespace qnn { * * Credit to TFLite reference implementation. */ -std::pair GetFixedPointMultiplierShift( - double double_multiplier) { +std::pair GetFixedPointMultiplierShift(double double_multiplier) { int32_t significand, exponent; if (double_multiplier == 0.) { significand = 0; @@ -84,8 +84,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& // 1) Calculating the integer multiplier and integer shift int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = - GetFixedPointMultiplierShift(multiplier); + std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); int left_shift = shift > 0 ? shift : 0; int right_shift = shift > 0 ? 0 : -shift; @@ -119,8 +118,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype); auto zero_t = Zeros(input_shape, hp_dtype); - round_scalar = - Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t); } else { LOG(FATAL) << "Rounding mode " << rounding << " not supported."; } @@ -128,8 +126,7 @@ Expr FixedPointMultiply(Expr tensor, double multiplier, const Array& tensor = Add(tensor, round_scalar); // 5) Simply right shift the result to get the final output. 
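FixedPointMultiply above turns the floating-point requantize ratio into integer arithmetic: the double multiplier is split by frexp into a Q31 significand and an exponent, the tensor is widened, multiplied by the significand, a rounding constant is added, and the result is shifted right. A standalone sketch of the same steps in plain C++ (assuming a multiplier roughly in (2^-32, 2^30), which covers requantize scale ratios; the rounding constants mirror the TONEAREST branch, and names and values outside the patch are illustrative):

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    // Decompose multiplier into significand * 2^(shift - 31), significand in Q31.
    void GetFixedPointMultiplierShift(double multiplier, int32_t* significand, int32_t* shift) {
      if (multiplier == 0.0) { *significand = 0; *shift = 0; return; }
      int exp = 0;
      double sig = std::frexp(multiplier, &exp);    // multiplier = sig * 2^exp, sig in [0.5, 1)
      int64_t q = std::llround(sig * (1LL << 31));  // Q31 fixed-point significand
      if (q == (1LL << 31)) { q /= 2; ++exp; }      // rounding pushed sig up to 1.0
      *significand = static_cast<int32_t>(q);
      *shift = exp;
    }

    // y ~= x * multiplier using only integer multiply, add and shift.
    int32_t FixedPointMultiply(int32_t x, double multiplier) {
      int32_t sig, shift;
      GetFixedPointMultiplierShift(multiplier, &sig, &shift);
      int total_right_shift = 31 - shift;              // undo Q31 plus the exponent
      int64_t prod = static_cast<int64_t>(x) * sig;
      int64_t half = 1LL << (total_right_shift - 1);
      int64_t round = (prod >= 0) ? half : half - 1;   // round to nearest
      return static_cast<int32_t>((prod + round) >> total_right_shift);
    }

    int main() {
      std::cout << FixedPointMultiply(1000, 0.3752) << "\n";  // 375, i.e. ~1000 * 0.3752
    }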
- tensor = - RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); + tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift)); // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32. return Cast(tensor, DataType::Int(32)); @@ -205,8 +202,8 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, round_scalar = exp_pos_rounding_value_expr; } else if (rounding == "TONEAREST") { // To satisfy where op shape requirements, the rounding values are broadcasted. - auto pos_rounder = MakeBroadCastTo(exp_pos_rounding_value_expr, input_shape); - auto neg_rounder = MakeBroadCastTo(exp_neg_rounding_value_expr, input_shape); + auto pos_rounder = BroadCastTo(exp_pos_rounding_value_expr, input_shape); + auto neg_rounder = BroadCastTo(exp_neg_rounding_value_expr, input_shape); auto zero_t = Zeros(input_shape, hp_dtype); round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder); diff --git a/src/relay/qnn/util.h b/src/relay/qnn/util.h index d4046ae90607..736b7361a300 100644 --- a/src/relay/qnn/util.h +++ b/src/relay/qnn/util.h @@ -25,14 +25,15 @@ #ifndef TVM_RELAY_QNN_UTIL_H_ #define TVM_RELAY_QNN_UTIL_H_ -#include -#include #include #include +#include +#include + #include #include -#include #include +#include namespace tvm { namespace relay { @@ -46,8 +47,7 @@ static inline Array get_shape(const Type& type) { } static inline int32_t GetQmin(const DataType& dtype) { - CHECK_LE(dtype.bits(), 32) - << "QNN ops support int32 or lower precision"; + CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { auto* min_value = tir::as_const_int(tvm::min_value(dtype)); CHECK(min_value != nullptr); @@ -59,8 +59,7 @@ static inline int32_t GetQmin(const DataType& dtype) { } static inline int32_t GetQmax(const DataType& dtype) { - CHECK_LE(dtype.bits(), 32) - << "QNN ops support int32 or lower precision"; + CHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision"; if (dtype.is_int() || dtype.is_uint()) { auto* max_value = tir::as_const_int(tvm::max_value(dtype)); CHECK(max_value != nullptr); @@ -171,8 +170,7 @@ static inline void AssignType(const Type& expr_type, const DataType& dtype, cons const TypeReporter& reporter) { // Scale/Zero_points can be either const scalar or a vector with C axis num elems. const auto* tensor_type = expr_type.as(); - CHECK(tensor_type) << "Can assign type to Tensor type only. But got " - << AsText(expr_type, false); + CHECK(tensor_type) << "Can assign type to Tensor type only. But got " << AsText(expr_type, false); const auto tensor_dtype = tensor_type->dtype; CHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype; if (tensor_type->shape.size() != 0) { diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 4492ed5bebca..8ae7df9e2941 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -24,8 +24,9 @@ * \brief Annotating the graph with simulated quantize operators. 
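GetQmin/GetQmax in util.h read the clip range off the dtype via tvm::min_value/max_value; for the integer dtypes QNN accepts this reduces to the usual two's-complement bounds. A tiny standalone equivalent in plain C++, covering only the int/uint case (QMin/QMax are illustrative names):

    #include <cstdint>
    #include <iostream>

    // Bounds of an n-bit integer type (n <= 32), as used for clipping in QNN lowering.
    int64_t QMin(int bits, bool is_signed) {
      return is_signed ? -(int64_t{1} << (bits - 1)) : 0;
    }
    int64_t QMax(int bits, bool is_signed) {
      return is_signed ? (int64_t{1} << (bits - 1)) - 1 : (int64_t{1} << bits) - 1;
    }

    int main() {
      std::cout << QMin(8, true) << " " << QMax(8, true) << "\n";    // -128 127
      std::cout << QMin(8, false) << " " << QMax(8, false) << "\n";  // 0 255
      std::cout << QMin(32, true) << " " << QMax(32, true) << "\n";  // int32 limits
    }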
*/ -#include #include +#include + #include "./quantize.h" namespace tvm { @@ -63,10 +64,7 @@ class QAnnotateExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode); }; - -Expr QAnnotateExprNode::Realize() const { - return expr; -} +Expr QAnnotateExprNode::Realize() const { return expr; } QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { auto rnode = make_object(); @@ -75,12 +73,10 @@ QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) { data_ = std::move(rnode); } -TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr") -.set_body_typed([](Expr expr, int kind) { +TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) { return QAnnotateExpr(expr, static_cast(kind)); }); - Pass QuantizeAnnotate() { // TODO(tvm-teams): since partition has added cast_hint in different // branches, try to remove this in the future. @@ -88,8 +84,7 @@ Pass QuantizeAnnotate() { if (e->IsInstance()) { const auto* n = e.as(); CHECK(n); - const PackedFunc* f = - runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); Expr ret = (*f)(n->expr, static_cast(kQInput)); return static_cast(QAnnotateExpr(ret, kQInput)); } @@ -97,23 +92,18 @@ Pass QuantizeAnnotate() { }; runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); - auto new_params = func->params; - for (const auto& x : FreeVars(func)) { - new_params.push_back(x); - } - return Function(new_params, - func->body, - func->ret_type, - func->type_params, - func->attrs); - }; + [=](Function f, IRModule m, PassContext pc) { + auto func = Downcast(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref)); + auto new_params = func->params; + for (const auto& x : FreeVars(func)) { + new_params.push_back(x); + } + return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs); + }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate") -.set_body_typed(QuantizeAnnotate); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate); TVM_REGISTER_NODE_TYPE(QAnnotateExprNode); diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 7b1e909501b5..ea42a198bf84 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -26,7 +26,9 @@ #include #include #include + #include + #include "./quantize.h" namespace tvm { @@ -65,8 +67,8 @@ static std::vector SmoothDistribution(const std::vector& p, } static float ComputeEntropy(float* p, float* q, size_t size) { - float p_sum = std::accumulate(p, p+size, 0.f); - float q_sum = std::accumulate(q, q+size, 0.f); + float p_sum = std::accumulate(p, p + size, 0.f); + float q_sum = std::accumulate(q, q + size, 0.f); float ret = 0; for (size_t i = 0; i < size; i++) { CHECK(p[i] > 0 && q[i] > 0); @@ -77,9 +79,8 @@ static float ComputeEntropy(float* p, float* q, size_t size) { return ret; } -float MinimizeKL(const std::vector& hist, - const std::vector& hist_edges, - int num_bins, int num_quantized_bins) { +float MinimizeKL(const std::vector& hist, const std::vector& hist_edges, int num_bins, + int num_quantized_bins) { const int zero_bin_idx = num_bins / 2; const int num_half_quantized_bins = num_quantized_bins / 2; std::vector thresholds(num_bins / 2 + 
1 - num_quantized_bins / 2, 0.f); @@ -137,9 +138,9 @@ float MinimizeKL(const std::vector& hist, divergence[i - num_half_quantized_bins] = ComputeEntropy(p.data(), q.data(), p.size()); } } - auto min_divergence_idx = std::distance(divergence.begin(), - std::min_element(divergence.begin(), divergence.end())); - return thresholds[min_divergence_idx];; + auto min_divergence_idx = + std::distance(divergence.begin(), std::min_element(divergence.begin(), divergence.end())); + return thresholds[min_divergence_idx]; } class StatsCollector : private ExprMutator { @@ -152,7 +153,7 @@ class StatsCollector : private ExprMutator { CHECK(func) << "Input shoule be Function"; Expr new_body = Tuple(std::move(profile_data_)); return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + func->attrs); } private: @@ -167,7 +168,7 @@ class StatsCollector : private ExprMutator { auto attrs = new_call->attrs.as(); // rewrite the annotation auto new_attrs = make_object(); - const Expr& quantize_input = new_call->args[0]; // expression being quantized + const Expr& quantize_input = new_call->args[0]; // expression being quantized auto placeholder = MakeConstantScalar(DataType::Float(32), 0.); // unused argument Array new_args{quantize_input, placeholder, placeholder, placeholder}; new_attrs->kind = QAnnotateKind::kQIdentity; @@ -198,24 +199,20 @@ class StatsCollector : private ExprMutator { * \param expr The simulation graph after annotation. * \return The profile graph. */ -Expr CreateStatsCollector(const Expr& expr) { - return StatsCollector().Collect(expr); -} - -TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector") -.set_body_typed(CreateStatsCollector); +Expr CreateStatsCollector(const Expr& expr) { return StatsCollector().Collect(expr); } +TVM_REGISTER_GLOBAL("relay._quantize.CreateStatsCollector").set_body_typed(CreateStatsCollector); TVM_REGISTER_GLOBAL("relay._quantize.FindScaleByKLMinimization") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int* hist_ptr = static_cast(static_cast(args[0])); - float* hist_edges_ptr = static_cast(static_cast(args[1])); - int num_bins = args[2]; - int num_quantized_bins = args[3]; - std::vector hist(hist_ptr, hist_ptr + num_bins); - std::vector hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1); - ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + int* hist_ptr = static_cast(static_cast(args[0])); + float* hist_edges_ptr = static_cast(static_cast(args[1])); + int num_bins = args[2]; + int num_quantized_bins = args[3]; + std::vector hist(hist_ptr, hist_ptr + num_bins); + std::vector hist_edges(hist_edges_ptr, hist_edges_ptr + num_bins + 1); + ret[0] = MinimizeKL(hist, hist_edges, num_bins, num_quantized_bins); + }); } // namespace quantize } // namespace relay diff --git a/src/relay/quantize/partition.cc b/src/relay/quantize/partition.cc index 39de0bc49d4c..14b420d6034c 100644 --- a/src/relay/quantize/partition.cc +++ b/src/relay/quantize/partition.cc @@ -25,6 +25,7 @@ */ #include + #include "../transforms/pattern_util.h" #include "./quantize.h" @@ -34,16 +35,13 @@ namespace quantize { using namespace relay::transform; - class QPartitionExpr; class QPartitionExprNode : public TempExprNode { public: /*! 
\brief The original expression */ Expr expr; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("expr", &expr); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); } Expr Realize() const final; @@ -62,7 +60,6 @@ class QPartitionExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); }; - Expr QPartitionExprNode::Realize() const { // insert cast hint and stop fusion const QConfig& cfg = QConfig::Current(); @@ -76,23 +73,20 @@ QPartitionExpr::QPartitionExpr(Expr expr) { data_ = std::move(rnode); } -TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr") -.set_body_typed([](Expr expr) { +TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr").set_body_typed([](Expr expr) { return QPartitionExpr(expr); }); Pass QuantizePartition() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - auto ret = Downcast( - ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); - return ret; - }; + [=](Function f, IRModule m, PassContext pc) { + auto ret = Downcast(ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr)); + return ret; + }; return CreateFunctionPass(pass_func, 1, "QuantizePartition", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition") -.set_body_typed(QuantizePartition); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition").set_body_typed(QuantizePartition); TVM_REGISTER_NODE_TYPE(QPartitionExprNode); diff --git a/src/relay/quantize/quantize.cc b/src/relay/quantize/quantize.cc index 631d8c0fdf58..1bf858b43db0 100644 --- a/src/relay/quantize/quantize.cc +++ b/src/relay/quantize/quantize.cc @@ -23,12 +23,13 @@ * \brief transform a graph to a low-bit graph * for compression and acceleration. */ +#include "./quantize.h" + #include #include #include -#include -#include "./quantize.h" +#include namespace tvm { namespace relay { @@ -36,9 +37,7 @@ namespace quantize { TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); -bool SimulatedQuantizeRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool SimulatedQuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 5); const auto param = attrs.as(); @@ -48,36 +47,34 @@ bool SimulatedQuantizeRel(const Array& types, CHECK(data != nullptr); CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; - reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale - reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min - reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max - reporter->Assign(types[4], types[0]); // output + reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale + reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min + reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max + reporter->Assign(types[4], types[0]); // output return true; } RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) -.set_num_inputs(4) -.add_argument("data", "Tensor", "The input data.") -.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") -.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") -.add_argument("clip_max", "Tensor", "upper bound. 
It should be a scalar") -.set_attrs_type() -.set_support_level(11) -.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); + .describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) + .set_num_inputs(4) + .add_argument("data", "Tensor", "The input data.") + .add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") + .add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") + .add_argument("clip_max", "Tensor", "upper bound. It should be a scalar") + .set_attrs_type() + .set_support_level(11) + .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") -.set_body_typed( - [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, - int kind, bool sign, std::string rounding) { - auto attrs = make_object(); - attrs->kind = kind; - attrs->sign = sign; - attrs->rounding = rounding; - static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); - return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); - }); - + .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign, + String rounding) { + auto attrs = make_object(); + attrs->kind = kind; + attrs->sign = sign; + attrs->rounding = rounding; + static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); + return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); + }); /*! \brief Entry to hold the BuildConfig context stack. */ struct TVMQConfigThreadLocalEntry { @@ -87,26 +84,24 @@ struct TVMQConfigThreadLocalEntry { /*! \brief The current build config context */ std::stack context_stack; - TVMQConfigThreadLocalEntry() : - default_config(make_object()) { - } + TVMQConfigThreadLocalEntry() : default_config(make_object()) {} }; /*! \brief Thread local store to hold the BuildConfig context stack. 
*/ typedef dmlc::ThreadLocalStore TVMQConfigThreadLocalStore; void QConfig::EnterQConfigScope(const QConfig& build_config) { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); entry->context_stack.push(build_config); } void QConfig::ExitQConfigScope() { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); entry->context_stack.pop(); } QConfig& QConfig::Current() { - TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } @@ -117,31 +112,31 @@ QConfig& QConfig::Current() { TVM_REGISTER_NODE_TYPE(QConfigNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* op = static_cast(ref.get()); - p->stream << "qconfig("; - p->stream << "nbit_input=" << op->nbit_input << ", "; - p->stream << "nbit_weight=" << op->nbit_weight << ", "; - p->stream << "nbit_activation=" << op->nbit_activation << ", "; - p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; - p->stream << "global_scale=" << op->global_scale << ", "; - p->stream << "weight_scale=" << op->weight_scale << ", "; - p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; - p->stream << "do_simulation==" << op->do_simulation << ", "; - p->stream << "round_for_shift==" << op->round_for_shift << ", "; - p->stream << "debug_enabled_ops==" << op->debug_enabled_ops <<", "; - p->stream << "rounding==" << op->rounding; - p->stream << ")"; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* op = static_cast(ref.get()); + p->stream << "qconfig("; + p->stream << "nbit_input=" << op->nbit_input << ", "; + p->stream << "nbit_weight=" << op->nbit_weight << ", "; + p->stream << "nbit_activation=" << op->nbit_activation << ", "; + p->stream << "calibrate_mode=" << op->calibrate_mode << ", "; + p->stream << "global_scale=" << op->global_scale << ", "; + p->stream << "weight_scale=" << op->weight_scale << ", "; + p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; + p->stream << "do_simulation==" << op->do_simulation << ", "; + p->stream << "round_for_shift==" << op->round_for_shift << ", "; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; + p->stream << "rounding==" << op->rounding; + p->stream << ")"; + }); + +TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig").set_body_typed([]() -> QConfig { + return QConfig::Current(); }); -TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig") -.set_body_typed(QConfig::Current); - TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope") -.set_body_typed(QConfig::EnterQConfigScope); + .set_body_typed(QConfig::EnterQConfigScope); -TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope") -.set_body_typed(QConfig::ExitQConfigScope); +TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope").set_body_typed(QConfig::ExitQConfigScope); } // namespace quantize } // namespace relay diff --git a/src/relay/quantize/quantize.h b/src/relay/quantize/quantize.h index 563f47f56933..86f8926c98ac 100644 --- a/src/relay/quantize/quantize.h +++ b/src/relay/quantize/quantize.h @@ -24,9 +24,11 @@ #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_ #define TVM_RELAY_QUANTIZE_QUANTIZE_H_ -#include #include +#include + #include + #include "../transforms/pattern_util.h" 
namespace tvm { @@ -34,12 +36,7 @@ namespace relay { namespace quantize { /*! \brief Kind of annotate field */ -enum QAnnotateKind : int { - kQIdentity = 0, - kQInput = 1, - kQWeight = 2, - kQActivation = 3 -}; +enum QAnnotateKind : int { kQIdentity = 0, kQInput = 1, kQWeight = 2, kQActivation = 3 }; /*! \brief Attribute for simulated quantize operator */ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { @@ -48,20 +45,17 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode { std::string rounding; TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { - TVM_ATTR_FIELD(kind) - .describe("kind of field, hint for nbit/dtype configuration."); - TVM_ATTR_FIELD(sign).set_default(true) - .describe("whether to use signed data type."); - TVM_ATTR_FIELD(rounding).set_default("round") - .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round").describe( + "rounding mode. Can be 'floor', 'ceil', 'round'"); } }; - class QConfig; /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class QConfigNode : public Object { public: int nbit_input = 8; @@ -73,6 +67,7 @@ class QConfigNode : public Object { std::string calibrate_mode = "global_scale"; double global_scale = 8.0; std::string weight_scale = "power2"; + bool skip_dense_layer = true; Array skip_conv_layers = Array(ObjectPtr(nullptr)); bool do_simulation = false; bool round_for_shift = true; @@ -90,6 +85,7 @@ class QConfigNode : public Object { v->Visit("calibrate_mode", &calibrate_mode); v->Visit("global_scale", &global_scale); v->Visit("weight_scale", &weight_scale); + v->Visit("skip_dense_layer", &skip_dense_layer); v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("do_simulation", &do_simulation); v->Visit("round_for_shift", &round_for_shift); @@ -103,20 +99,16 @@ class QConfigNode : public Object { }; /*! -* \brief Container for build configuration options -*/ + * \brief Container for build configuration options + */ class QConfig : public ObjectRef { public: QConfig() {} explicit QConfig(ObjectPtr n) : ObjectRef(n) {} - const QConfigNode* operator->() const { - return static_cast(get()); - } + const QConfigNode* operator->() const { return static_cast(get()); } - QConfigNode* operator->() { - return static_cast(get_mutable()); - } + QConfigNode* operator->() { return static_cast(get_mutable()); } /*! * \brief Push a new BuildConfig context onto the thread local stack. @@ -150,14 +142,10 @@ struct QConfigContext { * context. When the BuildConfigContext is destructed, the previous context is restored. * \param build_config The BuildConfig to set as the new current context. */ - explicit QConfigContext(const QConfig& qconfig) { - QConfig::EnterQConfigScope(qconfig); - } + explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); } /*! \brief Destructor. Pops the context off the thread local stack. */ - ~QConfigContext() { - QConfig::ExitQConfigScope(); - } + ~QConfigContext() { QConfig::ExitQConfigScope(); } }; } // namespace quantize diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 6d56e19d229c..49d1e522f7d7 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -25,12 +25,13 @@ * graph. 
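QConfig in quantize.h is scoped through a thread-local stack: EnterQConfigScope pushes, ExitQConfigScope pops, Current returns the top (or a default), and QConfigContext is the RAII guard that ties the push and pop to a C++ scope. A standalone sketch of that pattern in plain C++ (Config, its fields and default values, and the guard name are illustrative, not the actual QConfigNode layout):

    #include <iostream>
    #include <stack>
    #include <string>

    struct Config {
      std::string rounding = "UPWARD";
      int nbit_activation = 32;
    };

    thread_local std::stack<Config> g_config_stack;  // one stack per thread

    const Config& CurrentConfig() {
      static const Config kDefault;                  // fallback when no scope is active
      return g_config_stack.empty() ? kDefault : g_config_stack.top();
    }

    struct ConfigScope {                             // RAII: push on entry, pop on exit
      explicit ConfigScope(const Config& c) { g_config_stack.push(c); }
      ~ConfigScope() { g_config_stack.pop(); }
    };

    int main() {
      std::cout << CurrentConfig().rounding << "\n";    // UPWARD (default)
      {
        ConfigScope scope(Config{"TONEAREST", 16});
        std::cout << CurrentConfig().rounding << "\n";  // TONEAREST (scoped override)
      }
      std::cout << CurrentConfig().rounding << "\n";    // UPWARD again
    }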
*/ -#include #include #include -#include "./quantize.h" -#include "../transforms/pattern_util.h" +#include + #include "../qnn/util.h" +#include "../transforms/pattern_util.h" +#include "./quantize.h" namespace tvm { namespace relay { @@ -53,7 +54,6 @@ class QRealizeExpr : public TempExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeExpr, TempExpr, QRealizeExprNode); }; - class QRealizeIntExprNode : public QRealizeExprNode { public: Expr dom_scale; @@ -67,7 +67,7 @@ class QRealizeIntExprNode : public QRealizeExprNode { Expr Realize() const final; - static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; + static constexpr const char* _type_key = "relay.quantize.QRealizeIntExpr"; TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode); }; @@ -78,7 +78,6 @@ class QRealizeIntExpr : public QRealizeExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); }; - Expr QRealizeIntExprNode::Realize() const { Expr data = this->data; // dequantize @@ -95,15 +94,13 @@ QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) { data_ = std::move(n); } - inline Expr ForwardOp(const Call& ref_call, const Array& args) { return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args); } - /* calculate `data * s1 / s2`, use shift if possible */ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, - const Array &data_shape) { + const Array& data_shape) { const QConfig& cfg = QConfig::Current(); // here we assume the dtype of data is dtype activation if (s1 == s2) return data; @@ -112,8 +109,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, float shift_factor = std::log2(factor); CHECK_GT(shift_factor, 0); if (static_cast(shift_factor) == shift_factor) { - return LeftShift(data, MakeConstantScalar(dtype, - static_cast(shift_factor))); + return LeftShift(data, MakeConstantScalar(dtype, static_cast(shift_factor))); } else if (static_cast(factor) == factor) { return Multiply(data, MakeConstantScalar(dtype, factor)); } else { @@ -122,9 +118,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, } } -Expr QuantizeRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); // do not handle data type cast const auto param = ref_call->attrs.as(); @@ -158,22 +152,20 @@ Expr QuantizeRealize(const Call& ref_call, // use right shift if (cfg->round_for_shift) { float round_bias = std::pow(2.0, shift_nbit - 1); - data = Add(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(round_bias))); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); } - data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); + data = RightShift(data, + MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } else { - data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation, - static_cast(shift_nbit))); + data = LeftShift(data, + MakeConstantScalar(cfg->dtype_activation, static_cast(shift_nbit))); } data = Clip(data, clip_min_imm, clip_max_imm); return QRealizeIntExpr(data, dom_scale, n->dtype); } else { data = Cast(data, DataType::Int(64)); data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm, - ref_call->type_as()->shape, - cfg->rounding); + ref_call->type_as()->shape, cfg->rounding); data = Cast(Clip(data, clip_min_imm, clip_max_imm), 
n->dtype); return QRealizeIntExpr(data, dom_scale, n->dtype); } @@ -195,12 +187,9 @@ Expr FoldConstantOpt(const Expr& expr) { } RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") -.set_attr("FQRealizeRewrite", QuantizeRealize); - + .set_attr("FQRealizeRewrite", QuantizeRealize); -Expr Conv2dRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr Conv2dRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() && !new_args[1]->IsInstance()) { @@ -223,20 +212,15 @@ Expr Conv2dRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = Call(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FQRealizeRewrite", Conv2dRealize); - +RELAY_REGISTER_OP("nn.conv2d").set_attr("FQRealizeRewrite", Conv2dRealize); -Expr DenseRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr DenseRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (!new_args[0]->IsInstance() || !new_args[1]->IsInstance()) { @@ -257,20 +241,15 @@ Expr DenseRealize(const Call& ref_call, DataType out_dtype = cfg->dtype_activation; attrs->out_dtype = out_dtype; - Expr ret = Call(ref_call->op, - {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); Expr dom_scale = FoldConstantOpt(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } -RELAY_REGISTER_OP("nn.dense") -.set_attr("FQRealizeRewrite", DenseRealize); +RELAY_REGISTER_OP("nn.dense").set_attr("FQRealizeRewrite", DenseRealize); - -Expr MulRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr MulRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { @@ -297,9 +276,7 @@ Expr MulRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("multiply") -.set_attr("FQRealizeRewrite", MulRealize); - +RELAY_REGISTER_OP("multiply").set_attr("FQRealizeRewrite", MulRealize); float ChooseDomScale(const std::vector& nptrs) { if (nptrs.size() == 2) { @@ -316,7 +293,6 @@ float ChooseDomScale(const std::vector& nptrs) { } } - /* \brief Unify the dom scale of arguments */ Array UnifyDTypeScale(const Array& ref_args, const Array& args, DataType* dtype_ptr, Expr* scale_ptr) { @@ -366,9 +342,7 @@ Array UnifyDTypeScale(const Array& ref_args, const Array& args return ret; } -Expr AddRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr AddRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 2); if (new_args[0].as() && new_args[1].as()) { DataType dtype; @@ -382,12 +356,9 @@ Expr AddRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("add") -.set_attr("FQRealizeRewrite", AddRealize); +RELAY_REGISTER_OP("add").set_attr("FQRealizeRewrite", AddRealize); -Expr 
ClipRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr ClipRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { const auto ref_attrs = ref_call->attrs.as(); @@ -396,21 +367,16 @@ Expr ClipRealize(const Call& ref_call, attrs->a_min = ref_attrs->a_min / dom_scale; attrs->a_max = ref_attrs->a_max / dom_scale; - Expr ret = Call(ref_call->op, - {n->data}, Attrs(attrs), ref_call->type_args); + Expr ret = Call(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args); return QRealizeIntExpr(ret, n->dom_scale, n->dtype); } CHECK(!new_args[0]->IsInstance()); return Expr(nullptr); } -RELAY_REGISTER_OP("clip") -.set_attr("FQRealizeRewrite", ClipRealize); - +RELAY_REGISTER_OP("clip").set_attr("FQRealizeRewrite", ClipRealize); -Expr ConcatenateRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr ConcatenateRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); CHECK_EQ(ref_call->args.size(), 1); @@ -435,14 +401,10 @@ Expr ConcatenateRealize(const Call& ref_call, } } -RELAY_REGISTER_OP("concatenate") -.set_attr("FQRealizeRewrite", ConcatenateRealize); - +RELAY_REGISTER_OP("concatenate").set_attr("FQRealizeRewrite", ConcatenateRealize); /* \brief forward the original operator */ -Expr IdentityRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr IdentityRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); @@ -452,18 +414,15 @@ Expr IdentityRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("nn.relu").set_attr("FQRealizeRewrite", IdentityRealize); -RELAY_REGISTER_OP("strided_slice") -.set_attr("FQRealizeRewrite", IdentityRealize); +RELAY_REGISTER_OP("strided_slice").set_attr("FQRealizeRewrite", IdentityRealize); RELAY_REGISTER_OP("annotation.stop_fusion") -.set_attr("FQRealizeRewrite", IdentityRealize); + .set_attr("FQRealizeRewrite", IdentityRealize); /* \brief for unary operators which requantize its input to dtype_nbit */ -Expr CastDtypeInputRealize(const Call& ref_call, - const Array& new_args, +Expr CastDtypeInputRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); @@ -477,12 +436,9 @@ Expr CastDtypeInputRealize(const Call& ref_call, } RELAY_REGISTER_OP("nn.max_pool2d") -.set_attr("FQRealizeRewrite", CastDtypeInputRealize); - + .set_attr("FQRealizeRewrite", CastDtypeInputRealize); -Expr AvgPoolRealize(const Call& ref_call, - const Array& new_args, - const ObjectRef& ctx) { +Expr AvgPoolRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const QConfig& cfg = QConfig::Current(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -497,15 +453,12 @@ Expr AvgPoolRealize(const Call& ref_call, return Expr(nullptr); } -RELAY_REGISTER_OP("nn.avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); +RELAY_REGISTER_OP("nn.avg_pool2d").set_attr("FQRealizeRewrite", AvgPoolRealize); RELAY_REGISTER_OP("nn.global_avg_pool2d") -.set_attr("FQRealizeRewrite", AvgPoolRealize); + .set_attr("FQRealizeRewrite", AvgPoolRealize); -Expr CastHintRealize(const Call& ref_call, - const Array& new_args, - 
const ObjectRef& ctx) { +Expr CastHintRealize(const Call& ref_call, const Array& new_args, const ObjectRef& ctx) { const auto param = ref_call->attrs.as(); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { @@ -517,19 +470,17 @@ Expr CastHintRealize(const Call& ref_call, } RELAY_REGISTER_OP("annotation.cast_hint") -.set_attr("FQRealizeRewrite", CastHintRealize); + .set_attr("FQRealizeRewrite", CastHintRealize); Pass QuantizeRealizePass() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr)); + }; return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {}); } -TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize") -.set_body_typed(QuantizeRealizePass); +TVM_REGISTER_GLOBAL("relay._quantize.QuantizeRealize").set_body_typed(QuantizeRealizePass); } // namespace quantize } // namespace relay diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc index aab0b3a30a7c..3d242cd09f7d 100644 --- a/src/relay/transforms/alter_op_layout.cc +++ b/src/relay/transforms/alter_op_layout.cc @@ -24,20 +24,20 @@ custom layouts or other general weight pre-transformation. */ #include -#include -#include #include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include -#include "transform_layout.h" #include "pattern_util.h" +#include "transform_layout.h" namespace tvm { namespace relay { @@ -72,7 +72,7 @@ class AlterTransformMemorizer : public TransformMemorizer { * \return The new Call after calling the packed func. */ Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { - static auto falter_layout = Op::GetAttr("FTVMAlterOpLayout"); + static auto falter_layout = Op::GetAttrMap("FTVMAlterOpLayout"); Op op = Downcast(ref_call->op); Expr new_e; @@ -85,8 +85,8 @@ class AlterTransformMemorizer : public TransformMemorizer { } // TODO(@kevinthesun, @icemelon9): This won't work if inputs/outputs are dynamic shapes. // Probably we need to disable the AlterOpLayout when compiling dynamic models. 
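        // A minimal sketch of the attr-map lookup and dispatch used here, assuming an
        // op with a registered FTVMAlterOpLayout functor (names below are illustrative,
        // not part of this change):
        //
        //   static auto falter_layout = Op::GetAttrMap<FTVMAlterOpLayout>("FTVMAlterOpLayout");
        //   const Op& conv2d = Op::Get("nn.conv2d");
        //   if (falter_layout.count(conv2d)) {
        //     // attrs / args / tinfos / out_type describe the call being rewritten.
        //     Expr altered = falter_layout[conv2d](attrs, args, tinfos, out_type);
        //   }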
- Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos, - ref_call->checked_type()); + Expr altered_value = + falter_layout[op](ref_call->attrs, new_args, tinfos, ref_call->checked_type()); if (altered_value.defined()) { new_e = altered_value; modified = true; @@ -122,14 +122,13 @@ namespace transform { Pass AlterOpLayout() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::alter_op_layout::AlterOpLayout(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::alter_op_layout::AlterOpLayout(f)); + }; return CreateFunctionPass(pass_func, 3, "AlterOpLayout", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout") -.set_body_typed(AlterOpLayout); +TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout").set_body_typed(AlterOpLayout); } // namespace transform diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index bc6b4b993ae8..c307d75a9aba 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -29,22 +29,21 @@ #include #include +#include "pass_util.h" + namespace tvm { namespace relay { namespace annotate_target { -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - const PackedFunc* make_begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. -class AnnotateTargetWrapper : public ExprMutator { +class AnnotateTargetRewriter : public ExprRewriter { public: - explicit AnnotateTargetWrapper(Array targets) : targets_(std::move(targets)) {} + explicit AnnotateTargetRewriter(Array targets) : targets_(std::move(targets)) {} /*! * \brief This function annotates a compiler end and a compiler begin to all arguments. @@ -66,12 +65,12 @@ class AnnotateTargetWrapper : public ExprMutator { std::string arg_target = "default"; const CallNode* call = arg.as(); - if (call && call->op == compiler_begin_op) { + if (call && call->op == CompilerBeginOp()) { // Argument is already compiler begin node meaning that this is not the first time // running this pass, so we simply remove it and will add a new one later. CHECK_EQ(call->args.size(), 1U); const CallNode* end = call->args[0].as(); - if (end->op == compiler_end_op) { + if (end->op == CompilerEndOp()) { arg_target = end->attrs.as()->compiler; } compiler_ends.push_back(call->args[0]); @@ -108,30 +107,30 @@ class AnnotateTargetWrapper : public ExprMutator { return new_op; } - Expr VisitExpr_(const CallNode* cn) final { + Expr Rewrite_(const CallNode* pre, const Expr& post) final { // Supported targets for this node. The order implies the priority. std::vector supported_targets; - auto op_node = cn->op.as(); + auto op_node = pre->op.as(); // This graph has annotations, meaning that this is not the first time running this pass. - if (op_node && cn->op == compiler_begin_op) { + if (op_node && pre->op == CompilerBeginOp()) { // Bypass compiler begin due to lack of target information. It will be processed // when the following op handling arguments. 
- CHECK_EQ(cn->args.size(), 1U); - return VisitExpr(cn->args[0]); - } else if (op_node && cn->op == compiler_end_op) { + CHECK_EQ(pre->args.size(), 1U); + return post.as()->args[0]; + } else if (op_node && pre->op == CompilerEndOp()) { // Override compiler end with the new target. - CHECK_EQ(cn->args.size(), 1U); - auto input_expr = VisitExpr(cn->args[0]); + CHECK_EQ(pre->args.size(), 1U); + auto input_expr = post.as()->args[0]; CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); } // Peek the first argument. If it is compiler begin then this node had annotated by // another target before, so we also consider that target as a supported target. - const CallNode* first_arg_call = cn->args[0].as(); - if (first_arg_call && first_arg_call->op == compiler_begin_op) { + const CallNode* first_arg_call = pre->args[0].as(); + if (first_arg_call && first_arg_call->op == CompilerBeginOp()) { std::string arg_target = first_arg_call->attrs.as()->compiler; if (arg_target != "default") { supported_targets.push_back(arg_target); @@ -142,21 +141,21 @@ class AnnotateTargetWrapper : public ExprMutator { if (op_node) { // TVM operators: Check target specific op checking function and add to supported_targets // if it is supported. - Op op = Downcast(cn->op); + Op op = Downcast(pre->op); CHECK(op.defined()); for (const auto& target : this->targets_) { - if (!Op::HasAttr("target." + std::string(target))) { + if (!Op::HasAttrMap("target." + std::string(target))) { continue; } - auto fannotate = Op::GetAttr("target." + std::string(target)); - if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { + auto fannotate = Op::GetAttrMap("target." + std::string(target)); + if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) { supported_targets.push_back(target); } } - } else if (cn->op->IsInstance()) { + } else if (pre->op->IsInstance()) { // Composite function: Add the target of a composite function to supported_targets // if it is in the target list. - Function func = Downcast(cn->op); + Function func = Downcast(pre->op); CHECK(func.defined()); if (auto comp_name = func->GetAttr(attr::kComposite)) { @@ -181,23 +180,22 @@ class AnnotateTargetWrapper : public ExprMutator { std::string target = supported_targets[0]; // Visit and mutate arguments after the target of this op has been determined. - auto new_call = Downcast(ExprMutator::VisitExpr_(cn)); + Call post_call = Downcast(post); // Add annotations to each arg. - auto target_n_args = AnnotateArgs(new_call->args, target); + auto target_n_args = AnnotateArgs(post_call->args, target); Array compiler_begins = std::get<1>(target_n_args); - Call call = Call(new_call->op, compiler_begins, new_call->attrs); - call->checked_type_ = cn->checked_type_; + Call new_call = Call(post_call->op, compiler_begins, post_call->attrs); + new_call->checked_type_ = pre->checked_type_; // Update the target map. 
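      // In sketch form, the ExprRewriter protocol this file now relies on: Rewrite_
      // receives the original node (`pre`, used for inspection) and the same node with
      // already-rewritten arguments (`post`, used to build the result), while
      // PostOrderRewrite drives a single post-order traversal. A hypothetical no-op
      // rewriter, for illustration only:
      //
      //   class NoopRewriter : public ExprRewriter {
      //    public:
      //     Expr Rewrite_(const CallNode* pre, const Expr& post) final {
      //       return post;  // keep the argument-rewritten call unchanged
      //     }
      //   };
      //   NoopRewriter rewriter;
      //   Expr out = PostOrderRewrite(expr, &rewriter);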
- op_expr_to_target_[call] = target; + op_expr_to_target_[new_call] = target; - return std::move(call); + return std::move(new_call); } - Expr VisitExpr_(const TupleNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const TupleNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(expr->fields); auto new_expr = Tuple(std::get<1>(target_n_args)); @@ -205,9 +203,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const TupleGetItemNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->tuple})); auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); @@ -215,7 +212,7 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const FunctionNode* fn) final { + Expr Rewrite_(const FunctionNode* fn, const Expr& post) final { Function func; Expr new_body; // don't step into composite functions @@ -223,8 +220,7 @@ class AnnotateTargetWrapper : public ExprMutator { func = GetRef(fn); new_body = func->body; } else { - auto new_e = ExprMutator::VisitExpr_(fn); - func = Downcast(new_e); + func = Downcast(post); new_body = func->body; if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); @@ -234,9 +230,8 @@ class AnnotateTargetWrapper : public ExprMutator { return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } - Expr VisitExpr_(const LetNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto let = Downcast(new_e); + Expr Rewrite_(const LetNode* op, const Expr& post) final { + auto let = Downcast(post); auto target_n_args = AnnotateArgs({let->value, let->body}); auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); @@ -244,9 +239,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const IfNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const IfNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); CHECK_EQ(std::get<1>(target_n_args).size(), 3U); @@ -256,9 +250,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const RefCreateNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const RefCreateNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->value})); auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); @@ -266,9 +259,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr VisitExpr_(const RefReadNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const RefReadNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->ref})); auto new_expr = RefRead(std::get<1>(target_n_args)[0]); @@ -276,9 +268,8 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(new_expr); } - Expr 
VisitExpr_(const RefWriteNode* op) final { - auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); + Expr Rewrite_(const RefWriteNode* op, const Expr& post) final { + auto expr = Downcast(post); auto target_n_args = AnnotateArgs(Array({expr->ref, expr->value})); auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); @@ -290,11 +281,12 @@ class AnnotateTargetWrapper : public ExprMutator { /*! \brief The target backends for annotation. */ Array targets_; /*! \brief Maintain the decision of the target for each op expr. */ - std::unordered_map op_expr_to_target_; + std::unordered_map op_expr_to_target_; }; Expr AnnotateTarget(const Expr& expr, const Array& targets) { - return AnnotateTargetWrapper(targets).Mutate(expr); + auto rewriter = AnnotateTargetRewriter(targets); + return PostOrderRewrite(expr, &rewriter); } } // namespace annotate_target @@ -306,8 +298,7 @@ Pass AnnotateTarget(const Array& targets) { [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); }; - auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", - {"InferType"}); + auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc index ebcbd578b5f0..055ab1480a6e 100644 --- a/src/relay/transforms/canonicalize_cast.cc +++ b/src/relay/transforms/canonicalize_cast.cc @@ -22,9 +22,10 @@ * \brief Canonicalize cast expressions to make operator fusion more efficient. */ #include -#include #include +#include #include + #include "pass_util.h" #include "pattern_util.h" @@ -65,7 +66,7 @@ class CastCanonicalizer : public ExprMutator { CastCanonicalizer() : cast_op_(Op::Get("cast")) {} Expr VisitExpr_(const CallNode* call) { - static auto fpattern = Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttrMap("TOpPattern"); if (const OpNode* opnode = call->op.as()) { auto pattern = fpattern[GetRef(opnode)]; @@ -112,8 +113,7 @@ class CastCanonicalizer : public ExprMutator { const CallNode* new_call = new_expr.as(); CHECK(new_call); CHECK(new_call->op == cast_op_); - return Call(new_call->op, new_call->args, new_call->attrs, - new_call->type_args); + return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args); } } } @@ -122,22 +122,19 @@ class CastCanonicalizer : public ExprMutator { } }; -Expr CanonicalizeCast(const Expr& e) { - return CastCanonicalizer().Mutate(e); -} +Expr CanonicalizeCast(const Expr& e) { return CastCanonicalizer().Mutate(e); } namespace transform { Pass CanonicalizeCast() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeCast(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeCast(f)); + }; return CreateFunctionPass(pass_func, 3, "CanonicalizeCast", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast") -.set_body_typed(CanonicalizeCast); +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast); } // namespace transform diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc index 1d3111b29d7d..fec757ee68d5 100644 --- a/src/relay/transforms/canonicalize_ops.cc +++ b/src/relay/transforms/canonicalize_ops.cc @@ -23,10 +23,11 @@ This can simplify latter analysis. (e.g. 
Expand bias_add to expand_dims and broadcast_add.) */ #include +#include #include #include -#include #include + #include "pattern_util.h" namespace tvm { @@ -71,14 +72,13 @@ namespace transform { Pass CanonicalizeOps() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeOps(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeOps(f)); + }; return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps") -.set_body_typed(CanonicalizeOps); +TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps").set_body_typed(CanonicalizeOps); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index af6b1353f5ac..0bf9e7fd38a6 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -33,15 +33,17 @@ */ #include -#include #include #include +#include #include #include + #include #include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -50,13 +52,10 @@ namespace relay { class ParallelConv2DCombiner : public ParallelOpCombiner { public: explicit ParallelConv2DCombiner(uint64_t min_num_branches) - : ParallelOpCombiner("nn.conv2d", min_num_branches) { - } + : ParallelOpCombiner("nn.conv2d", min_num_branches) {} protected: - bool IsSupportedOp(const CallNode* n) { - return n->attrs.as()->groups == 1; - } + bool IsSupportedOp(const CallNode* n) { return n->attrs.as()->groups == 1; } bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { StructuralEqual eq; @@ -67,10 +66,10 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { CHECK(attrs_b); const auto* tweight_a = a->args[1]->type_as(); const auto* tweight_b = b->args[1]->type_as(); - const auto shape_a = tir::BijectiveLayout( - Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); - const auto shape_b = tir::BijectiveLayout( - Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); + const auto shape_a = + tir::BijectiveLayout(Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape); + const auto shape_b = + tir::BijectiveLayout(Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape); return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && @@ -118,8 +117,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { auto toutput_a = a->type_as(); auto toutput_b = b->type_as(); - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; // Position of the 'C' dimension in the argument size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size(); @@ -132,15 +130,12 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { for (size_t i = 0; i < ta->shape.size(); i++) { if (i == arg_channel_pos) continue; - if (!eq(ta->shape[i], tb->shape[i])) - return false; + if (!eq(ta->shape[i], tb->shape[i])) return false; } return true; } - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) { 
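    // A sketch of the slicing performed in UpdateGroupOutput below, assuming an
    // NCHW-style output where channel_pos_ == 1 (values are illustrative): with
    // slice_mode == "size", `begin` holds per-axis start indices and `end` holds the
    // slice *length* per axis, with -1 meaning "everything remaining on that axis".
    // Only the axes up to and including the channel axis are specified. For a branch
    // contributing 64 channels at offset 128 of the combined conv2d output:
    //
    //   begin   = {0, 128}
    //   end     = {-1, 64}   // lengths, not stop indices
    //   strides = {1, 1}
    //
    // i.e. the slice recovers data[:, 128:192, :, :] for that branch.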
Array new_args; const CallNode* call = branches[0][depth]; @@ -166,24 +161,31 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { return Call(call->op, new_args, call->attrs, {}); } - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int64_t index = 0; + for (const auto& branch : branches) { const CallNode* conv2d = branch[0]; int64_t channels = GetConv2DSuperChannelsDim(conv2d); - Array begin; - Array end; + std::vector begin; + std::vector end; for (size_t i = 0; i < channel_pos_; i++) { begin.push_back(0); - end.push_back(NullValue()); + end.push_back(-1); } begin.push_back(index); index += channels; end.push_back(index); - auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); + std::vector strides(begin.size(), 1); + for (size_t i = 0; i < begin.size(); ++i) { + end[i] -= begin[i]; + } + std::vector ndarray_shape = {static_cast(begin.size())}; + Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin); + Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end); + Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides); + auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size"); subst_map->insert({GetRef(branch[depth]), slice}); } } @@ -217,14 +219,13 @@ namespace transform { Pass CombineParallelConv2D(uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelConv2D(f, min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CombineParallelConv2D(f, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D") -.set_body_typed(CombineParallelConv2D); +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D").set_body_typed(CombineParallelConv2D); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index 1278020ac735..8613dbe1466e 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -32,16 +32,18 @@ */ #include -#include #include #include +#include #include #include + #include #include + +#include "./combine_parallel_op_batch.h" #include "./expr_subst.h" #include "pattern_util.h" -#include "./combine_parallel_op_batch.h" namespace tvm { namespace relay { @@ -49,8 +51,7 @@ namespace relay { class ParallelDenseCombiner : public ParallelOpBatchCombiner { public: explicit ParallelDenseCombiner(uint64_t min_num_branches) - : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) { - } + : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {} protected: virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { @@ -63,8 +64,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner { const auto* weight_b = b->args[1]->type_as(); return eq(attrs_a->out_dtype, attrs_b->out_dtype) && - eq(weight_a->shape[0], weight_b->shape[0]) && - eq(weight_a->shape[1], weight_b->shape[1]); + eq(weight_a->shape[0], weight_b->shape[0]) && eq(weight_a->shape[1], weight_b->shape[1]); } }; @@ -77,14 +77,13 @@ namespace transform { Pass CombineParallelDense(uint64_t min_num_branches) { runtime::TypedPackedFunc 
pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelDense(f, min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(CombineParallelDense(f, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense") -.set_body_typed(CombineParallelDense); +TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense").set_body_typed(CombineParallelDense); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc index a7f7af2b79e5..7ca2ce8b5dba 100644 --- a/src/relay/transforms/combine_parallel_op.cc +++ b/src/relay/transforms/combine_parallel_op.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -23,33 +23,33 @@ * \brief Abstract class to combine parallel ops and their successive element-wise ops. */ +#include "combine_parallel_op.h" + #include #include -#include #include #include +#include #include #include #include + #include -#include #include #include +#include + #include "expr_subst.h" #include "pattern_util.h" -#include "combine_parallel_op.h" - namespace tvm { namespace relay { -BranchGroupFinder::BranchGroupFinder(const Op& op, - FIsSupportedOp fis_supported_op, +BranchGroupFinder::BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops) - : cached_op_(op), - fis_supported_op_(fis_supported_op), - fare_compatible_ops_(fare_compatible_ops) { -} + : cached_op_(op), + fis_supported_op_(fis_supported_op), + fare_compatible_ops_(fare_compatible_ops) {} std::vector BranchGroupFinder::Find(const Expr& expr) { this->VisitExpr(expr); @@ -81,7 +81,7 @@ std::vector BranchGroupFinder::Find(const Expr& expr) { // Create a branch starting from op. 
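// A Branch is the chain formed by one combinable op (the op this finder was built
// for) followed by the successive element-wise/broadcast ops that consume it, and a
// Group collects the parallel Branches that share the same input. For instance, two
// parallel `conv2d -> add -> relu` chains fed by one tensor form a Group of two
// Branches of depth three (an illustrative sketch of the grouping, not an exhaustive
// statement of its rules).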
Branch BranchGroupFinder::CreateBranch(const CallNode* op) { - auto fpattern = Op::GetAttr("TOpPattern"); + auto fpattern = Op::GetAttrMap("TOpPattern"); // each branch has at least one element, the first element is always op Branch branch{op}; auto it = children_map_.find(GetRef(branch.back())); @@ -111,18 +111,13 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) { } ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) - : cached_op_(Op::Get(op_name)), - min_num_branches_(min_num_branches) { -} + : cached_op_(Op::Get(op_name)), min_num_branches_(min_num_branches) {} Expr ParallelOpCombiner::Combine(const Expr& expr) { - auto groups = BranchGroupFinder(cached_op_, - [&](const CallNode* n) { - return IsSupportedOp(n); - }, - [&](const CallNode* a, const CallNode* b) { - return CanOpsBeCombined(a, b); - }).Find(expr); + auto groups = BranchGroupFinder( + cached_op_, [&](const CallNode* n) { return IsSupportedOp(n); }, + [&](const CallNode* a, const CallNode* b) { return CanOpsBeCombined(a, b); }) + .Find(expr); for (const Group& group : groups) { if (group.size() < min_num_branches_) { continue; @@ -135,10 +130,9 @@ Expr ParallelOpCombiner::Combine(const Expr& expr) { void ParallelOpCombiner::CombineBranches(const Group& branches) { Call combined = MakeCombinedOp(branches); auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); + [](const Branch& branch_a, const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); size_t depth = it->size(); size_t i; // starting from 1 to skip the op @@ -155,32 +149,30 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) { } bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { - const CallNode* call = branches[0][depth]; - tvm::StructuralEqual attrs_equal; - // check if all branches in current depth can be combined - for (auto it = branches.begin() + 1; it != branches.end(); it++) { - const Branch& branch = *it; - if (!branch[depth]->op.same_as(call->op) || - !attrs_equal(branch[depth]->attrs, call->attrs) || - branch[depth]->args.size() != call->args.size()) { - return false; - } + const CallNode* call = branches[0][depth]; + tvm::StructuralEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } - if (branch[depth]->args[parent_index].get() != branch[depth - 1]) - return false; + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) return false; - // Check args - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) continue; + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; - if (!IsArgCompatible(call, branch[depth], i) || - !attrs_equal(call->attrs, branch[depth]->attrs)) { - return false; - } + if (!IsArgCompatible(call, branch[depth], i) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; } } - return true; } + return true; +} } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/combine_parallel_op.h b/src/relay/transforms/combine_parallel_op.h index 0097e29b13ea..6f53e86d534b 100644 --- 
a/src/relay/transforms/combine_parallel_op.h +++ b/src/relay/transforms/combine_parallel_op.h @@ -26,27 +26,28 @@ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_H_ #include -#include #include #include +#include #include #include + +#include #include #include #include -#include + #include "./expr_subst.h" #include "pattern_util.h" - namespace tvm { namespace relay { using Branch = std::vector; using Group = std::vector; -using FIsSupportedOp = std::function; -using FAreCompatibleOps = std::function; -using ExprSubstMap = std::unordered_map; +using FIsSupportedOp = std::function; +using FAreCompatibleOps = std::function; +using ExprSubstMap = std::unordered_map; /* * Class to find parallel branches starting with op that are @@ -74,8 +75,7 @@ class BranchGroupFinder : private ExprVisitor { * \param fare_compatible_ops function that returns true if * two ops are compatible for combining */ - BranchGroupFinder(const Op& op, - FIsSupportedOp fis_supported_op, + BranchGroupFinder(const Op& op, FIsSupportedOp fis_supported_op, FAreCompatibleOps fare_compatible_ops); /* @@ -103,10 +103,11 @@ class BranchGroupFinder : private ExprVisitor { /* \brief ops that are on the first (logically, leftmost) branch * of parallel ops and are eligible to be combined */ - std::unordered_set op_roots_; + std::unordered_set op_roots_; /* \brief map of Expr to CallNodes that follow it */ - std::unordered_map, ObjectHash, ObjectEqual> children_map_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> + children_map_; /* * \brief Creates new branch from op and its children that have @@ -188,10 +189,8 @@ class ParallelOpCombiner { * all combined ops * \return new combined call */ - virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, - size_t parent_index) = 0; + virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, + size_t depth, size_t parent_index) = 0; /* * \brief Updates map of expr to substitute with combined expr. 
This usually involves @@ -201,9 +200,7 @@ class ParallelOpCombiner { * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ - virtual void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + virtual void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) = 0; private: diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index 361565ef11d7..2e9ffdb9bb3c 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -44,17 +44,20 @@ * */ +#include "./combine_parallel_op_batch.h" + #include -#include #include #include +#include #include #include + #include #include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" -#include "./combine_parallel_op_batch.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -63,13 +66,9 @@ namespace relay { ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches) - : ParallelOpCombiner(op_name, min_num_branches), - batch_op_name_(batch_op_name) { -} + : ParallelOpCombiner(op_name, min_num_branches), batch_op_name_(batch_op_name) {} -bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { - return true; -} +bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { return true; } bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) { if (a->args.size() != b->args.size()) { @@ -116,19 +115,16 @@ bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); - if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) - return false; + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) return false; for (size_t i = 0; i < ta->shape.size(); i++) { - if (!eq(ta->shape[i], tb->shape[i])) - return false; + if (!eq(ta->shape[i], tb->shape[i])) return false; } return true; } Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + const Group& branches, size_t depth, size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; @@ -160,10 +156,8 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, return Call(call->op, new_args, call->attrs, {}); } -void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, - ExprSubstMap* subst_map) { +void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, + size_t depth, ExprSubstMap* subst_map) { int index = 0; auto split = MakeSplit(data, Integer(branches.size()), 0); for (const auto& branch : branches) { @@ -174,30 +168,25 @@ void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, } /*! 
\brief Combine parallel op into batched op if number of branches >= min_num_branches */ -Expr CombineParallelOpBatch(const Expr& expr, - const std::string& op_name, - const std::string& batch_op_name, - uint64_t min_num_branches) { +Expr CombineParallelOpBatch(const Expr& expr, const std::string& op_name, + const std::string& batch_op_name, uint64_t min_num_branches) { return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr); } namespace transform { -Pass CombineParallelOpBatch(const std::string& op_name, - const std::string& batch_op_name, +Pass CombineParallelOpBatch(const String& op_name, const String& batch_op_name, uint64_t min_num_branches) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CombineParallelOpBatch(f, - op_name, - batch_op_name, - min_num_branches)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast( + CombineParallelOpBatch(f, op_name, batch_op_name, min_num_branches)); + }; return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch") -.set_body_typed(CombineParallelOpBatch); + .set_body_typed(CombineParallelOpBatch); } // namespace transform diff --git a/src/relay/transforms/combine_parallel_op_batch.h b/src/relay/transforms/combine_parallel_op_batch.h index 687660433946..9f87d9d2184f 100644 --- a/src/relay/transforms/combine_parallel_op_batch.h +++ b/src/relay/transforms/combine_parallel_op_batch.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,16 +25,18 @@ #define TVM_RELAY_TRANSFORMS_COMBINE_PARALLEL_OP_BATCH_H_ #include -#include #include #include +#include #include #include + +#include #include #include -#include -#include "./expr_subst.h" + #include "./combine_parallel_op.h" +#include "./expr_subst.h" #include "pattern_util.h" namespace tvm { @@ -68,8 +70,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * \param min_num_branches min number of parallel branches beginning with op * to start combining */ - ParallelOpBatchCombiner(const std::string& op_name, - const std::string& batch_op_name, + ParallelOpBatchCombiner(const std::string& op_name, const std::string& batch_op_name, uint64_t min_num_branches); protected: @@ -116,9 +117,7 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * all combined ops * \return new combined call as batch op by stacking args */ - Call MakeCombinedCallFromFollowingOps(const Expr& data, - const Group& branches, - size_t depth, + Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth, size_t parent_index) final; /* @@ -129,15 +128,13 @@ class ParallelOpBatchCombiner : public ParallelOpCombiner { * \param depth depth at which to substitute * \param subst_map map of Expr to replace with Expr to replace it with */ - void UpdateGroupOutput(const Expr& data, - const Group& branches, - size_t depth, + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) final; private: /* \brief name of op to replace combined ops with. 
for example, * for combining parallel dense, this will will be set to - * nn.batch_matmul + * nn.batch_matmul */ std::string batch_op_name_; }; diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc index dbb2c38e3f27..9a71642aac13 100644 --- a/src/relay/transforms/convert_layout.cc +++ b/src/relay/transforms/convert_layout.cc @@ -24,20 +24,20 @@ custom layouts or other general weight pre-transformation. */ #include -#include -#include #include +#include #include #include -#include -#include + #include #include -#include +#include #include +#include +#include -#include "transform_layout.h" #include "pattern_util.h" +#include "transform_layout.h" namespace tvm { namespace relay { @@ -51,13 +51,15 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode { public: /*! * \brief Initializes the desired_layout. - * \param desired_layout The desired layout. + * \param desired_layouts Specify mapping of op_name to array of desired layouts for each input. + * For example: Map("nn.conv2d", Array("NHWC", "OHWI")), + * this specifies the desired layout for data then kernel for nn.conv2d. */ - explicit ConvertTransformMemorizerNode(const std::string& desired_layout) - : desired_layout_(desired_layout) {} + explicit ConvertTransformMemorizerNode(Map> desired_layouts) + : desired_layouts_(std::move(desired_layouts)) {} - /*! \brief The desired layout for the Convert Layout pass */ - std::string desired_layout_; + /*! \brief A mapping of op_name to array of desired layouts for each input. */ + Map> desired_layouts_; }; /*! @@ -80,7 +82,7 @@ class ConvertTransformMemorizer : public TransformMemorizer { * \return The new Call after calling the packed func. */ Call CallWithNewLayouts(const Call& ref_call, const std::vector& new_args) override { - static auto fconvert_layout = Op::GetAttr("FTVMConvertOpLayout"); + static auto fconvert_layout = Op::GetAttrMap("FTVMConvertOpLayout"); Op op = Downcast(ref_call->op); Expr new_e; @@ -91,8 +93,14 @@ class ConvertTransformMemorizer : public TransformMemorizer { auto ttype = expr->type_as(); tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype)); } + + auto desired_layouts = operator->()->desired_layouts_; + if (desired_layouts.find(op->name) == desired_layouts.end()) { + LOG(FATAL) << "Desired layout(s) not specified for op: " << op->name; + } + Array op_desired_layouts = desired_layouts.at(op->name); Expr altered_value = - fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_); + fconvert_layout[op](ref_call->attrs, new_args, tinfos, op_desired_layouts); if (altered_value.defined()) { new_e = altered_value; modified = true; @@ -115,9 +123,9 @@ class ConvertTransformMemorizer : public TransformMemorizer { * 1. The altered op should have the same number of arguments as the previous one. * 2. Do not support nested tuple arguments. 
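 *
 * A minimal caller-side sketch (hypothetical snippet; the op name and layouts are
 * illustrative values taken from the example earlier in this file, not defaults):
 *
 *   Map<String, Array<String>> desired_layouts;
 *   desired_layouts.Set("nn.conv2d", {"NHWC", "OHWI"});  // data layout, then kernel layout
 *   Expr converted = ConvertLayout(expr, desired_layouts);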
*/ -Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { +Expr ConvertLayout(const Expr& expr, const Map>& desired_layouts) { ConvertTransformMemorizer transformMemorizer( - make_object(desired_layout)); + make_object(desired_layouts)); auto fcontext = [&](const Call& call) -> ObjectRef { return transformMemorizer; }; return ForwardRewrite(expr, LayoutRewriter, fcontext); @@ -127,13 +135,12 @@ Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { namespace transform { -Pass ConvertLayout(const std::string& desired_layout) { +Pass ConvertLayout(const Map>& desired_layouts) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layout)); + return Downcast(relay::convert_op_layout::ConvertLayout(f, desired_layouts)); }; - return CreateFunctionPass( - pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); + return CreateFunctionPass(pass_func, 3, "ConvertLayout", {"InferType", "CanonicalizeOps"}); } TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc new file mode 100644 index 000000000000..36aaa478eab6 --- /dev/null +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * + * \file convert_sparse_dense.cc + * + * \brief Mutate dense operator to sparse dense operator + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relay { + +// Search dense op weight name from Expr +class DenseOpWeightVisitor : private ExprVisitor { + public: + DenseOpWeightVisitor() : dense_op_(Op::Get("nn.dense")) {} + + Array Search(const Expr& expr) { + VisitExpr(expr); + return memo_; + } + + private: + void VisitExpr_(const CallNode* n) final { + if (n->op == dense_op_) { + const auto weight = n->args[1].as(); + if (weight) { + memo_.push_back(weight->name_hint()); + } + } + for (const auto& arg : n->args) { + VisitExpr(arg); + } + } + // Cache op + const Op& dense_op_; + + Array memo_; +}; // SearchDenseOpWeight + +Array SearchDenseOpWeight(const Expr& e) { return DenseOpWeightVisitor().Search(e); } + +TVM_REGISTER_GLOBAL("relay.analysis.search_dense_op_weight").set_body_typed(SearchDenseOpWeight); + +// Mutate ```nn.dense``` to ```nn.sparse_dense``` +class DenseToSparseDenseMutator : public ExprRewriter { + public: + DenseToSparseDenseMutator(const Array& weight_name, + const Array >& weight_shape) + : dense_op_(Op::Get("nn.dense")), sparse_dense_op_(Op::Get("nn.sparse_dense")) { + CHECK_EQ(weight_name.size(), weight_shape.size()); + for (size_t i = 0; i < weight_name.size(); ++i) { + CHECK(weight_name[i]->IsInstance()); + std::string k = weight_name[i].as()->data; + const auto& ws = weight_shape[i]; + std::vector v(ws.size()); + for (size_t j = 0; j < ws.size(); ++j) { + v[j] = ws[j].as()->value; + } + target_weights_.emplace(k, v); + } + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (pre->op == dense_op_) { + const auto weight = pre->args[1].as(); + if (weight) { + if (target_weights_.count(weight->name_hint())) { + const auto& prefix = weight->name_hint(); + const auto& ws = target_weights_.at(prefix); + const auto data = post.as()->args[0]; + auto ws_data_type = + relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32)); + auto ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); + auto ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32)); + Var weight_data(prefix + ".data", ws_data_type); + Var weight_indices(prefix + ".indices", ws_indices_type); + Var weight_indptr(prefix + ".indptr", ws_indptr_type); + + return Call(sparse_dense_op_, {data, weight_data, weight_indices, weight_indptr}); + } + } + } + return post; + } + + private: + // Cached op + const Op& dense_op_; + const Op& sparse_dense_op_; + std::unordered_map > target_weights_; +}; // class DenseToSparseDenseAlter + +Expr DenseToSparse(const Expr& e, const Array& weight_name, + const Array >& weight_shape) { + auto rewriter = DenseToSparseDenseMutator(weight_name, weight_shape); + return PostOrderRewrite(e, &rewriter); +} + +namespace transform { + +Pass DenseToSparse(const Array& weight_name, + const Array >& weight_shape) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + // Remove FreeVar warnings + auto f0 = Downcast(DenseToSparse(f, weight_name, weight_shape)); + Array sparse_params = FreeVars(f0); + auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + Array params = FreeVars(f1); + for (const auto& var : sparse_params) { + params.push_back(var); + } + return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + }; + return CreateFunctionPass(pass_func, 4, 
"DenseToSparse", {"DeadCodeElimination"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.DenseToSparse").set_body_typed(DenseToSparse); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 48b8666856a6..d90e5c584df3 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -23,17 +23,15 @@ * \brief Use a fresh Id for every Var to make the result well-formed. */ #include -#include #include +#include #include namespace tvm { namespace relay { Expr DeDup(const Expr& e) { - class DeDupMutator : public TypeMutator, - public ExprMutator, - public PatternMutator { + class DeDupMutator : public TypeMutator, public ExprMutator, public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { TypeVar ret = TypeVar(tv->name_hint, tv->kind); @@ -65,9 +63,7 @@ Expr DeDup(const Expr& e) { return Let(v, VisitExpr(op->value), VisitExpr(op->body)); } - Type VisitType(const Type& t) final { - return t.defined() ? TypeMutator::VisitType(t) : t; - } + Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } Expr VisitExpr_(const FunctionNode* op) final { tvm::Array type_params; @@ -78,33 +74,23 @@ Expr DeDup(const Expr& e) { for (const Var& param : op->params) { params.push_back(Fresh(param)); } - return Function(params, - VisitExpr(op->body), - VisitType(op->ret_type), - type_params, - op->attrs); + return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs); } - Pattern VisitPattern(const Pattern& p) final { - return PatternFunctor::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } - Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVar(Fresh(op->var)); - } + Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(Fresh(op->var)); } Type VisitType_(const TypeVarNode* op) final { TypeVar v = GetRef(op); return type_rename_.count(v) != 0 ? 
type_rename_.at(v) : v; } - Var VisitVar(const Var& v) final { - return Fresh(v); - } + Var VisitVar(const Var& v) final { return Fresh(v); } private: - std::unordered_map rename_; - std::unordered_map type_rename_; + std::unordered_map rename_; + std::unordered_map type_rename_; }; CHECK(WellFormed(e)) << AsText(e, false); Expr ret = DeDupMutator().VisitExpr(e); @@ -113,8 +99,7 @@ Expr DeDup(const Expr& e) { return ret; } -TVM_REGISTER_GLOBAL("relay._transform.dedup") -.set_body_typed(DeDup); +TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/dead_code.cc b/src/relay/transforms/dead_code.cc index a0d093f197d6..f6c2272a3018 100644 --- a/src/relay/transforms/dead_code.cc +++ b/src/relay/transforms/dead_code.cc @@ -30,14 +30,15 @@ #include #include #include + #include "let_list.h" namespace tvm { namespace relay { -template -using VarMap = std::unordered_map; -using VarSet = std::unordered_set; +template +using VarMap = std::unordered_map; +using VarSet = std::unordered_set; class CalcDep; class FindDef : private ExprVisitor { @@ -59,20 +60,18 @@ class Eliminator : private ExprMutator { VarMap expr_map_; VarMap use_map_; bool inline_once_; - explicit Eliminator(const VarMap& expr_map, - const VarMap& use_map, - bool inline_once) : - expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) { } + explicit Eliminator(const VarMap& expr_map, const VarMap& use_map, bool inline_once) + : expr_map_(expr_map), use_map_(use_map), inline_once_(inline_once) {} friend CalcDep; bool HasLet(const Var& v) { switch (use_map_[v]) { - case 0: - return false; - case 1: - return !inline_once_; - default: - return true; + case 0: + return false; + case 1: + return !inline_once_; + default: + return true; } } @@ -104,8 +103,7 @@ class CalcDep : protected MixedModeVisitor { } private: - explicit CalcDep(const VarMap& expr_map) - : MixedModeVisitor(2), expr_map_(expr_map) {} + explicit CalcDep(const VarMap& expr_map) : MixedModeVisitor(2), expr_map_(expr_map) {} VarMap expr_map_; VarMap use_map_; @@ -123,9 +121,7 @@ class CalcDep : protected MixedModeVisitor { } } - void VisitExpr_(const LetNode* l) final { - VisitExpr(l->body); - } + void VisitExpr_(const LetNode* l) final { VisitExpr(l->body); } void VisitExpr_(const VarNode* v) final { Var var = GetRef(v); @@ -144,14 +140,13 @@ namespace transform { Pass DeadCodeElimination(bool inline_once) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(DeadCodeElimination(f, inline_once)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(DeadCodeElimination(f, inline_once)); + }; return CreateFunctionPass(pass_func, 1, "DeadCodeElimination", {}); } -TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination") -.set_body_typed(DeadCodeElimination); +TVM_REGISTER_GLOBAL("relay._transform.DeadCodeElimination").set_body_typed(DeadCodeElimination); } // namespace transform diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc index d5e1d2efd8e6..39cf563f730a 100644 --- a/src/relay/transforms/device_annotation.cc +++ b/src/relay/transforms/device_annotation.cc @@ -28,12 +28,12 @@ * 3. Collect the device allocation of each expression. */ -#include -#include #include +#include #include #include #include +#include #include #include @@ -103,8 +103,7 @@ class ValidateAnnotation : private ExprVisitor { * \return The device type. 
*/ int GetDeviceId(const CallNode* call_node) { - CHECK(IsOnDeviceNode(call_node)) - << "The input call node must be on_device node."; + CHECK(IsOnDeviceNode(call_node)) << "The input call node must be on_device node."; const OnDeviceAttrs* on_device_attr = call_node->attrs.as(); return on_device_attr->device_type; } @@ -160,8 +159,7 @@ class RewriteAnnotation : public ExprMutator { Expr VisitExpr_(const TupleGetItemNode* op) final { Expr tuple = op->tuple; if (NeedDeviceCopy(tuple.operator->(), op)) { - Expr new_expr = - TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); + Expr new_expr = TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index); UpdateAnnotationMap(op, new_expr.operator->()); return this->VisitExpr(new_expr); } else { @@ -201,8 +199,7 @@ class RewriteAnnotation : public ExprMutator { } if (annotated) { - Call new_call = Call(call_node->op, new_args, call_node->attrs, - call_node->type_args); + Call new_call = Call(call_node->op, new_args, call_node->attrs, call_node->type_args); UpdateAnnotationMap(call_node, new_call.operator->()); return this->VisitExpr(new_call); @@ -235,8 +232,7 @@ class RewriteAnnotation : public ExprMutator { return CreateDeviceCopy(src, fallback_device_, dit->second); } else { const auto dit = annotation_map_.find(dst); - int dst_dev_type = - dit == annotation_map_.end() ? fallback_device_ : dit->second; + int dst_dev_type = dit == annotation_map_.end() ? fallback_device_ : dit->second; return CreateDeviceCopy(src, sit->second, dst_dev_type); } } @@ -301,6 +297,7 @@ class AnnotatationVisitor : private ExprVisitor { visitor(expr); return visitor.annotations_; } + private: void VisitExpr_(const CallNode* call_node) { if (IsOnDeviceNode(call_node)) { @@ -414,9 +411,7 @@ class DeviceInfo { // TODO(zhiics) Skip annotation of tuple node for now. 
} - void VisitExpr_(const TupleGetItemNode* op) final { - ExprVisitor::VisitExpr_(op); - } + void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); } void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); @@ -432,7 +427,6 @@ class DeviceInfo { post_dfs_order_.push_back(std::make_pair(in, has_copy_)); } - int num_device_copy_ops_{0}; bool has_copy_ = false; std::vector> post_dfs_order_; @@ -479,25 +473,23 @@ class DeviceInfo { const auto* attrs = last_copy_node->attrs.as(); cur_dev_type = attrs->src_dev_type; if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; - if (it->second) device_map_.Set(GetRef(it->first), - attrs->dst_dev_type); + if (it->second) device_map_.Set(GetRef(it->first), attrs->dst_dev_type); } else if (last_copy_node) { Expr expr = GetRef(it->first); CHECK_EQ(device_map_.count(expr), 0U); if (it->second) device_map_.Set(expr, cur_dev_type); } } - return out_dev_type; + return out_dev_type; } void FillPropagation(int out_dev_type) { for (const auto& it : post_visitor_.post_dfs_order_) { - Expr expr = GetRef(it.first); - if (!it.second) device_map_.Set(expr, out_dev_type); + Expr expr = GetRef(it.first); + if (!it.second) device_map_.Set(expr, out_dev_type); } } - PostDfsOrderVisitor post_visitor_; Map device_map_; }; @@ -521,14 +513,12 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { } CHECK_GT(new_body.size(), 0U); if (new_body.size() == 1) { - return Function(params, new_body[0], Type(nullptr), - fn->type_params, fn->attrs); + return Function(params, new_body[0], Type(nullptr), fn->type_params, fn->attrs); } else if (tuple->fields.size() == new_body.size()) { - return new_expr; + return new_expr; } else { Tuple tuple_body = Tuple(new_body); - return Function(params, tuple_body, Type(nullptr), - fn->type_params, fn->attrs); + return Function(params, tuple_body, Type(nullptr), fn->type_params, fn->attrs); } } else { return new_expr; @@ -544,40 +534,35 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { if (tuple->fields.size() == new_fields.size()) { return new_fields.size() == 1 ? new_fields[0] : new_expr; } else { - return new_fields.size() == 1 ? new_fields[0] - : Tuple(new_fields); + return new_fields.size() == 1 ? 
new_fields[0] : Tuple(new_fields); } } else { return new_expr; } } -Map CollectDeviceInfo(const Expr& expr) { - return DeviceInfo::GetDeviceMap(expr); -} +Map CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); } Map CollectDeviceAnnotationOps(const Expr& expr) { return AnnotatationVisitor::GetAnnotations(expr); } -TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo") -.set_body_typed(CollectDeviceInfo); +TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceInfo").set_body_typed(CollectDeviceInfo); TVM_REGISTER_GLOBAL("relay.analysis.CollectDeviceAnnotationOps") -.set_body_typed(CollectDeviceAnnotationOps); + .set_body_typed(CollectDeviceAnnotationOps); namespace transform { Pass RewriteAnnotatedOps(int fallback_device) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::RewriteAnnotatedOps(f, fallback_device)); + }; return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation") -.set_body_typed(RewriteAnnotatedOps); +TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation").set_body_typed(RewriteAnnotatedOps); } // namespace transform diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc index 68c59f5ea2ef..8f7375c9dd35 100644 --- a/src/relay/transforms/eliminate_common_subexpr.cc +++ b/src/relay/transforms/eliminate_common_subexpr.cc @@ -29,7 +29,9 @@ #include #include #include + #include + #include "pattern_util.h" namespace tvm { @@ -37,10 +39,10 @@ namespace relay { class CommonSubexprEliminator : public ExprMutator { public: - explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip): fskip_(fskip) {} + explicit CommonSubexprEliminator(runtime::TypedPackedFunc fskip) : fskip_(fskip) {} Expr VisitExpr_(const CallNode* call) final { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); Expr new_expr = ExprMutator::VisitExpr_(call); const CallNode* new_call = new_expr.as(); CHECK(new_call); @@ -76,7 +78,7 @@ class CommonSubexprEliminator : public ExprMutator { return new_expr; } - std::unordered_map, ObjectHash, ObjectEqual> expr_map_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> expr_map_; runtime::TypedPackedFunc fskip_; }; @@ -88,14 +90,14 @@ namespace transform { Pass EliminateCommonSubexpr(PackedFunc fskip) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(EliminateCommonSubexpr(f, fskip)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(EliminateCommonSubexpr(f, fskip)); + }; return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr") -.set_body_typed(EliminateCommonSubexpr); + .set_body_typed(EliminateCommonSubexpr); } // namespace transform diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index c720bdfa14ee..42718eec9179 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -24,9 +24,9 @@ * */ #include +#include #include #include -#include namespace tvm { namespace relay { @@ -49,7 +49,7 @@ class TypeVarReplacer : public TypeMutator { private: /*! 
\brief variable replacement map to remap old type vars to fresh ones */ - std::unordered_map replace_map_; + std::unordered_map replace_map_; }; /*! @@ -62,16 +62,14 @@ class EtaExpander : public ExprMutator { type_var_replacer_(TypeVarReplacer()), expand_constructor_(expand_constructor), expand_global_var_(expand_global_var) { - CHECK(expand_constructor || expand_global_var) - << "must expand at least one language feature"; + CHECK(expand_constructor || expand_global_var) << "must expand at least one language feature"; } IRModule Expand() { for (GlobalVar global_var : mod_->GetGlobalVars()) { const BaseFunc base_func = mod_->Lookup(global_var); if (auto* n = base_func.as()) { - const Function new_func = Downcast( - VisitExpr(GetRef(n))); + const Function new_func = Downcast(VisitExpr(GetRef(n))); mod_->Update(global_var, new_func); } } @@ -111,11 +109,8 @@ class EtaExpander : public ExprMutator { Expr body = Call(cons, params, Attrs()); Type ret_type = TypeCall(cons->belong_to, type_params); - return Function( - Downcast>(params), - body, - ret_type, - Downcast>(type_params)); + return Function(Downcast>(params), body, ret_type, + Downcast>(type_params)); } Expr VisitExpr_(const GlobalVarNode* gvar_node) final { @@ -124,7 +119,7 @@ class EtaExpander : public ExprMutator { return std::move(gvar); } const auto base_func = mod_->Lookup(gvar); - if (auto *ptr = base_func.as()) { + if (auto* ptr = base_func.as()) { // handle relay function, skip external functions. auto func = GetRef(ptr); tvm::Array params; @@ -135,11 +130,7 @@ class EtaExpander : public ExprMutator { args.push_back(var); } - return Function( - args, - Call(gvar, params), - func->ret_type, - func->type_params); + return Function(args, Call(gvar, params), func->ret_type, func->type_params); } else { return std::move(gvar); } @@ -161,15 +152,14 @@ class EtaExpander : public ExprMutator { namespace transform { Pass EtaExpand(bool expand_constructor, bool expand_global_var) { - runtime::TypedPackedFunc pass_func = - [=](IRModule mod, PassContext pc) { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext pc) { return eta_expand::EtaExpander(mod, expand_constructor, expand_global_var).Expand(); }; return CreateModulePass(pass_func, 1, "EtaExpand", {}); } -TVM_REGISTER_GLOBAL("relay._transform.EtaExpand") -.set_body_typed(EtaExpand); +TVM_REGISTER_GLOBAL("relay._transform.EtaExpand").set_body_typed(EtaExpand); } // namespace transform diff --git a/src/relay/transforms/expr_subst.cc b/src/relay/transforms/expr_subst.cc index d3e6aa8dbfe6..96f139b7cdeb 100644 --- a/src/relay/transforms/expr_subst.cc +++ b/src/relay/transforms/expr_subst.cc @@ -22,15 +22,16 @@ * \brief Utility functions for substituting expressions. 
*/ -#include #include "./expr_subst.h" +#include + namespace tvm { namespace relay { class ExprSubstituter : public ExprMutator { public: - explicit ExprSubstituter(std::unordered_map subst_map) + explicit ExprSubstituter(std::unordered_map subst_map) : subst_map_(subst_map) {} Expr VisitExpr(const Expr& expr) final { @@ -46,7 +47,7 @@ class ExprSubstituter : public ExprMutator { }; Expr ExprSubst(const Expr& expr, - std::unordered_map subst_map) { + std::unordered_map subst_map) { return ExprSubstituter(std::move(subst_map)).Mutate(expr); } diff --git a/src/relay/transforms/expr_subst.h b/src/relay/transforms/expr_subst.h index 849ffc2db9e2..104ce0be0106 100644 --- a/src/relay/transforms/expr_subst.h +++ b/src/relay/transforms/expr_subst.h @@ -24,13 +24,14 @@ #ifndef TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_ #define TVM_RELAY_TRANSFORMS_EXPR_SUBST_H_ #include + #include namespace tvm { namespace relay { Expr ExprSubst(const Expr& expr, - std::unordered_map subst_map); + std::unordered_map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc index 8234dea5e075..3c8d8db637c8 100644 --- a/src/relay/transforms/fast_math.cc +++ b/src/relay/transforms/fast_math.cc @@ -22,10 +22,11 @@ * \brief Replaces non linear activation functions with their fast but approximate counterparts. */ #include -#include #include -#include +#include #include +#include + #include "pattern_util.h" namespace tvm { @@ -33,10 +34,7 @@ namespace relay { class FastMathMutator : public ExprRewriter { public: - FastMathMutator() - : exp_op_(Op::Get("exp")), - erf_op_(Op::Get("erf")), - tanh_op_(Op::Get("tanh")) {} + FastMathMutator() : exp_op_(Op::Get("exp")), erf_op_(Op::Get("erf")), tanh_op_(Op::Get("tanh")) {} Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (pre->op == exp_op_) { @@ -67,14 +65,11 @@ namespace transform { Pass FastMath() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(FastMath(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(FastMath(f)); }; return CreateFunctionPass(pass_func, 4, "FastMath", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.FastMath") -.set_body_typed(FastMath); +TVM_REGISTER_GLOBAL("relay._transform.FastMath").set_body_typed(FastMath); } // namespace transform diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index a52f42054c3e..b2eab8f96987 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -21,15 +21,16 @@ * \file constant_folding.cc */ #include +#include #include +#include #include #include -#include -#include #include -#include -#include #include +#include +#include + #include "pattern_util.h" namespace tvm { @@ -48,14 +49,13 @@ class ConstantChecker : private ExprVisitor { return true; } const auto it = memo_.find(expr); - if (it != memo_.end()) - return it->second; + if (it != memo_.end()) return it->second; VisitExpr(expr); return memo_[expr]; // return memoized result or the default value false } private: - std::unordered_map memo_; + std::unordered_map memo_; void VisitExpr_(const TupleNode* n) final { bool result = true; @@ -69,12 +69,9 @@ class ConstantChecker : private ExprVisitor { } }; -bool ConstantCheck(const Expr& e) { - return ConstantChecker().Check(e); -} +bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); } -TVM_REGISTER_GLOBAL("relay.analysis.check_constant") 
-.set_body_typed(ConstantCheck); +TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantCheck); // TODO(tvm-team) consider combine dead-code with constant folder. // or make a more powerful partial evaluator. @@ -98,9 +95,7 @@ class ConstantFolder : public ExprMutator { } else { Var var = Downcast(this->Mutate(op->var)); Expr body = this->Mutate(op->body); - if (var.same_as(op->var) && - value.same_as(op->value) && - body.same_as(op->body)) { + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { return Let(var, value, body); @@ -109,7 +104,7 @@ class ConstantFolder : public ExprMutator { } Expr VisitExpr_(const CallNode* call) final { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); std::unordered_set skip_list{"zeros_like", "ones_like", "full_like", "full"}; @@ -123,7 +118,7 @@ class ConstantFolder : public ExprMutator { const OpNode* op = call->op.as(); if (op == nullptr) return res; if (skip_list.count(op->name)) { - return res; + return res; } // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; @@ -133,9 +128,7 @@ class ConstantFolder : public ExprMutator { } // We should think about potentially constant evaluation over these ops too. - if (call->op == invoke_tvm_op_ || - call->op == shape_func_op_ || - call->op == alloc_tensor_op_ || + if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || call->op == alloc_storage_op_) { return GetRef(call); } @@ -184,8 +177,7 @@ class ConstantFolder : public ExprMutator { if (value->IsInstance()) { auto nd_array = Downcast(value); for (auto dim : nd_array.Shape()) { - CHECK_GT(dim, 0) - << "invalid dimension after constant eval"; + CHECK_GT(dim, 0) << "invalid dimension after constant eval"; } return Constant(nd_array); } else if (const auto* val = value.as()) { @@ -202,7 +194,7 @@ class ConstantFolder : public ExprMutator { } // Constant evaluate a expression. 
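The ConstantChecker in the hunks above memoizes, per node, whether an expression is a constant or a tuple made entirely of constants, so repeated queries during folding stay cheap. A self-contained sketch of the same memoized recursion over a toy expression tree (this models the idea only; it is not the Relay node hierarchy):

#include <memory>
#include <unordered_map>
#include <vector>

// Toy expression: either a literal constant or a tuple of sub-expressions.
struct ToyExpr {
  bool is_constant = false;
  std::vector<std::shared_ptr<ToyExpr>> fields;  // non-empty means "tuple"
};

// A node counts as constant if it is a literal, or a tuple whose fields are all
// constant; results are cached by node address, as ConstantChecker does with memo_.
bool IsConstant(const std::shared_ptr<ToyExpr>& e,
                std::unordered_map<const ToyExpr*, bool>* memo) {
  if (e->is_constant) return true;
  auto it = memo->find(e.get());
  if (it != memo->end()) return it->second;
  bool result = !e->fields.empty();
  for (const auto& f : e->fields) {
    if (!IsConstant(f, memo)) {
      result = false;
      break;
    }
  }
  return (*memo)[e.get()] = result;
}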
Expr ConstEvaluate(Expr expr) { - std::vector passes = {transform::FuseOps(0), + std::vector passes = {transform::FuseOps(0), transform::ToANormalForm(), transform::InferType()}; Function func; if (expr.as()) { @@ -211,10 +203,7 @@ class ConstantFolder : public ExprMutator { // TODO(@jroesch): fix this func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {}); } - auto mod = IRModule( - {}, - module_->type_definitions, - module_->Imports()); + auto mod = IRModule({}, module_->type_definitions, module_->Imports()); auto global = GlobalVar("main"); mod->Add(global, func); auto seq = transform::Sequential(passes); @@ -250,7 +239,7 @@ class ConstantFolder : public ExprMutator { value = runtime::NDArray::Empty({}, cdtype, ctx); } else { CHECK_NE(ishape.size(), 0); - std::vector cshape = { static_cast(ishape.size()) }; + std::vector cshape = {static_cast(ishape.size())}; value = runtime::NDArray::Empty(cshape, cdtype, ctx); int32_t* dims = static_cast(value->data); using ::tvm::tir::IntImmNode; @@ -273,20 +262,20 @@ class ConstantFolder : public ExprMutator { // Cast the constant into correct dtype auto cast_attrs = make_object(); cast_attrs->dtype = param->dtype; - Expr ret = Call(cast_op_, { shape }, Attrs(cast_attrs), {}); + Expr ret = Call(cast_op_, {shape}, Attrs(cast_attrs), {}); return ConstEvaluate(ret); } }; - Expr FoldConstant(const Expr& expr, const IRModule& mod) { + using tvm::transform::PassContext; DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - With fresh_build_ctx(BuildConfig::Create()); + With fresh_build_ctx(PassContext::Create()); return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); } @@ -295,14 +284,13 @@ namespace transform { Pass FoldConstant() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(FoldConstant(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FoldConstant(f, m)); + }; return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); } -TVM_REGISTER_GLOBAL("relay._transform.FoldConstant") -.set_body_typed(FoldConstant); +TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant); } // namespace transform diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc index cfe74bfd8ef1..a3765f3c3bef 100644 --- a/src/relay/transforms/fold_scale_axis.cc +++ b/src/relay/transforms/fold_scale_axis.cc @@ -23,14 +23,15 @@ * \brief Fold axis scaling into weights of * conv/dense operators. */ -#include #include #include #include #include -#include "pattern_util.h" -#include "pass_util.h" +#include +#include "../op/tensor/transform.h" +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -39,11 +40,11 @@ namespace relay { * * Use namespace to reduce potential naming conflict. 
*/ + namespace fold_scale_axis { using runtime::TypedPackedFunc; - // FoldScaleAxis algorithm: // // The general idea is to transform Expr to tuple of @@ -109,7 +110,7 @@ class Message : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode); }; -Message::Message(const AxesSet& axes, bool require_positive) { +Message::Message(const AxesSet& axes, bool require_positive) { auto n = make_object(); n->axes = axes; n->require_positive = require_positive; @@ -139,7 +140,8 @@ AxesSet Intersect(const AxesSet& lhs, const AxesSet& rhs) { ++j; } else { ret.push_back(lhs[i]); - ++i; ++j; + ++i; + ++j; } } return ret; @@ -166,8 +168,8 @@ Message Intersect(const Message& lhs, const Message& rhs) { * positive scale is required. * \return The message containing the result scaling on axes of the input. */ -using FForwardPrep = runtime::TypedPackedFunc< - Array (const Call& call, const Message& out_message)>; +using FForwardPrep = + runtime::TypedPackedFunc(const Call& call, const Message& out_message)>; /*! \brief Axis scale tuple. */ class ScaledExprNode : public TempExprNode { @@ -180,8 +182,7 @@ class ScaledExprNode : public TempExprNode { Expr scale = NullValue(); Expr Realize() const final { - CHECK(!axes.defined()) - << "outstanding scale"; + CHECK(!axes.defined()) << "outstanding scale"; return value; } @@ -195,18 +196,15 @@ class ScaledExprNode : public TempExprNode { TVM_DECLARE_FINAL_OBJECT_INFO(ScaledExprNode, TempExprNode); }; -using FForwardRewrite = TypedPackedFunc< - Expr(const Call& ref_call, - const Array& new_args, - const Message& message)>; +using FForwardRewrite = TypedPackedFunc& new_args, + const Message& message)>; //---------------------------------------------- // Generic Visitors for FScaleAxisForward //---------------------------------------------- class ForwardPrep : private ExprVisitor { public: - std::unordered_map - Prepare(const Expr& body) { + std::unordered_map Prepare(const Expr& body) { this->Update(body, NullValue()); this->VisitExpr(body); // flist is added in the Post-DFS order @@ -222,7 +220,7 @@ class ForwardPrep : private ExprVisitor { private: // The invoke list - std::vector > flist_; + std::vector> flist_; // The message on each node. std::unordered_map message_; // Update the message stored at node. @@ -245,15 +243,11 @@ class ForwardPrep : private ExprVisitor { } } // Visitor pattern override. - void VisitExpr_(const LetNode* call) { - LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; - } + void VisitExpr_(const LetNode* call) { LOG(FATAL) << "FoldScaleAxis only accept dataflow-form"; } void VisitExpr_(const FunctionNode* op) { ExprVisitor::VisitExpr_(op); - auto flazy = [this, op] { - this->Update(op->body, NullValue()); - }; + auto flazy = [this, op] { this->Update(op->body, NullValue()); }; flist_.push_back(flazy); } @@ -261,8 +255,7 @@ class ForwardPrep : private ExprVisitor { ExprVisitor::VisitExpr_(call); // function to be lazily invoked auto flazy = [this, call]() { - static const auto& fprep = - Op::GetAttr("FScaleAxisForwardPrep"); + static const auto& fprep = Op::GetAttrMap("FScaleAxisForwardPrep"); // find the message send to this node. 
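Scale requirements flow through the graph as sorted axis lists, and when two consumers ask for different axes only the common ones survive; the reformatted Intersect above implements that with a two-pointer merge. The same merge over plain ints (the pass itself works on tvm Arrays of integers, so this is only an illustration):

#include <vector>

// Intersection of two ascending axis lists: an axis can be folded only if every
// consumer agrees on it.
std::vector<int> IntersectAxes(const std::vector<int>& lhs, const std::vector<int>& rhs) {
  std::vector<int> ret;
  size_t i = 0, j = 0;
  while (i < lhs.size() && j < rhs.size()) {
    if (lhs[i] < rhs[j]) {
      ++i;
    } else if (rhs[j] < lhs[i]) {
      ++j;
    } else {
      ret.push_back(lhs[i]);
      ++i;
      ++j;
    }
  }
  return ret;
}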
auto it = message_.find(call); Message out_message; @@ -314,6 +307,42 @@ class ForwardPrep : private ExprVisitor { } }; +static bool IsIntInArray(const Array& axis, int v) { + for (size_t i = 0; i < axis.size(); i++) { + if (axis[i] == v) return true; + } + return false; +} + +static Expr ReshapeToMatchAxis(Expr scale, const Array& shape, + const Array& axis) { + Array arr; + for (size_t i = 0; i < shape.size(); i++) { + if (IsIntInArray(axis, i)) { + auto node = shape[i].as(); + if (!node) { + // if the shape is not a constant, use normal transform + return Expr(); + } + arr.push_back(node->value); + } else { + arr.push_back(1); + } + } + return MakeReshape( + scale, MakeConstantTensor(DataType::Int(32), {static_cast(arr.size())}, arr)); +} + +// if only one axis, use expand dim. Else, use reshape +static Expr ReshapeOrExpandToMatchAxis(Expr scale, const Array& shape, + const Array& axis) { + if (axis.size() > 1) { + return ReshapeToMatchAxis(scale, shape, axis); + } else { + return ExpandBiasToMatchAxis(scale, shape.size(), axis); + } +} + //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- @@ -326,31 +355,26 @@ Array ReluForwardPrep(const Call& call, const Message& out_message) { return {out_message}; } -Expr ReluForwardRewrite(const Call& ref_call, - const Array& new_args, - const Message& message) { +Expr ReluForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { const auto* input = new_args[0].as(); if (input == nullptr) return Expr(nullptr); // return transformed conv2d auto rnode = make_object(); - rnode->value = Call( - ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args); rnode->scale = input->scale; rnode->axes = input->axes; return Expr(rnode); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisForwardPrep", ReluForwardPrep); -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisForwardRewrite", + ReluForwardRewrite); -RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardPrep", ReluForwardPrep); +RELAY_REGISTER_OP("nn.leaky_relu").set_attr("FScaleAxisForwardPrep", ReluForwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", ReluForwardRewrite); // AddSub Array AddSubForwardPrep(const Call& call, const Message& out_message) { @@ -367,8 +391,7 @@ Array AddSubForwardPrep(const Call& call, const Message& out_message) { return {none, none}; } -Expr AddSubForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr AddSubForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { const auto* slhs = new_args[0].as(); const auto* srhs = new_args[1].as(); @@ -380,43 +403,42 @@ Expr AddSubForwardRewrite(const Call& ref_call, if (slhs != nullptr) { CHECK(srhs == nullptr); CHECK(MatchBroadcastToLeftAxes(tlhs, trhs, slhs->axes)); - Expr scale = ExpandBiasToMatchAxis( - slhs->scale, tlhs->shape.size(), slhs->axes); + Expr scale = ReshapeOrExpandToMatchAxis(slhs->scale, tlhs->shape, slhs->axes); + if (!scale.defined()) { + return Expr(); + } Expr rhs = Divide(new_args[1], scale); - rnode->value = Call(ref_call->op, {slhs->value, rhs}, - ref_call->attrs, 
ref_call->type_args); + rnode->value = Call(ref_call->op, {slhs->value, rhs}, ref_call->attrs, ref_call->type_args); rnode->scale = slhs->scale; rnode->axes = slhs->axes; } else { CHECK(srhs != nullptr); CHECK(MatchBroadcastToLeftAxes(trhs, tlhs, srhs->axes)); - Expr scale = ExpandBiasToMatchAxis( - srhs->scale, trhs->shape.size(), srhs->axes); + Expr scale = ReshapeOrExpandToMatchAxis(srhs->scale, trhs->shape, srhs->axes); + if (!scale.defined()) { + return Expr(); + } Expr lhs = Divide(new_args[0], scale); - rnode->value = Call(ref_call->op, {lhs, srhs->value}, - ref_call->attrs, ref_call->type_args); + rnode->value = Call(ref_call->op, {lhs, srhs->value}, ref_call->attrs, ref_call->type_args); rnode->scale = srhs->scale; rnode->axes = srhs->axes; } return Expr(rnode); } -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisForwardRewrite", + AddSubForwardRewrite); -RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); +RELAY_REGISTER_OP("subtract").set_attr("FScaleAxisForwardPrep", AddSubForwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", AddSubForwardRewrite); // Producer operators // Multiply produces the scale-axis pair. -Expr MultiplyForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr MultiplyForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { if (!message.defined()) return Expr(); const auto& expected_out_axes = message->axes; @@ -451,7 +473,7 @@ Expr MultiplyForwardRewrite(const Call& ref_call, } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); + .set_attr("FScaleAxisForwardRewrite", MultiplyForwardRewrite); // Consumer operators // Conv2D send out requirement of axis folding. @@ -467,7 +489,6 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { CHECK_GE(c_big_axis, 0); Message none = NullValue(); - AxesSet data_axes = NullValue(); // For now, we only support simple pattern (no folded weight/data) // More general layout can be supported under the current framework. // By using a unified layout transformation. @@ -476,20 +497,23 @@ Array Conv2DForwardPrep(const Call& call, const Message& out_message) { // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if (kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && - c_small_axis < 0 && - (param->groups == 1 || is_depthwise_conv2d)) { - data_axes = {c_big_axis}; - } - if (data_axes.defined()) { - return {Message(data_axes, false), none}; + if (param->groups == 1 || is_depthwise_conv2d) { + auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout + Array arr{c_big_axis}; + if (c_small_axis >= 0) { + arr.push_back(c_small_axis); + } + return {Message(arr, false), none}; + } } return {none, none}; } // Conv2D consumes the scale axis during transformation. 
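Conv2DForwardPrep above now handles two layout families: simple layouts where the channel appears only as an upper-case axis (for example NCHW with OIHW kernels) and blocked layouts where the channel is split into a big and a small axis (for example NCHW4c). The rewrite that follows makes the matching distinction through its is_simple and is_blocking flags. A rough standalone classifier over a data-layout string, assuming the usual upper/lower-case convention (the pass itself queries tir::Layout, not raw strings):

#include <string>

// Returns 0 for a simple layout (channel only as 'C', e.g. "NCHW"), 1 for a blocked
// layout (channel split into 'C' and 'c', e.g. "NCHW4c"), and -1 when there is no
// channel axis to fold over at all.
int ClassifyDataLayout(const std::string& layout) {
  bool has_big_c = layout.find('C') != std::string::npos;
  bool has_small_c = layout.find('c') != std::string::npos;
  if (!has_big_c) return -1;
  return has_small_c ? 1 : 0;
}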
-Expr Conv2DForwardRewrite(const Call& ref_call, - const Array& new_args, +Expr Conv2DForwardRewrite(const Call& ref_call, const Array& new_args, const Message& message) { // if data do not have scale, normal transform path. const auto* sdata = new_args[0].as(); @@ -502,13 +526,14 @@ Expr Conv2DForwardRewrite(const Call& ref_call, Layout kernel_layout(param->kernel_layout); int c_big_axis = data_layout.IndexOf(LayoutAxis::Get('C')); CHECK_GE(c_big_axis, 0); - // For now, we only support simple pattern (no folded weight/data) - // TODO(tvm-team) support general data layout - CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(sdata->axes.size() == 1 && - c_big_axis == sdata->axes[0]->value); - int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); - int big_ic_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); + int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); + int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); + + bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0); + bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0); + CHECK(is_simple || is_blocking); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(ref_call, param, kernel_layout); @@ -518,29 +543,39 @@ Expr Conv2DForwardRewrite(const Call& ref_call, // match the ic_axis if (is_depthwise_conv2d) { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_oc_axis}); - weight = Multiply(weight, scale); + if (is_simple) { + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ko_axis}); + weight = Multiply(weight, scale); + } else { + weight = Multiply(weight, + ReshapeToMatchAxis(sdata->scale, weight->type_as()->shape, + {big_ko_axis, small_ko_axis})); + if (!weight.defined()) return Expr(); + } + } else { - Expr scale = ExpandBiasToMatchAxis( - sdata->scale, kernel_layout.ndim(), {big_ic_axis}); - weight = Multiply(weight, scale); + if (is_simple) { + Expr scale = ExpandBiasToMatchAxis(sdata->scale, kernel_layout.ndim(), {big_ki_axis}); + weight = Multiply(weight, scale); + } else { + weight = Multiply(weight, + ReshapeToMatchAxis(sdata->scale, weight->type_as()->shape, + {big_ki_axis, small_ki_axis})); + if (!weight.defined()) return Expr(); + } } // return transformed conv2d - return Call( - ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); + return Call(ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args); } -RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); +RELAY_REGISTER_OP("nn.conv2d").set_attr("FScaleAxisForwardPrep", Conv2DForwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); - + .set_attr("FScaleAxisForwardRewrite", Conv2DForwardRewrite); Expr ForwardFoldScaleAxis(const Expr& data) { auto message = ForwardPrep().Prepare(data); - auto fcontext = [&](const Call& call) -> ObjectRef{ + auto fcontext = [&](const Call& call) -> ObjectRef { auto it = message.find(call.get()); if (it != message.end()) { return it->second; @@ -548,8 +583,7 @@ Expr ForwardFoldScaleAxis(const Expr& data) { return ObjectRef(nullptr); } }; - return ForwardRewrite( - data, "FScaleAxisForwardRewrite", fcontext); + return ForwardRewrite(data, "FScaleAxisForwardRewrite", fcontext); } 
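Both directions of the pass rest on a simple algebraic fact: a per-channel scale on a convolution's input can instead be multiplied into the kernel along the input-channel axis, and a scale on the output can be folded into the kernel along the output-channel axis. A tiny numeric check of the input-side identity, with the convolution reduced to a dot product over channels (purely illustrative):

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0};   // one input pixel, three channels
  std::vector<double> w = {0.5, -1.0, 2.0};  // kernel weights over input channels
  std::vector<double> s = {2.0, 0.25, 3.0};  // per-channel scale

  double scale_on_input = 0.0, scale_on_weight = 0.0;
  for (size_t i = 0; i < x.size(); ++i) {
    scale_on_input += (x[i] * s[i]) * w[i];   // conv(x * s, w)
    scale_on_weight += x[i] * (w[i] * s[i]);  // conv(x, w * s): what the pass produces
  }
  assert(std::fabs(scale_on_input - scale_on_weight) < 1e-12);
  return 0;
}

Because the multiply commutes elementwise along the folded axis, the rewritten graph computes the same values while the explicit scale either disappears or becomes a constant-foldable multiply on the weights.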
//---------------------------------------- @@ -564,14 +598,11 @@ class BackwardTransformer; * positive scale is required. * \return Message containing the result scaling on axes of the input. */ -using FBackwardPrep = TypedPackedFunc< - Message(const Call& call, const Array& in_messages)>; +using FBackwardPrep = TypedPackedFunc& in_messages)>; -using FBackwardTransform = TypedPackedFunc< - Expr(const Call& call, - const Message& message, - const Expr& scale, - const BackwardTransformer& transformer)>; +using FBackwardTransform = + TypedPackedFunc; //---------------------------------------------- // Generic Visitors for FScaleAxisBackward @@ -580,8 +611,7 @@ using FBackwardTransform = TypedPackedFunc< class BackwardPrep : private ExprVisitor { public: // The message on each node. - std::unordered_map - Prepare(const Expr& body) { + std::unordered_map Prepare(const Expr& body) { ref_counter_ = GetExprRefCount(body); this->VisitExpr(body); return std::move(message_); @@ -595,8 +625,7 @@ class BackwardPrep : private ExprVisitor { // Visit the expression. void VisitExpr_(const CallNode* call) { ExprVisitor::VisitExpr_(call); - static const auto& fprep = - Op::GetAttr("FScaleAxisBackwardPrep"); + static const auto& fprep = Op::GetAttrMap("FScaleAxisBackwardPrep"); auto f = fprep.get(call->op, nullptr); if (f == nullptr) return; auto rit = ref_counter_.find(call); @@ -620,9 +649,7 @@ class BackwardPrep : private ExprVisitor { } }; -class BackwardTransformerNode : - public Object, - private ExprMutator { +class BackwardTransformerNode : public Object, private ExprMutator { public: // Run forward transform. Expr Fold(Expr expr) { @@ -692,19 +719,15 @@ class BackwardTransformerNode : class BackwardTransformer : public ObjectRef { public: BackwardTransformer() {} - explicit BackwardTransformer( - ::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) { - } + explicit BackwardTransformer(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} BackwardTransformerNode* operator->() const { return static_cast(get_mutable()); } using ContainerType = BackwardTransformerNode; }; -Expr BackwardTransformerNode::Transform( - const CallNode* call_node, Message message, Expr scale) { - static const auto& ftransform = - Op::GetAttr("FScaleAxisBackwardTransform"); +Expr BackwardTransformerNode::Transform(const CallNode* call_node, Message message, Expr scale) { + static const auto& ftransform = Op::GetAttrMap("FScaleAxisBackwardTransform"); auto f = ftransform.get(call_node->op, nullptr); if (f != nullptr) { const Call call = GetRef(call_node); @@ -712,10 +735,7 @@ Expr BackwardTransformerNode::Transform( if (it != memo_.end()) { return it->second; } - Expr new_expr = f(GetRef(call_node), - message, - scale, - GetRef(this)); + Expr new_expr = f(GetRef(call_node), message, scale, GetRef(this)); memo_[call] = new_expr; return new_expr; } else { @@ -724,7 +744,6 @@ Expr BackwardTransformerNode::Transform( } } - //---------------------------------------------- // Per operator defs for FScaleAxisForward //---------------------------------------------- @@ -737,45 +756,38 @@ Message ReluBackwardPrep(const Call& call, const Array& in_messages) { return in_messages[0]; } -Expr ReluBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr ReluBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); } - Expr input = transformer->Transform( - 
call->args[0], message, scale); + Expr input = transformer->Transform(call->args[0], message, scale); return Call(call->op, {input}, call->attrs, call->type_args); } -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); -RELAY_REGISTER_OP("nn.relu") -.set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); +RELAY_REGISTER_OP("nn.relu").set_attr("FScaleAxisBackwardTransform", + ReluBackwardTransform); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); + .set_attr("FScaleAxisBackwardPrep", ReluBackwardPrep); RELAY_REGISTER_OP("nn.leaky_relu") -.set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", ReluBackwardTransform); // AddSub Message AddSubBackwardPrep(const Call& call, const Array& in_messages) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); StructuralEqual equal; - if (in_messages[0].defined() && - MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { + if (in_messages[0].defined() && MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) { return in_messages[0]; } else if (in_messages[1].defined() && MatchBroadcastToLeftAxes(trhs, tlhs, in_messages[1]->axes)) { return in_messages[1]; - } else if (in_messages[0].defined() && - in_messages[1].defined() && - equal(in_messages[0]->axes, in_messages[1]->axes) && - equal(tlhs->shape, trhs->shape)) { + } else if (in_messages[0].defined() && in_messages[1].defined() && + equal(in_messages[0]->axes, in_messages[1]->axes) && equal(tlhs->shape, trhs->shape)) { // add of two elements. return in_messages[0]; } else { @@ -784,9 +796,7 @@ Message AddSubBackwardPrep(const Call& call, const Array& in_messages) } } -Expr AddSubBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr AddSubBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { const auto* tlhs = call->args[0]->type_as(); const auto* trhs = call->args[1]->type_as(); @@ -806,19 +816,21 @@ Expr AddSubBackwardTransform(const Call& call, } else if (lhs_message.defined()) { CHECK(equal(message->axes, lhs_message->axes)); Expr lhs = transformer->Transform(call->args[0], message, scale); - Expr rhs = transformer->Transform( - call->args[1], NullValue(), NullValue()); - Expr rhs_scale = ExpandBiasToMatchAxis( - scale, tlhs->shape.size(), message->axes); + Expr rhs = transformer->Transform(call->args[1], NullValue(), NullValue()); + Expr rhs_scale = ReshapeOrExpandToMatchAxis(scale, tlhs->shape, message->axes); + if (!rhs_scale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } rhs = Multiply(rhs, rhs_scale); return Call(call->op, {lhs, rhs}, call->attrs, call->type_args); } else if (rhs_message.defined()) { CHECK(equal(message->axes, rhs_message->axes)); - Expr lhs = transformer->Transform( - call->args[0], NullValue(), NullValue()); + Expr lhs = transformer->Transform(call->args[0], NullValue(), NullValue()); Expr rhs = transformer->Transform(call->args[1], message, scale); - Expr lhs_scale = ExpandBiasToMatchAxis( - scale, trhs->shape.size(), message->axes); + Expr lhs_scale = ReshapeOrExpandToMatchAxis(scale, trhs->shape, message->axes); + if (!lhs_scale.defined()) { + return transformer->NormalCallTransform(call.operator->()); + } lhs = Multiply(lhs, lhs_scale); return Call(call->op, {lhs, rhs}, 
call->attrs, call->type_args); } else { @@ -827,23 +839,19 @@ Expr AddSubBackwardTransform(const Call& call, } } -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); -RELAY_REGISTER_OP("add") -.set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); +RELAY_REGISTER_OP("add").set_attr("FScaleAxisBackwardTransform", + AddSubBackwardTransform); -RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); +RELAY_REGISTER_OP("subtract").set_attr("FScaleAxisBackwardPrep", AddSubBackwardPrep); RELAY_REGISTER_OP("subtract") -.set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", AddSubBackwardTransform); // Producer operators // Multiply produces the scale-axis pair. -Expr MultiplyBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr MultiplyBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { CHECK(!message.defined()) << "outstanding scale"; const auto* tlhs = call->args[0]->type_as(); @@ -871,7 +879,7 @@ Expr MultiplyBackwardTransform(const Call& call, } RELAY_REGISTER_OP("multiply") -.set_attr("FScaleAxisBackwardTransform", MultiplyBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", MultiplyBackwardTransform); // Consumer operators // Conv2D send out requirement of axis folding. @@ -892,20 +900,23 @@ Message Conv2DBackwardPrep(const Call& call, const Array& in_messages) // only handle depthwise or full conv2d. // TODO(tvm-team) handle grouped conv by reshape + bcast bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); - if (kernel_layout.IndexOf(LayoutAxis::Get('o')) < 0 && - kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 && - c_small_axis < 0 && - (param->groups == 1 || is_depthwise_conv2d)) { - return Message({c_big_axis}, false); - } else { - return NullValue(); + if (param->groups == 1 || is_depthwise_conv2d) { + auto ko_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + auto ki_small_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + if ((ko_small_axis < 0 && ki_small_axis < 0 && c_small_axis < 0) || // simple layout + (ko_small_axis >= 0 && ki_small_axis >= 0 && c_small_axis >= 0)) { // blocked layout + Array arr{c_big_axis}; + if (c_small_axis >= 0) { + arr.push_back(c_small_axis); + } + return Message(arr, false); + } } + return NullValue(); } // Conv2D consumes the scale axis during transformation. 
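Both the forward and backward conv2d rewrites, and the add/sub cases above, now route non-trivial layouts through ReshapeOrExpandToMatchAxis: the scale is reshaped to a broadcast-compatible shape that keeps the original extents on the folded axes and uses 1 everywhere else. A standalone version of that target-shape computation (the real helper additionally returns an undefined Expr when an extent is not a compile-time constant, so the caller can fall back to the normal path):

#include <algorithm>
#include <cstdint>
#include <vector>

// Example: shape = {8, 4, 3, 3, 4}, axis = {1, 4}  ->  {1, 4, 1, 1, 4}.
std::vector<int64_t> ReshapeTargetShape(const std::vector<int64_t>& shape,
                                        const std::vector<int>& axis) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < shape.size(); ++i) {
    bool keep = std::find(axis.begin(), axis.end(), static_cast<int>(i)) != axis.end();
    out.push_back(keep ? shape[i] : 1);
  }
  return out;
}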
-Expr Conv2DBackwardTransform(const Call& call, - const Message& message, - const Expr& scale, +Expr Conv2DBackwardTransform(const Call& call, const Message& message, const Expr& scale, const BackwardTransformer& transformer) { if (!message.defined()) { return transformer->NormalCallTransform(call.operator->()); @@ -918,33 +929,37 @@ Expr Conv2DBackwardTransform(const Call& call, CHECK_GE(c_big_axis, 0); // For now, we only support simple pattern (no folded weight/data) // TODO(tvm-team) support general data layout - CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('o')), -1); - CHECK_EQ(kernel_layout.IndexOf(LayoutAxis::Get('i')), -1); - CHECK(message->axes.size() == 1 && - c_big_axis == message->axes[0]->value); - - int big_oc_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); + int small_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('o')); + int small_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('i')); + int big_ki_axis = kernel_layout.IndexOf(LayoutAxis::Get('I')); + int big_ko_axis = kernel_layout.IndexOf(LayoutAxis::Get('O')); // Check it must be depthwise or full conv2d. bool is_depthwise_conv2d = IsDepthwiseConv2D(call, param, kernel_layout); CHECK(param->groups == 1 || is_depthwise_conv2d); + bool is_simple = (small_ko_axis < 0 && small_ki_axis < 0 && big_ki_axis >= 0); + bool is_blocking = (small_ko_axis >= 0 && small_ki_axis >= 0 && big_ki_axis >= 0); + CHECK(is_simple || is_blocking); - Expr data = transformer->Transform( - call->args[0], NullValue(), NullValue()); - Expr weight = transformer->Transform( - call->args[1], NullValue(), NullValue()); + Expr data = transformer->Transform(call->args[0], NullValue(), NullValue()); + Expr weight = transformer->Transform(call->args[1], NullValue(), NullValue()); // scale on input for deptwise. 
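For simple kernel layouts the branch below still uses ExpandBiasToMatchAxis, whose effect is a broadcast multiply of the kernel by the output scale along the 'O' axis. Spelled out as explicit loops over an OIHW kernel (illustrative only, not the pass's code):

#include <vector>

// weight[o][i][h][w] *= scale[o]: fold an output-channel scale into an OIHW kernel.
void ScaleOutputChannels(std::vector<std::vector<std::vector<std::vector<double>>>>* weight,
                         const std::vector<double>& scale) {
  for (size_t o = 0; o < weight->size(); ++o) {
    for (auto& plane : (*weight)[o]) {
      for (auto& row : plane) {
        for (double& v : row) v *= scale[o];
      }
    }
  }
}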
- Expr wscale = ExpandBiasToMatchAxis( - scale, kernel_layout.ndim(), {big_oc_axis}); + Expr wscale; + if (is_simple) { + wscale = ExpandBiasToMatchAxis(scale, kernel_layout.ndim(), {big_ko_axis}); + } else { + wscale = ReshapeToMatchAxis(scale, weight->type_as()->shape, + {big_ko_axis, small_ko_axis}); + if (!wscale.defined()) return transformer->NormalCallTransform(call.operator->()); + } weight = Multiply(weight, wscale); - return Call( - call->op, {data, weight}, call->attrs, call->type_args); + return Call(call->op, {data, weight}, call->attrs, call->type_args); } RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisBackwardPrep", Conv2DBackwardPrep); + .set_attr("FScaleAxisBackwardPrep", Conv2DBackwardPrep); RELAY_REGISTER_OP("nn.conv2d") -.set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); + .set_attr("FScaleAxisBackwardTransform", Conv2DBackwardTransform); Expr BackwardFoldScaleAxis(const Expr& data) { return make_object()->Fold(data); @@ -956,39 +971,33 @@ namespace transform { Pass ForwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - relay::fold_scale_axis::ForwardFoldScaleAxis(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::fold_scale_axis::ForwardFoldScaleAxis(f)); + }; return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis") -.set_body_typed(ForwardFoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis").set_body_typed(ForwardFoldScaleAxis); Pass BackwardFoldScaleAxis() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast( - relay::fold_scale_axis::BackwardFoldScaleAxis(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(relay::fold_scale_axis::BackwardFoldScaleAxis(f)); + }; return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis") -.set_body_typed(BackwardFoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis").set_body_typed(BackwardFoldScaleAxis); Pass FoldScaleAxis() { // FoldScaleAxis pass contains the following three passes. Therefore, we can // register it as a sequential pass. - Pass pass = Sequential( - {BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, - "FoldScaleAxis"); + Pass pass = Sequential({BackwardFoldScaleAxis(), ForwardFoldScaleAxis(), FoldConstant()}, + "FoldScaleAxis"); return pass; } -TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis") -.set_body_typed(FoldScaleAxis); +TVM_REGISTER_GLOBAL("relay._transform.FoldScaleAxis").set_body_typed(FoldScaleAxis); } // namespace transform diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index f01c4faeff3e..f093f5425d94 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -26,6 +26,7 @@ #include #include #include + #include "pass_util.h" namespace tvm { @@ -36,9 +37,7 @@ namespace relay { // so that calling realize repeatively won't hurt perf. 
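The forward_rewrite.cc hunks below change the rule lookup from Op::GetAttr to Op::GetAttrMap and accept the attribute name as a String, but the driver keeps the same shape: visit calls bottom-up, look up a per-operator rewrite rule, and keep the original call whenever the rule declines. A minimal standalone model of that dispatch (MiniCall and RewriteRule are made-up stand-ins, not TVM types):

#include <functional>
#include <string>
#include <unordered_map>

struct MiniCall {
  std::string op;       // operator name used for rule lookup
  std::string payload;  // stands in for the (already rewritten) arguments
};

// A rule returns an empty string to decline, the way FForwardRewrite returns an
// undefined Expr; the driver then falls back to the unrewritten call.
using RewriteRule = std::function<std::string(const MiniCall&)>;

std::string RewriteCall(const MiniCall& call,
                        const std::unordered_map<std::string, RewriteRule>& rules) {
  auto it = rules.find(call.op);
  if (it != rules.end()) {
    std::string res = it->second(call);
    if (!res.empty()) return res;  // the rule fired
  }
  return call.payload;  // no rule, or the rule declined
}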
class TempRealizer : private MixedModeMutator { public: - Expr Realize(Expr expr) { - return Mutate(expr); - } + Expr Realize(Expr expr) { return Mutate(expr); } private: Expr DispatchVisitExpr(const Expr& expr) final { @@ -54,20 +53,15 @@ class TempRealizer : private MixedModeMutator { class ForwardRewriter : private MixedModeMutator { public: - ForwardRewriter(const OpMap* rewrite_map, + ForwardRewriter(const OpAttrMap* rewrite_map, std::function fcontext, std::function fmulti_ref_trigger) - : rewrite_map_(rewrite_map), - fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) {} + : rewrite_map_(rewrite_map), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} ForwardRewriter(const FForwardRewrite* rewrite_func, std::function fcontext, std::function fmulti_ref_trigger) - : rewrite_func_(rewrite_func), - fcontext_(fcontext), - fmulti_ref_trigger_(fmulti_ref_trigger) {} - + : rewrite_func_(rewrite_func), fcontext_(fcontext), fmulti_ref_trigger_(fmulti_ref_trigger) {} // Transform expression. Expr Rewrite(const Expr& expr) { @@ -79,7 +73,7 @@ class ForwardRewriter : private MixedModeMutator { private: // The rewrite rule. - const OpMap* rewrite_map_{nullptr}; + const OpAttrMap* rewrite_map_{nullptr}; const FForwardRewrite* rewrite_func_{nullptr}; // The context.const std::function fcontext_{nullptr}; @@ -91,7 +85,7 @@ class ForwardRewriter : private MixedModeMutator { TempRealizer realizer_; // Visit and allow non-realized version. - Expr GetTempExpr(const Expr& expr, const Expr& post) { + Expr GetTempExpr(const Expr& expr, const Expr& post) { if (fmulti_ref_trigger_ != nullptr) { Expr ret = post; auto it = ref_counter_.find(expr.get()); @@ -160,9 +154,8 @@ class ForwardRewriter : private MixedModeMutator { } // try to rewrite. if (frewrite != nullptr) { - Expr res = frewrite( - ref_call, call_args, - fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); + Expr res = frewrite(ref_call, call_args, + fcontext_ != nullptr ? fcontext_(ref_call) : ObjectRef(nullptr)); if (res.defined()) return res; // abort, use old rule for (size_t i = 0; i < call_args.size(); ++i) { @@ -175,21 +168,18 @@ class ForwardRewriter : private MixedModeMutator { } } if (unchanged) return ref_call; - return Call( - new_op, call_args, call_node->attrs, call_node->type_args); + return Call(new_op, call_args, call_node->attrs, call_node->type_args); } }; -Expr ForwardRewrite(const Expr& expr, - const std::string& rewrite_map_name, +Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_name, std::function fcontext, std::function fmulti_ref_trigger) { - auto rewrite_map = Op::GetAttr(rewrite_map_name); + auto rewrite_map = Op::GetAttrMap(rewrite_map_name); return ForwardRewriter(&rewrite_map, fcontext, fmulti_ref_trigger).Rewrite(expr); } -Expr ForwardRewrite(const Expr& expr, - const FForwardRewrite& rewrite_func, +Expr ForwardRewrite(const Expr& expr, const FForwardRewrite& rewrite_func, std::function fcontext, std::function fmulti_ref_trigger) { return ForwardRewriter(&rewrite_func, fcontext, fmulti_ref_trigger).Rewrite(expr); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f646042962f0..01f1eeea30b3 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -24,13 +24,15 @@ * \brief This is a backend-aware optimization pass. * Fuse necessary ops into a single one. 
*/ -#include #include #include #include #include -#include "pattern_util.h" +#include + #include "../../support/arena.h" +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -53,8 +55,9 @@ namespace relay { However, at the point of conv2d we do not necessarily know that all the future paths will merge at the elemwise add. The fusion algorithm applies post-dominator analysis. - The immediate post-dominator of a node defined by the closest node where all the future path goes into. - In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows: + The immediate post-dominator of a node defined by the closest node where all the future path goes + into. In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm + is as follows: - Construct a DAG of dataflow graph for dominator analysis - Construct a post-dominator tree which gives immediate post dominator of each node. @@ -73,8 +76,8 @@ namespace relay { - CommitFuse: mark all the nodes between source and post-dominator as the same group. - We use an Union-Find data structure to manage the groups. */ -using support::LinkNode; using support::LinkedList; +using support::LinkNode; constexpr uint32_t kMaxFusedOps = 256; @@ -123,9 +126,7 @@ class IndexedForwardGraph { std::ostringstream os; for (size_t i = 0; i < post_dfs_order.size(); ++i) { Node* node = post_dfs_order[i]; - os << "node[" << i << "], " - << GetRef(node->ref) - << " outputs=["; + os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; for (auto* link = node->outputs.head; link != nullptr; link = link->next) { os << link->value.node->index << ", "; } @@ -147,8 +148,7 @@ class IndexedForwardGraph { // Creator of post dominator tree of the dataflow class IndexedForwardGraph::Creator : private ExprVisitor { public: - explicit Creator(support::Arena* arena) - : arena_(arena) {} + explicit Creator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); @@ -164,9 +164,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // attribute equal comparator StructuralEqual attr_equal_; // Update the message stored at the node. - void Update(const Expr& node, - IndexedForwardGraph::Node* parent, - OpPatternKind pattern) { + void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { const tvm::Object* key = node.get(); IndexedForwardGraph::Node* current; auto it = graph_.node_map.find(key); @@ -188,8 +186,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void AddNode(const tvm::Object* key) { auto it = graph_.node_map.find(key); - CHECK(it != graph_.node_map.end()) - << "Cannot find node " << GetRef(key); + CHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef(key); IndexedForwardGraph::Node* node = it->second; CHECK(node->ref == nullptr); node->ref = key; @@ -199,6 +196,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Post order tree void VisitExpr_(const FunctionNode* op) final { + // Skip the function that should be handled by external codegen. + if (op->GetAttr(attr::kCompiler).defined()) return; + for (auto param : op->params) { this->Update(param, nullptr, kOpaque); } @@ -211,12 +211,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor { Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. 
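Further down in this file, GraphPartitioner assigns every node to a fusion group and merges groups with a union-find structure: Group::FindRoot chases parent pointers with path compression, and CommitFuse unions all groups on the path from a node to its post-dominator into one. A compact standalone union-find with those two operations (indices instead of Group pointers; illustrative only):

#include <vector>

struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    for (int i = 0; i < n; ++i) parent[i] = i;
  }
  // Root lookup with path halving, the same idea as Group::FindRoot.
  int FindRoot(int x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];
      x = parent[x];
    }
    return x;
  }
  // Attach one group under another, as CommitFuse does along a dominated path.
  void Merge(int child, int target) { parent[FindRoot(child)] = FindRoot(target); }
};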
- bool is_simple_const = ( - dtype == DataType::Int(32) || - dtype == DataType::Int(64) || - dtype == DataType::Float(32) || - dtype == DataType::Float(64) || - dtype == DataType::Bool()); + bool is_simple_const = + (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) || + dtype == DataType::Float(64) || dtype == DataType::Bool()); if (op->is_scalar() && is_simple_const) { node->pattern = kElemWise; } else { @@ -229,8 +226,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { CHECK(graph_.node_map.count(call)); Node* node = graph_.node_map.at(call); - static auto fpattern = - Op::GetAttr("TOpPattern"); + static auto fpattern = Op::GetAttrMap("TOpPattern"); // Now we set the pattern of this call. // // If we see a call mentioning an operator we should mark it with its @@ -242,7 +238,13 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // need to call Update, as it may be an arbitrary expression. OpPatternKind op_pattern = kOpaque; if (const OpNode* opnode = call->op.as()) { - op_pattern = static_cast(fpattern[GetRef(opnode)]); + auto op = GetRef(opnode); + if (IsDynamic(call->checked_type()) && IsDataDependant(call)) { + // output of a shape func can't be fed to a data-dependent shape func + op_pattern = kOpaque; + } else { + op_pattern = static_cast(fpattern[op]); + } } else { this->Update(call->op, node, kOpaque); } @@ -252,13 +254,10 @@ class IndexedForwardGraph::Creator : private ExprVisitor { const auto* rtype = call->checked_type().as(); // pass the analysis back to all the children it references. for (size_t i = 0; i < call->args.size(); ++i) { - const auto* arg_type = - call->args[i]->checked_type().as(); + const auto* arg_type = call->args[i]->checked_type().as(); // specifically check if result type is the same as arguments type OpPatternKind edge_pattern = op_pattern; - if (edge_pattern == kBroadcast && - arg_type != nullptr && - rtype != nullptr && + if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr && attr_equal_(rtype->shape, arg_type->shape)) { edge_pattern = kElemWise; } @@ -310,9 +309,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const VarNode* op) final { - this->AddNode(op); - } + void VisitExpr_(const VarNode* op) final { this->AddNode(op); } void VisitExpr_(const LetNode* op) final { // do not fuse through let. @@ -361,8 +358,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create( - support::Arena* arena, const Expr& body) { +IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { return Creator(arena).Prepare(body); } @@ -395,13 +391,11 @@ class DominatorTree { * \note This algorithm makes use of the fact that graph is DAG, * and runs a single pass algorithm via LCA (Least Common Ancestor) */ - static DominatorTree PostDom(support::Arena* arena, - const IndexedForwardGraph& graph); + static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); private: // Combine pattern together. - static OpPatternKind CombinePattern( - OpPatternKind lhs, OpPatternKind rhs) { + static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > rhs) return lhs; return rhs; } @@ -413,26 +407,19 @@ class DominatorTree { * The combined edge pattern across all the parents. * \return The least common ancestor of the two. 
*/ - static Node* LeastCommonAncestor( - Node* lhs, - Node* rhs, - OpPatternKind* edge_pattern) { + static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { while (lhs != rhs) { if (lhs == nullptr) return nullptr; if (rhs == nullptr) return nullptr; if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern( - edge_pattern[0], rhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); rhs = rhs->parent; } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern( - edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); lhs = lhs->parent; } else { - edge_pattern[0] = CombinePattern( - edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern( - edge_pattern[0], rhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); lhs = lhs->parent; rhs = rhs->parent; } @@ -493,9 +480,7 @@ class DominatorTree { } }; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, - const IndexedForwardGraph& graph) { +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { DominatorTree tree; tree.nodes.resize(graph.post_dfs_order.size(), nullptr); // reverse topo order @@ -569,13 +554,11 @@ class GraphPartitioner { /*! \brief internal field used for deduplication */ std::unordered_set visited_; // Internal implelementation of CheckPath - template - bool CheckPath_(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - F fcond) { + template + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { if (visited_.count(src)) return true; visited_.insert(src); - Group* gnode = groups_[src->index]; + Group* gnode = groups_[src->index]; CHECK(gnode != nullptr); gnode = gnode->FindRoot(); if (!fcond(gnode->pattern, src == sink)) return false; @@ -597,10 +580,8 @@ class GraphPartitioner { * \tparam F the condition function, with signature * \note sink must be a post-dominator of src. */ - template - bool CheckPath(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - F fcond) { + template + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { CHECK(!src->extern_ref); visited_.clear(); CHECK(src != sink); @@ -610,8 +591,7 @@ class GraphPartitioner { return true; } // Combine two patterns together. - static OpPatternKind CombinePattern( - OpPatternKind lhs, OpPatternKind rhs) { + static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { if (lhs > kBroadcast && rhs > kBroadcast) { LOG(FATAL) << "Cannot merge two complex group together"; } @@ -634,14 +614,11 @@ class GraphPartitioner { if (child->master_ref != nullptr) { CHECK(parent->master_ref == nullptr); parent->master_ref = child->master_ref; - parent->pattern = CombinePattern( - child->pattern, parent->pattern); + parent->pattern = CombinePattern(child->pattern, parent->pattern); } } // Internal implelementation of CommitFuse - void CommitFuse_(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink, - Group* target) { + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { if (src == sink) return; if (visited_.count(src)) return; visited_.insert(src); @@ -659,8 +636,7 @@ class GraphPartitioner { * \param sink The termination node. * \note sink must be a post-dominator of src. 
*/ - void CommitFuse(IndexedForwardGraph::Node* src, - IndexedForwardGraph::Node* sink) { + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { Group* target = groups_[sink->index]; visited_.clear(); CHECK(src != sink); @@ -684,9 +660,7 @@ class GraphPartitioner { } // execute the fusion algorithm. - void RunFuse(const IndexedForwardGraph& graph, - const DominatorTree& post_dom_tree, - int phase) { + void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { for (size_t nid = 0; nid < groups_.size(); ++nid) { // the group of current node has been specified already. auto* graph_node = graph.post_dfs_order[nid]; @@ -701,8 +675,7 @@ class GraphPartitioner { size_t dom_parent_gindex = dom_node->parent->gnode->index; // refuse the fusion if too many ops are going to be fused together - if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) - continue; + if (groups_[dom_parent_gindex]->num_nodes + group_node->num_nodes > kMaxFusedOps) continue; if (phase == 2) { // Fuse injective ops into intermediate tuples, if any @@ -713,9 +686,7 @@ class GraphPartitioner { if (dom_root_group->pattern == kTuple) continue; if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kInjective; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; // dom_root_group can also be tuple, as in inception layers // CheckPath is needed to avoid fusing two intermediate tuples if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { @@ -740,9 +711,7 @@ class GraphPartitioner { if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { CHECK(dom_node->parent->gnode != nullptr); // The fuse can be executed if all the intermediate ops are still broadcast. - auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kBroadcast; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } @@ -750,8 +719,7 @@ class GraphPartitioner { } else if (group_node->pattern <= kBroadcast) { // Pre-condition: can only be fused to parent which is injective or reduction. if (dom_node->parent != nullptr && - (dom_node->pattern <= kInjective || - dom_node->pattern == kCommReduce)) { + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { // Check if all the intermediate ops are still broadcast. // The final terminal node can already be fused to a OutEWiseFusable group. auto fcond = [](OpPatternKind kind, bool is_sink) { @@ -760,9 +728,7 @@ class GraphPartitioner { // are allowed be fused to the elemwise/broadcast master. return kind <= kInjective; } else { - return (kind <= kBroadcast || - kind == kCommReduce || - kind == kInjective || + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || kind == kOutEWiseFusable); } }; @@ -775,9 +741,7 @@ class GraphPartitioner { // so conv2d always finishes fusing. if (phase != 1) continue; // Check if all path are injective. 
- auto fcond = [](OpPatternKind kind, bool is_sink) { - return kind <= kInjective; - }; + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { CommitFuse(graph_node, dom_node->parent->gnode); } @@ -789,8 +753,8 @@ class GraphPartitioner { } }; -std::vector -GraphPartitioner::Partition(const IndexedForwardGraph& graph) { +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { this->InitGroups(graph); if (opt_level_ == 0) return std::move(groups_); // get post dominator tree @@ -808,8 +772,7 @@ class FuseMutator : private ExprMutator { Expr Transform(const Expr& body, int fuse_opt_level) { // setup the group map. auto graph = IndexedForwardGraph::Create(&arena_, body); - auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( - graph); + auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { CHECK(graph.post_dfs_order[nid]->ref != nullptr); gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; @@ -819,7 +782,6 @@ class FuseMutator : private ExprMutator { return this->Mutate(body); } - private: /*! \brief Temporary information from each group. */ struct GroupInfo { @@ -862,8 +824,7 @@ class FuseMutator : private ExprMutator { // Transform calls. Expr VisitExpr_(const CallNode* call) { if (call->op.as()) { - static auto fnoncomputational = - Op::GetAttr("TNonComputational"); + static auto fnoncomputational = Op::GetAttrMap("TNonComputational"); if (fnoncomputational.get(Downcast(call->op), false)) { return ExprMutator::VisitExpr_(call); @@ -878,8 +839,7 @@ class FuseMutator : private ExprMutator { auto* ret_group = gmap_.at(call)->FindRoot(); Array new_args = GetNewArguments(call->args, ret_group); - auto new_call = Call( - call->op, new_args, call->attrs, call->type_args); + auto new_call = Call(call->op, new_args, call->attrs, call->type_args); if (ret_group->root_ref == call) { // This is the root of the group @@ -926,9 +886,7 @@ class FuseMutator : private ExprMutator { // If the function has no call, it is not a primitive function. struct HasCallVisitor : ExprVisitor { bool has_call = false; - void VisitExpr_(const CallNode* op) final { - has_call = true; - } + void VisitExpr_(const CallNode* op) final { has_call = true; } } visitor; visitor(body); const GroupInfo& ginfo = ginfo_[group]; @@ -957,13 +915,13 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string { - auto it = gmap_.find(expr.get()); - if (it == gmap_.end()) return ""; - std::ostringstream os; - auto *group = it->second->FindRoot(); - os << " /* group=" << group << " */"; - return os.str(); - }); + auto it = gmap_.find(expr.get()); + if (it == gmap_.end()) return ""; + std::ostringstream os; + auto* group = it->second->FindRoot(); + os << " /* group=" << group << " */"; + return os.str(); + }); LOG(INFO) << "Dump of group info:\n" << text; } }; @@ -976,15 +934,14 @@ namespace transform { Pass FuseOps(int fuse_opt_level) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level; - return Downcast(FuseOps(f, opt_level, m)); - }; + [=](Function f, IRModule m, PassContext pc) { + int opt_level = fuse_opt_level == -1 ? 
pc->opt_level : fuse_opt_level; + return Downcast(FuseOps(f, opt_level, m)); + }; return CreateFunctionPass(pass_func, 1, "FuseOps", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.FuseOps") -.set_body_typed(FuseOps); +TVM_REGISTER_GLOBAL("relay._transform.FuseOps").set_body_typed(FuseOps); } // namespace transform diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index d0ff169445fb..4bc643935dc9 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -22,13 +22,14 @@ * \brief API for Automatic Differentiation for the Relay IR. */ #include -#include -#include #include +#include #include -#include "pattern_util.h" -#include "pass_util.h" +#include + #include "let_list.h" +#include "pass_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -42,12 +43,14 @@ using namespace tvm::runtime; * Formally speaking, such requirement mean that the input function is a closed expression - * that is, it only refer to local variable that is it's parameter, or defined inside it. * Every top level definition satisfy this criteria. - * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> (Float[] -> Float[]). - * In relay we currently only support compile-time AD, but it should be enough for a lot of use case. + * AD can also be run-time, which mean it is merely a function term of AD : (Float[] -> Float[]) -> + * (Float[] -> Float[]). In relay we currently only support compile-time AD, but it should be enough + * for a lot of use case. * - * In deep learning, the most common way to train a deep neural network is by gradient descent or some of it's variant. - * Such optimization method require us to input the gradient of neural network, which can be obtained easily using AD. - * In fact, back propagation is essentially reverse-mode automatic differentiation, a kind of AD! + * In deep learning, the most common way to train a deep neural network is by gradient descent or + * some of it's variant. Such optimization method require us to input the gradient of neural + * network, which can be obtained easily using AD. In fact, back propagation is essentially + * reverse-mode automatic differentiation, a kind of AD! */ /*! In relay, automatic differentiation(AD) is a macro, @@ -55,32 +58,31 @@ using namespace tvm::runtime; * (x0, x1, x2, ...) -> Float[] to * (x0, x1, x2, ...) -> (Float[], (x0, x1, x2, ...)), * When x0, x1, x2... are Float of different shape. - * the return value is a pair, with left hand side as the original value, and right hand side as gradient of the input. - * WithGradientType will take the type of input, and produce the type of output. - * There are multiple implementation of AD in relay, with different characteristic. - * However, they all transform the input expr according to WithGradientType. + * the return value is a pair, with left hand side as the original value, and right hand side as + * gradient of the input. WithGradientType will take the type of input, and produce the type of + * output. There are multiple implementation of AD in relay, with different characteristic. However, + * they all transform the input expr according to WithGradientType. */ Type WithGradientType(const Type&); /*! return an expression that represent differentiation of e (according to WithGradientType). * This version only work on first order code without control flow. 
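The comment above notes that back propagation is just reverse-mode AD; the first-order visitor that follows builds exactly that structure, recording one closure per primitive in `backprop_actions` and replaying them in reverse after seeding the output gradient with `OnesLike`. A self-contained scalar sketch of the same tape idea, with no TVM types involved:

```cpp
#include <cstdio>
#include <deque>
#include <functional>
#include <vector>

// Toy scalar tape: the forward walk records one closure per primitive op
// ("backprop action"); the reverse walk replays them in reverse order,
// accumulating into `reverse`, which starts at zero like ZerosLike.
struct Value {
  double forward = 0.0;
  double reverse = 0.0;
};

struct Tape {
  std::deque<Value> values;  // stable storage for intermediate values
  std::vector<std::function<void()>> backprop_actions;

  Value* Fresh(double forward) {
    values.push_back(Value{forward, 0.0});
    return &values.back();
  }
  Value* Mul(Value* a, Value* b) {
    Value* out = Fresh(a->forward * b->forward);
    backprop_actions.push_back([a, b, out] {
      a->reverse += out->reverse * b->forward;
      b->reverse += out->reverse * a->forward;
    });
    return out;
  }
  Value* Add(Value* a, Value* b) {
    Value* out = Fresh(a->forward + b->forward);
    backprop_actions.push_back([a, b, out] {
      a->reverse += out->reverse;
      b->reverse += out->reverse;
    });
    return out;
  }
  void Backward(Value* result) {
    result->reverse = 1.0;  // analogous to OnesLike on the function output
    for (auto it = backprop_actions.rbegin(); it != backprop_actions.rend(); ++it) (*it)();
  }
};

int main() {
  Tape tape;
  Value* x = tape.Fresh(3.0);
  Value* y = tape.Fresh(4.0);
  Value* f = tape.Add(tape.Mul(x, x), y);  // f(x, y) = x*x + y
  tape.Backward(f);
  std::printf("f=%g df/dx=%g df/dy=%g\n", f->forward, x->reverse, y->reverse);  // 13 6 1
  return 0;
}
```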
*/ -Expr FirstOrderGradient(const Expr& e, const IRModule& mod); +Expr FirstOrderGradient(const Expr& e, const Optional& mod); Type WithGradientType(const Type& t) { // TODO(M.K.): stricter checking auto ty = t.as(); CHECK(ty) << "input should be a function"; - return FuncType(ty->arg_types, - TupleType({ - ty->ret_type, - TupleType(ty->arg_types)}), {}, {}); + return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {}); } //! \brief if the expression is a GlobalVar, transform to it's expression. -Expr DeGlobal(const IRModule& mod, const Expr& e) { - if (const auto* x = e.as()) { - BaseFunc base_func = mod->Lookup(GetRef(x)); +Expr DeGlobal(const Optional& mod, const Expr& e) { + const auto* x = e.as(); + + if (mod.defined() && (x)) { + BaseFunc base_func = mod.value()->Lookup(GetRef(x)); if (auto* n = base_func.as()) { return n->body; } else { @@ -95,7 +97,7 @@ Expr DeGlobal(const IRModule& mod, const Expr& e) { * pass. */ struct ADValueNode { - virtual ~ADValueNode() { } + virtual ~ADValueNode() {} template T& get() { auto ret = dynamic_cast(this); @@ -110,8 +112,8 @@ using ADValue = std::shared_ptr; struct ADTensor : ADValueNode { Expr forward; mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& forward) : - forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { + ADTensor(LetList* ll, const Expr& forward) + : forward(ll->Push(forward)), reverse(ll->Push(ZerosLike(this->forward))) { this->forward->checked_type_ = forward->checked_type(); } }; @@ -121,51 +123,46 @@ struct ADTensor : ADValueNode { * can compute away this function to obtain a reverse mode program. */ struct ADFunction : ADValueNode { - std::function&, - const Attrs&, - const tvm::Array&)> func; - explicit ADFunction(const std::function&, - const Attrs&, - const tvm::Array&)>& func) : - func(func) { } + std::function&, const Attrs&, + const tvm::Array&)> + func; + explicit ADFunction(const std::function&, + const Attrs&, const tvm::Array&)>& func) + : func(func) {} }; -struct FirstOrderReverseAD : ExprFunctor { - const OpMap rev_map = Op::GetAttr("FPrimalGradient"); +struct FirstOrderReverseAD : ExprFunctor { + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); std::vector> backprop_actions; // we assume no closure so no need for lexical scoping - std::unordered_map env; + std::unordered_map env; LetList* ll; - FirstOrderReverseAD(LetList* ll) : ll(ll) { } + FirstOrderReverseAD(LetList* ll) : ll(ll) {} ADValue VisitExpr_(const OpNode* op) final { Op op_ref = GetRef(op); - CHECK(rev_map.count(op_ref)) - << op->name << " does not have reverse mode defined"; - return std::make_shared([this, op_ref](const Type& orig_type, - const std::vector& args, - const Attrs& attrs, - const tvm::Array& type_args) { - std::vector call_args; - for (const ADValue& adval : args) { - call_args.push_back(adval->get().forward); - } - auto orig = Call(op_ref, call_args, attrs, type_args); - orig->checked_type_ = orig_type; - auto ret = std::make_shared(ll, orig); - backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, ret->reverse); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - args[i]->get().reverse = - ll->Push(Add(args[i]->get().reverse, rev[i])); - } - }); - return ret; - }); + CHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined"; + return std::make_shared( + [this, op_ref](const Type& orig_type, const 
std::vector& args, const Attrs& attrs, + const tvm::Array& type_args) { + std::vector call_args; + for (const ADValue& adval : args) { + call_args.push_back(adval->get().forward); + } + auto orig = Call(op_ref, call_args, attrs, type_args); + orig->checked_type_ = orig_type; + auto ret = std::make_shared(ll, orig); + backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { + tvm::Array rev = rev_map[op_ref](orig, ret->reverse); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + args[i]->get().reverse = + ll->Push(Add(args[i]->get().reverse, rev[i])); + } + }); + return ret; + }); } ADValue VisitExpr_(const ConstantNode* op) final { @@ -185,16 +182,15 @@ struct FirstOrderReverseAD : ExprFunctor { ADValue VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); // todo: assert no closure - return std::make_shared([this, f](const Type& orig_type, - const std::vector& args, - const Attrs& attrs, - const tvm::Array& type_args) { - CHECK_EQ(f->params.size(), args.size()); - for (size_t i = 0; i < f->params.size(); ++i) { - env[f->params[i]] = args[i]; - } - return VisitExpr(f->body); - }); + return std::make_shared( + [this, f](const Type& orig_type, const std::vector& args, const Attrs& attrs, + const tvm::Array& type_args) { + CHECK_EQ(f->params.size(), args.size()); + for (size_t i = 0; i < f->params.size(); ++i) { + env[f->params[i]] = args[i]; + } + return VisitExpr(f->body); + }); } ADValue VisitExpr_(const VarNode* op) final { @@ -220,7 +216,7 @@ Type GradRetType(const Function& f) { return TupleType({f->ret_type, TupleType(vt)}); } -Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { +Expr FirstOrderGradient(const Expr& re, const Optional& mod) { // Currently we first remove any global functions for the first // order case. auto e = DeGlobal(mod, re); @@ -240,8 +236,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { const auto& res = c->get(); Expr grad = LetList::With([&](LetList* ll) { res.reverse = OnesLike(res.forward); - for (auto it = reverse_ad.backprop_actions.rbegin(); - it != reverse_ad.backprop_actions.rend(); + for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); ++it) { (*it)(ll); } @@ -257,8 +252,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { return Function(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") -.set_body_typed(FirstOrderGradient); +TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { Type VisitType_(const TensorTypeNode* ttn) final { @@ -267,17 +261,13 @@ struct ReverseADType : TypeMutator { } }; -Type ReverseType(const Type& t) { - return ReverseADType()(t); -} +Type ReverseType(const Type& t) { return ReverseADType()(t); } /*! \brief Lift a function that transform Tensor to a function that also transform more type * by doing a structure preserving map. 
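`LiftTensor`, declared just above, is a structure-preserving map: it applies a per-tensor transform to every `TensorType` leaf while rebuilding any enclosing `TupleType` structure unchanged. A toy sketch of that recursion shape, with a hypothetical leaf/tuple node standing in for the Relay types:

```cpp
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

// Toy stand-in for "TensorType or TupleType": a node is either a scalar leaf
// or a tuple of sub-nodes. Only meant to illustrate the recursion shape.
struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node {
  bool is_leaf;
  double value;                 // meaningful only when is_leaf
  std::vector<NodePtr> fields;  // meaningful only when !is_leaf
};

NodePtr Leaf(double v) { return std::make_shared<Node>(Node{true, v, {}}); }
NodePtr Tuple(std::vector<NodePtr> fields) {
  return std::make_shared<Node>(Node{false, 0.0, std::move(fields)});
}

// Structure-preserving map in the spirit of LiftTensor: transform every leaf,
// recurse through tuples, keep the tuple structure intact.
NodePtr Lift(const std::function<double(double)>& f, const NodePtr& e) {
  if (e->is_leaf) return Leaf(f(e->value));
  std::vector<NodePtr> fields;
  for (const NodePtr& field : e->fields) fields.push_back(Lift(f, field));
  return Tuple(std::move(fields));
}

double Sum(const NodePtr& e) {
  if (e->is_leaf) return e->value;
  double total = 0.0;
  for (const NodePtr& field : e->fields) total += Sum(field);
  return total;
}

int main() {
  NodePtr e = Tuple({Leaf(1.0), Tuple({Leaf(2.0), Leaf(3.0)})});
  NodePtr zeros = Lift([](double) { return 0.0; }, e);  // like ZerosLike on every leaf
  std::printf("sum(e)=%g sum(zeros)=%g\n", Sum(e), Sum(zeros));  // 6 0
  return 0;
}
```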
*/ Expr LiftTensor(const std::function& f, - const std::function& tf, - const Type& forward_type, - const Expr& e, + const std::function& tf, const Type& forward_type, const Expr& e, LetList* ll) { CHECK(IsAtomic(e)) << e; if (forward_type.as()) { @@ -288,11 +278,7 @@ Expr LiftTensor(const std::function& f, tvm::Array fields; tvm::Array types; for (size_t i = 0; i < tt->fields.size(); ++i) { - auto field = LiftTensor(f, - tf, - tt->fields[i], - ll->Push(GetField(e, i)), - ll); + auto field = LiftTensor(f, tf, tt->fields[i], ll->Push(GetField(e, i)), ll); fields.push_back(field); types.push_back(field->checked_type_); } @@ -308,10 +294,7 @@ Expr LiftTensor(const std::function& f, /*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr, * by stitching the references in the AD values. */ -void TransferGrads(const Type& forward_type, - const Expr& from, - const Expr& to, - LetList* ll) { +void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, LetList* ll) { CHECK(IsAtomic(from)) << from; CHECK(IsAtomic(to)) << to; if (forward_type.as()) { @@ -320,9 +303,7 @@ void TransferGrads(const Type& forward_type, ll->Push(RefWrite(to_ref, RefRead(from_ref))); } else if (auto* tt = forward_type.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { - TransferGrads(tt->fields[i], - ll->Push(TupleGetItem(from, i)), - ll->Push(TupleGetItem(to, i)), + TransferGrads(tt->fields[i], ll->Push(TupleGetItem(from, i)), ll->Push(TupleGetItem(to, i)), ll); } } else { @@ -333,48 +314,31 @@ void TransferGrads(const Type& forward_type, /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) { - auto rev = [&](const Expr& e) { - return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); - }; - auto rev_type = [&](const Type& forward_type) { - return ReverseType(forward_type); - }; + auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); }; + auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); }; return LiftTensor(rev, rev_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the original value. */ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) { - auto val = [&](const Expr& e) { - return GetField(e, 0); - }; - auto val_type = [&](const Type& forward_type) { - return forward_type; - }; + auto val = [&](const Expr& e) { return GetField(e, 0); }; + auto val_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(val, val_type, forward_type, e, ll); } /*! \brief ReverseType(t) -> t. Get the gradient. 
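`GetRev`, `GetValue`, `GetGrad`, and `UpdateGrad` together define the reverse-mode value encoding used here: every tensor becomes a pair of its forward value and a mutable reference that accumulates its gradient. A scalar sketch of that encoding, assuming `shared_ptr<double>` as a stand-in for Relay's reference cell:

```cpp
#include <cstdio>
#include <memory>

// ReverseType(T) pairs each value with a mutable cell holding its gradient.
struct RevValue {
  double value;
  std::shared_ptr<double> grad;  // plays the role of RefCreate(ZerosLike(e))
};

RevValue GetRev(double v) { return RevValue{v, std::make_shared<double>(0.0)}; }
double GetValue(const RevValue& r) { return r.value; }           // field 0
double GetGrad(const RevValue& r) { return *r.grad; }            // RefRead(field 1)
void UpdateGrad(const RevValue& r, double g) { *r.grad += g; }    // RefWrite(.., old + g)

int main() {
  RevValue x = GetRev(3.0);
  UpdateGrad(x, 2.0);   // gradient contributions accumulate across uses of x
  UpdateGrad(x, 0.5);
  std::printf("value=%g grad=%g\n", GetValue(x), GetGrad(x));  // value=3 grad=2.5
  return 0;
}
```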
*/ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) { - auto grad = [&](const Expr& e) { - return ll->Push(RefRead(GetField(e, 1))); - }; - auto grad_type = [&](const Type& forward_type) { - return forward_type; - }; + auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); }; + auto grad_type = [&](const Type& forward_type) { return forward_type; }; return LiftTensor(grad, grad_type, forward_type, e, ll); } void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { if (t.as()) { - ll->Push(RefWrite(GetField(arg, 1), - Add(ll->Push(RefRead(GetField(arg, 1))), - grad))); + ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad))); } else if (auto* tt = t.as()) { for (size_t i = 0; i < tt->fields.size(); ++i) { - UpdateGrad(tt->fields[i], - ll->Push(GetField(arg, i)), - ll->Push(GetField(grad, i)), - ll); + UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll); } } else { LOG(FATAL) << "unsupported arg type of operator: " << t; @@ -388,21 +352,20 @@ Expr BPEmpty() { } struct ReverseAD : ExprMutator { - using ADVarMap = std::unordered_map; + using ADVarMap = std::unordered_map; Var bp; std::shared_ptr ad_vars; - const OpMap rev_map = Op::GetAttr("FPrimalGradient"); + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); - explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) - : bp(bp), ad_vars(ad_vars) { } + explicit ReverseAD(const Var& bp, std::shared_ptr ad_vars) : bp(bp), ad_vars(ad_vars) {} Expr VisitExpr_(const OpNode* op) final { LOG(FATAL) << "op should only be inside call"; throw; } - Expr VisitCheckpoint(const CallNode *call) { + Expr VisitCheckpoint(const CallNode* call) { const OpNode* op_node = call->op.as(); CHECK(op_node) << "expected op in call"; Op op_ref = GetRef(op_node); @@ -412,20 +375,17 @@ struct ReverseAD : ExprMutator { auto x_var = ll->Push(x); auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function( - {}, - LetList::With([&](LetList* ll) { - // we need a new ReverseAD visitor to avoid clobbering the bp local var - auto dup_bp = ll->Push(BPEmpty()); - ReverseAD dup_diff(dup_bp, ad_vars); - auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); - - TransferGrads(call->checked_type(), ret, dup_ad, ll); - ll->Push(Call(RefRead(dup_bp), {})); - return Call(bpv, {}); - }), - TupleType::Empty(), - {}); + Expr nbp = Function({}, LetList::With([&](LetList* ll) { + // we need a new ReverseAD visitor to avoid clobbering the bp local var + auto dup_bp = ll->Push(BPEmpty()); + ReverseAD dup_diff(dup_bp, ad_vars); + auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x))); + + TransferGrads(call->checked_type(), ret, dup_ad, ll); + ll->Push(Call(RefRead(dup_bp), {})); + return Call(bpv, {}); + }), + TupleType::Empty(), {}); ll->Push(RefWrite(bp, nbp)); return ret; }); @@ -439,8 +399,7 @@ struct ReverseAD : ExprMutator { return VisitCheckpoint(call); } - CHECK(rev_map.count(op_ref)) - << op_node->name << " does not have reverse mode defined"; + CHECK(rev_map.count(op_ref)) << op_node->name << " does not have reverse mode defined"; return LetList::With([&](LetList* ll) { std::vector args; for (const auto& arg : call->args) { @@ -456,18 +415,16 @@ struct ReverseAD : ExprMutator { orig_var->checked_type_ = call->checked_type(); auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); auto bpv = ll->Push(RefRead(bp)); - Expr nbp = Function( - {}, - LetList::With([&](LetList* ll) { - 
tvm::Array rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); - CHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); - } - return Call(bpv, {}); - }), - TupleType::Empty(), - {}); + Expr nbp = Function({}, LetList::With([&](LetList* ll) { + tvm::Array rev = + rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); + CHECK(args.size() == rev.size()); + for (size_t i = 0; i < args.size(); ++i) { + UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll); + } + return Call(bpv, {}); + }), + TupleType::Empty(), {}); ll->Push(RefWrite(bp, nbp)); return ret; }); @@ -481,9 +438,8 @@ struct ReverseAD : ExprMutator { } Expr VisitExpr_(const IfNode* op) final { - return If(TupleGetItem(VisitExpr(op->cond), 0), - VisitExpr(op->true_branch), - VisitExpr(op->false_branch)); + return If(TupleGetItem(VisitExpr(op->cond), 0), VisitExpr(op->true_branch), + VisitExpr(op->false_branch)); } Expr VisitExpr_(const VarNode* var) final { @@ -497,14 +453,12 @@ struct ReverseAD : ExprMutator { return ad_vars->at(var_ref); } - Type VisitType(const Type& t) final { - return t.defined() ? ReverseType(t) : t; - } + Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } }; bool MissingGrad(const Expr& e) { struct MGVisitor : ExprVisitor { - const OpMap rev_map = Op::GetAttr("FPrimalGradient"); + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); std::unordered_set op_names; void VisitExpr_(const OpNode* op) final { @@ -530,7 +484,7 @@ bool MissingGrad(const Expr& e) { return false; } -Expr Gradient(const Expr& re, const IRModule& mod) { +Expr Gradient(const Expr& re, const Optional& mod) { auto e = DeGlobal(mod, re); auto f = e.as(); CHECK(f) << "input need to be a function"; @@ -585,8 +539,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) { return Function(f->params, body, GradRetType(GetRef(f)), {}); } -TVM_REGISTER_GLOBAL("relay._transform.gradient") -.set_body_typed(Gradient); +TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/infer_layout_util.h b/src/relay/transforms/infer_layout_util.h index ca730034327a..9868ee5d03db 100644 --- a/src/relay/transforms/infer_layout_util.h +++ b/src/relay/transforms/infer_layout_util.h @@ -27,11 +27,13 @@ #ifndef TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_ #define TVM_RELAY_TRANSFORMS_INFER_LAYOUT_UTIL_H_ -#include #include #include +#include + #include #include + #include "pattern_util.h" namespace tvm { @@ -94,17 +96,15 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o * \return infered_layout An array of two elements that are inferred input layouts and * inferred output layouts. */ -using FInferCorrectLayout = runtime::TypedPackedFunc< - Array>(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types)>; +using FInferCorrectLayout = runtime::TypedPackedFunc>( + const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, + const Array& old_in_types)>; /*! 
\brief take arbitrary input layout and copy to output */ -inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +inline Array> ElemwiseArbitraryLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Layout ret; if (new_in_layouts.defined()) { @@ -119,14 +119,14 @@ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, } } - return Array >{Array(old_in_layouts.size(), ret), {ret}}; + return Array>{Array(old_in_layouts.size(), ret), {ret}}; } /*! \brief Infer layout for binary broadcast operators */ -inline Array > BinaryBroadcastLayout(const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array &old_in_types) { +inline Array> BinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { Array layouts; Array> old_in_shapes; for (auto old_in_t : old_in_types) { @@ -135,35 +135,34 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, } if (new_in_layouts.defined()) { - layouts.assign(new_in_layouts.begin(), new_in_layouts.end()); + layouts.Assign(new_in_layouts.begin(), new_in_layouts.end()); } else { - layouts.assign(old_in_layouts.begin(), old_in_layouts.end()); + layouts.Assign(old_in_layouts.begin(), old_in_layouts.end()); } if (!layouts[0].defined() && !layouts[1].defined()) { // both undefined, infer fails - return Array > {{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } else if (!layouts[0].defined() || !layouts[1].defined()) { // only one is defined, use shape information to help infer int defined_idx = layouts[0].defined() ? 0 : 1; int undef_idx = 1 - defined_idx; if (old_in_shapes[defined_idx].size() >= old_in_shapes[undef_idx].size()) { - layouts.Set(undef_idx, - layouts[defined_idx].SubLayout( - old_in_shapes[defined_idx].size() - old_in_shapes[undef_idx].size(), - old_in_shapes[undef_idx].size())); - return Array >{layouts, {layouts[defined_idx]}}; + layouts.Set(undef_idx, layouts[defined_idx].SubLayout(old_in_shapes[defined_idx].size() - + old_in_shapes[undef_idx].size(), + old_in_shapes[undef_idx].size())); + return Array>{layouts, {layouts[defined_idx]}}; } else { // only know the tensor with smaller dimensions, // so we cannot infer the final broadcasted output. // fails in this case. - return Array >{{Layout::Undef()}, {Layout::Undef()}}; + return Array>{{Layout::Undef()}, {Layout::Undef()}}; } } else if (layouts[0].defined() && layouts[1].defined() && - (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) { + (layouts[0].ndim() == 0 || layouts[1].ndim() == 0)) { int scalar = layouts[0].ndim() == 0 ? 0 : 1; - return Array >{layouts, {layouts[1-scalar]}}; + return Array>{layouts, {layouts[1 - scalar]}}; } else { // Set the layout of the larger dimension. If one dimension size is lower, we call expand dims // while transforming layout. 
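When only one operand of a broadcasting binary op has a known layout, `BinaryBroadcastLayout` assigns the other operand the trailing sub-layout that matches its smaller rank, since broadcasting aligns dimensions from the right. A simplified sketch of that rule with plain strings; the real `Layout::SubLayout` also carries subordinate (split) axes such as `4c`, which this ignores:

```cpp
#include <cstdio>
#include <string>

// Toy version of the rule: take the last `undef_ndim` axes of the known layout
// and use them as the layout of the lower-rank, layout-less operand.
std::string SubLayoutForSmaller(const std::string& defined_layout, size_t undef_ndim) {
  return defined_layout.substr(defined_layout.size() - undef_ndim, undef_ndim);
}

int main() {
  // NCHW tensor broadcast against a 1-d bias: the bias gets layout "W".
  std::printf("%s\n", SubLayoutForSmaller("NCHW", 1).c_str());  // W
  // NCHW tensor broadcast against a 3-d operand: it gets "CHW".
  std::printf("%s\n", SubLayoutForSmaller("NCHW", 3).c_str());  // CHW
  return 0;
}
```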
@@ -209,7 +208,7 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, static inline std::tuple, Array, bool> InferCorrectLayouts( const Call& call, const Array& new_in_layouts, const Array& old_in_layouts, const Array& old_in_types) { - static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); + static auto finfer_layout = Op::GetAttrMap("FInferCorrectLayout"); if (!call->op.as()) { return std::make_tuple<>(Array(nullptr), Array(nullptr), false); } @@ -217,8 +216,7 @@ static inline std::tuple, Array, bool> InferCorrectLayouts Op op = Downcast(call->op); if (finfer_layout.count(op)) { Array> inferred_layouts; - inferred_layouts = - finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); + inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); CHECK_EQ(inferred_layouts.size(), 2) << "FInferCorrectLayout should return an array with size of 2"; for (auto x : inferred_layouts) { diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index ba0f5688ea9d..c9a0de44e2d4 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -35,8 +35,9 @@ #include #include -#include #include +#include + #include #include @@ -83,11 +84,8 @@ class Inliner : ExprMutator { } Function Inline(const Function& func) { - return Function(func->params, - VisitExpr(func->body), - func->ret_type, - func->type_params, - func->attrs); + return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, + func->attrs); } private: @@ -115,20 +113,13 @@ class Inliner : ExprMutator { } // Make a new Relay expression to replace the callee. - Expr MakeNewExpr(const GlobalVar& global, - const Array& args, - const Expr& callee) { - CHECK(callee->IsInstance() || - callee->IsInstance()); + Expr MakeNewExpr(const GlobalVar& global, const Array& args, const Expr& callee) { + CHECK(callee->IsInstance() || callee->IsInstance()); auto base_func = call_graph_->GetGlobalFunction(global); const auto* fn = base_func.as(); CHECK(fn) << "Expected to work on a Relay function."; - auto func = Function(fn->params, - fn->body, - fn->ret_type, - fn->type_params, - fn->attrs); + auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. if (!func->GetAttr(attr::kCompiler).defined()) { @@ -144,14 +135,13 @@ class Inliner : ExprMutator { // Cannot replace TensorType/TensorTupleType with FuncType. Therefore, // we simply inline the function as a closure instead of directly using // its body when the global var returns FuncType. - return ret_type->IsInstance() ? std::move(func) - : func->body; + return ret_type->IsInstance() ? 
std::move(func) : func->body; } else { CHECK(callee->IsInstance()); return Bind(func->body, bind_map); } } else if (const auto* call_node = callee.as()) { - return Call(func, args, call_node->attrs, call_node->type_args); + return Call(func, args, call_node->attrs, call_node->type_args); } else { return std::move(func); } @@ -214,14 +204,11 @@ namespace transform { Pass Inline() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::Inline(m); - }; + [=](IRModule m, PassContext pc) { return relay::Inline(m); }; return CreateModulePass(pass_func, 1, "InlineGlobals", {}); } -TVM_REGISTER_GLOBAL("relay._transform.Inline") -.set_body_typed(Inline); +TVM_REGISTER_GLOBAL("relay._transform.Inline").set_body_typed(Inline); } // namespace transform diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index e6248f11a00e..f06246667a8b 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -24,21 +24,21 @@ * \brief Lazily instantiate 0-filled or 1-filled tensors. * This pass should be used after reverse-mode ad so that gradient tensors * are not instantiated until after the forward pass. - * - * This pass delays or removes memory allocation by converting tensors into + * + * This pass delays or removes memory allocation by converting tensors into * GradCell, an algebraic data type defined in gradient.rly. - * + * * This will delay or decrease memory usage. All calls to * ones, ones_like, zeros, zeros_like will call the One or Zero constructor * of GradCell, which will not instantiate in memory until needed. All other cases result * in using the Raw constructor which means the tensor is instantiated in memory. - * + * * It also overloads + and * operation which can increase performance when doing * operations involving tensors with values of only 0 or 1. - * + * * Note: this pass can only be used with functions where the input/output types are * a combination of TupleTypes and TensorTypes - * + * * This pass optimizes 6 ops: * - add * - multiply @@ -46,39 +46,40 @@ * - ones_like * - zeros * - zeros_like - * + * * This pass makes use of three visitor. The most important one visits the entire function, * one is used for wrap inputs and one to unwrap outputs. - * + * * For example: * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32] - * + * * After this pass * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]] - * + * * Thus, it is necessary to wrap this outer function so that the input/output types remain the same */ +#include #include #include #include -#include #include + #include "let_list.h" namespace tvm { namespace relay { /*! 
-* \brief Visitor appropriately wraps tensors with Raw constructor -* -* Recursively looks at the type of the expression (TensorType or TupleType are only supported for now) -* and either call the GradCell constructor if TensorType -* or unfold and recursively visit if TupleType -*/ -class InputVisitor: public ExprFunctor { + * \brief Visitor appropriately wraps tensors with Raw constructor + * + * Recursively looks at the type of the expression (TensorType or TupleType are only supported for + * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if + * TupleType + */ +class InputVisitor : public ExprFunctor { public: - explicit InputVisitor(IRModule module): module_(module) {} + explicit InputVisitor(IRModule module) : module_(module) {} Expr VisitExpr_(const VarNode* op, const Type& t) final { std::cout << op->type_annotation << std::endl; @@ -88,13 +89,13 @@ class InputVisitor: public ExprFunctor { Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { return WrapExpr(GetRef(op), t); } + private: IRModule module_; Expr WrapExpr(const Expr expr, const Type& type) { if (type.as()) { - return Call(module_->GetConstructor("GradCell", "Raw"), - {expr}, Attrs(), {type}); + return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); } else if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { @@ -110,15 +111,15 @@ class InputVisitor: public ExprFunctor { }; /*! -* \brief Visitor appropriately unwraps expressions with GradCell type into Tensors -* -* Recursively looks at the type of the expression -* and either use the FromGradCell function if TypeCall to GradCell -* or unfold and recursively visit if TupleType -*/ -class OutputVisitor: public ExprFunctor { + * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors + * + * Recursively looks at the type of the expression + * and either use the FromGradCell function if TypeCall to GradCell + * or unfold and recursively visit if TupleType + */ +class OutputVisitor : public ExprFunctor { public: - explicit OutputVisitor(IRModule module): module_(module) {} + explicit OutputVisitor(IRModule module) : module_(module) {} Expr VisitExpr_(const CallNode* op, const Type& t) final { return UnwrapExpr(GetRef(op), t); @@ -127,6 +128,7 @@ class OutputVisitor: public ExprFunctor { Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { return UnwrapExpr(GetRef(op), t); } + private: IRModule module_; @@ -150,19 +152,18 @@ class OutputVisitor: public ExprFunctor { } }; -class LazyGradientInitializer: public ExprMutator, public TypeMutator { +class LazyGradientInitializer : public ExprMutator, public TypeMutator { public: - explicit LazyGradientInitializer(IRModule module): - module_(module) { - module_->ImportFromStd("gradient.rly"); - } + explicit LazyGradientInitializer(IRModule module) : module_(module) { + module_->ImportFromStd("gradient.rly"); + } /*! 
- * \brief apply LazyGradientInit transformation and wrap function - * so that function type stays the same - * - * input/output types should only be a combination of TupleTypes and TensorTypes - */ + * \brief apply LazyGradientInit transformation and wrap function + * so that function type stays the same + * + * input/output types should only be a combination of TupleTypes and TensorTypes + */ Expr Transform(const Expr& e) { auto* f = (e).as(); auto* transformed = this->Mutate(e).as(); @@ -185,8 +186,8 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { } Expr VisitExpr_(const ConstantNode* op) final { - return Call(module_->GetConstructor("GradCell", "Raw"), - {GetRef(op)}, Attrs(), {op->checked_type()}); + return Call(module_->GetConstructor("GradCell", "Raw"), {GetRef(op)}, Attrs(), + {op->checked_type()}); } Expr VisitExpr_(const CallNode* call_node) final { @@ -202,13 +203,13 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { } if (op_expr == Op::Get("ones") || op_expr == Op::Get("zeros")) { - // fn() -> T, function returns result of the operation - Expr func = Function({}, {ExprMutator::VisitExpr_(call_node)}, - {call_node->checked_type()}, {}); + // ones and zeros need TensorType input + Expr result = CallPrimitiveOp(call_node); + Expr func = Function({}, result, {call_node->checked_type()}, Array()); // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones") ? "One" : "Zero"; - return Call(module_->GetConstructor("GradCell", constructor_name), - {func}, Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", constructor_name), {func}, Attrs(), + {call_node->checked_type()}); } if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) { @@ -218,23 +219,21 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { Expr func = Function({}, result, {call_node->checked_type()}, Array()); // call appropriate GradCell constructor std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" : "Zero"; - return Call(module_->GetConstructor("GradCell", "One"), - {func}, Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", "One"), {func}, Attrs(), + {call_node->checked_type()}); } // handle all other ops Expr result = CallPrimitiveOp(call_node); // wrap result with Raw constructor - return Call(module_->GetConstructor("GradCell", "Raw"), {result}, - Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(), + {call_node->checked_type()}); } // not an op return ExprMutator::VisitExpr_(call_node); } - Type VisitType(const Type& t) final { - return TypeMutator::VisitType(t); - } + Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); } Type VisitType_(const TensorTypeNode* op) { GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); @@ -248,23 +247,22 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { IRModule module_; /*! 
- * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type - */ + * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type + */ Expr CallGradCellFunction(const CallNode* call_node, GlobalVar overloaded_op) { // can only use overloaded functions if 2 arguments of same type if (call_node->args.size() != 2 || !tvm::StructuralEqual()(call_node->args[0]->checked_type(), call_node->args[1]->checked_type())) { Expr result = CallPrimitiveOp(call_node); - return Call(module_->GetConstructor("GradCell", "Raw"), {result}, - Attrs(), {call_node->checked_type()}); + return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(), + {call_node->checked_type()}); } tvm::Array args; // create "fallback" function for overloaded function Type paramType = call_node->args[0]->checked_type(); - tvm::Array params = {Var("lhs", paramType), - Var("rhs", paramType)}; + tvm::Array params = {Var("lhs", paramType), Var("rhs", paramType)}; // use primitive op in this case Expr callOp = Call(call_node->op, {params[0], params[1]}); Expr func = Function(params, callOp, paramType, Array()); @@ -279,19 +277,18 @@ class LazyGradientInitializer: public ExprMutator, public TypeMutator { } /*! - * \brief Convert calls to other ops by converting args into TensorType - * \return call expr returning result of op - */ + * \brief Convert calls to other ops by converting args into TensorType + * \return call expr returning result of op + */ Expr CallPrimitiveOp(const CallNode* call_node) { const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; // use FromGradCell to convert args to Tensor for (Expr expr : call_node->args) { - args.push_back(Call(fromFunc, - {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); + args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } // result of operation - return Call(call_node->op, args); + return Call(call_node->op, args, call_node->attrs); } }; @@ -302,14 +299,13 @@ Expr LazyGradientInit(const Expr& e, IRModule mod) { namespace transform { Pass LazyGradientInit() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(LazyGradientInit(f, m)); - }; - return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); + [=](Function f, IRModule m, PassContext pc) { + return Downcast(LazyGradientInit(f, m)); + }; + return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); } -TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit") -.set_body_typed(LazyGradientInit); +TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit); } // namespace transform diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc index 0b5c671ab7f6..89f59f625a8d 100644 --- a/src/relay/transforms/legalize.cc +++ b/src/relay/transforms/legalize.cc @@ -23,10 +23,10 @@ * shape, dtype or layout to another op or a sequence of ops. */ -#include #include #include #include +#include namespace tvm { namespace relay { @@ -44,13 +44,13 @@ class Legalizer : public ExprRewriter { // Get the new_call node without any changes to current call node. Call new_call = Downcast(post); - // Check if the string is registered in the OpRegistry. - if (!Op::HasAttr(legalize_map_attr_name_)) { + // Check if the string is registered. + if (!Op::HasAttrMap(legalize_map_attr_name_)) { return post; } // Collect the registered legalize function. 
- auto fop_legalize = Op::GetAttr(legalize_map_attr_name_); + auto fop_legalize = Op::GetAttrMap(legalize_map_attr_name_); auto call_op = call_node->op; if (call_op.as()) { Op op = Downcast(call_node->op); @@ -96,7 +96,7 @@ Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) { namespace transform { -Pass Legalize(const std::string& legalize_map_attr_name) { +Pass Legalize(const String& legalize_map_attr_name) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name)); diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index f195c3060e2f..c0e0b3a23864 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -29,12 +29,14 @@ #ifndef TVM_RELAY_TRANSFORMS_LET_LIST_H_ #define TVM_RELAY_TRANSFORMS_LET_LIST_H_ -#include #include +#include + +#include +#include #include #include -#include -#include + #include "tvm/relay/type.h" namespace tvm { @@ -77,9 +79,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr, Type ty) { - return Push(Var("x", ty), expr); - } + Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); } /*! * \brief insert a binding. @@ -88,9 +88,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr) { - return Push(expr, Type()); - } + Var Push(Expr expr) { return Push(expr, Type()); } /*! * \brief wrap an expr around the LetList. @@ -130,16 +128,14 @@ class LetList { * * \return the wrapped Expr. */ - template + template static Expr With(F&& f) { LetList ll; return ll.Get(f(&ll)); } static Expr LetBind(const Expr& e, const std::function& f) { - return With([&](LetList* ll) { - return f(ll->Push(e)); - }); + return With([&](LetList* ll) { return f(ll->Push(e)); }); } private: diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 601be0f96bc4..5e615e4316bd 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -43,22 +43,18 @@ #include #include "../analysis/annotated_region_set.h" +#include "pass_util.h" namespace tvm { namespace relay { namespace merge_compiler_region { -// Cache compiler_begin and compiler_end annotation ops for equivalence check to -// reduce registry lookup overhead. -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - -class RegionMerger : public ExprVisitor { +class RegionMerger : public MixedModeVisitor { public: explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} void VisitExpr_(const CallNode* call) final { - if (call->op == compiler_end_op) { + if (call->op == CompilerEndOp()) { auto region = regions_->GetRegion(GetRef(call)); // Skip this region if it has been merged to the other region. @@ -75,7 +71,7 @@ class RegionMerger : public ExprVisitor { // Region inputs must be begin annotation, and the region of // the begin annotation's argument is the parent region. auto begin = Downcast(arg); - CHECK_EQ(begin->op, compiler_begin_op); + CHECK_EQ(begin->op, CompilerBeginOp()); auto parent_region = regions_->GetRegion(begin->args[0]); // Skip this region if it has been merged. @@ -87,10 +83,10 @@ class RegionMerger : public ExprVisitor { } // Collect unmerged parent regions. 
- std::unordered_set mergeable_regions; + std::unordered_set mergeable_regions; for (const auto& arg : region->GetInputs()) { auto begin = Downcast(arg); - CHECK_EQ(begin->op, compiler_begin_op); + CHECK_EQ(begin->op, CompilerBeginOp()); auto parent_region = regions_->GetRegion(begin->args[0]); if (parent_region.defined()) { mergeable_regions.insert(parent_region); @@ -131,7 +127,6 @@ class RegionMerger : public ExprVisitor { } merged_regions_.insert(region->GetID()); } - ExprVisitor::VisitExpr_(call); } private: @@ -140,25 +135,26 @@ class RegionMerger : public ExprVisitor { std::unordered_map> region_restrictions_; }; -class MergeAnnotations : public ExprMutator { +class MergeAnnotations : public ExprRewriter { public: explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { // Merge annotations which are now internal to a region. // This happens if we see a compiler begin next to a // compiler end and they're both in the same region. - if (call->op == compiler_begin_op && call->args[0]->IsInstance()) { + if (call->op == CompilerBeginOp() && call->args[0]->IsInstance()) { auto arg = Downcast(call->args[0]); - if (arg->op == compiler_end_op) { + if (arg->op == CompilerEndOp()) { auto region1 = regions_->GetRegion(GetRef(call)); auto region2 = regions_->GetRegion(arg); if (region1 == region2) { - return VisitExpr(arg->args[0]); + auto post_arg = post.as()->args[0]; + return post_arg.as()->args[0]; } } } - return ExprMutator::VisitExpr_(call); + return post; } private: @@ -167,7 +163,7 @@ class MergeAnnotations : public ExprMutator { Expr MergeCompilerRegions(const Expr& expr) { // Create regions using the annotations. - AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); + AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, CompilerBeginOp(), CompilerEndOp()); // Analyze the graph to explore the opportunities of merging regions. RegionMerger merger(regions); @@ -175,7 +171,7 @@ Expr MergeCompilerRegions(const Expr& expr) { // Remove annotations that are not in the region boundaries. 
MergeAnnotations merge_anno(regions); - return merge_anno.Mutate(expr); + return PostOrderRewrite(expr, &merge_anno); } } // namespace merge_compiler_region diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 75d95f0378f1..324b2cb3a1c4 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -26,6 +26,7 @@ */ #include +#include #include #include #include @@ -35,188 +36,24 @@ namespace tvm { namespace relay { namespace merge_composite { -class MergeCompositeWrapper : public ExprMutator { - public: - explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern, - const PackedFunc& check) - : pattern_name_(pattern_name), pattern_(pattern), check_(check) {} - - Expr ExtractPattern(const Var& pattern, const Expr& root, - Map>* var_map) { - if (var_map->find(pattern->name_hint()) == var_map->end()) { - // if we haven't encountered this var yet, make a new free var and associate - // it with the value at 'root' - auto free_var = Var(pattern->name_hint(), Type()); - var_map->Set(pattern->name_hint(), Array({free_var, root})); - return std::move(free_var); - } else { - // if we have encountered this var already, return the free var that was created - auto vars = (*var_map)[pattern->name_hint()]; - auto free_var = vars[0]; - auto graph_expr = vars[1]; - // make sure to first check they both map to the same node in the graph - if (graph_expr != root) { - return Expr(); - } - return (*var_map)[pattern->name_hint()][0]; - } - } - - Expr ExtractPattern(const Constant& pattern, const Expr& root, - Map>* var_map) { - return root; - } - - Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root, - Map>* var_map, Map* call_map) { - if (!root->IsInstance()) { - return Expr(); - } - auto root_node = Downcast(root); - if (pattern->index != root_node->index) { - return Expr(); - } - if (pattern->tuple->IsInstance() && root_node->tuple->IsInstance()) { - Expr new_arg; - if (call_map->find(pattern->tuple) != call_map->end()) { - new_arg = (*call_map)[pattern->tuple]; - } else { - new_arg = ExtractPattern(Downcast(pattern->tuple), Downcast(root_node->tuple), - var_map, call_map); - call_map->Set(pattern->tuple, new_arg); - } - return TupleGetItem(new_arg, root_node->index); - } - return Expr(); - } - - /*! - * \brief Try and extract a given pattern from a graph as a subgraph. - * \param pattern The pattern to extract. - * \param root The graph to extract from. - * \param var_map A map between free vars in the subgraph and nodes in the graph. - * \return The extracted subgraph. - * - * \note How does this work? - * - * A pattern consists of Relay expression containing only operator call nodes, constants - * and free variables. The free variables indicate where the pattern can 'attach' in your - * graph. This function takes the final call node of the pattern and the call node currently - * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node - * from the graph (referred to as the 'root' node here) to check they're identical. If at any - * point they differ, an empty expression is returned to signify the extract failed. If a free var - * is reached in the pattern, the corresponding value in the root is associated with the name of - * the free var (via the var_map) so that when we construct the composite function, the inputs - * match up correctly with the rest of the graph. 
The return value of this function when - * successful is a new Relay expression ready to be wrapped into a composite function. - */ - Expr ExtractPattern(const Call& pattern, const Call& root, Map>* var_map, - Map* call_map) { - // check to make sure both calls are to operators (not functions) - if (!pattern->op->IsInstance() || !root->op->IsInstance()) return Expr(); - if (pattern->op.as()->name != root->op.as()->name) return Expr(); - - unsigned int i = 0; - Array new_args; - for (const auto& arg : pattern->args) { - Expr new_arg; - if (arg->IsInstance()) { - // if we've already processed this call node, return the previous result - if (call_map->find(arg) != call_map->end()) { - new_arg = (*call_map)[arg]; - } else { - // fail if the root argument is not also a call node - if (!root->args[i]->IsInstance()) { - return Expr(); - } - // if it's a call node, recursively call this function - new_arg = - ExtractPattern(Downcast(arg), Downcast(root->args[i]), var_map, call_map); - call_map->Set(arg, new_arg); - } - } else if (arg->IsInstance()) { - // if there's a var in the pattern, it must be a free var - // so call the function to update the var_map - new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); - } else if (arg->IsInstance()) { - // if there's a constant, simply get the corresponding - // value of the constant from the root - new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); - } else if (arg->IsInstance()) { - new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map, call_map); - } - if (!new_arg.defined()) { - return Expr(); - } - new_args.push_back(new_arg); - i++; - } - return Call(root->op, new_args, root->attrs); - } - - Expr VisitExpr_(const CallNode* cn) { - Call call = GetRef(cn); - if (call->op->IsInstance()) { - Function func = Downcast(call->op); - CHECK(func.defined()); - auto name_node = func->GetAttr(attr::kComposite); - // don't step into existing composite functions - if (name_node.defined() && name_node != "") { - tvm::Array new_args; - for (const auto& arg : call->args) { - auto new_e = this->Mutate(arg); - new_args.push_back(new_e); - } - return Call(call->op, new_args, call->attrs); - } - } - - Expr expr = ExprMutator::VisitExpr_(cn); - call = Downcast(expr); - if (!call->op->IsInstance()) return std::move(call); - - // only call patterns are supported - Call pattern = Downcast(pattern_); - CHECK(pattern.defined()); - Map> args_map; - Map call_map; - auto extract = ExtractPattern(pattern, call, &args_map, &call_map); - if (extract.defined() && static_cast(check_(extract))) { - auto free_vars = FreeVars(extract); - // make the composite function - auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs()); - f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_)); - // find the expressions associated with the free vars using the args_map - // this tells us which expressions should be given as inputs to the composite function - Array args; - for (const auto& free_var : free_vars) { - args.push_back(args_map[free_var->name_hint()][1]); - } - auto new_call = Call(f, args); - return std::move(new_call); - } - return std::move(call); - } - - private: - /*! \brief The name of the pattern to match */ - std::string pattern_name_; - /*! \brief The pattern to match */ - Expr pattern_; - /*! 
\brief The function to check whether an extract is supported */ - PackedFunc check_; -}; +Function InferType(const Function& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + return Downcast(mod->Lookup("main")); +} -Expr MergeComposite(const Expr& expr, const Array& pattern_names, - const Array& patterns, const std::vector& checks) { +Expr MergeComposite(const Function& func, const Array& pattern_names, + const Array& patterns, const std::vector& checks) { CHECK_EQ(pattern_names.size(), patterns.size()); - Expr merged_expr = expr; + Function merged_func = func; // merge the patterns one-by-one in order for (size_t i = 0; i < patterns.size(); i++) { - merged_expr = - MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr); + Map attrs; + attrs.Set("Composite", pattern_names[i]); + merged_func = Downcast(PartitionPattern(patterns[i], merged_func, attrs, checks[i])); + merged_func = InferType(merged_func); } - return merged_expr; + return std::move(merged_func); } } // namespace merge_composite @@ -224,7 +61,7 @@ Expr MergeComposite(const Expr& expr, const Array& pattern_name namespace transform { Pass MergeComposite(const tvm::Array& pattern_names, - const tvm::Array& patterns, const std::vector& checks) { + const tvm::Array& patterns, const std::vector& checks) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast( @@ -234,10 +71,9 @@ Pass MergeComposite(const tvm::Array& pattern_names, return func_pass; } -TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { tvm::Array pattern_names = args[0]; - tvm::Array patterns = args[1]; + tvm::Array patterns = args[1]; std::vector checks; for (int i = 2; i < args.size(); i++) { checks.push_back(args[i]); diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index cd1f40c28767..371142ad76a2 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -63,7 +63,7 @@ * so we have to deduplicate them. * * 4: In the generated code, as it call TypeSubst, multiple VarNode might have same Id. - * While it is permitted, most pass use ObjectHash for Var, + * While it is permitted, most pass use ObjectPtrHash for Var, * and having multiple VarNode for same Id break them. * Thus we remap them to a single Id for now. * @@ -91,12 +91,13 @@ */ #include #include -#include #include -#include #include -#include "pass_util.h" +#include +#include + #include "let_list.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -109,9 +110,7 @@ using namespace runtime; * Use VarHash to hash Var by id. */ struct VarHash { - size_t operator()(const Var& v) const { - return ObjectHash()(v->vid); - } + size_t operator()(const Var& v) const { return ObjectPtrHash()(v->vid); } }; /*! \brief Compare Var by it's id. @@ -119,9 +118,7 @@ struct VarHash { * Use VarEqual to compare Var by id. 
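`VarHash`/`VarEqual` key maps by the variable's `vid` rather than by object identity of the `Var` handle, so that distinct `VarNode`s produced by substitution but sharing one `Id` still hit the same entry. A small sketch of the same pattern with hypothetical `Id`/`Var` structs:

```cpp
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins: several Var handles may share one underlying Id.
struct Id { std::string name_hint; };
struct Var { std::shared_ptr<Id> vid; };

// Hash and compare by the identity of the Id, not of the Var handle.
struct VarHash {
  size_t operator()(const Var& v) const { return std::hash<Id*>()(v.vid.get()); }
};
struct VarEqual {
  bool operator()(const Var& l, const Var& r) const { return l.vid.get() == r.vid.get(); }
};

int main() {
  auto id = std::make_shared<Id>(Id{"x"});
  Var a{id}, b{id};  // two handles, one Id
  std::unordered_map<Var, int, VarHash, VarEqual> env;
  env[a] = 42;
  std::printf("%d\n", env.at(b));  // 42: found through a different handle
  return 0;
}
```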
*/ struct VarEqual { - bool operator()(const Var& l, const Var& r) const { - return l->vid.get() == r->vid.get(); - } + bool operator()(const Var& l, const Var& r) const { return l->vid.get() == r->vid.get(); } }; Expr PostProcess(const Expr&); @@ -137,9 +134,7 @@ class Static : public ObjectRef { public: Static() {} explicit Static(ObjectPtr n) : ObjectRef(n) {} - const StaticNode* operator->() const { - return static_cast(get()); - } + const StaticNode* operator->() const { return static_cast(get()); } using ContainerType = StaticNode; }; @@ -156,9 +151,9 @@ struct PStaticNode : Object { Static pstatic; // may be null Expr dynamic; Time created_time; - PStaticNode(const Static& pstatic, const Expr& dynamic) : - pstatic(pstatic), dynamic(dynamic), created_time(time()) { } - explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } + PStaticNode(const Static& pstatic, const Expr& dynamic) + : pstatic(pstatic), dynamic(dynamic), created_time(time()) {} + explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) {} static constexpr const char* _type_key = "relay.PStatic"; TVM_DECLARE_FINAL_OBJECT_INFO(PStaticNode, Object); }; @@ -170,7 +165,7 @@ class PStatic : public ObjectRef { struct STupleNode : StaticNode { std::vector fields; - explicit STupleNode(const std::vector& fields) : fields(fields) { } + explicit STupleNode(const std::vector& fields) : fields(fields) {} static constexpr const char* _type_key = "relay.STuple"; TVM_DECLARE_FINAL_OBJECT_INFO(STupleNode, StaticNode); }; @@ -186,7 +181,7 @@ Static MkSTuple(const std::vector& fields) { struct STensorNode : StaticNode { runtime::NDArray data; - explicit STensorNode(const NDArray& data) : data(data) { } + explicit STensorNode(const NDArray& data) : data(data) {} static constexpr const char* _type_key = "relay.STensor"; TVM_DECLARE_FINAL_OBJECT_INFO(STensorNode, StaticNode); }; @@ -196,15 +191,13 @@ class STensor : public Static { TVM_DEFINE_OBJECT_REF_METHODS(STensor, Static, STensorNode); }; -Static MkSTensor(const NDArray& data) { - return Static(make_object(data)); -} +Static MkSTensor(const NDArray& data) { return Static(make_object(data)); } struct SConstructorNode : StaticNode { Constructor constructor; std::vector fields; - SConstructorNode(const Constructor& constructor, const std::vector& fields) : - constructor(constructor), fields(fields) { } + SConstructorNode(const Constructor& constructor, const std::vector& fields) + : constructor(constructor), fields(fields) {} static constexpr const char* _type_key = "relay.SConstructor"; TVM_DECLARE_FINAL_OBJECT_INFO(SConstructorNode, StaticNode); }; @@ -229,19 +222,14 @@ class SRef : public Static { TVM_DEFINE_OBJECT_REF_METHODS(SRef, Static, SRefNode); }; -Static MkSRef() { - return Static(make_object()); -} +Static MkSRef() { return Static(make_object()); } -using Func = std::function&, - const Attrs&, - const Array&, - LetList*)>; +using Func = std::function&, const Attrs&, + const Array&, LetList*)>; struct SFuncNode : StaticNode { Func func; - explicit SFuncNode(const Func& func) : func(func) { } + explicit SFuncNode(const Func& func) : func(func) {} static constexpr const char* _type_key = "relay.SFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(SFuncNode, StaticNode); }; @@ -251,15 +239,13 @@ class SFunc : public Static { TVM_DEFINE_OBJECT_REF_METHODS(SFunc, Static, SFuncNode); }; -Static MkSFunc(const Func& func) { - return Static(make_object(func)); -} - +Static MkSFunc(const Func& func) { return Static(make_object(func)); } class FuelNode; /*! 
\brief A meet-semilattice with finite descending chain. * It means that we can meet two element to get an element, - * and for every element, there is only a finite amount of meet before getting back the same element. + * and for every element, there is only a finite amount of meet before getting back the same + * element. * * Every time we recurse, we do a meet and require that progress must be made. * This ensures we do not recurse infinitely in the Partial Evaluator. @@ -301,9 +287,7 @@ class FuelNode : public RelayNode { TVM_DECLARE_BASE_OBJECT_INFO(FuelNode, RelayNode); }; -const FuelNode* Fuel::operator->() const { - return static_cast(get()); -} +const FuelNode* Fuel::operator->() const { return static_cast(get()); } Fuel MkFSeq(const std::vector& fuels); struct FSeqNode : FuelNode { @@ -318,7 +302,7 @@ struct FSeqNode : FuelNode { } return MkFSeq(new_fuels); } - explicit FSeqNode(const std::vector& fuels) : fuels(fuels) { } + explicit FSeqNode(const std::vector& fuels) : fuels(fuels) {} static constexpr const char* _type_key = "relay.FSeq"; TVM_DECLARE_FINAL_OBJECT_INFO(FSeqNode, FuelNode); }; @@ -328,9 +312,7 @@ class FSeq : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FSeq, Fuel, FSeqNode); }; -Fuel MkFSeq(const std::vector& fuels) { - return Fuel(make_object(fuels)); -} +Fuel MkFSeq(const std::vector& fuels) { return Fuel(make_object(fuels)); } Fuel MkFTime(Time time); struct FTimeNode : FuelNode { @@ -341,7 +323,7 @@ struct FTimeNode : FuelNode { Time new_time = std::min(time, x->time); return std::make_tuple(MkFTime(new_time), new_time < time); } - explicit FTimeNode(Time time) : time(time) { } + explicit FTimeNode(Time time) : time(time) {} static constexpr const char* _type_key = "relay.FTime"; TVM_DECLARE_FINAL_OBJECT_INFO(FTimeNode, FuelNode); }; @@ -351,9 +333,7 @@ class FTime : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTime, Fuel, FTimeNode); }; -Fuel MkFTime(Time time) { - return Fuel(make_object(time)); -} +Fuel MkFTime(Time time) { return Fuel(make_object(time)); } Fuel MkFTValue(size_t tvalue); /*! \brief If the pstatic is hold a positive integer scalar, that number, else 0. */ @@ -365,7 +345,7 @@ struct FTValueNode : FuelNode { size_t new_tvalue = std::min(tvalue, x->tvalue); return std::make_tuple(MkFTValue(new_tvalue), new_tvalue < tvalue); } - explicit FTValueNode(size_t tvalue) : tvalue(tvalue) { } + explicit FTValueNode(size_t tvalue) : tvalue(tvalue) {} static constexpr const char* _type_key = "relay.FTValue"; TVM_DECLARE_FINAL_OBJECT_INFO(FTValueNode, FuelNode); }; @@ -375,9 +355,7 @@ class FTValue : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTValue, Fuel, FTValueNode); }; -Fuel MkFTValue(size_t tvalue) { - return Fuel(make_object(tvalue)); -} +Fuel MkFTValue(size_t tvalue) { return Fuel(make_object(tvalue)); } /*! \brief Initially every element has Fuel of FTop. It is the largest element. * @@ -397,9 +375,7 @@ class FTop : public Fuel { TVM_DEFINE_OBJECT_REF_METHODS(FTop, Fuel, FTopNode); }; -Fuel MkFTop() { - return Fuel(make_object()); -} +Fuel MkFTop() { return Fuel(make_object()); } /*! * \brief A stack frame in the Relay interpreter. 
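The Fuel classes above carry the partial evaluator's termination argument: every specialization site holds a fuel value, the evaluator keeps inlining only while Meet reports strict progress, and a finite descending chain bounds how often progress can happen. A standalone sketch of that contract (plain C++, not TVM code), modeled on the FTime fuel:

#include <algorithm>
#include <cassert>
#include <tuple>

// Time-based fuel: Meet keeps the minimum and reports whether it strictly decreased.
struct TimeFuel {
  int time;
  std::tuple<TimeFuel, bool> Meet(const TimeFuel& other) const {
    int new_time = std::min(time, other.time);
    return std::make_tuple(TimeFuel{new_time}, new_time < time);
  }
};

int main() {
  TimeFuel fuel{10};
  TimeFuel incoming{7};
  auto first = fuel.Meet(incoming);
  assert(std::get<1>(first) && std::get<0>(first).time == 7);    // progress: keep specializing
  auto second = std::get<0>(first).Meet(incoming);
  assert(!std::get<1>(second) && std::get<0>(second).time == 7); // no progress: stop inlining
  return 0;
}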
@@ -414,10 +390,10 @@ struct Frame { class Environment { public: - Environment() : env_({Frame()}) { } + Environment() : env_({Frame()}) {} Environment(const Environment&) = delete; - template + template T Extend(const std::function& body) { FrameContext fc(this); return body(); @@ -447,12 +423,8 @@ class Environment { struct FrameContext { Environment* env_; - explicit FrameContext(Environment* env) : env_(env) { - env_->env_.push_back(Frame()); - } - ~FrameContext() { - env_->env_.pop_back(); - } + explicit FrameContext(Environment* env) : env_(env) { env_->env_.push_back(Frame()); } + ~FrameContext() { env_->env_.pop_back(); } }; }; @@ -470,16 +442,16 @@ struct StoreFrame { * It only outdate the frame above it, but not the current frame. */ bool history_valid = true; - explicit StoreFrame(const std::unordered_map& store) : store(store) { } + explicit StoreFrame(const std::unordered_map& store) : store(store) {} StoreFrame() = default; }; class Store { public: - Store() : store_({StoreFrame()}) { } + Store() : store_({StoreFrame()}) {} Store(const Store&) = delete; - template + template T Extend(const std::function& body) { StoreFrameContext sfc(this); return body(); @@ -534,16 +506,12 @@ PStatic HasStatic(const Static& stat, const Expr& dynamic) { return PStatic(make_object(stat, dynamic)); } -PStatic NoStatic(const Expr& dynamic) { - return PStatic(make_object(dynamic)); -} +PStatic NoStatic(const Expr& dynamic) { return PStatic(make_object(dynamic)); } -enum struct MatchStatus { - Match, NoMatch, Unknown -}; +enum struct MatchStatus { Match, NoMatch, Unknown }; bool StatefulOp(const Expr& e) { - static auto op_stateful = Op::GetAttr("TOpIsStateful"); + static auto op_stateful = Op::GetAttrMap("TOpIsStateful"); struct StatefulOpVisitor : ExprVisitor { bool stateful = false; void VisitExpr_(const OpNode* op) { @@ -565,10 +533,12 @@ DLContext CPUContext() { } FInterpreter CPUInterpreter() { + using tvm::transform::PassContext; + Target target = Target::Create("llvm"); // use a fresh build context // in case we are already in a build context. - With fresh_build_ctx(BuildConfig::Create()); + With fresh_build_ctx(PassContext::Create()); return CreateInterpreter(IRModule(nullptr), CPUContext(), target); } @@ -582,20 +552,16 @@ struct WithFuncIdAttrs : public tvm::AttrsNode { FuncId fid; TVM_DECLARE_ATTRS(WithFuncIdAttrs, "relay.attrs.WithFuncIdAttrs") { - TVM_ATTR_FIELD(fid) - .describe("The FuncId that an function is annotated with.") - .set_default(-1); + TVM_ATTR_FIELD(fid).describe("The FuncId that an function is annotated with.").set_default(-1); } }; TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs); - RELAY_REGISTER_OP("annotation.with_funcid") -.describe(R"code(Annotate a function with a funcid.)code" -TVM_ADD_FILELINE) -.set_num_inputs(1) -.add_argument("func", "Function", "The input data."); + .describe(R"code(Annotate a function with a funcid.)code" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("func", "Function", "The input data."); // Cache with_funcid op to reduce lookup overhead during traversal. 
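Environment::Extend and Store::Extend above rely on small RAII helpers (FrameContext, StoreFrameContext) so that the frame pushed for the duration of body() is popped even on an early return or an exception. A self-contained sketch of that idiom, independent of TVM:

#include <cassert>
#include <functional>
#include <vector>

class ScopedFrames {
 public:
  ScopedFrames() : frames_(1) {}  // keep one base frame, like Environment/Store above

  template <typename T>
  T Extend(const std::function<T()>& body) {
    FrameContext fc(this);  // push on entry; popped when fc is destroyed
    return body();
  }

  size_t depth() const { return frames_.size(); }

 private:
  struct FrameContext {
    ScopedFrames* owner;
    explicit FrameContext(ScopedFrames* o) : owner(o) { owner->frames_.emplace_back(); }
    ~FrameContext() { owner->frames_.pop_back(); }
  };
  std::vector<std::vector<int>> frames_;
};

int main() {
  ScopedFrames env;
  size_t inside = env.Extend<size_t>([&]() { return env.depth(); });
  assert(inside == 2 && env.depth() == 1);  // the extra frame lives only inside Extend
  return 0;
}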
static const Op& with_funcid_op = Op::Get("annotation.with_funcid"); @@ -624,7 +590,7 @@ Function AsFunc(const Expr& e) { class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const IRModule& mod) : mod_(mod) { } + PartialEvaluator(const IRModule& mod) : mod_(mod) {} PStatic VisitExpr(const Expr& e, LetList* ll) final { PStatic ret = ExprFunctor::VisitExpr(e, ll); @@ -639,9 +605,8 @@ class PartialEvaluator : public ExprFunctor return VisitExpr(c->args[0], ll, name); } } - PStatic ret = e.as() ? - VisitFunc(Downcast(e), ll, name) : - VisitExpr(e, ll); + PStatic ret = + e.as() ? VisitFunc(Downcast(e), ll, name) : VisitExpr(e, ll); CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; return ret; } @@ -670,9 +635,7 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const VarNode* op, LetList* ll) final { - return env_.Lookup(GetRef(op)); - } + PStatic VisitExpr_(const VarNode* op, LetList* ll) final { return env_.Lookup(GetRef(op)); } PStatic VisitGlobalVar(const GlobalVar& gv) { CHECK(mod_.defined()); @@ -714,15 +677,11 @@ class PartialEvaluator : public ExprFunctor } } else { Expr t = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - return VisitExpr(op->true_branch, ll)->dynamic; - }); - }); + return LetList::With([&](LetList* ll) { return VisitExpr(op->true_branch, ll)->dynamic; }); + }); Expr f = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - return VisitExpr(op->false_branch, ll)->dynamic; - }); - }); + return LetList::With([&](LetList* ll) { return VisitExpr(op->false_branch, ll)->dynamic; }); + }); store_.Invalidate(); return NoStatic(ll->Push(If(c->dynamic, t, f))); } @@ -782,16 +741,12 @@ class PartialEvaluator : public ExprFunctor PartialEvaluator* pe_; FuncId fid_; Fuel old_fuel; - FuelFrame(PartialEvaluator* pe, - FuncId fid, - const Fuel& new_fuel) : pe_(pe), fid_(fid) { + FuelFrame(PartialEvaluator* pe, FuncId fid, const Fuel& new_fuel) : pe_(pe), fid_(fid) { CHECK_GT(pe_->fuel_map_.count(fid_), 0); old_fuel = pe_->fuel_map_[fid_]; pe_->fuel_map_[fid_] = new_fuel; } - ~FuelFrame() { - pe_->fuel_map_[fid_] = old_fuel; - } + ~FuelFrame() { pe_->fuel_map_[fid_] = old_fuel; } }; size_t GetFTValue(const PStatic& ps) { @@ -829,82 +784,76 @@ class PartialEvaluator : public ExprFunctor free_vars.push_back(std::pair(v, env_.Lookup(v))); } } - return [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + return [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { return env_.Extend([&]() { - CHECK_EQ(pv.size(), func->params.size()); - CHECK_GT(func_map_.count(func), 0); - FuncId fid = func_map_.at(func); - if (fuel_map_.count(fid) == 0) { - fuel_map_.insert({fid, MkFTop()}); + CHECK_EQ(pv.size(), func->params.size()); + CHECK_GT(func_map_.count(func), 0); + FuncId fid = func_map_.at(func); + if (fuel_map_.count(fid) == 0) { + fuel_map_.insert({fid, MkFTop()}); + } + std::vector args_fuel; + for (const auto& v : pv) { + args_fuel.push_back(GetFuel(v)); + } + auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); + if (std::get<1>(meet_res)) { + FuelFrame tf(this, fid, std::get<0>(meet_res)); + Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); + Function func = AsFunc(dedup_func); + if (var.as()) { + env_.Insert(Downcast(var), self); } - std::vector args_fuel; - for (const auto& v : pv) { - args_fuel.push_back(GetFuel(v)); + for (size_t i = 0; i < pv.size(); 
++i) { + env_.Insert(func->params[i], pv[i]); + } + for (const auto& p : free_vars) { + env_.Insert(p.first, p.second); + } + tvm::Map subst; + for (size_t i = 0; i < type_args.size(); ++i) { + subst.Set(func->type_params[i], type_args[i]); } - auto meet_res = fuel_map_[fid]->Meet(MkFSeq(args_fuel)); - if (std::get<1>(meet_res)) { - FuelFrame tf(this, fid, std::get<0>(meet_res)); - Expr dedup_func = RegisterFuncId(DeDup(AnnotateFuncId(func))); - Function func = AsFunc(dedup_func); - if (var.as()) { - env_.Insert(Downcast(var), self); - } - for (size_t i = 0; i < pv.size(); ++i) { - env_.Insert(func->params[i], pv[i]); - } - for (const auto& p : free_vars) { - env_.Insert(p.first, p.second); - } - tvm::Map subst; - for (size_t i = 0; i < type_args.size(); ++i) { - subst.Set(func->type_params[i], type_args[i]); - } - for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { - subst.Set(func->type_params[i], IncompleteType(kType)); - } - return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); - } else { - std::vector dyn; - for (const auto& v : pv) { - dyn.push_back(v->dynamic); - } - return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); + for (size_t i = type_args.size(); i < func->type_params.size(); ++i) { + subst.Set(func->type_params[i], IncompleteType(kType)); } - }); + return VisitExpr(RegisterFuncId(TypeSubst(AnnotateFuncId(func->body), subst)), ll); + } else { + std::vector dyn; + for (const auto& v : pv) { + dyn.push_back(v->dynamic); + } + return NoStatic(ll->Push(Call(var, dyn, attrs, type_args))); + } + }); }; } Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return Function(func->params, - LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - }), func->ret_type, func->type_params, func->attrs); + return Function(func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; + }), + func->ret_type, func->type_params, func->attrs); }); } - PStatic VisitFunc(const Function& func, - LetList* ll, - const Var& name = Var("x", Type())) { + PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. // restore letrec support across whole relay. - return HasStatic(MkSFunc(f), - ll->Push(name, VisitFuncDynamic(u_func, f, name))); + return HasStatic(MkSFunc(f), ll->Push(name, VisitFuncDynamic(u_func, f, name))); } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { @@ -912,7 +861,7 @@ class PartialEvaluator : public ExprFunctor } struct ReflectError : dmlc::Error { - ReflectError() : dmlc::Error("static value not found") { } + ReflectError() : dmlc::Error("static value not found") {} }; Expr Reflect(const PStatic& st) { @@ -954,31 +903,24 @@ class PartialEvaluator : public ExprFunctor // Constant evaluate a expression. 
PStatic ConstEvaluate(const Expr& expr, LetList* ll) { - std::vector passes = {transform::FuseOps(0), - transform::InferType()}; + std::vector passes = {transform::FuseOps(0), transform::InferType()}; auto mod = IRModule::FromExpr(expr); auto seq = transform::Sequential(passes); mod = seq(mod); auto entry_func = Downcast(mod->Lookup("main")); - auto fused_infered = - expr.as() == nullptr ? entry_func->body : entry_func; + auto fused_infered = expr.as() == nullptr ? entry_func->body : entry_func; return Reify(executor_(fused_infered), ll); } Func ConstEvaluateFunc(const Expr& expr) { CHECK_EQ(FreeVars(expr).size(), 0); - return [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + return [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { tvm::Array ns_args; for (const PStatic& ps : pv) { ns_args.push_back(ps->dynamic); } - auto ns = [&]() { - return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); - }; + auto ns = [&]() { return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args))); }; if (StatefulOp(expr)) { return ns(); } @@ -988,8 +930,7 @@ class PartialEvaluator : public ExprFunctor args.push_back(Reflect(ps)); } return ConstEvaluate(Call(expr, args, attrs, type_args), ll); - } - catch (const ReflectError&) { + } catch (const ReflectError&) { return ns(); } }; @@ -1001,11 +942,8 @@ class PartialEvaluator : public ExprFunctor PStatic VisitExpr_(const ConstructorNode* op, LetList* ll) final { Constructor c = GetRef(op); - Func f = [=](const PStatic& self, - const std::vector& pv, - const Attrs& attrs, - const tvm::Array& type_args, - LetList* ll) { + Func f = [=](const PStatic& self, const std::vector& pv, const Attrs& attrs, + const tvm::Array& type_args, LetList* ll) { tvm::Array dyn; for (const PStatic& ps : pv) { dyn.push_back(ps->dynamic); @@ -1020,30 +958,30 @@ class PartialEvaluator : public ExprFunctor return env_.Extend([&]() { for (const Clause& c : op->clauses) { switch (VisitPattern(c->lhs, ps)) { - case MatchStatus::Match: - return VisitExpr(c->rhs, ll); - case MatchStatus::NoMatch: - continue; - case MatchStatus::Unknown: - return [&]() { - tvm::Array clauses; - for (const Clause& c : op->clauses) { - Expr expr = store_.Extend([&]() { - return LetList::With([&](LetList* ll) { - for (const Var& v : BoundVars(c->lhs)) { - env_.Insert(v, NoStatic(v)); - } - return VisitExpr(c->rhs, ll)->dynamic; + case MatchStatus::Match: + return VisitExpr(c->rhs, ll); + case MatchStatus::NoMatch: + continue; + case MatchStatus::Unknown: + return [&]() { + tvm::Array clauses; + for (const Clause& c : op->clauses) { + Expr expr = store_.Extend([&]() { + return LetList::With([&](LetList* ll) { + for (const Var& v : BoundVars(c->lhs)) { + env_.Insert(v, NoStatic(v)); + } + return VisitExpr(c->rhs, ll)->dynamic; + }); }); - }); - clauses.push_back(Clause(c->lhs, expr)); - } - store_.Invalidate(); - return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); - }(); - default: - LOG(FATAL) << "Unknown MatchStatus"; - throw; + clauses.push_back(Clause(c->lhs, expr)); + } + store_.Invalidate(); + return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete))); + }(); + default: + LOG(FATAL) << "Unknown MatchStatus"; + throw; } } LOG(FATAL) << "No case Match"; @@ -1071,12 +1009,12 @@ class PartialEvaluator : public ExprFunctor for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], scn->fields[i]); 
switch (ms) { - case MatchStatus::Match: - continue; - case MatchStatus::NoMatch: - return MatchStatus::NoMatch; - case MatchStatus::Unknown: - current_match_status = MatchStatus::Unknown; + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; } } return current_match_status; @@ -1095,12 +1033,12 @@ class PartialEvaluator : public ExprFunctor for (size_t i = 0; i < op->patterns.size(); ++i) { MatchStatus ms = VisitPattern(op->patterns[i], stn->fields[i]); switch (ms) { - case MatchStatus::Match: - continue; - case MatchStatus::NoMatch: - return MatchStatus::NoMatch; - case MatchStatus::Unknown: - current_match_status = MatchStatus::Unknown; + case MatchStatus::Match: + continue; + case MatchStatus::NoMatch: + return MatchStatus::NoMatch; + case MatchStatus::Unknown: + current_match_status = MatchStatus::Unknown; } } return current_match_status; @@ -1112,7 +1050,7 @@ class PartialEvaluator : public ExprFunctor void InitializeFuncId(const Expr& e) { struct InitializeFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; - explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } + explicit InitializeFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {} void VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); @@ -1121,9 +1059,7 @@ class PartialEvaluator : public ExprFunctor VisitExpr(f->body); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; InitializeFuncIdVisitor(this).VisitExpr(e); } @@ -1131,7 +1067,7 @@ class PartialEvaluator : public ExprFunctor Expr RegisterFuncId(const Expr& e) { struct RegisterFuncIdVisitor : ExprVisitor, PatternVisitor { PartialEvaluator* pe; - explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { } + explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) {} void VisitExpr_(const CallNode* op) final { if (op->op == with_funcid_op) { @@ -1154,9 +1090,7 @@ class PartialEvaluator : public ExprFunctor ExprVisitor::VisitExpr_(op); } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } }; RegisterFuncIdVisitor(this).VisitExpr(e); return e; @@ -1165,7 +1099,7 @@ class PartialEvaluator : public ExprFunctor Expr AnnotateFuncId(const Expr& e) { struct AnnotateFuncIdMutator : ExprMutator, PatternMutator { PartialEvaluator* pe; - explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) { } + explicit AnnotateFuncIdMutator(PartialEvaluator* pe) : pe(pe) {} Expr VisitExpr_(const FunctionNode* op) final { Function f = GetRef(op); @@ -1173,13 +1107,9 @@ class PartialEvaluator : public ExprFunctor return MkWithFuncId(ExprMutator::VisitExpr_(op), pe->func_map_.at(f)); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Var VisitVar(const Var& v) final { - return v; - } + Var VisitVar(const Var& v) final { return v; } }; return AnnotateFuncIdMutator(this).VisitExpr(e); } @@ -1187,7 +1117,7 @@ class PartialEvaluator : public ExprFunctor private: Environment env_; IRModule mod_; - std::unordered_map gv_map_; + std::unordered_map gv_map_; /*! Termination checking is done as follows: * We have finitely many FunctionIds. 
* Each FunctionId maps to a class of semantically equivalent function (ignoring type), @@ -1199,9 +1129,10 @@ class PartialEvaluator : public ExprFunctor * If no progress is made, we do not inline. * In both case, we remap the mapping to the new Fuel * when we PE inside the Function body. - * Termination is guaranteed because Fuel is finitely descending - there can only be so many meet. + * Termination is guaranteed because Fuel is finitely descending - there can only be so many + * meet. */ - std::unordered_map func_map_; + std::unordered_map func_map_; std::unordered_map fuel_map_; Store store_; DLContext context_ = CPUContext(); @@ -1219,9 +1150,7 @@ Expr Remap(const Expr& e) { return remap_.at(v); } - Var VisitVar(const Var& v) final { - return Downcast(VisitExpr(v)); - } + Var VisitVar(const Var& v) final { return Downcast(VisitExpr(v)); } private: std::unordered_map remap_; @@ -1240,20 +1169,14 @@ Expr StripWithFuncId(const Expr& e) { } } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Var VisitVar(const Var& v) final { - return v; - } + Var VisitVar(const Var& v) final { return v; } }; return StripWithFuncIdMutator().VisitExpr(e); } -Expr PostProcess(const Expr& e) { - return StripWithFuncId(DeDup(Remap(e))); -} +Expr PostProcess(const Expr& e) { return StripWithFuncId(DeDup(Remap(e))); } } // namespace partial_eval @@ -1273,14 +1196,11 @@ namespace transform { Pass PartialEval() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::PartialEval(m); - }; + [=](IRModule m, PassContext pc) { return relay::PartialEval(m); }; return CreateModulePass(pass_func, 1, "PartialEvaluate", {}); } -TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate") -.set_body_typed(PartialEval); +TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval); } // namespace transform diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 15ad60be3a95..e173bc32734f 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -44,49 +44,36 @@ #include "../analysis/annotated_region_set.h" #include "../backend/utils.h" +#include "pass_util.h" namespace tvm { namespace relay { namespace partitioning { -// Cache compiler_begin and compiler_end annotation ops for equivalence check to -// reduce registry lookup overhead. -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); -static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); - -/*! - * \brief The checker that verifies if a Relay program is annotated correctly - * for partitioning. +/*! \brief This struct maintains the required metadata for a region to generate a corresponding + * global function and function call. Global function will be passed to the target specific codegen + * and function call will be used in the transform Relay graph to invoke the function in runtime. */ -class AnnotationChecker : public ExprVisitor { - public: - bool Check() { - if (!found_start_ && !found_end_) { - LOG(WARNING) << "No compiler annotation found"; - } else if (!found_start_) { - LOG(ERROR) << "compiler_begin annotation is missing"; - return false; - } else if (!found_end_) { - LOG(ERROR) << "compiler_end annotation is missing"; - return false; - } - return true; - } +struct RegionFuncMetadata { + /*! 
\brief The call node of the generated global function for this region. */ + Call func_call; - void VisitExpr_(const CallNode* call) final { - auto op_node = call->op.as(); - if (op_node == nullptr || call->attrs.as() == nullptr) { - return; - } else if (call->op == compiler_begin_op) { - found_start_ = true; - } else if (call->op == compiler_end_op) { - found_end_ = true; - } - } + /*! \brief A list of argument pairs. Each pair includes (var, expr). var is used + * as a function node argument; input expression is used as a function call parameter. + */ + std::vector> args; - private: - bool found_start_{false}; - bool found_end_{false}; + /*! \brief Map from each region output expr (compiler end) node to + * the corresponding function output expr. + */ + std::unordered_map region_func_out; + + /*! \brief Map from each region input expression (compiler begin) to + * the corresponding function input variable. This cache is used to make sure + * a region function will not have duplicated inputs even if it refers to + * the same expr multiple times. + */ + std::unordered_map region_func_in; }; /*! \brief This class partitions the expr labeled with begin and end annotations @@ -124,37 +111,34 @@ class AnnotationChecker : public ExprVisitor { * the compiler name. */ -class Partitioner : public ExprMutator { +class Partitioner : public MixedModeMutator { public: explicit Partitioner(const IRModule& module) : module_(module) { for (auto f : module->functions) { GlobalVar f_var = f.first; BaseFunc f_func = f.second; - // Creating regionset per function in the module - auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op, - partitioning::compiler_end_op); + // Creating regionset per function in the module. + auto region_set = AnnotatedRegionSet::Create(f_func, CompilerBeginOp(), CompilerEndOp()); regions_sets_[region_set] = f_func; } } - Expr VisitExpr_(const CallNode* call) final { + Expr Rewrite_(const CallNode* call, const Expr& post) final { auto op_node = call->op.as(); if (op_node == nullptr || call->attrs.as() == nullptr) { - return ExprMutator::VisitExpr_(call); - } else if (call->op == compiler_begin_op) { - // The annotation node is inserted on edge so it must have only one - // argument. + return post; + } else if (call->op == CompilerBeginOp()) { + // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); // Traverse the rest graph. Expr parent = call->args[0]; - auto input_expr = VisitExpr(parent); + auto input_expr = Downcast(post)->args[0]; // Backtrace the parent to find the first ancestor node that is not a begin or end op while (const auto* parent_call = parent.as()) { - if (parent_call->op == compiler_begin_op || - parent_call->op == compiler_end_op) { + if (parent_call->op == CompilerBeginOp() || parent_call->op == CompilerEndOp()) { parent = parent_call->args[0]; } else { break; @@ -165,8 +149,8 @@ class Partitioner : public ExprMutator { int index = GetArgIdx(sg, GetRef(call)); CHECK_NE(index, -1); - if (shared_output_.count(parent) && shared_output_[parent].count(sg)) { - return shared_output_[parent][sg]; + if (region_func_meta_[sg].region_func_in.count(parent)) { + return region_func_meta_[sg].region_func_in[parent]; } else { // The type of the created variable is the same as the compiler_begin // node. 
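region_func_in above is what keeps a region function from gaining duplicate parameters: when the same producer expression reaches the region through several compiler_begin edges, every use is mapped back to the one Var already created for it. A standalone sketch of that caching, with opaque pointers standing in for Relay expressions:

#include <cassert>
#include <unordered_map>
#include <vector>

// Map each external producer to a single parameter slot of the region function.
struct RegionInputs {
  std::unordered_map<const void*, int> param_of_producer;  // analogue of region_func_in
  std::vector<const void*> params;                         // analogue of the args list

  int ParamFor(const void* producer) {
    auto it = param_of_producer.find(producer);
    if (it != param_of_producer.end()) return it->second;  // reuse the cached parameter
    int idx = static_cast<int>(params.size());
    params.push_back(producer);
    param_of_producer.emplace(producer, idx);
    return idx;
  }
};

int main() {
  int producer_a = 0, producer_b = 0;
  RegionInputs region;
  assert(region.ParamFor(&producer_a) == 0);
  assert(region.ParamFor(&producer_b) == 1);
  assert(region.ParamFor(&producer_a) == 0);  // second use of the same producer: no new param
  assert(region.params.size() == 2);
  return 0;
}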
@@ -177,15 +161,15 @@ class Partitioner : public ExprMutator { std::pair cand = std::make_pair(var, input_expr); - if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) == - region_args[sg].end()) { - region_args[sg].push_back(cand); + if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) == + region_func_meta_[sg].args.end()) { + region_func_meta_[sg].args.push_back(cand); } - shared_output_[parent][sg] = var; + region_func_meta_[sg].region_func_in[parent] = var; return std::move(var); } } else { - CHECK_EQ(call->op, compiler_end_op); + CHECK_EQ(call->op, CompilerEndOp()); // The annotation node is inserted on edge so it must have only one // argument. CHECK_EQ(call->args.size(), 1U); @@ -197,114 +181,21 @@ class Partitioner : public ExprMutator { BaseFunc f = GetFunc(GetRef(call)); // Traverse subgraph inputs. - auto input = VisitExpr(call->args[0]); + auto input = Downcast(post)->args[0]; CHECK(region.defined()) << "Region not defined for " << GetRef(call); // functions are created for each annotated regions, // when their first output is encountered. // If multiple outputs are there, a tuple node is inserted at the end. - // region_function_calls is map that maintains - // (each annotated regions) --> created function - if (region_function_calls.find(region) == region_function_calls.end()) { - // First time this region is encountered in the traversal. - // Creating the function. + if (!region_func_meta_[region].func_call.defined()) { + // First time this region is encountered in the traversal. Creating the function. CreateFunction(region, call); } - // Retrieve this particular output of function. - return GetFunctionOutput(region, GetRef(call)); - } - } - - Expr VisitExpr_(const TupleNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array fields; - for (auto field : op->fields) { - fields.push_back(VisitExpr(field)); - } - return Tuple(fields); - } - } - - Expr VisitExpr_(const TupleGetItemNode* g) final { - auto region = GetRegion(GetRef(g)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(g); - } else { - auto t = VisitExpr(g->tuple); - return TupleGetItem(t, g->index); - } - } - - Expr VisitExpr_(const FunctionNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Array params; - for (auto param : op->params) { - Var new_param = Downcast(VisitExpr(param)); - params.push_back(new_param); - } - auto body = VisitExpr(op->body); - return Function(params, body, op->ret_type, op->type_params, op->attrs); - } - } - - Expr VisitExpr_(const LetNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Var var = Downcast(VisitExpr(op->var)); - auto value = VisitExpr(op->value); - auto body = VisitExpr(op->body); - return Let(var, value, body); - } - } - - Expr VisitExpr_(const IfNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - auto guard = VisitExpr(op->cond); - auto true_b = VisitExpr(op->true_branch); - auto false_b = VisitExpr(op->false_branch); - return If(guard, true_b, false_b); - } - } - Expr VisitExpr_(const RefCreateNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr value = VisitExpr(op->value); - return 
RefCreate(value); - } - } - - Expr VisitExpr_(const RefReadNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - return RefRead(ref); - } - } - - Expr VisitExpr_(const RefWriteNode* op) final { - auto region = GetRegion(GetRef(op)); - if (!region.defined()) { - return ExprMutator::VisitExpr_(op); - } else { - Expr ref = VisitExpr(op->ref); - Expr value = VisitExpr(op->value); - return RefWrite(ref, value); + // Retrieve this particular output of function. + Expr region_out_expr = Downcast(GetRef(call))->args[0]; + CHECK(region_func_meta_[region].region_func_out.count(region_out_expr)); + return region_func_meta_[region].region_func_out[region_out_expr]; } } @@ -370,35 +261,41 @@ class Partitioner : public ExprMutator { } /*! - * \brief This function is called first time that we encounter a compiler_end - * node to create the function for the subgraph. + * \brief Create a function and its function call for the given region. If the function has + * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes + * will be created to serve output consumers. */ - void CreateFunction(AnnotatedRegion region, const CallNode* call) { - // Create fields which is a unique list of outputs. Also populate - // region_return_indices_ map which maps parent of compiler_end node to - // corresponding index in fields. + void CreateFunction(AnnotatedRegion region, const CallNode* end_node) { + // Create fields which is a unique list of outputs. Array fields; - int i = 0; - for (auto ret : region->GetOutputs()) { - auto ret_node = Downcast(ret)->args[0]; + std::unordered_map out_expr_to_idx; + int out_idx = 0; + for (auto region_end_node : region->GetOutputs()) { + auto ret_node = Downcast(region_end_node)->args[0]; // Don't duplicate outputs. 
- if (!region_return_indices_.count(region) || - !region_return_indices_[region].count(ret_node)) { - auto ret_expr = VisitExpr(ret_node); + if (!out_expr_to_idx.count(ret_node)) { + auto ret_expr = MixedModeMutator::VisitExpr(ret_node); fields.push_back(ret_expr); - region_return_indices_[region][ret_node] = i; - i++; + out_expr_to_idx[ret_node] = out_idx++; } } Array params; Array param_expr; - std::unordered_map params_bind; + Map params_bind; + + auto IsConstant = [](const Expr& expr) { + if (expr->IsInstance()) return true; + if (!expr->IsInstance()) return false; + const auto* tn = expr.as(); + return std::all_of(tn->fields.begin(), tn->fields.end(), + [](const Expr& e) { return e->IsInstance(); }); + }; - for (auto pair : region_args[region]) { + for (auto pair : region_func_meta_[region].args) { params.push_back(pair.first); - if (const auto* cn = pair.second.as()) { - params_bind[pair.first->name_hint()] = cn->data; + if (IsConstant(pair.second)) { + params_bind.Set(pair.first, pair.second); } else { param_expr.push_back(pair.second); } @@ -408,32 +305,29 @@ class Partitioner : public ExprMutator { if (fields.size() == 1) { // If there are only a single output; no need to add a tuple global_region_func = - Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs()); + Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs()); } else { auto tuple = Tuple(fields); global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs()); } - std::string target = call->attrs.as()->compiler; + std::string target = end_node->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); - global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, - runtime::String(name)); global_region_func = - WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); - global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, - tvm::runtime::String(target)); + WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name)); + global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); global_region_func = - WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); + WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target)); + global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1)); // Constant propagation if (!params_bind.empty()) { - global_region_func = backend::BindParamsByName(global_region_func, params_bind); + global_region_func = Downcast(relay::Bind(global_region_func, params_bind)); } std::string fname = name; - CHECK(!module_->ContainGlobalVar(fname)) - << "Global function " << fname << " already exists"; + CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists"; // Create a global function and add it to the IRModule for the region. // This way we lift the functions that should be handled by external // codegen to the module scope and rely on the pass manager to prevent @@ -442,132 +336,153 @@ class Partitioner : public ExprMutator { GlobalVar glob_func(fname); module_->Add(glob_func, global_region_func); - // The return type of callnode is the same as the type of the - // compiler_end node. - auto ret = Call(glob_func, param_expr); - region_function_calls[region] = ret; - } + // Create a call node for the function. 
+ auto call = Call(glob_func, param_expr); + region_func_meta_[region].func_call = call; - /*! - * \brief Get the return(output) of the function for compiler end node "end_arg". - * This will return either a Call (for a function with a single output) or a - * TupleGetItem (for a function with multiple outputs). - */ - Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) { - Expr arg = Downcast(end_arg)->args[0]; - // Function has one output. - if (region_return_indices_[region].size() == 1) { - return region_function_calls[region]; - } - // Function has multiple outputs. - // Use already made TupleGetItem. - if (region_return_tuplegetitem_.count(region) && - region_return_tuplegetitem_[region].count(arg)) { - return region_return_tuplegetitem_[region][arg]; + // Create output expr(s) for the function call. + if (out_expr_to_idx.size() == 1) { + // Single output direcly uses the call node as the output expr. + region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call; + } else { + // Multiple outptus need to create TupleGetItem nodes as output exprs. + for (auto pair : out_expr_to_idx) { + Expr region_out_expr = pair.first; // The arg of a compiler end node of this region. + int idx = pair.second; // Corresponding function output tuple index. + auto tuple_get_item = TupleGetItem(call, idx); + tuple_get_item->checked_type_ = region_out_expr->checked_type_; + region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item; + } } - // Create new TupleGetItem. - CHECK(region_return_indices_.count(region) && - region_return_indices_[region].count(arg)); - int index = region_return_indices_[region][arg]; - - auto func_call = region_function_calls[region]; - auto tuple_get_item_ = TupleGetItem(func_call, index); - tuple_get_item_->checked_type_ = arg->checked_type_; - region_return_tuplegetitem_[region][arg] = tuple_get_item_; - return std::move(tuple_get_item_); } - /*! - * \brief This map maintains the already created function calls. - * This is required in the multi-output scenario, to link rest of the outputs - * to call - */ - std::unordered_map region_function_calls; - - /*! - * \brief This map maintains arguments (of region) visits through visitor - * patterns. Those arguement var and expression will be used to when creating - * the function. - */ - std::unordered_map>, ObjectHash, ObjectEqual> - region_args; - - /*! - * \brief This map maintains the index of an output in the subgraph function - * for a given region. If there are multiple entries for a region, then the - * function has a tuple of multiple outputs for its return. - */ - using RegionRetIndexMap = std::unordered_map; - std::unordered_map - region_return_indices_; - - /*! - * \brief This map holds already created TupleGetItem nodes for accessing - * outputs of a function. - */ - using RegionRetTupleGetItemMap = std::unordered_map; - std::unordered_map - region_return_tuplegetitem_; + /*! \brief Map from each region to its metadata of the generated function. */ + std::unordered_map + region_func_meta_; - /*! - * \brief Each region set is associated with a function in the module. + /*! \brief Each region set is associated with a function in the module. * This map maintains the mapping between regionsets and the function it * belongs to */ - std::unordered_map regions_sets_; - - /*!\brief Cache the output that is shared by different nodes. 
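Every region function built by CreateFunction above is tagged with the same four attributes before being added to the module; kCompiler is what later routes the function to the external codegen, while kPrimitive and kInline keep other passes from fusing or inlining it away. A condensed sketch of that tagging, using the attribute keys from the diff (namespaces as in this file; the target name is purely illustrative):

// Sketch; assumes the tvm::relay namespace and headers already used in this file.
Function TagExternFunction(Function func, const std::string& target, int region_id) {
  std::string name = target + "_" + std::to_string(region_id);  // e.g. "myccompiler_0"
  func = WithAttr(std::move(func), tvm::attr::kGlobalSymbol, runtime::String(name));
  func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(1));
  func = WithAttr(std::move(func), attr::kCompiler, tvm::runtime::String(target));
  func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1));
  return func;
}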
*/ - using RegionOutputMap = std::unordered_map; - std::unordered_map shared_output_; + std::unordered_map regions_sets_; /*!\brief The IRModule used for partitioning. */ IRModule module_; }; -class DefaultRemover : public ExprMutator { - public: - explicit DefaultRemover(const IRModule& module) : module_(module) {} +IRModule RemoveDefaultAnnotations(IRModule module) { + class DefaultRemover : public ExprRewriter { + public: + DefaultRemover() = default; - IRModule Remove() { - auto glob_funcs = module_->functions; - for (const auto& pair : glob_funcs) { - if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); - module_->Update(pair.first, func); + Expr Rewrite_(const CallNode* call, const Expr& post) final { + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return Downcast(post)->args[0]; } + return post; } - return module_; - } + }; - Expr VisitExpr_(const CallNode* call) final { - auto attrs = call->attrs.as(); - if (attrs != nullptr && attrs->compiler == "default") { - return VisitExpr(call->args[0]); + auto glob_funcs = module->functions; + // module is mutable, hence, we make a copy of it. + module.CopyOnWrite(); + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + DefaultRemover remover; + auto removed = PostOrderRewrite(func->body, &remover); + func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + module->Update(pair.first, func); } - return ExprMutator::VisitExpr_(call); } + return module; +} - private: - IRModule module_; -}; +/*! \brief There can be regions with multiple outputs where each output + * could be a tuple output. Such tuple outputs needs to be flattened + * otherwise the function would create tuples of tuples. Moreover, tuple + * of tuples are valid relay, however they are not currently supported by + * graph runtime or relay VM. + */ + +// New annotations would be required to be added for each flattened output +const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + +IRModule FlattenTupleOutputs(IRModule module) { + class TupleOutFlattener : public ExprRewriter { + public: + TupleOutFlattener() = default; + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + if (call->op == CompilerEndOp()) { + std::string target = call->attrs.as()->compiler; + // Arguments of annotation ops should be 1 + CHECK_EQ(call->args.size(), 1U); + auto annotated_op = Downcast(post)->args[0]; + if (const auto* tn = annotated_op.as()) { + Array new_fields; + + // Here each input of the tuple will be annotated with compiler_ends + for (auto& tn_arg : tn->fields) { + new_fields.push_back((*make_end_op)(tn_arg, target)); + } + + // Return a tuple of compiler_ends in the place of the tuple that was + // annotated with a compiler_end. + auto out = Tuple(new_fields); + return std::move(out); + } + } + return post; + } + }; + + auto glob_funcs = module->functions; + // module is mutable, hence, we make a copy of it. 
+ module.CopyOnWrite(); + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + TupleOutFlattener to_flattener; + auto removed = PostOrderRewrite(func->body, &to_flattener); + func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + module->Update(pair.first, func); + } + } + return module; +} } // namespace partitioning namespace transform { Pass PartitionGraph() { - runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { - // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute - // by treating them as un-annotated, but we don't have it yet. This workaround pass removes - // all "default" annotations and should be deleted in the future. - auto new_m = partitioning::DefaultRemover(m).Remove(); - return partitioning::Partitioner(new_m).Partition(); + runtime::TypedPackedFunc flatten_tuples = [=](IRModule m, + PassContext pc) { + // There could be compiler_end annotations on tuples + // If the corresponding region is having multiple compiler_ends, + // this would lead to creation of tuples of tuples. + // Thus, we flatten the tuples by transfering the compiler_end to + // the tuple inputs. + return partitioning::FlattenTupleOutputs(m); + }; + + runtime::TypedPackedFunc remove_defaults = [=](IRModule m, + PassContext pc) { + // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute + // by treating them as un-annotated, but we don't have it yet. This workaround pass removes + // all "default" annotations and should be deleted in the future. + return partitioning::RemoveDefaultAnnotations(m); }; - auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); - return Sequential({partitioned, InferType()}); + + runtime::TypedPackedFunc part_func = + [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; + + auto flatten_tuples_pass = CreateModulePass(flatten_tuples, 0, "FlattenNestedTuples", {}); + auto remove_default_pass = CreateModulePass(remove_defaults, 0, "RemoveDefaultAnnotations", {}); + auto partition_pass = CreateModulePass(part_func, 0, "PartitionGraph", {}); + return Sequential({flatten_tuples_pass, remove_default_pass, partition_pass, InferType()}); } TVM_REGISTER_GLOBAL("relay._transform.PartitionGraph").set_body_typed(transform::PartitionGraph); diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h index 56b064573841..35bbb234dbc1 100644 --- a/src/relay/transforms/pass_util.h +++ b/src/relay/transforms/pass_util.h @@ -25,9 +25,10 @@ #ifndef TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ #define TVM_RELAY_TRANSFORMS_PASS_UTIL_H_ -#include -#include #include +#include +#include + #include #include @@ -75,6 +76,20 @@ Type TypeSubst(const Type& type, const tvm::Map& subst_map); */ Expr TypeSubst(const Expr& expr, const tvm::Map& subst_map); +/*! + * \brief Check if type is dynamic. + * \param ty The type to be checked. + * \return Whether the type is dynamic. + */ +bool IsDynamic(const Type& ty); + +/*! + * \brief Check if call is data dependant. + * \param call The call to be checked. + * \return Whether the call is data dependant. + */ +bool IsDataDependant(const CallNode* call); + /*! * \brief Make arbitrary transformation preserve the out most function. * \param func The transformation. @@ -100,41 +115,57 @@ inline bool IsAtomic(const Expr& e) { return e.as() || e.as() || e.as() || e.as(); } -template +/*! 
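PartitionGraph above is now a Sequential of three module passes plus InferType, so callers still invoke a single pass: tuple outputs annotated by one compiler_end are first split into per-field compiler_end calls (avoiding tuples of tuples), "default" annotations are dropped, and only then are regions lifted into external functions. A hedged usage sketch:

#include <tvm/ir/module.h>
#include <tvm/relay/transform.h>

// 'mod' is assumed to already carry compiler_begin/compiler_end annotations
// (for example from AnnotateTarget); the composed pass then runs
// FlattenNestedTuples -> RemoveDefaultAnnotations -> PartitionGraph -> InferType.
tvm::IRModule PartitionAnnotatedModule(tvm::IRModule mod) {
  tvm::transform::Pass pass = tvm::relay::transform::PartitionGraph();
  return pass(mod);
}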
+ * \brief Cache the compiler_begin annotation op to reduce registry lookup overhead + * \param void + * \return compiler_begin op + */ +inline const Op& CompilerBeginOp() { + static Op op = Op::Get("annotation.compiler_begin"); + return op; +} + +/*! + * \brief Cache the compiler_end annotation op to reduce registry lookup overhead + * \param void + * \return compiler_end op + */ +inline const Op& CompilerEndOp() { + static Op op = Op::Get("annotation.compiler_end"); + return op; +} + +template struct TreeNode { typedef std::shared_ptr> pointer; virtual ~TreeNode() {} }; -template +template struct TreeLeafNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; Expr body; - explicit TreeLeafNode(Expr body): body(body) {} + explicit TreeLeafNode(Expr body) : body(body) {} - static TreeObjectPtr Make(Expr body) { - return std::make_shared(body); - } + static TreeObjectPtr Make(Expr body) { return std::make_shared(body); } ~TreeLeafNode() {} }; -template +template struct TreeLeafFatalNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; TreeLeafFatalNode() = default; - static TreeObjectPtr Make() { - return std::make_shared(); - } + static TreeObjectPtr Make() { return std::make_shared(); } ~TreeLeafFatalNode() {} }; -template +template struct TreeBranchNode : TreeNode { using TreeObjectPtr = typename TreeNode::pointer; @@ -142,15 +173,11 @@ struct TreeBranchNode : TreeNode { TreeObjectPtr then_branch; TreeObjectPtr else_branch; - TreeBranchNode(ConditionObjectPtr cond, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) - : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - + TreeBranchNode(ConditionObjectPtr cond, TreeObjectPtr then_branch, TreeObjectPtr else_branch) + : cond(cond), then_branch(then_branch), else_branch(else_branch) {} - static TreeObjectPtr Make(ConditionObjectPtr cond, - TreeObjectPtr then_branch, - TreeObjectPtr else_branch) { + static TreeObjectPtr Make(ConditionObjectPtr cond, TreeObjectPtr then_branch, + TreeObjectPtr else_branch) { return std::make_shared(cond, then_branch, else_branch); } diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h index cd2af9f2ac2e..7518eb9ac81a 100644 --- a/src/relay/transforms/pattern_util.h +++ b/src/relay/transforms/pattern_util.h @@ -28,19 +28,19 @@ #include #include -#include -#include -#include #include #include -#include #include +#include +#include +#include #include +#include +#include #include -#include #include - +#include namespace tvm { namespace relay { @@ -49,42 +49,42 @@ namespace relay { * \brief Dispatch DataType to the C++ data type * during runtime. */ -#define TVM_DTYPE_DISPATCH(type, DType, ...) 
\ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ } /*! @@ -99,10 +99,8 @@ namespace relay { * \param rhs_value A squeezed version of rhs which only contains matched dimension. * \return Whether match is successful. */ -inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, - const TensorTypeNode* trhs, - const Array& lhs_axes, - Expr* rhs_value = nullptr) { +inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, const TensorTypeNode* trhs, + const Array& lhs_axes, Expr* rhs_value = nullptr) { if (tlhs->shape.size() < trhs->shape.size()) return false; StructuralEqual equal; size_t base = tlhs->shape.size() - trhs->shape.size(); @@ -145,9 +143,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs, * \param target_ndim Target dimension. * \param axes The axis on the output we want to match on. */ -inline Expr ExpandBiasToMatchAxis(Expr bias, - int target_ndim, - const Array& axes) { +inline Expr ExpandBiasToMatchAxis(Expr bias, int target_ndim, const Array& axes) { static const Op& expand_dims = Op::Get("expand_dims"); for (size_t i = axes.size(); i != 0; --i) { if (i == axes.size()) { @@ -179,14 +175,12 @@ inline Expr ExpandBiasToMatchAxis(Expr bias, * \param param The conv2d attributes. * \return Whether it is depthwise_conv2d. 
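The reformatted TVM_DTYPE_DISPATCH above binds a runtime DataType to a concrete typedef DType inside the braced body; that is how the MakeConstantTensor helpers below write elements into an NDArray whose element type is only known at run time. A hedged usage sketch in the same style:

// Fill a small 1-D NDArray of a runtime-chosen dtype with 0, 1, 2, 3.
// Sketch only; relies on the includes already present in this header. Note that for
// DataType::Float(16) the dispatched DType is the uint16_t storage type, so real code
// converts through __truncXfYf2__ as MakeConstantScalar/MakeConstantTensor do.
inline runtime::NDArray MakeIota4(DataType dtype) {
  runtime::NDArray arr = runtime::NDArray::Empty({4}, dtype, {kDLCPU, 0});
  TVM_DTYPE_DISPATCH(dtype, DType, {
    DType* data = static_cast<DType*>(arr->data);
    for (int i = 0; i < 4; ++i) {
      data[i] = static_cast<DType>(i);
    }
  });
  return arr;
}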
*/ -inline bool IsDepthwiseConv2D(const Call& call, - const Conv2DAttrs* param, +inline bool IsDepthwiseConv2D(const Call& call, const Conv2DAttrs* param, const Layout& kernel_layout) { static const Layout kOIHW("OIHW"); const auto bilayout = tir::BijectiveLayout(kernel_layout, kOIHW); auto wshape = bilayout.ForwardShape(call->args[1]->type_as()->shape); - return tir::is_const_int(wshape[0], param->groups) && - tir::is_const_int(wshape[1], 1); + return tir::is_const_int(wshape[0], param->groups) && tir::is_const_int(wshape[1], 1); } /*! @@ -195,12 +189,12 @@ inline bool IsDepthwiseConv2D(const Call& call, * \return Super-dimension size of output channels of conv2d. */ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { - auto param = call->attrs.as(); - auto tweight = call->args[1]->type_as(); - auto index = param->kernel_layout.find('O'); - CHECK_NE(index, std::string::npos); - auto channels = tir::as_const_int(tweight->shape[index]); - return *channels; + auto param = call->attrs.as(); + auto tweight = call->args[1]->type_as(); + auto index = param->kernel_layout.find('O'); + CHECK_NE(index, std::string::npos); + auto channels = tir::as_const_int(tweight->shape[index]); + return *channels; } /*! @@ -289,6 +283,53 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector s return Constant(arr); } +/*! + * \brief Create a Constant with a tensor. + * + * \param dtype The data type. + * \param value The array of the tensor values. + * \return A Constant. + */ +template +static inline Constant MakeConstantTensor(DataType dtype, std::vector shape, + Array value) { + runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0}); + TVM_DTYPE_DISPATCH(dtype, DType, { + for (size_t i = 0; i < value.size(); i++) { + if (dtype == DataType::Float(16)) { + // convert to float16 + // storage is uint16_t + // Similar handling as that in MakeConstantScalar + *(static_cast(arr->data) + i) = + __truncXfYf2__( + static_cast(value[i])); + } else { + *(static_cast(arr->data) + i) = value[i]; + } + } + }) + return Constant(arr); +} + +/*! + * \brief Check whether a shape is static and create corresponding Constant. + * + * \param shape The Array of the shape values. + * \return A Constant. + */ +static inline Constant CheckConstantShape(const Array& shape) { + auto shape_array = + runtime::NDArray::Empty({int64_t(shape.size())}, DataType::Int(64), {kDLCPU, 0}); + auto* shape_data = static_cast(shape_array->data); + for (size_t i = 0; i < shape.size(); ++i) { + const auto& dim_val = shape[i].as(); + CHECK(dim_val) << "Do not support symbolic shape for " + "Array format. Pass shape as Expr instead."; + shape_data[i] = dim_val->value; + } + return Constant(shape_array); +} + /*! * \brief Check if two expressions are equal scalars. * \param a The expression to be checked. @@ -304,14 +345,71 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) { return tvm::StructuralEqual()(a, b); } -inline Expr GetField(Expr t, size_t i) { - return TupleGetItem(t, i); +/*! + * \brief Convert an element of a NDArray with type int or float to scalar. + * \param array Input NDArray + * \param i element index + * \return Converted scalar value. 
+ */ +static inline double ToScalar(const runtime::NDArray& array, size_t i = 0) { + if (array->dtype.code == kDLInt) { + if (array->dtype.bits == 8) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 16) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 32) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 64) { + return reinterpret_cast(array->data)[i]; + } + } else if (array->dtype.code == kDLUInt) { + if (array->dtype.bits == 8) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 16) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 32) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 64) { + return reinterpret_cast(array->data)[i]; + } + } else if (array->dtype.code == kDLFloat) { +#if (__ARM_FP16_FORMAT_IEEE == 1) + if (array->dtype.bits == 16) { + return reinterpret_cast<__fp16*>(array->data)[i]; + } +#endif + if (array->dtype.bits == 32) { + return reinterpret_cast(array->data)[i]; + } else if (array->dtype.bits == 64) { + return reinterpret_cast(array->data)[i]; + } + } + LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype); + // make compiler happy + return -std::numeric_limits::infinity(); } -inline Expr Pair(Expr l, Expr r) { - return Tuple({l, r}); +/*! + * \brief Convert a NDArray with type int or float to Array. + * \param array Input NDArray + * \return Converted Array. + */ +static inline Array ToVector(const runtime::NDArray& array) { + size_t ndim = array.Shape().size(); + CHECK_EQ(ndim, 1) << "This function should only used for shape tensor."; + size_t len = array.Shape().front(); + Array out; + for (size_t i = 0; i < len; ++i) { + double elem_val = ToScalar(array, i); + out.push_back(Integer(static_cast(elem_val))); + } + return out; } +inline Expr GetField(Expr t, size_t i) { return TupleGetItem(t, i); } + +inline Expr Pair(Expr l, Expr r) { return Tuple({l, r}); } + inline Expr Exp(Expr e) { static const Op& op = Op::Get("exp"); return Call(op, {e}); @@ -362,25 +460,21 @@ inline Expr Negative(Expr x) { return Call(op, {x}, Attrs(), {}); } - inline Expr Sqrt(Expr x) { static const Op& op = Op::Get("sqrt"); return Call(op, {x}, Attrs(), {}); } - inline Expr Relu(Expr x) { static const Op& op = Op::Get("nn.relu"); return Call(op, {x}, Attrs(), {}); } - inline Expr Round(Expr x) { static const Op& op = Op::Get("round"); return Call(op, {x}, Attrs(), {}); } - inline Expr Clip(Expr x, double a_min, double a_max) { static const Op& op = Op::Get("clip"); auto attrs = make_object(); @@ -389,25 +483,21 @@ inline Expr Clip(Expr x, double a_min, double a_max) { return Call(op, {x}, Attrs(attrs), {}); } - inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Subtract(Expr lhs, Expr rhs) { static const Op& op = Op::Get("subtract"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Divide(Expr lhs, Expr rhs) { static const Op& op = Op::Get("divide"); return Call(op, {lhs, rhs}, Attrs(), {}); @@ -423,12 +513,10 @@ inline Expr ZerosLike(Expr e) { return Call(op, {e}); } +Expr MakeZeros(Expr shape, DataType dtype); + inline Expr Zeros(Array shape, DataType dtype) { - auto attrs = make_object(); - attrs->shape = std::move(shape); - attrs->dtype = std::move(dtype); - 
static const Op& op = Op::Get("zeros"); - return Call(op, {}, Attrs(attrs), {}); + return MakeZeros(CheckConstantShape(shape), dtype); } inline Expr OnesLike(Expr e) { @@ -446,31 +534,26 @@ inline Expr Power(Expr lhs, Expr rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr RightShift(Expr x, Expr nbit) { static const Op& op = Op::Get("right_shift"); return Call(op, {x, nbit}, Attrs(), {}); } - inline Expr LeftShift(Expr x, Expr nbit) { static const Op& op = Op::Get("left_shift"); return Call(op, {x, nbit}, Attrs(), {}); } - inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); return Call(op, {lhs, rhs}, Attrs(), {}); } - inline Expr Copy(Expr data) { static const Op& op = Op::Get("copy"); return Call(op, {data}, Attrs(), {}); } - inline Expr Mean(Expr data, Array axis, bool keepdims, bool exclude) { auto attrs = make_object(); attrs->axis = std::move(axis); @@ -489,7 +572,6 @@ inline Expr Variance(Expr data, Expr mean, Array axis, bool keepdims, b return Call(op, {data, mean}, Attrs(attrs), {}); } - static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { static const Op& op = Op::Get("where"); return Call(op, {condition, x, y}); @@ -500,14 +582,10 @@ static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { return Call(op, {lhs, rhs}, Attrs(), {}); } -static inline Expr Full(Expr fill_value, - Array shape, - DataType dtype) { - auto attrs = make_object(); - attrs->shape = std::move(shape); - attrs->dtype = std::move(dtype); - static const Op& op = Op::Get("full"); - return Call(op, {fill_value}, Attrs(attrs), {}); +Expr MakeFull(Expr fill_value, Expr shape, DataType dtype); + +static inline Expr Full(Expr fill_value, Array shape, DataType dtype) { + return MakeFull(fill_value, CheckConstantShape(shape), dtype); } static inline Expr Conv2D(Expr data, Expr weight, Array strides, @@ -529,10 +607,7 @@ static inline Expr Conv2D(Expr data, Expr weight, Array strides, return Call(op, {data, weight}, Attrs(attrs), {}); } -static inline Expr Dense(Expr data, - Expr weight, - IndexExpr units, - DataType out_dtype) { +static inline Expr Dense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) { auto attrs = make_object(); attrs->units = units; attrs->out_dtype = out_dtype; @@ -549,12 +624,12 @@ static inline Expr Sum(Expr data, Array axis, bool keepdims, bool exclu return Call(op, {data}, Attrs(attrs), {}); } +Expr MakeReshape(Expr data, Expr newshape); + static inline Expr Reshape(Expr data, Array newshape) { - auto attrs = make_object(); - attrs->newshape = std::move(newshape); - attrs->reverse = false; - static const Op& op = Op::Get("reshape"); - return Call(op, {data}, Attrs(attrs), {}); + auto newshape_tensor = + MakeConstantTensor(DataType::Int(32), {static_cast(newshape.size())}, newshape); + return MakeReshape(data, newshape_tensor); } static inline Expr AvgPool2D(Expr data, Array pool_size, Array strides, @@ -588,13 +663,17 @@ static inline Expr Tile(Expr data, Array reps) { return Call(op, {data}, Attrs(attrs), {}); } -Expr MakeBroadCastTo(Expr data, Array shape); +Expr MakeBroadCastTo(Expr data, Expr shape); + +static inline Expr BroadCastTo(Expr data, Array shape) { + return MakeBroadCastTo(data, CheckConstantShape(shape)); +} Expr MakeConcatenate(Expr data, int axis); Expr MakeRepeat(Expr data, int repeats, int axis); -Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); +Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode); Expr 
MakeStack(Expr data, int axis); @@ -604,7 +683,7 @@ Expr MakeSqueeze(Expr data, Array axis); Expr MakeExpandDims(Expr data, int axis, int num_newaxis); -Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout); +Expr MakeLayoutTransform(Expr data, String src_layout, String dst_layout); Expr StopFusion(Expr data); diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc new file mode 100644 index 000000000000..99ded0ba591d --- /dev/null +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file simplify_fc_transpose.cc + * + * \brief Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to + * ```y = nn.dense(x, wt)``` + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relay { + +// Find name of weight in ```y = nn.dense(x, tranpose(w, [1, 0]))``` +class FCTransposeVisitor : private ExprVisitor { + public: + FCTransposeVisitor() : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {} + + Array Search(const Expr& expr) { + VisitExpr(expr); + return memo_; + } + + private: + void VisitExpr_(const CallNode* n) final { + if (n->op == dense_op_) { + const auto weight = n->args[1].as(); + if (weight) { + if (weight->op == transpose_op_) { + if (weight->args[0].as()) { + const auto arg = weight->args[0].as(); + memo_.push_back(arg->name_hint()); + } + } + } + } + for (const auto& arg : n->args) { + VisitExpr(arg); + } + } + + const Op& dense_op_; + const Op& transpose_op_; + Array memo_; +}; // SearchDenseOpWeight + +Array SearchFCTranspose(const Expr& e) { return FCTransposeVisitor().Search(e); } + +TVM_REGISTER_GLOBAL("relay.analysis.search_fc_transpose").set_body_typed(SearchFCTranspose); + +// Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)``` +class FCTransposeMutator : public ExprRewriter { + public: + explicit FCTransposeMutator(const Array& target_weights) + : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) { + for (size_t i = 0; i < target_weights.size(); ++i) { + CHECK(target_weights[i]->IsInstance()); + std::string k = target_weights[i].as()->data; + target_weights_.emplace(k); + } + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (pre->op == dense_op_) { + const auto data = post.as()->args[0]; + const auto weight = pre->args[1].as(); + if (weight) { + if (weight->op == transpose_op_) { + const auto arg = weight->args[0]; + if (arg.as()) { + const auto& arg_node = arg.as(); + CHECK_GT(target_weights_.count(arg_node->name_hint()), 0); + const auto& tt = arg_node->type_annotation.as(); + auto wt_type = 
TensorType({tt->shape[1], tt->shape[0]}, tt->dtype); + Var wt(arg_node->name_hint() + ".T", wt_type); + return Call(dense_op_, {data, wt}, pre->attrs, pre->type_args); + } + } + } + } + return post; + } + + private: + // Cached op + const Op& dense_op_; + const Op& transpose_op_; + std::unordered_set target_weights_; +}; // class DenseToSparseDenseAlter + +Expr SimplifyFCTranspose(const Expr& e, const Array& target_weights) { + auto rewriter = FCTransposeMutator(target_weights); + return PostOrderRewrite(e, &rewriter); +} + +namespace transform { + +Pass SimplifyFCTranspose(const Array& target_weights) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + // Remove FreeVar warning + auto f0 = Downcast(SimplifyFCTranspose(f, target_weights)); + Array wt_params = FreeVars(f0); + auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + Array params = FreeVars(f1); + for (const auto& var : wt_params) { + params.push_back(var); + } + return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + }; + return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.SimplifyFCTranspose").set_body_typed(SimplifyFCTranspose); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc index d349fdddeeea..8728e90f55a3 100644 --- a/src/relay/transforms/simplify_inference.cc +++ b/src/relay/transforms/simplify_inference.cc @@ -21,22 +21,18 @@ * \file simplify_inference.cc */ #include -#include #include -#include +#include #include +#include + #include "pattern_util.h" namespace tvm { namespace relay { -Expr BatchNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Expr moving_mean, - Expr moving_var, - Type tdata) { +Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean, + Expr moving_var, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -64,11 +60,62 @@ Expr BatchNormToInferUnpack(const Attrs attrs, return out; } -Expr LayerNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr GroupNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { + auto ttype = tdata.as(); + CHECK(ttype); + const auto param = attrs.as(); + CHECK(param); + + int ndim = ttype->shape.size(); + int axis = (param->axis < 0) ? 
param->axis + ndim : param->axis; + Array reduced_axes; + Array new_shape; + Array old_shape; + + int num_groups = param->num_groups; + int channel = ttype->shape[axis].as()->value; + + // old_shape = N, C, H, W + // new shape = N, num_groups, C/num_groups, H, W + // reduce_axes = axis of (C/num_groups, H, W) + for (int i = 0; i < ndim; ++i) { + auto val = ttype->shape[i].as()->value; + + // Save the old shape to reshape later + old_shape.push_back(val); + if (i == axis) { + new_shape.push_back(num_groups); + new_shape.push_back(channel / num_groups); + reduced_axes.push_back(i + 1); + continue; + } + if (i >= axis) { + reduced_axes.push_back(i + 1); + } + new_shape.push_back(val); + } + + data = Reshape(data, new_shape); + + Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon)); + Expr mean = Mean(data, {reduced_axes}, true, false); + Expr var = Variance(data, mean, {reduced_axes}, true, false); + Expr denom = Sqrt(Add(var, epsilon)); + Expr out = Divide(Subtract(data, mean), denom); + + out = Reshape(out, old_shape); + + if (param->scale) { + out = Multiply(out, ExpandBiasToMatchAxis(gamma, ndim, {axis})); + } + if (param->center) { + out = Add(out, ExpandBiasToMatchAxis(beta, ndim, {axis})); + } + + return out; +} + +Expr LayerNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -91,11 +138,7 @@ Expr LayerNormToInferUnpack(const Attrs attrs, return out; } -Expr InstanceNormToInferUnpack(const Attrs attrs, - Expr data, - Expr gamma, - Expr beta, - Type tdata) { +Expr InstanceNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Type tdata) { auto ttype = tdata.as(); CHECK(ttype); const auto param = attrs.as(); @@ -105,8 +148,7 @@ Expr InstanceNormToInferUnpack(const Attrs attrs, int axis = (param->axis < 0) ? 
param->axis + ndim : param->axis; Array reduced_axes; for (int i = 1; i < ndim; ++i) { - if (i != axis) - reduced_axes.push_back(i); + if (i != axis) reduced_axes.push_back(i); } Expr epsilon = MakeConstantScalar(DataType::Float(32), static_cast(param->epsilon)); @@ -143,6 +185,7 @@ class InferenceSimplifier : public ExprMutator { dropout_op_(Op::Get("nn.dropout")), instance_norm_op_(Op::Get("nn.instance_norm")), layer_norm_op_(Op::Get("nn.layer_norm")), + group_norm_op_(Op::Get("nn.group_norm")), l2_norm_op_(Op::Get("nn.l2_normalize")) {} Expr VisitExpr_(const TupleGetItemNode* n) final { @@ -170,6 +213,10 @@ class InferenceSimplifier : public ExprMutator { const auto* call = new_n.as(); return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], n->args[0]->checked_type()); + } else if (n->op == group_norm_op_) { + const auto* call = new_n.as(); + return GroupNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], + n->args[0]->checked_type()); } else if (n->op == instance_norm_op_) { const auto* call = new_n.as(); return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2], @@ -189,26 +236,24 @@ class InferenceSimplifier : public ExprMutator { const Op& dropout_op_; const Op& instance_norm_op_; const Op& layer_norm_op_; + const Op& group_norm_op_; const Op& l2_norm_op_; - std::unordered_map ty_map_; + std::unordered_map ty_map_; }; -Expr SimplifyInference(const Expr& e) { - return InferenceSimplifier().Mutate(e); -} +Expr SimplifyInference(const Expr& e) { return InferenceSimplifier().Mutate(e); } namespace transform { Pass SimplifyInference() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(SimplifyInference(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyInference(f)); + }; return CreateFunctionPass(pass_func, 0, "SimplifyInference", {"InferType"}); } -TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference") -.set_body_typed(SimplifyInference); +TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference").set_body_typed(SimplifyInference); } // namespace transform diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 21c516201dd7..8d1024217a1e 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -26,12 +26,12 @@ #include #include #include -#include #include -#include "let_list.h" -#include "pass_util.h" + #include "../../support/arena.h" #include "../analysis/dependency_graph.h" +#include "let_list.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -47,13 +47,11 @@ struct ScopeNode { size_t level; Scope parent; std::shared_ptr ll = std::make_shared(); - explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { } - ScopeNode() : level(0) { } + explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {} + ScopeNode() : level(0) {} }; -Scope ChildScope(const Scope& s) { - return std::make_shared(s); -} +Scope ChildScope(const Scope& s) { return std::make_shared(s); } Scope LCA(Scope lhs, Scope rhs) { while (lhs != rhs) { @@ -100,8 +98,7 @@ std::unordered_map CalcScope(const DependencyGrap */ class Fill : ExprFunctor { public: - static Expr ToANormalForm(const Expr& e, - const DependencyGraph& dg, + static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, std::unordered_map* node_scope) { Fill fi(dg, node_scope); return 
fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); @@ -110,16 +107,12 @@ class Fill : ExprFunctor { private: const DependencyGraph& dg_; std::unordered_map* node_scope_; - std::unordered_map memo; + std::unordered_map memo; - Fill(const DependencyGraph& dg, - std::unordered_map* node_scope) : - dg_(dg), - node_scope_(node_scope) { } + Fill(const DependencyGraph& dg, std::unordered_map* node_scope) + : dg_(dg), node_scope_(node_scope) {} - Scope GetScope(const Expr& e) { - return node_scope_->at(dg_.expr_node.at(e)); - } + Scope GetScope(const Expr& e) { return node_scope_->at(dg_.expr_node.at(e)); } Scope GetSubScope(const Expr& e, size_t i) { DependencyGraph::Node* n = dg_.expr_node.at(e); @@ -144,18 +137,12 @@ class Fill : ExprFunctor { return ret; } - Expr VisitExpr(const Expr& e) { - return this->VisitExpr(e, Var()); - } + Expr VisitExpr(const Expr& e) { return this->VisitExpr(e, Var()); } - Expr Atomic(const Expr& e, const Var& v) { - return v.defined() ? GetScope(e)->ll->Push(v, e) : e; - } + Expr Atomic(const Expr& e, const Var& v) { return v.defined() ? GetScope(e)->ll->Push(v, e) : e; } Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Var var = v.defined() ? - v : - Var(std::string("x"), Type()); + Var var = v.defined() ? v : Var(String("x"), Type()); return GetScope(orig)->ll->Push(var, now); } @@ -199,9 +186,8 @@ class Fill : ExprFunctor { Expr VisitExpr_(const IfNode* i, const Var& v) final { Expr e = GetRef(i); - Expr ret = If(VisitExpr(i->cond), - GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), - GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); + Expr ret = If(VisitExpr(i->cond), GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), + GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); return Compound(e, ret, v); } @@ -211,11 +197,8 @@ class Fill : ExprFunctor { if (f->HasNonzeroAttr(attr::kPrimitive)) { ret = e; } else { - ret = Function(f->params, - GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), - f->ret_type, - f->type_params, - f->attrs); + ret = Function(f->params, GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), f->ret_type, + f->type_params, f->attrs); } return Compound(e, ret, v); } @@ -257,9 +240,8 @@ class Fill : ExprFunctor { Expr data = VisitExpr(m->data); std::vector clauses; for (const Clause& c : m->clauses) { - clauses.push_back(Clause( - c->lhs, - GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); + clauses.push_back( + Clause(c->lhs, GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs)))); } return Compound(e, Match(data, clauses, m->complete), v); } @@ -301,14 +283,9 @@ IRModule ToANormalForm(const IRModule& m) { if (const auto* n = it.second.as()) { if (n->GetAttr(attr::kCompiler).defined()) continue; } - Expr ret = - TransformF([&](const Expr& e) { - return ToANormalFormAux(e); - }, it.second); + Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second); CHECK_EQ(FreeVars(ret).size(), 0) - << AsText(ret) - << "should not has free vars: " - << FreeVars(ret); + << AsText(ret) << "should not has free vars: " << FreeVars(ret); updates.Set(it.first, Downcast(ret)); } @@ -325,14 +302,11 @@ namespace transform { Pass ToANormalForm() { runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { - return relay::ToANormalForm(m); - }; + [=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); }; return CreateModulePass(pass_func, 1, "ToANormalForm", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm") -.set_body_typed(ToANormalForm); 
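// --- Illustrative standalone sketch (not part of this patch): the A-normal-form idea that
// --- the Fill/LetList machinery above implements. Every non-atomic subexpression is pushed
// --- onto a let-list as a fresh variable, so the result is a chain of lets whose right-hand
// --- sides each do one step of work. Toy string "expressions" only, not Relay IR.
#include <string>
#include <utility>
#include <vector>

struct ToyLetList {
  std::vector<std::pair<std::string, std::string>> binds;  // (var, value)
  int counter = 0;

  // Push a value and get back the variable that now names it (mirrors ll->Push(v, e)).
  std::string Push(const std::string& value) {
    std::string var = "x" + std::to_string(counter++);
    binds.emplace_back(var, value);
    return var;
  }

  // Wrap a body in the accumulated lets, innermost binding first (mirrors ll->Get(body)).
  std::string Get(const std::string& body) const {
    std::string out = body;
    for (auto it = binds.rbegin(); it != binds.rend(); ++it) {
      out = "let " + it->first + " = " + it->second + " in " + out;
    }
    return out;
  }
};

// Flattening "add(mul(a, b), c)": bind the inner call first, then the outer one.
// Yields "let x0 = mul(a, b) in let x1 = add(x0, c) in x1".
std::string ToyANFExample() {
  ToyLetList ll;
  std::string t0 = ll.Push("mul(a, b)");
  std::string t1 = ll.Push("add(" + t0 + ", c)");
  return ll.Get(t1);
}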
+TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm); } // namespace transform diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index e6c83928b098..6972d5a76b77 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -51,9 +51,10 @@ * wheter directly invoking it, or indirectly by recursion. */ #include -#include #include #include +#include + #include "let_list.h" #include "pass_util.h" @@ -62,9 +63,7 @@ namespace relay { // we assume the data type has no closure - no idea how to look into datatype right now. -Type Arrow(const Type& l, const Type& r) { - return FuncType({l}, r, {}, {}); -} +Type Arrow(const Type& l, const Type& r) { return FuncType({l}, r, {}, {}); } Type CPSType(const Type& t, const TypeVar& answer); @@ -79,7 +78,7 @@ FuncType CPSFuncType(const FuncType& f, const TypeVar& answer) { Type CPSType(const Type& t, const TypeVar& answer) { struct CPSTypeMutator : TypeMutator { - explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) { } + explicit CPSTypeMutator(const TypeVar& answer) : answer(answer) {} TypeVar answer; Type VisitType_(const FuncTypeNode* t) final { return CPSFuncType(GetRef(t), answer); @@ -89,10 +88,10 @@ Type CPSType(const Type& t, const TypeVar& answer) { } // transform global functions into cps form. -using CPSMap = std::unordered_map; +using CPSMap = std::unordered_map; // transform vars from the original program into new vars, so their type will be correct. -using VarMap = std::unordered_map; +using VarMap = std::unordered_map; /* * The meta continuation. @@ -113,22 +112,15 @@ using MCont = std::function; Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm); -Function ToCPS(const Function& f, - const IRModule& m, - CPSMap* cm, - VarMap* vm, +Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, const TypeVar& answer) { - std::function remap = [&](const Var& v) { - return vm->count(v) == 0 ? v : vm->at(v); - }; + std::function remap = [&](const Var& v) { return vm->count(v) == 0 ? v : vm->at(v); }; auto function_type = Downcast(f->checked_type()); // Each MCont can be used at most once. 
struct CPSFunctor : ExprFunctor, PatternMutator { - CPSFunctor(const std::function& remap, - const TypeVar& answer, - const IRModule& m, - VarMap* vm, - CPSMap* cm) : remap(remap), answer(answer), m(m), vm(vm), cm(cm) { } + CPSFunctor(const std::function& remap, const TypeVar& answer, const IRModule& m, + VarMap* vm, CPSMap* cm) + : remap(remap), answer(answer), m(m), vm(vm), cm(cm) {} const std::function& remap; TypeVar answer; IRModule m; @@ -136,9 +128,8 @@ Function ToCPS(const Function& f, CPSMap* cm; Expr VisitExpr_(const LetNode* op, const MCont& k) final { - return VisitExpr(op->value, [&](const Expr& v) { - return Let(remap(op->var), v, VisitExpr(op->body, k)); - }); + return VisitExpr( + op->value, [&](const Expr& v) { return Let(remap(op->var), v, VisitExpr(op->body, k)); }); } Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { @@ -150,13 +141,9 @@ Function ToCPS(const Function& f, return k(GetRef(op)); } - Expr VisitExpr_(const VarNode* op, const MCont& k) final { - return k(remap(GetRef(op))); - } + Expr VisitExpr_(const VarNode* op, const MCont& k) final { return k(remap(GetRef(op))); } - Pattern VisitPattern_(const PatternVarNode* op) final { - return PatternVar(remap(op->var)); - } + Pattern VisitPattern_(const PatternVarNode* op) final { return PatternVar(remap(op->var)); } Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final { auto gv = GetRef(op); @@ -164,7 +151,7 @@ Function ToCPS(const Function& f, // only look unfold non-external calls. BaseFunc base_func = m->Lookup(gv); if (auto* n = base_func.as()) { - auto cps_gv = GlobalVar(gv->name_hint + "_cps"); + auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps"); cm->insert({gv, cps_gv}); m->Add(cps_gv, ToCPS(GetRef(n), m, cm)); } else { @@ -186,16 +173,14 @@ Function ToCPS(const Function& f, } Expr reify(const MCont& k, const std::function& cont) { - return LetList::LetBind(reify(k), - [&](const Var& f) { + return LetList::LetBind(reify(k), [&](const Var& f) { return cont([&](const Expr& e) { return Call(f, {e}); }); }); } Expr VisitExpr_(const IfNode* op, const MCont& k) final { return reify(k, [&](const MCont& kf) { - return VisitExpr(op->cond, - [&](const Expr& v) { + return VisitExpr(op->cond, [&](const Expr& v) { return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf)); }); }); @@ -214,19 +199,13 @@ Function ToCPS(const Function& f, } Expr VisitExpr_(const RefReadNode* op, const MCont& k) final { - return VisitExpr(op->ref, - [&](const Expr& r) { - return LetList::LetBind(RefRead(r), k); - }); + return VisitExpr(op->ref, [&](const Expr& r) { return LetList::LetBind(RefRead(r), k); }); } Expr VisitExpr_(const RefWriteNode* op, const MCont& k) final { - return VisitExpr(op->ref, - [&](const Expr& r) { + return VisitExpr(op->ref, [&](const Expr& r) { return VisitExpr(op->value, - [&](const Expr& v) { - return LetList::LetBind(RefWrite(r, v), k); - }); + [&](const Expr& v) { return LetList::LetBind(RefWrite(r, v), k); }); }); } @@ -234,20 +213,18 @@ Function ToCPS(const Function& f, tvm::Array fields; std::function next; next = [&]() { - return (fields.size() == op->fields.size()) ? - k(Tuple(fields)) : - VisitExpr(op->fields[fields.size()], [&](const Expr& v) { - fields.push_back(v); - return next(); - }); + return (fields.size() == op->fields.size()) + ? 
k(Tuple(fields)) + : VisitExpr(op->fields[fields.size()], [&](const Expr& v) { + fields.push_back(v); + return next(); + }); }; return next(); } Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final { - return VisitExpr(op->tuple, [&](const Expr& v) { - return k(TupleGetItem(v, op->index)); - }); + return VisitExpr(op->tuple, [&](const Expr& v) { return k(TupleGetItem(v, op->index)); }); } Expr VisitExpr_(const CallNode* op, const MCont& k) final { @@ -259,9 +236,9 @@ Function ToCPS(const Function& f, return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k); } else { return VisitExpr(op->args[args.size()], [&](const Expr& v) { - args.push_back(v); - return next(); - }); + args.push_back(v); + return next(); + }); } }; return next(); @@ -279,7 +256,7 @@ Function ToCPS(const Function& f, return next(); }); } - }; + }; return VisitExpr(op->op, [&](const Expr& v) { f = v; return next(); @@ -293,19 +270,15 @@ Function ToCPS(const Function& f, new_params.push_back(remap(v)); } new_params.push_back(k); - return Function(new_params, - mut.VisitExpr(f->body, - [&](const Expr& e) { return Call(k, {e}); }), - answer, - f->type_params, - f->attrs); + return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), + answer, f->type_params, f->attrs); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { TypeVar answer = TypeVar("answer", kType); VarMap var; struct Remapper : ExprVisitor, PatternVisitor { - Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) { } + Remapper(const TypeVar& answer, VarMap* vm) : answer(answer), vm(vm) {} TypeVar answer; VarMap* vm; void VisitExpr_(const VarNode* vn) final { @@ -316,13 +289,9 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { } } - void VisitPattern(const Pattern& p) final { - PatternVisitor::VisitPattern(p); - } + void VisitPattern(const Pattern& p) final { PatternVisitor::VisitPattern(p); } - void VisitPattern_(const PatternVarNode* op) final { - VisitExpr(op->var); - } + void VisitPattern_(const PatternVarNode* op) final { VisitExpr(op->var); } } remap(answer, &var); remap.VisitExpr(f); Function ret = ToCPS(f, m, cm, &var, answer); @@ -366,43 +335,32 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - return Function(new_params, - Call(f, args, {}, type_args), - new_ret_type, - new_type_params, - f->attrs); + return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params, + f->attrs); } TVM_REGISTER_GLOBAL("relay._transform.to_cps") -.set_body_typed(static_cast(ToCPS)); + .set_body_typed(static_cast(ToCPS)); -TVM_REGISTER_GLOBAL("relay._transform.un_cps") -.set_body_typed(UnCPS); +TVM_REGISTER_GLOBAL("relay._transform.un_cps").set_body_typed(UnCPS); namespace transform { Pass ToCPS() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Function(ToCPS(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { return Function(ToCPS(f, m)); }; return CreateFunctionPass(pass_func, 1, "ToCPS", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToCPS") -.set_body_typed(ToCPS); - +TVM_REGISTER_GLOBAL("relay._transform.ToCPS").set_body_typed(ToCPS); Pass UnCPS() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Function(UnCPS(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Function(UnCPS(f)); }; return CreateFunctionPass(pass_func, 1, "UnCPS", {}); } 
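// --- Illustrative standalone sketch (not TVM code) of the "un-CPS" wrapping that UnCPS
// --- performs above: a CPS-style function is turned back into a direct-style one by
// --- applying it to a trivial continuation that simply returns the value it receives.
#include <functional>

using ToyContI = std::function<int(int)>;

// A CPS-style callee: the result goes through the continuation k.
int add1_cps(int x, const ToyContI& k) { return k(x + 1); }

// The conceptual wrapper: same parameters, but it supplies the identity continuation
// itself, so callers see an ordinary int(int) signature again.
int add1_direct(int x) {
  return add1_cps(x, [](int r) { return r; });
}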
-TVM_REGISTER_GLOBAL("relay._transform.UnCPS") -.set_body_typed(UnCPS); +TVM_REGISTER_GLOBAL("relay._transform.UnCPS").set_body_typed(UnCPS); } // namespace transform diff --git a/src/relay/transforms/to_graph_normal_form.cc b/src/relay/transforms/to_graph_normal_form.cc index 8bf41a4610c0..ff5ff568b048 100644 --- a/src/relay/transforms/to_graph_normal_form.cc +++ b/src/relay/transforms/to_graph_normal_form.cc @@ -26,6 +26,7 @@ #include #include #include + #include "let_list.h" namespace tvm { @@ -33,7 +34,7 @@ namespace relay { class UseVarVisitor : public ExprVisitor { public: - explicit UseVarVisitor(const Var& v) : v(v) { } + explicit UseVarVisitor(const Var& v) : v(v) {} static bool UseVar(const Var& v, const Expr& e) { UseVarVisitor uv(v); @@ -45,22 +46,18 @@ class UseVarVisitor : public ExprVisitor { bool use_var = false; Var v; - void VisitExpr_(const VarNode* vn) override { - use_var = use_var || (v == GetRef(vn)); - } + void VisitExpr_(const VarNode* vn) override { use_var = use_var || (v == GetRef(vn)); } }; class GNF : public ExprMutator { private: - std::unordered_map var_map_; + std::unordered_map var_map_; Expr VisitExpr_(const VarNode* vn) override { Var v = GetRef(vn); return var_map_.count(v) == 0 ? v : var_map_.at(v); } - static bool UseVar(const Var& v, const Expr& e) { - return UseVarVisitor::UseVar(v, e); - } + static bool UseVar(const Var& v, const Expr& e) { return UseVarVisitor::UseVar(v, e); } static Expr WrapRec(const Var& var, const Expr& val) { return UseVar(var, val) ? Let(var, val, var) : val; @@ -72,22 +69,19 @@ class GNF : public ExprMutator { } }; -Expr ToGraphNormalForm(const Expr& e) { - return GNF()(e); -} +Expr ToGraphNormalForm(const Expr& e) { return GNF()(e); } namespace transform { Pass ToGraphNormalForm() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(ToGraphNormalForm(f)); - }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(ToGraphNormalForm(f)); + }; return CreateFunctionPass(pass_func, 1, "ToGraphNormalForm", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm") -.set_body_typed(ToGraphNormalForm); +TVM_REGISTER_GLOBAL("relay._transform.ToGraphNormalForm").set_body_typed(ToGraphNormalForm); } // namespace transform diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index b6e75ae4f585..19632defc826 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -26,14 +26,16 @@ #ifndef TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ #define TVM_RELAY_TRANSFORMS_TRANSFORM_LAYOUT_H_ -#include #include +#include + #include -#include #include +#include #include -#include "pattern_util.h" + #include "infer_layout_util.h" +#include "pattern_util.h" namespace tvm { namespace relay { @@ -49,8 +51,8 @@ class TransformMemorizerNode : public Object { struct key_hash : public std::function { std::size_t operator()(const TransformKey& k) const { return dmlc::HashCombine( - dmlc::HashCombine( - std::hash()(std::get<0>(k)), std::get<1>(k)), + dmlc::HashCombine(std::hash()(std::get<0>(k)), + std::get<1>(k)), (std::get<2>(k))); } }; @@ -300,8 +302,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj // new_in2, new_out = op.infer(new_in) if (new_call->op->IsInstance()) { success = false; - std::tie(new_in2, new_out, success) = - InferCorrectLayouts(new_call, new_in, old_in, types); + std::tie(new_in2, new_out, success) = InferCorrectLayouts(new_call, new_in, old_in, types); 
if (!success) { return Expr(nullptr); } diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 3a16d8ff793b..45e1af1c960f 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -37,14 +37,15 @@ * If we can not infer a type or there are conflicting typing * constraints we will trigger an error. */ -#include #include +#include +#include #include #include -#include #include -#include "pass_util.h" + #include "../analysis/type_solver.h" +#include "pass_util.h" namespace tvm { namespace relay { @@ -53,21 +54,16 @@ namespace relay { struct TupleGetItemAttrs : public tvm::AttrsNode { int index; - TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { - TVM_ATTR_FIELD(index); - } + TVM_DECLARE_ATTRS(TupleGetItemAttrs, "relay.attrs.TupleGetItemAttrs") { TVM_ATTR_FIELD(index); } }; -bool TupleGetItemRel(const Array& types, - int num_inputs, - const Attrs& attrs, +bool TupleGetItemRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(types.size(), 2); if (types[0].as()) return false; const auto* data = types[0].as(); - CHECK(data != nullptr) - << "TupleGetItem expect input type to be TupleType " - << " get " << types[0] << " instead"; + CHECK(data != nullptr) << "TupleGetItem expect input type to be TupleType " + << " get " << types[0] << " instead"; const auto* param = attrs.as(); CHECK(param != nullptr); CHECK_GE(param->index, 0); @@ -77,9 +73,7 @@ bool TupleGetItemRel(const Array& types, } TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); -TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem") -.set_body_typed( - TupleGetItemRel); +TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem").set_body_typed(TupleGetItemRel); struct ResolvedTypeInfo { explicit ResolvedTypeInfo(Type checked_type, Array type_args) @@ -105,8 +99,10 @@ class TypeInferencer : private ExprFunctor, // constructors explicit TypeInferencer(IRModule mod, GlobalVar current_func) - : mod_(mod), current_func_(current_func), - err_reporter(), solver_(current_func, mod, &this->err_reporter) { + : mod_(mod), + current_func_(current_func), + err_reporter(), + solver_(current_func, mod, &this->err_reporter) { CHECK(mod.defined()) << "internal error: Module must be set in the type inferencer"; } @@ -127,7 +123,7 @@ class TypeInferencer : private ExprFunctor, // map from expression to checked type // type inferencer will populate it up - std::unordered_map type_map_; + std::unordered_map type_map_; // The solver used by the inferencer. TypeSolver solver_; @@ -140,22 +136,16 @@ class TypeInferencer : private ExprFunctor, Type Unify(const Type& t1, const Type& t2, const ObjectRef& expr) { try { return solver_.Unify(t1, t2, expr); - } catch (const dmlc::Error &e) { + } catch (const dmlc::Error& e) { this->ReportFatalError( - expr, - ErrorBuilder() - << "Error unifying `" - << t1 - << "` and `" - << t2 - << "`: " << e.what()); + expr, ErrorBuilder() << "Error unifying `" << t1 << "` and `" << t2 << "`: " << e.what()); return Type(); } } // Lazily get type for expr // expression, we will populate it now, and return the result. 
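// --- Illustrative standalone sketch of the lazy, memoized GetType pattern that follows:
// --- look the expression up in a cache; if the entry is missing or not yet filled in,
// --- infer it once, store it, and return the cached value thereafter. Plain C++ stand-ins
// --- for Expr/Type, not the TypeInferencer itself.
#include <string>
#include <unordered_map>

struct ToyTypeInfo {
  std::string checked_type;  // empty string plays the role of "undefined"
};

class ToyTypeMemo {
 public:
  const std::string& GetType(int expr_id) {
    auto it = cache_.find(expr_id);
    if (it != cache_.end() && !it->second.checked_type.empty()) {
      return it->second.checked_type;  // already inferred: reuse
    }
    // not yet known: infer now and populate the cache
    ToyTypeInfo& info = cache_[expr_id];
    info.checked_type = Infer(expr_id);
    return info.checked_type;
  }

 private:
  std::string Infer(int expr_id) { return "Tensor[(), float32] #" + std::to_string(expr_id); }
  std::unordered_map<int, ToyTypeInfo> cache_;
};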
- Type GetType(const Expr &expr) { + Type GetType(const Expr& expr) { auto it = type_map_.find(expr); if (it != type_map_.end() && it->second.checked_type.defined()) { return it->second.checked_type; @@ -186,19 +176,15 @@ class TypeInferencer : private ExprFunctor, Type VisitExpr_(const GlobalVarNode* op) final { GlobalVar var = GetRef(op); if (!mod_.defined()) { - this->ReportFatalError( - GetRef(op), - ErrorBuilder() << - "Cannot do type inference on global variables " \ - "without a module"); + this->ReportFatalError(GetRef(op), + ErrorBuilder() << "Cannot do type inference on global variables " + "without a module"); } Expr e = mod_->Lookup(var); return e->checked_type(); } - Type VisitExpr_(const ConstantNode* op) final { - return op->tensor_type(); - } + Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); } Type VisitExpr_(const TupleNode* op) final { Array types; @@ -209,23 +195,22 @@ class TypeInferencer : private ExprFunctor, } Type VisitExpr_(const TupleGetItemNode* op) final { - if (!tuple_getitem_rel_.defined()) { - tuple_getitem_rel_ = Downcast( - EnvFunc::Get("tvm.relay.type_relation.TupleGetItem")); + if (!tuple_getitem_rel_.defined()) { + tuple_getitem_rel_ = + Downcast(EnvFunc::Get("tvm.relay.type_relation.TupleGetItem")); } Type tuple_type = GetType(op->tuple); Type rtype = IncompleteType(Kind::kType); auto attrs = make_object(); attrs->index = op->index; - solver_.AddConstraint(TypeRelation( - tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), GetRef(op)); + solver_.AddConstraint(TypeRelation(tuple_getitem_rel_, {tuple_type, rtype}, 1, Attrs(attrs)), + GetRef(op)); return rtype; } void VisitPattern_(const PatternConstructorNode* con, const Type& t) { - CHECK(mod_.defined()) - << "Cannot do type inference without a environment:" - << con->constructor->name_hint; + CHECK(mod_.defined()) << "Cannot do type inference without a environment:" + << con->constructor->name_hint; TypeData td = mod_->type_definitions.at(con->constructor->belong_to); auto pc = GetRef(con); @@ -242,26 +227,24 @@ class TypeInferencer : private ExprFunctor, this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified); } if (td->header != tc->func) { - this->ReportFatalError(pc, - ErrorBuilder() << "ADT headers must match, but we have " - << td->header << " and " << tc->func); + this->ReportFatalError(pc, ErrorBuilder() << "ADT headers must match, but we have " + << td->header << " and " << tc->func); } if (td->type_vars.size() != tc->args.size()) { - this->ReportFatalError(pc, - ErrorBuilder() << "The number of type args must match" - << "the number of type vars in the type data: " - << td->type_vars.size() << " != " << tc->args.size()); + this->ReportFatalError( + pc, ErrorBuilder() << "The number of type args must match" + << "the number of type vars in the type data: " << td->type_vars.size() + << " != " << tc->args.size()); } - std::unordered_map type_var_map_; + std::unordered_map type_var_map_; for (size_t i = 0; i < td->type_vars.size(); ++i) { type_var_map_[td->type_vars[i]] = tc->args[i]; } CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern"; if (con->constructor->inputs.size() != con->patterns.size()) { - this->ReportFatalError(pc, - ErrorBuilder() << "Not enough inputs for the constructor; " - << "expected " << con->constructor->inputs.size() - << ", got " << con->patterns.size()); + this->ReportFatalError(pc, ErrorBuilder() << "Not enough inputs for the constructor; " + << "expected " << 
con->constructor->inputs.size() + << ", got " << con->patterns.size()); } for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); @@ -294,7 +277,7 @@ class TypeInferencer : private ExprFunctor, Unify(vt, t, pv->span); } - void VisitPattern_(const PatternWildcardNode* wc, const Type& t) { } + void VisitPattern_(const PatternWildcardNode* wc, const Type& t) {} Type VisitExpr_(const MatchNode* op) final { Type dtype = GetType(op->data); @@ -303,9 +286,7 @@ class TypeInferencer : private ExprFunctor, } Type rtype = IncompleteType(Kind::kType); for (const auto& c : op->clauses) { - rtype = this->Unify(rtype, - GetType(c->rhs), - op->span); + rtype = this->Unify(rtype, GetType(c->rhs), op->span); } if (op->complete) { @@ -319,18 +300,14 @@ class TypeInferencer : private ExprFunctor, for (auto cs : unmatched_cases) { ss << "case " << i++ << ": \n" << PrettyPrint(cs); } - this->ReportFatalError( - match, - ss); + this->ReportFatalError(match, ss); } } return rtype; } - Type VisitExpr_(const OpNode* op) final { - return op->op_type; - } + Type VisitExpr_(const OpNode* op) final { return op->op_type; } Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion @@ -342,7 +319,6 @@ class TypeInferencer : private ExprFunctor, type_map_[let->var].checked_type = let_type; } - if (let->var->type_annotation.defined()) { let_type = Unify(let_type, let->var->type_annotation, GetRef(let)); } @@ -360,9 +336,7 @@ class TypeInferencer : private ExprFunctor, // Ensure the type of the guard is of Tensor[Bool, ()], // that is a rank-0 boolean tensor. Type cond_type = this->GetType(ite->cond); - this->Unify(cond_type, - TensorType::Scalar(tvm::DataType::Bool()), - ite->cond); + this->Unify(cond_type, TensorType::Scalar(tvm::DataType::Bool()), ite->cond); Type checked_true = this->GetType(ite->true_branch); Type checked_false = this->GetType(ite->false_branch); return this->Unify(checked_true, checked_false, GetRef(ite)); @@ -372,9 +346,7 @@ class TypeInferencer : private ExprFunctor, // which are registered in the style defined in src/relay/op/*. // // The result will be the return type of the operator. 
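// --- Illustrative standalone sketch of the unification fold used just above for Match and
// --- If: start from an "incomplete" result type and unify it with the type of every clause
// --- or branch, failing when two concrete types disagree. Toy types only, not the TypeSolver.
#include <stdexcept>
#include <string>
#include <vector>

// "" plays the role of IncompleteType: it unifies with anything.
std::string ToyUnify(const std::string& a, const std::string& b) {
  if (a.empty()) return b;
  if (b.empty()) return a;
  if (a != b) throw std::runtime_error("type mismatch: " + a + " vs " + b);
  return a;
}

std::string ResultTypeOfClauses(const std::vector<std::string>& clause_types) {
  std::string rtype;  // incomplete to begin with
  for (const auto& t : clause_types) rtype = ToyUnify(rtype, t);
  return rtype;
}
// e.g. ResultTypeOfClauses({"", "float32", "float32"}) == "float32";
// a mix of "float32" and "int32" raises instead of silently picking one.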
- Type PrimitiveCall(const FuncTypeNode* op, - Array arg_types, - const Attrs& attrs, + Type PrimitiveCall(const FuncTypeNode* op, Array arg_types, const Attrs& attrs, const ObjectRef& loc) { if (op->type_params.size() != arg_types.size() + 1) return Type(); if (op->type_constraints.size() != 1) return Type(); @@ -387,8 +359,7 @@ class TypeInferencer : private ExprFunctor, Type rtype = IncompleteType(Kind::kType); arg_types.push_back(rtype); // we can do simple replacement here - solver_.AddConstraint(TypeRelation( - rel->func, arg_types, arg_types.size() - 1, attrs), loc); + solver_.AddConstraint(TypeRelation(rel->func, arg_types, arg_types.size() - 1, attrs), loc); return rtype; } @@ -417,9 +388,7 @@ class TypeInferencer : private ExprFunctor, ret_type = IncompleteType(Kind::kType); } - Type inst_ty = FuncType(fn_ty->arg_types, - ret_type, {}, - fn_ty->type_constraints); + Type inst_ty = FuncType(fn_ty->arg_types, ret_type, {}, fn_ty->type_constraints); inst_ty = Bind(inst_ty, subst_map); return Downcast(inst_ty); } @@ -437,7 +406,6 @@ class TypeInferencer : private ExprFunctor, return InstantiateFuncType(fn_ty, type_args); } - void AddTypeArgs(const Expr& expr, Array type_args) { auto type_info = type_map_.find(expr); if (type_info == type_map_.end()) { @@ -456,10 +424,8 @@ class TypeInferencer : private ExprFunctor, if (fn_ty_node == nullptr && inc_ty_node == nullptr) { this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "only expressions with function types can be called, found " - << ftype); + GetRef(call), + ErrorBuilder() << "only expressions with function types can be called, found " << ftype); } // incomplete type => it must be a function taking the arg types @@ -474,12 +440,10 @@ class TypeInferencer : private ExprFunctor, Array type_args = call->type_args; if (type_args.size() > fn_ty_node->type_params.size()) { this->ReportFatalError(GetRef(call), - ErrorBuilder() - << "Incorrect number of type args in " - << call->span << ": " - << "Expected " - << fn_ty_node->type_params.size() - << "but got " << type_args.size()); + ErrorBuilder() + << "Incorrect number of type args in " << call->span << ": " + << "Expected " << fn_ty_node->type_params.size() << "but got " + << type_args.size()); } FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); @@ -491,17 +455,15 @@ class TypeInferencer : private ExprFunctor, if (type_arity != number_of_args) { if (type_arity < number_of_args) { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "the function is provided too many arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->ReportFatalError(GetRef(call), + ErrorBuilder() + << "the function is provided too many arguments " + << "expected " << type_arity << ", found " << number_of_args); } else { - this->ReportFatalError( - GetRef(call), - ErrorBuilder() - << "the function is provided too few arguments " - << "expected " << type_arity << ", found " << number_of_args); + this->ReportFatalError(GetRef(call), + ErrorBuilder() + << "the function is provided too few arguments " + << "expected " << type_arity << ", found " << number_of_args); } } @@ -511,9 +473,8 @@ class TypeInferencer : private ExprFunctor, for (auto cs : fn_ty->type_constraints) { if (const auto* tr = cs.as()) { - solver_.AddConstraint( - TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), - GetRef(call)); + solver_.AddConstraint(TypeRelation(tr->func, tr->args, tr->num_inputs, call->attrs), + GetRef(call)); } else { solver_.AddConstraint(cs, GetRef(call)); } 
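// --- Illustrative standalone sketch of the instantiation step performed by
// --- InstantiateFuncType above: every type parameter of a polymorphic function type is
// --- bound either to an explicitly supplied type argument or to a fresh "incomplete"
// --- placeholder for the solver to fill in later. Toy string types, not Relay types.
#include <map>
#include <string>
#include <vector>

struct ToyFuncType {
  std::vector<std::string> type_params;  // e.g. {"T"}
  std::vector<std::string> arg_types;    // e.g. {"T", "T"}
  std::string ret_type;                  // e.g. "T"
};

ToyFuncType Instantiate(const ToyFuncType& fn, const std::vector<std::string>& type_args) {
  std::map<std::string, std::string> subst;
  for (size_t i = 0; i < fn.type_params.size(); ++i) {
    // explicit argument if given, otherwise a fresh placeholder ("?0", "?1", ...)
    subst[fn.type_params[i]] =
        i < type_args.size() ? type_args[i] : "?" + std::to_string(i);
  }
  auto bind = [&](const std::string& t) {
    auto it = subst.find(t);
    return it == subst.end() ? t : it->second;
  };
  ToyFuncType out;
  for (const auto& a : fn.arg_types) out.arg_types.push_back(bind(a));
  out.ret_type = bind(fn.ret_type);
  return out;  // no remaining type_params: the instantiated type is monomorphic
}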
@@ -529,9 +490,7 @@ class TypeInferencer : private ExprFunctor, } if (const OpNode* opnode = call->op.as()) { - Type rtype = PrimitiveCall(opnode->op_type.as(), - arg_types, - call->attrs, + Type rtype = PrimitiveCall(opnode->op_type.as(), arg_types, call->attrs, GetRef(call)); if (rtype.defined()) { AddTypeArgs(GetRef(call), arg_types); @@ -560,9 +519,7 @@ class TypeInferencer : private ExprFunctor, return solver_.Resolve(ret); } - Type VisitExpr_(const RefCreateNode* op) final { - return RelayRefType(GetType(op->value)); - } + Type VisitExpr_(const RefCreateNode* op) final { return RelayRefType(GetType(op->value)); } Type VisitExpr_(const RefReadNode* op) final { Type it = IncompleteType(Kind::kType); @@ -578,16 +535,13 @@ class TypeInferencer : private ExprFunctor, } Type VisitExpr_(const ConstructorNode* c) final { - CHECK(mod_.defined()) - << "Cannot do type inference without a environment:" - << c->name_hint; + CHECK(mod_.defined()) << "Cannot do type inference without a environment:" << c->name_hint; TypeData td = mod_->LookupTypeDef(c->belong_to); std::vector types; - for (const auto & t : td->type_vars) { + for (const auto& t : td->type_vars) { types.push_back(t); } - return FuncType(c->inputs, TypeCall(c->belong_to, types), - td->type_vars, {}); + return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } void Solve() { @@ -601,74 +555,41 @@ class TypeInferencer : private ExprFunctor, class TypeInferencer::Resolver : public ExprMutator, PatternMutator { public: - Resolver(const std::unordered_map& tmap, + Resolver(const std::unordered_map& tmap, TypeSolver* solver) - : tmap_(tmap), solver_(solver) { - } + : tmap_(tmap), solver_(solver) {} - Expr VisitExpr_(const VarNode* op) final { - return VisitVar(GetRef(op)); - } + Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef(op)); } - Expr VisitExpr_(const ConstantNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const ConstantNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const GlobalVarNode* op) final { - return GetRef(op); - } + Expr VisitExpr_(const GlobalVarNode* op) final { return GetRef(op); } - Expr VisitExpr_(const OpNode* op) final { - return ExprMutator::VisitExpr_(op); - } + Expr VisitExpr_(const OpNode* op) final { return ExprMutator::VisitExpr_(op); } - Expr VisitExpr_(const TupleNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const TupleNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const TupleGetItemNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const TupleGetItemNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const FunctionNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const FunctionNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const CallNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const CallNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const LetNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const LetNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const IfNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const IfNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefCreateNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefReadNode* op) final { - return 
AttachCheckedType(op); - } + Expr VisitExpr_(const RefReadNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const RefWriteNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const RefWriteNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const ConstructorNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const ConstructorNode* op) final { return AttachCheckedType(op); } - Expr VisitExpr_(const MatchNode* op) final { - return AttachCheckedType(op); - } + Expr VisitExpr_(const MatchNode* op) final { return AttachCheckedType(op); } - Pattern VisitPattern(const Pattern& p) final { - return PatternMutator::VisitPattern(p); - } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } Var VisitVar(const Var& v) final { if (vmap_.count(v) == 0) { @@ -678,7 +599,7 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { } // attach checked type to the mutated node. - template + template Expr AttachCheckedType(const T* op) { auto it = tmap_.find(GetRef(op)); CHECK(it != tmap_.end()); @@ -687,42 +608,34 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // TODO(@jroesch): it would be nice if we would report resolution // errors directly on the program. CHECK(checked_type.as() == nullptr) - << "Cannot resolve type of " << GetRef(op) - << " at " << op->span; + << "Cannot resolve type of " << GetRef(op) << " at " << op->span; Expr new_e = ExprMutator::VisitExpr_(op); // new_call and new_var's code is only going to be valid for VarNode/CallNode. // Compiler optimization will likely fold these away for other nodes. - CallNode* new_call =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); - VarNode* new_var =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); - FunctionNode* new_fn =( - std::is_base_of::value ? - const_cast(static_cast(new_e.get())) : nullptr); + CallNode* new_call = (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); + VarNode* new_var = (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); + FunctionNode* new_fn = + (std::is_base_of::value + ? const_cast(static_cast(new_e.get())) + : nullptr); // check if we need update the new_e bool need_update_type = !checked_type.same_as(new_e->checked_type_); - bool need_update_call = ( - std::is_base_of::value && - it->second.type_args.defined() && - !it->second.type_args.same_as(new_call->type_args)); - bool need_update_var = ( - std::is_base_of::value && - update_missing_type_annotation_ && - !new_var->type_annotation.defined()); - - bool need_update_fn =( - std::is_base_of::value && - update_missing_type_annotation_ && - !new_fn->ret_type.defined()); - - if (!need_update_type && - !need_update_var && - !need_update_call && - !need_update_fn) { + bool need_update_call = + (std::is_base_of::value && it->second.type_args.defined() && + !it->second.type_args.same_as(new_call->type_args)); + bool need_update_var = (std::is_base_of::value && update_missing_type_annotation_ && + !new_var->type_annotation.defined()); + + bool need_update_fn = (std::is_base_of::value && + update_missing_type_annotation_ && !new_fn->ret_type.defined()); + + if (!need_update_type && !need_update_var && !need_update_call && !need_update_fn) { return new_e; } @@ -732,15 +645,11 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { // we make a copy mutating an existing reference. 
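// --- Illustrative standalone sketch of the copy-on-write step happening here: mutate the
// --- node in place only when this reference is the sole owner; otherwise clone it first so
// --- other holders of the old expression are unaffected. std::shared_ptr stands in for
// --- ObjectPtr; toy node, not the Relay expression classes.
#include <memory>

struct ToyNode {
  int checked_type = 0;
};

std::shared_ptr<ToyNode> WithType(std::shared_ptr<ToyNode> n, int new_type) {
  if (n.use_count() != 1) {
    // shared with someone else: copy before mutating
    n = std::make_shared<ToyNode>(*n);
  }
  n->checked_type = new_type;  // safe: this copy is exclusively owned
  return n;
}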
ObjectPtr ptr = make_object(*new_e.as()); new_e = Expr(ptr); - new_call = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); - new_var = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); - new_fn = ( - std::is_base_of::value ? - static_cast(ptr.get()) : nullptr); + new_call = + (std::is_base_of::value ? static_cast(ptr.get()) : nullptr); + new_var = (std::is_base_of::value ? static_cast(ptr.get()) : nullptr); + new_fn = (std::is_base_of::value ? static_cast(ptr.get()) + : nullptr); } // attach the information. @@ -765,13 +674,11 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { return new_e; } - Type VisitType(const Type &t) final { - return solver_->Resolve(t); - } + Type VisitType(const Type& t) final { return solver_->Resolve(t); } private: - std::unordered_map vmap_; - const std::unordered_map& tmap_; + std::unordered_map vmap_; + const std::unordered_map& tmap_; TypeSolver* solver_; // whether attach the checked type as type_annotation // if original type anntation is missing. @@ -793,17 +700,21 @@ Expr TypeInferencer::Infer(Expr expr) { struct AllCheckTypePopulated : ExprVisitor { void VisitExpr(const Expr& e) { - if (e.as()) { return; } - if (e.as()) { return; } - if (e.as()) { return; } + if (e.as()) { + return; + } + if (e.as()) { + return; + } + if (e.as()) { + return; + } CHECK(e->checked_type_.defined()) << "Expression: " << e; return ExprVisitor::VisitExpr(e); } }; -void EnsureCheckedType(const Expr& e) { - AllCheckTypePopulated().VisitExpr(e); -} +void EnsureCheckedType(const Expr& e) { AllCheckTypePopulated().VisitExpr(e); } Expr InferType(const Expr& expr, const IRModule& mod) { auto main = mod->GetGlobalVar("main"); @@ -811,15 +722,12 @@ Expr InferType(const Expr& expr, const IRModule& mod) { auto e = inferencer.Infer(expr); CHECK(WellFormed(e)); auto free_tvars = FreeTypeVars(e, mod); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in " << e << ": " << free_tvars; + CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars; EnsureCheckedType(e); return e; } -Function InferType(const Function& func, - const IRModule& mod, - const GlobalVar& var) { +Function InferType(const Function& func, const IRModule& mod, const GlobalVar& var) { CHECK(mod.defined()) << "internal error: module must be set for type inference"; Function func_copy = Function(make_object(*func.operator->())); func_copy->checked_type_ = func_copy->func_type_annotation(); @@ -828,11 +736,9 @@ Function InferType(const Function& func, mod->Remove(var); CHECK(WellFormed(func_ret)); auto free_tvars = FreeTypeVars(func_ret, mod); - CHECK(free_tvars.size() == 0) - << "Found unbound type variables in: " - << std::endl - << AsText(func, true) - << std::endl << free_tvars; + CHECK(free_tvars.size() == 0) << "Found unbound type variables in: " << std::endl + << AsText(func, true) << std::endl + << free_tvars; return Downcast(func_ret); } @@ -840,16 +746,11 @@ namespace transform { Pass InferType() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(InferType(f, m)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(InferType(f, m)); }; return CreateFunctionPass(pass_func, 0, "InferType", {}); } -TVM_REGISTER_GLOBAL("relay._transform.InferType") -.set_body_typed([]() { - return InferType(); -}); +TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); } // namespace transform diff --git 
a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc index 60dc55d8c24a..d229491a4c7b 100644 --- a/src/runtime/builtin_fp16.cc +++ b/src/runtime/builtin_fp16.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,7 +20,7 @@ /*! * \file builtin_fp16.cc * \brief Functions for conversion between fp32 and fp16 -*/ + */ #include #include diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index fb1f74da2103..0164b1bc4d39 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -22,20 +22,22 @@ * \brief Device specific implementations */ #include -#include #include -#include +#include +#include #include +#include #include -#include -#include -#include + #include -#include -#include +#include #include -#include "runtime_base.h" +#include +#include +#include + #include "object_internal.h" +#include "runtime_base.h" namespace tvm { namespace runtime { @@ -90,9 +92,7 @@ class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; // Get API - static DeviceAPI* Get(const TVMContext& ctx) { - return Get(ctx.device_type); - } + static DeviceAPI* Get(const TVMContext& ctx) { return Get(ctx.device_type); } static DeviceAPI* Get(int dev_type, bool allow_missing = false) { return Global()->GetAPI(dev_type, allow_missing); } @@ -102,9 +102,7 @@ class DeviceAPIManager { DeviceAPI* rpc_api_{nullptr}; std::mutex mutex_; // constructor - DeviceAPIManager() { - std::fill(api_.begin(), api_.end(), nullptr); - } + DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } // Global static variable. static DeviceAPIManager* Global() { static DeviceAPIManager inst; @@ -130,8 +128,7 @@ class DeviceAPIManager { std::string factory = "device_api." 
+ name; auto* f = Registry::Get(factory); if (f == nullptr) { - CHECK(allow_missing) - << "Device API " << name << " is not enabled."; + CHECK(allow_missing) << "Device API " << name << " is not enabled."; return nullptr; } void* ptr = (*f)(); @@ -140,19 +137,14 @@ class DeviceAPIManager { }; DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { - return DeviceAPIManager::Get( - static_cast(ctx.device_type), allow_missing); + return DeviceAPIManager::Get(static_cast(ctx.device_type), allow_missing); } -void* DeviceAPI::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* DeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); } -void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { - FreeDataSpace(ctx, ptr); -} +void DeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { FreeDataSpace(ctx, ptr); } TVMStreamHandle DeviceAPI::CreateStream(TVMContext ctx) { LOG(FATAL) << "Device does not support stream api."; @@ -163,8 +155,7 @@ void DeviceAPI::FreeStream(TVMContext ctx, TVMStreamHandle stream) { LOG(FATAL) << "Device does not support stream api."; } -void DeviceAPI::SyncStreamFromTo(TVMContext ctx, - TVMStreamHandle event_src, +void DeviceAPI::SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) { LOG(FATAL) << "Device does not support stream api."; } @@ -256,7 +247,8 @@ std::string NormalizeError(std::string err_msg) { // Parse error type. { size_t start_pos = 0, end_pos; - for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) { + } for (end_pos = start_pos; end_pos < line.length(); ++end_pos) { char ch = line[end_pos]; if (ch == ':') { @@ -268,8 +260,9 @@ std::string NormalizeError(std::string err_msg) { } if (error_type.length() != 0) { // if we successfully detected error_type: trim the following space. - for (start_pos = end_pos + 1; - start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {} + for (start_pos = end_pos + 1; start_pos < line.length() && line[start_pos] == ' '; + ++start_pos) { + } line = line.substr(start_pos); } else { // did not detect error_type, use default value. 
@@ -345,22 +338,16 @@ struct TVMRuntimeEntry { typedef dmlc::ThreadLocalStore TVMAPIRuntimeStore; -const char *TVMGetLastError() { - return TVMAPIRuntimeStore::Get()->last_error.c_str(); -} +const char* TVMGetLastError() { return TVMAPIRuntimeStore::Get()->last_error.c_str(); } -int TVMAPIHandleException(const std::runtime_error &e) { +int TVMAPIHandleException(const std::runtime_error& e) { TVMAPISetLastError(NormalizeError(e.what()).c_str()); return -1; } -void TVMAPISetLastError(const char* msg) { - TVMAPIRuntimeStore::Get()->last_error = msg; -} +void TVMAPISetLastError(const char* msg) { TVMAPIRuntimeStore::Get()->last_error = msg; } -int TVMModLoadFromFile(const char* file_name, - const char* format, - TVMModuleHandle* out) { +int TVMModLoadFromFile(const char* file_name, const char* format, TVMModuleHandle* out) { API_BEGIN(); TVMRetValue ret; ret = Module::LoadFromFile(file_name, format); @@ -371,21 +358,16 @@ int TVMModLoadFromFile(const char* file_name, API_END(); } -int TVMModImport(TVMModuleHandle mod, - TVMModuleHandle dep) { +int TVMModImport(TVMModuleHandle mod, TVMModuleHandle dep) { API_BEGIN(); - ObjectInternal::GetModuleNode(mod)->Import( - GetRef(ObjectInternal::GetModuleNode(dep))); + ObjectInternal::GetModuleNode(mod)->Import(GetRef(ObjectInternal::GetModuleNode(dep))); API_END(); } -int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *func) { +int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* func) { API_BEGIN(); - PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction( - func_name, query_imports != 0); + PackedFunc pf = ObjectInternal::GetModuleNode(mod)->GetFunction(func_name, query_imports != 0); if (pf != nullptr) { *func = new PackedFunc(pf); } else { @@ -394,23 +376,15 @@ int TVMModGetFunction(TVMModuleHandle mod, API_END(); } -int TVMModFree(TVMModuleHandle mod) { - return TVMObjectFree(mod); -} +int TVMModFree(TVMModuleHandle mod) { return TVMObjectFree(mod); } -int TVMBackendGetFuncFromEnv(void* mod_node, - const char* func_name, - TVMFunctionHandle *func) { +int TVMBackendGetFuncFromEnv(void* mod_node, const char* func_name, TVMFunctionHandle* func) { API_BEGIN(); - *func = (TVMFunctionHandle)( - static_cast(mod_node)->GetFuncFromEnv(func_name)); + *func = (TVMFunctionHandle)(static_cast(mod_node)->GetFuncFromEnv(func_name)); API_END(); } -void* TVMBackendAllocWorkspace(int device_type, - int device_id, - uint64_t size, - int dtype_code_hint, +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, int dtype_bits_hint) { TVMContext ctx; ctx.device_type = static_cast(device_type); @@ -421,14 +395,10 @@ void* TVMBackendAllocWorkspace(int device_type, type_hint.bits = static_cast(dtype_bits_hint); type_hint.lanes = 1; - return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, - static_cast(size), - type_hint); + return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx, static_cast(size), type_hint); } -int TVMBackendFreeWorkspace(int device_type, - int device_id, - void* ptr) { +int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { TVMContext ctx; ctx.device_type = static_cast(device_type); ctx.device_id = device_id; @@ -436,10 +406,7 @@ int TVMBackendFreeWorkspace(int device_type, return 0; } -int TVMBackendRunOnce(void** handle, - int (*f)(void*), - void* cdata, - int nbytes) { +int TVMBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) { if (*handle == nullptr) 
{ *handle = reinterpret_cast(1); return (*f)(cdata); @@ -453,20 +420,14 @@ int TVMFuncFree(TVMFunctionHandle func) { API_END(); } -int TVMFuncCall(TVMFunctionHandle func, - TVMValue* args, - int* arg_type_codes, - int num_args, - TVMValue* ret_val, - int* ret_type_code) { +int TVMFuncCall(TVMFunctionHandle func, TVMValue* args, int* arg_type_codes, int num_args, + TVMValue* ret_val, int* ret_type_code) { API_BEGIN(); + TVMRetValue rv; - (*static_cast(func)).CallPacked( - TVMArgs(args, arg_type_codes, num_args), &rv); + (*static_cast(func)).CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv); // handle return string. - if (rv.type_code() == kTVMStr || - rv.type_code() == kTVMDataType || - rv.type_code() == kTVMBytes) { + if (rv.type_code() == kTVMStr || rv.type_code() == kTVMDataType || rv.type_code() == kTVMBytes) { TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); if (rv.type_code() != kTVMDataType) { e->ret_str = *rv.ptr(); @@ -488,10 +449,7 @@ int TVMFuncCall(TVMFunctionHandle func, API_END(); } -int TVMCFuncSetReturn(TVMRetValueHandle ret, - TVMValue* value, - int* type_code, - int num_ret) { +int TVMCFuncSetReturn(TVMRetValueHandle ret, TVMValue* value, int* type_code, int num_ret) { API_BEGIN(); CHECK_EQ(num_ret, 1); TVMRetValue* rv = static_cast(ret); @@ -499,32 +457,28 @@ int TVMCFuncSetReturn(TVMRetValueHandle ret, API_END(); } -int TVMFuncCreateFromCFunc(TVMPackedCFunc func, - void* resource_handle, - TVMPackedCFuncFinalizer fin, - TVMFunctionHandle *out) { +int TVMFuncCreateFromCFunc(TVMPackedCFunc func, void* resource_handle, TVMPackedCFuncFinalizer fin, + TVMFunctionHandle* out) { API_BEGIN(); if (fin == nullptr) { - *out = new PackedFunc( - [func, resource_handle](TVMArgs args, TVMRetValue* rv) { - int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) - args.num_args, rv, resource_handle); - if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); - } - }); + *out = new PackedFunc([func, resource_handle](TVMArgs args, TVMRetValue* rv) { + int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, resource_handle); + if (ret != 0) { + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + } + }); } else { // wrap it in a shared_ptr, with fin as deleter. // so fin will be called when the lambda went out of scope. 
std::shared_ptr rpack(resource_handle, fin); - *out = new PackedFunc( - [func, rpack](TVMArgs args, TVMRetValue* rv) { - int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) - args.num_args, rv, rpack.get()); - if (ret != 0) { - throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); - } - }); + *out = new PackedFunc([func, rpack](TVMArgs args, TVMRetValue* rv) { + int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*) + args.num_args, rv, rpack.get()); + if (ret != 0) { + throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace()); + } + }); } API_END(); } @@ -565,9 +519,7 @@ int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { API_END(); } -int TVMStreamStreamSynchronize(int device_type, - int device_id, - TVMStreamHandle src, +int TVMStreamStreamSynchronize(int device_type, int device_id, TVMStreamHandle src, TVMStreamHandle dst) { API_BEGIN(); TVMContext ctx; @@ -585,35 +537,55 @@ int TVMCbArgToReturn(TVMValue* value, int* code) { API_END(); } +int TVMDeviceAllocDataSpace(DLContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint, + void** out_data) { + API_BEGIN(); + out_data[0] = DeviceAPIManager::Get(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint); + API_END(); +} + +int TVMDeviceFreeDataSpace(DLContext ctx, void* ptr) { + API_BEGIN(); + DeviceAPIManager::Get(ctx)->FreeDataSpace(ctx, ptr); + API_END(); +} + +int TVMDeviceCopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) { + API_BEGIN(); + TVMContext ctx = ctx_from.device_type != kDLCPU ? ctx_from : ctx_to; + DeviceAPIManager::Get(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, num_bytes, ctx_from, + ctx_to, type_hint, stream); + API_END(); +} + // set device api TVM_REGISTER_GLOBAL(tvm::runtime::symbol::tvm_set_device) -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMContext ctx; - ctx.device_type = static_cast(args[0].operator int()); - ctx.device_id = args[1]; - DeviceAPIManager::Get(ctx)->SetDevice(ctx); - }); + .set_body([](TVMArgs args, TVMRetValue* ret) { + TVMContext ctx; + ctx.device_type = static_cast(args[0].operator int()); + ctx.device_id = args[1]; + DeviceAPIManager::Get(ctx)->SetDevice(ctx); + }); // set device api -TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr") -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMContext ctx; - ctx.device_type = static_cast(args[0].operator int()); - ctx.device_id = args[1]; - - DeviceAttrKind kind = static_cast(args[2].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, ret); - } else { - *ret = 0; - } +TVM_REGISTER_GLOBAL("runtime.GetDeviceAttr").set_body([](TVMArgs args, TVMRetValue* ret) { + TVMContext ctx; + ctx.device_type = static_cast(args[0].operator int()); + ctx.device_id = args[1]; + + DeviceAttrKind kind = static_cast(args[2].operator int()); + if (kind == kExist) { + DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, ret); } else { - DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); + *ret = 0; } - }); - + } else { + DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); + } +}); -TVM_REGISTER_GLOBAL("runtime.TVMSetStream") -.set_body_typed(TVMSetStream); +TVM_REGISTER_GLOBAL("runtime.TVMSetStream").set_body_typed(TVMSetStream); diff --git a/src/runtime/container.cc 
b/src/runtime/container.cc index 81dfd3d4e252..62220a885208 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -19,36 +19,32 @@ /*! * \file src/runtime/container.cc - * \brief Implementations of common plain old data (POD) containers. + * \brief Implementations of common containers. */ #include #include #include -#include #include +#include namespace tvm { namespace runtime { using namespace vm; -TVM_REGISTER_GLOBAL("runtime.GetADTTag") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.tag()); }); -TVM_REGISTER_GLOBAL("runtime.GetADTSize") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTSize").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; const auto& adt = Downcast(obj); *rv = static_cast(adt.size()); }); - -TVM_REGISTER_GLOBAL("runtime.GetADTFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetADTFields").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectRef obj = args[0]; int idx = args[1]; const auto& adt = Downcast(obj); @@ -56,8 +52,7 @@ TVM_REGISTER_GLOBAL("runtime.GetADTFields") *rv = adt[idx]; }); -TVM_REGISTER_GLOBAL("runtime.Tuple") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.Tuple").set_body([](TVMArgs args, TVMRetValue* rv) { std::vector fields; for (auto i = 0; i < args.size(); ++i) { fields.push_back(args[i]); @@ -65,8 +60,7 @@ TVM_REGISTER_GLOBAL("runtime.Tuple") *rv = ADT::Tuple(fields); }); -TVM_REGISTER_GLOBAL("runtime.ADT") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) { int itag = args[0]; size_t tag = static_cast(itag); std::vector fields; @@ -76,31 +70,14 @@ TVM_REGISTER_GLOBAL("runtime.ADT") *rv = ADT(tag, fields); }); -TVM_REGISTER_GLOBAL("runtime.String") -.set_body_typed([](std::string str) { +TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { return String(std::move(str)); }); -TVM_REGISTER_GLOBAL("runtime.GetStringSize") -.set_body_typed([](String str) { - return static_cast(str.size()); -}); - -TVM_REGISTER_GLOBAL("runtime.GetStdString") -.set_body_typed([](String str) { +TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { return std::string(str); }); -TVM_REGISTER_GLOBAL("runtime.CompareString") -.set_body_typed([](String lhs, String rhs) { - return lhs.compare(rhs); -}); - -TVM_REGISTER_GLOBAL("runtime.StringHash") -.set_body_typed([](String str) { - return static_cast(std::hash()(str)); -}); - TVM_REGISTER_OBJECT_TYPE(ADTObj); TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_OBJECT_TYPE(ClosureObj); diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index d4959be64cf1..0cf4c69cdf1e 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -21,8 +21,9 @@ * \file Use external cblas library call. 
*/ #include -#include #include +#include + #include "gemm_common.h" extern "C" { @@ -50,8 +51,8 @@ struct CblasSgemmOp { void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, int ldb, float beta, float* C, int ldc) { #if USE_DNNL == 1 - dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, - ldb, A, lda, beta, C, ldc); + dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B, ldb, A, + lda, beta, C, ldc); #else cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); @@ -159,8 +160,7 @@ struct CblasDgemmBatchIterativeOp { }; // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); @@ -170,8 +170,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul") CallGemm(args, ret, CblasDgemmOp()); }); -TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* A = args[0]; CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); if (TypeMatch(A->dtype, kDLFloat, 32)) { @@ -182,14 +181,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul") }); TVM_REGISTER_GLOBAL("tvm.contrib.cblas.batch_matmul_iterative") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 32)) { - CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); - } else { - CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); - } -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + CHECK(TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); + if (TypeMatch(A->dtype, kDLFloat, 32)) { + CallBatchGemm(args, ret, CblasSgemmBatchIterativeOp()); + } else { + CallBatchGemm(args, ret, CblasDgemmBatchIterativeOp()); + } + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index b73ababbbade..96d6322cc592 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -23,15 +23,16 @@ */ #pragma once -#include #include +#include + #include namespace tvm { namespace contrib { using namespace runtime; -inline int ColumnStride(DLTensor *tensor) { +inline int ColumnStride(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -42,7 +43,7 @@ inline int ColumnStride(DLTensor *tensor) { } } -inline int ElementStride(DLTensor *tensor) { +inline int ElementStride(DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[0], tensor->strides[1]); } else { @@ -51,25 +52,21 @@ inline int ElementStride(DLTensor *tensor) { } // Reversed strides indicates an in-place transpose operation. 
-inline bool IsInPlaceTransposed(DLTensor *tensor) { +inline bool IsInPlaceTransposed(DLTensor* tensor) { return tensor->strides && (tensor->strides[1] > tensor->strides[0]); } -inline int RowCount(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 1 : 0]; -} +inline int RowCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 1 : 0]; } -inline int ColumnCount(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 0 : 1]; -} +inline int ColumnCount(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 0 : 1]; } // Call a column major blas. Note that data is stored in tvm as row // major, so this we switch the arguments. template -inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallGemm(TVMArgs args, TVMRetValue* ret, TGemmOp op) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(typename TGemmOp::TDatatype) * 8; @@ -92,20 +89,17 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) { CHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), - ColumnCount(A, transa), static_cast(alpha), - reinterpret_cast( - static_cast(B->data) + B->byte_offset), + op(transb, transa, ColumnCount(B, transb), RowCount(A, transa), ColumnCount(A, transa), + static_cast(alpha), + reinterpret_cast(static_cast(B->data) + B->byte_offset), ColumnStride(B), - reinterpret_cast( - static_cast(A->data) + A->byte_offset), + reinterpret_cast(static_cast(A->data) + A->byte_offset), ColumnStride(A), static_cast(beta), - reinterpret_cast( - static_cast(C->data) + C->byte_offset), + reinterpret_cast(static_cast(C->data) + C->byte_offset), ColumnStride(C)); } -inline int ColumnStride3D(DLTensor *tensor) { +inline int ColumnStride3D(DLTensor* tensor) { // If the tensor itself is transposed then it will have strides // backward from what we expect. Regardless, the max of the strides // (the other stride is 1) is the column stride. @@ -115,7 +109,7 @@ inline int ColumnStride3D(DLTensor *tensor) { return tensor->shape[2]; } } -inline int ElementStride3D(DLTensor *tensor) { +inline int ElementStride3D(DLTensor* tensor) { if (tensor->strides) { return std::min(tensor->strides[1], tensor->strides[2]); } else { @@ -123,22 +117,18 @@ inline int ElementStride3D(DLTensor *tensor) { } } // Reversed strides indicates an in-place transpose operation. -inline bool IsInPlaceTransposed3D(DLTensor *tensor) { +inline bool IsInPlaceTransposed3D(DLTensor* tensor) { return tensor->strides && (tensor->strides[2] > tensor->strides[1]); } -inline int BatchCount3D(DLTensor *tensor) { return tensor->shape[0]; } -inline int RowCount3D(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 2 : 1]; -} -inline int ColumnCount3D(DLTensor *tensor, bool trans) { - return tensor->shape[trans ? 1 : 2]; -} +inline int BatchCount3D(DLTensor* tensor) { return tensor->shape[0]; } +inline int RowCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 2 : 1]; } +inline int ColumnCount3D(DLTensor* tensor, bool trans) { return tensor->shape[trans ? 
1 : 2]; } template -inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { +inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { using DType = typename TBatchGemmOp::TDatatype; - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; int bit_depth = sizeof(DType) * 8; @@ -163,16 +153,15 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) { const int A_size = A->shape[1] * A->shape[2]; const int B_size = B->shape[1] * B->shape[2]; const int C_size = C->shape[1] * C->shape[2]; - DType *A_data = reinterpret_cast( - static_cast(A->data) + A->byte_offset); - DType *B_data = reinterpret_cast( - static_cast(B->data) + B->byte_offset); - DType *C_data = reinterpret_cast( - static_cast(C->data) + C->byte_offset); - op(batch_size, transb, transa, ColumnCount3D(B, transb), - RowCount3D(A, transa), ColumnCount3D(A, transa), - static_cast(alpha), - B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), + DType* A_data = reinterpret_cast(static_cast(A->data) + + A->byte_offset); + DType* B_data = reinterpret_cast(static_cast(B->data) + + B->byte_offset); + DType* C_data = reinterpret_cast(static_cast(C->data) + + C->byte_offset); + op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa), + ColumnCount3D(A, transa), static_cast(alpha), B_data, B_size, + ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), static_cast(beta), C_data, C_size, ColumnStride3D(C)); } diff --git a/src/runtime/contrib/coreml/coreml_runtime.h b/src/runtime/contrib/coreml/coreml_runtime.h new file mode 100644 index 000000000000..05c9ac38fe2d --- /dev/null +++ b/src/runtime/contrib/coreml/coreml_runtime.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief CoreML runtime that can run coreml model + * containing only tvm PackedFunc. + * \file coreml_runtime.h + */ +#ifndef TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ +#define TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ + +#import +#import + +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief CoreML model. + */ +class CoreMLModel { + public: + /*! + * \brief constructor + * \param url The directory where compiled models are located. + */ + explicit CoreMLModel(NSURL* url) { + url_ = url; + model_ = [MLModel modelWithContentsOfURL:url error:nil]; + input_dict_ = [NSMutableDictionary dictionary]; + output_ = nil; + } + /*! + * \brief Invoke the coreml prediction. + */ + void Invoke(); + /*! + * \brief set input to the model. + * \param key The input name. 
+ * \param data_in The input data. + */ + void SetInput(const std::string& key, DLTensor* data_in); + /*! + * \brief Return NDArray for given output index. + * \param index The output index. + * + * \return NDArray corresponding to given output node index. + */ + NDArray GetOutput(int index) const; + /*! + * \brief Return the number of outputs + * + * \return The number of outputs + */ + int GetNumOutputs() const; + + // CoreML model url + NSURL* url_; + // CoreML model + MLModel* model_; + // CoreML model input dictionary + NSMutableDictionary* input_dict_; + // CoreML model output + id output_; +}; + +/*! + * \brief CoreML runtime. + * + * This runtime can be accessed in various language via + * TVM runtime PackedFunc API. + */ +class CoreMLRuntime : public ModuleNode { + public: + /*! + * \brief Get member function to front-end. + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * \return The corresponding member function. + */ + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + /*! + * \brief Serialize the content of the mlmodelc directory and save it to + * binary stream. + * \param stream The binary stream to save to. + */ + void SaveToBinary(dmlc::Stream* stream) final; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const { return "coreml"; } + + /*! + * \brief Initialize the coreml runtime with coreml model and context. + * \param model_dir The directory where compiled models are located. + */ + void Init(const std::string& model_dir); + + /*! + * \brief Get coreml model. + * \param model_name The name of the model. + */ + CoreMLModel& GetModel(const std::string& model_name); + + // Map of the avaiable CoreML models + std::unordered_map> model_map_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_ diff --git a/src/runtime/contrib/coreml/coreml_runtime.mm b/src/runtime/contrib/coreml/coreml_runtime.mm new file mode 100644 index 000000000000..e6d22517d20f --- /dev/null +++ b/src/runtime/contrib/coreml/coreml_runtime.mm @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file coreml_runtime.cc + */ +#include + +#include "coreml_runtime.h" + +namespace tvm { +namespace runtime { + +void CoreMLModel::Invoke() { + id input = [[MLDictionaryFeatureProvider alloc] initWithDictionary:input_dict_ + error:nil]; + output_ = [model_ predictionFromFeatures:input error:nil]; +} + +void CoreMLModel::SetInput(const std::string& key, DLTensor* data_in) { + int64_t size = 1; + NSMutableArray* shape = [[NSMutableArray alloc] init]; + for (int64_t i = 0; i < data_in->ndim; ++i) { + size *= data_in->shape[i]; + [shape addObject:[NSNumber numberWithInteger:data_in->shape[i]]]; + } + + DataType dtype(data_in->dtype); + MLMultiArrayDataType dataType; + if (dtype == DataType::Float(64)) { + dataType = MLMultiArrayDataTypeDouble; + size *= sizeof(double); + } else if (dtype == DataType::Float(32)) { + dataType = MLMultiArrayDataTypeFloat32; + size *= sizeof(float); + } else { + LOG(FATAL) << "unsupported data type " << dtype; + return; + } + + MLMultiArray* dest = [[MLMultiArray alloc] initWithShape:shape dataType:dataType error:nil]; + + CHECK(data_in->strides == NULL); + memcpy(dest.dataPointer, data_in->data, size); + + NSString* nsKey = [NSString stringWithUTF8String:key.c_str()]; + [input_dict_ setObject:dest forKey:nsKey]; +} + +NDArray CoreMLModel::GetOutput(int index) const { + MLModelDescription* model_desc = model_.modelDescription; + NSString* metadata = [model_desc metadata][MLModelDescriptionKey]; + NSData* data = [metadata dataUsingEncoding:NSUTF8StringEncoding]; + NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data + options:NSJSONReadingAllowFragments + error:nil]; + NSString* name = json[@"outputs"][index]; + MLFeatureDescription* output_desc = model_desc.outputDescriptionsByName[name]; + MLMultiArrayConstraint* data_desc = output_desc.multiArrayConstraint; + std::vector shape; + int64_t size = 1; + for (int64_t i = 0; i < data_desc.shape.count; ++i) { + int n = data_desc.shape[i].intValue; + size *= n; + shape.push_back(n); + } + + DataType dtype; + if (data_desc.dataType == MLMultiArrayDataTypeDouble) { + dtype = DataType::Float(64); + size *= sizeof(double); + } else if (data_desc.dataType == MLMultiArrayDataTypeFloat32) { + dtype = DataType::Float(32); + size *= sizeof(float); + } else { + LOG(FATAL) << "unexpected data type " << data_desc.dataType; + } + MLMultiArray* src = [output_ featureValueForName:name].multiArrayValue; + TVMContext cpu_ctx = { + .device_type = kDLCPU, + .device_id = 0, + }; + NDArray ret = NDArray::Empty(shape, dtype, cpu_ctx); + ret.CopyFromBytes(src.dataPointer, size); + + return ret; +} + +int CoreMLModel::GetNumOutputs() const { + MLModelDescription* model_desc = model_.modelDescription; + return [[model_desc outputDescriptionsByName] count]; +} + +void CoreMLRuntime::Init(const std::string& _model_dir) { + NSString* model_dir = [NSString stringWithUTF8String:(_model_dir).c_str()]; + if (![model_dir hasPrefix:@"/"]) { + // find models in the bundle's framework + NSBundle* bundle = [NSBundle mainBundle]; + NSString* base = [bundle privateFrameworksPath]; + model_dir = [base stringByAppendingPathComponent:model_dir]; + } + NSFileManager* fileMamager = [NSFileManager defaultManager]; + NSArray* files = [fileMamager contentsOfDirectoryAtPath:model_dir error:nil]; + for (NSString* file in files) { + if ([[file pathExtension] isEqualToString:@"mlmodelc"]) { + NSString* model_path = [model_dir stringByAppendingPathComponent:file]; + NSURL* url = [NSURL fileURLWithPath:model_path]; + const std::string& model_name = [[file 
stringByDeletingPathExtension] UTF8String]; + model_map_[model_name] = std::unique_ptr(new CoreMLModel(url)); + } + } +} + +CoreMLModel& CoreMLRuntime::GetModel(const std::string& model_name) { + CHECK(model_map_.count(model_name) > 0) << "No such model in this module: " << model_name; + return *model_map_[model_name]; +} + +PackedFunc CoreMLRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + // Return member functions during query. + if (name == "invoke") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { GetModel("main").Invoke(); }); + } else if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + const auto& input_name = args[0].operator std::string(); + GetModel("main").SetInput(input_name, args[1]); + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = GetModel("main").GetOutput(args[0]); + }); + } else if (name == "get_num_outputs") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = GetModel("main").GetNumOutputs(); + }); + } else { + // Return the packedfunc which executes the subgraph. + return PackedFunc([sptr_to_self, name, this](TVMArgs args, TVMRetValue* rv) { + CoreMLModel& model = GetModel(name); + MLModelDescription* model_desc = [model.model_ modelDescription]; + NSString* metadata = [model_desc metadata][MLModelDescriptionKey]; + NSData* data = [metadata dataUsingEncoding:NSUTF8StringEncoding]; + NSDictionary* json = [NSJSONSerialization JSONObjectWithData:data + options:NSJSONReadingAllowFragments + error:nil]; + NSArray* input_names = json[@"inputs"]; + + // Copy input tensors to corresponding data entries. + for (auto i = 0; i < args.size() - 1; ++i) { + CHECK(args[i].type_code() == kTVMDLTensorHandle || args[i].type_code() == kTVMNDArrayHandle) + << "Expect NDArray or DLTensor as inputs\n"; + if (args[i].type_code() == kTVMDLTensorHandle) { + model.SetInput([input_names[i] UTF8String], args[i]); + } else { + LOG(FATAL) << "Not implemented"; + } + } + + // Execute the subgraph. + model.Invoke(); + + // TODO: Support multiple outputs. + NDArray out = model.GetOutput(0); + if (args[args.size() - 1].type_code() == kTVMDLTensorHandle) { + DLTensor* arg = args[args.size() - 1]; + out.CopyTo(arg); + } else { + NDArray arg = args[args.size() - 1]; + out.CopyTo(arg); + } + *rv = out; + }); + } +} + +Module CoreMLRuntimeCreate(const std::string& model_dir) { + auto exec = make_object(); + exec->Init(model_dir); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.coreml_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = CoreMLRuntimeCreate(args[0]); +}); + +void CoreMLRuntime::SaveToBinary(dmlc::Stream* stream) { + stream->Write((uint32_t)model_map_.size()); + for (const auto& kv : model_map_) { + const std::string& model_name = kv.first; + NSURL* url = kv.second->url_; + NSFileWrapper* dirWrapper = [[[NSFileWrapper alloc] initWithURL:url options:0 + error:nil] autorelease]; + NSData* dirData = [dirWrapper serializedRepresentation]; + stream->Write(model_name); + stream->Write((uint64_t)[dirData length]); + stream->Write([dirData bytes], [dirData length]); + LOG(INFO) << "Save " << model_name << " (" << [dirData length] << " bytes)"; + } +} + +/*! + * \brief Load a CoreML module from stream. + * + * \param strm The binary stream to load json. + * + * \return The created CoreML module. 
+ */ +Module CoreMLRuntimeLoadFromBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + + uint32_t nr_models; + stream->Read(&nr_models); + + NSString* tempBaseDir = NSTemporaryDirectory(); + if (tempBaseDir == nil) tempBaseDir = @"/tmp"; + + NSString* templateStr = [tempBaseDir stringByAppendingPathComponent:@"tvm.XXXXXX"]; + const char* fsTemplate = [templateStr fileSystemRepresentation]; + NSMutableData* bufferData = [NSMutableData dataWithBytes:fsTemplate + length:strlen(fsTemplate) + 1]; + char* buffer = (char*)[bufferData mutableBytes]; + char* result = mkdtemp(buffer); + NSString* tempDir = [NSString stringWithUTF8String:result]; + + for (int i = 0; i < nr_models; i++) { + std::string model_name; + stream->Read(&model_name); + uint64_t length; + stream->Read(&length); + void* ptr = new char[length]; + stream->Read(ptr, length); + NSData* data = [[NSData alloc] initWithBytesNoCopy:ptr length:length]; + NSFileWrapper* dirWrapper = + [[[NSFileWrapper alloc] initWithSerializedRepresentation:data] autorelease]; + NSString* model_dir = [tempDir + stringByAppendingPathComponent:[NSString stringWithUTF8String:(model_name + ".mlmodelc") + .c_str()]]; + NSURL* url = [NSURL fileURLWithPath:model_dir]; + BOOL res = [dirWrapper writeToURL:url options:0 originalContentsURL:nil error:nil]; + CHECK(res) << "Failed to create model directory " << [model_dir UTF8String]; + } + + auto exec = make_object(); + exec->Init([tempDir UTF8String]); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_coreml").set_body_typed(CoreMLRuntimeLoadFromBinary); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 5424f4cdcddf..ff204457d1c4 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -20,152 +20,98 @@ /*! * \file Use external cblas library call. */ -#include -#include #include +#include +#include + #include "../cblas/gemm_common.h" #include "cublas_utils.h" - namespace tvm { namespace contrib { using namespace runtime; -inline cublasOperation_t BooleanToTranspose(bool item) { - return item ? CUBLAS_OP_T : CUBLAS_OP_N; -} +inline cublasOperation_t BooleanToTranspose(bool item) { return item ? 
CUBLAS_OP_T : CUBLAS_OP_N; } inline void TryEnableTensorCore(cublasHandle_t hdl) { // TensorCores are only supported in cublas 9.0 or higher int version; CHECK_CUBLAS_ERROR(cublasGetVersion(hdl, &version)); - if (version >= 9000) - CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); + if (version >= 9000) CHECK_CUBLAS_ERROR(cublasSetMathMode(hdl, CUBLAS_TENSOR_OP_MATH)); } struct CublasHgemmOp { typedef half TDatatype; cublasHandle_t handle; - explicit CublasHgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - - void operator()(bool ta, bool tb, - int M, int N, int K, - half alpha, half* A, int lda, - half* B, int ldb, - half beta, half* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasHgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasHgemmOp(cublasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, half alpha, half* A, int lda, half* B, + int ldb, half beta, half* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasHgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasSgemmOp { typedef float TDatatype; cublasHandle_t handle; - explicit CublasSgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - - void operator()(bool ta, bool tb, - int M, int N, int K, - float alpha, float* A, int lda, - float* B, int ldb, - float beta, float* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasSgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasSgemmOp(cublasHandle_t hdl) : handle(hdl) {} + + void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B, + int ldb, float beta, float* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasSgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasDgemmOp { typedef double TDatatype; cublasHandle_t handle; - explicit CublasDgemmOp(cublasHandle_t hdl) - : handle(hdl) {} - void operator()(bool ta, bool tb, - int M, int N, int K, - double alpha, double* A, int lda, - double* B, int ldb, - double beta, double* C, int ldc) { - CHECK_CUBLAS_ERROR(cublasDgemm(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, A, lda, - B, ldb, - &beta, C, ldc)); + explicit CublasDgemmOp(cublasHandle_t hdl) : handle(hdl) {} + void operator()(bool ta, bool tb, int M, int N, int K, double alpha, double* A, int lda, + double* B, int ldb, double beta, double* C, int ldc) { + CHECK_CUBLAS_ERROR(cublasDgemm(handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, + &alpha, A, lda, B, ldb, &beta, C, ldc)); } }; struct CublasHgemmBatchOp { typedef half TDatatype; cublasHandle_t handle; - explicit CublasHgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasHgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, half alpha, half* A, int a_stride, int lda, half* B, int b_stride, int ldb, half beta, half* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasHgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, 
c_stride, batch_size)); } }; struct CublasSgemmBatchOp { typedef float TDatatype; cublasHandle_t handle; - explicit CublasSgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasSgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, float alpha, float* A, int a_stride, int lda, float* B, int b_stride, int ldb, float beta, float* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasSgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; struct CublasDgemmBatchOp { typedef double TDatatype; cublasHandle_t handle; - explicit CublasDgemmBatchOp(cublasHandle_t hdl) - : handle(hdl) {} + explicit CublasDgemmBatchOp(cublasHandle_t hdl) : handle(hdl) {} void operator()(int batch_size, bool ta, bool tb, int M, int N, int K, double alpha, double* A, int a_stride, int lda, double* B, int b_stride, int ldb, double beta, double* C, int c_stride, int ldc) { - CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched(handle, - BooleanToTranspose(ta), - BooleanToTranspose(tb), - M, N, K, - &alpha, - A, lda, a_stride, - B, ldb, b_stride, - &beta, - C, ldc, c_stride, - batch_size)); + CHECK_CUBLAS_ERROR(cublasDgemmStridedBatched( + handle, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, &alpha, A, lda, a_stride, + B, ldb, b_stride, &beta, C, ldc, c_stride, batch_size)); } }; @@ -174,22 +120,19 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s if (int_support && TypeMatch(out_dtype, kDLInt, 32)) { return TypeMatch(in_dtype, kDLInt, 8); } else if (TypeMatch(out_dtype, kDLFloat, 32)) { - return TypeMatch(in_dtype, kDLInt, 8) || - TypeMatch(in_dtype, kDLFloat, 16); + return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16); } else { return false; } } -int roundoff(int v, int d) { - return (v + d - 1) / d * d; -} +int roundoff(int v, int d) { return (v + d - 1) / d * d; } #if CUDART_VERSION >= 10010 -inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallLtIgemm(TVMArgs args, TVMRetValue* ret, cublasLtHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; // Reversed strides indicates an in-place transpose operation. 
@@ -230,53 +173,37 @@ inline void CallLtIgemm(TVMArgs args, TVMRetValue *ret, cublasLtHandle_t hdl) { cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; cublasLtMatmulDesc_t operationDesc = nullptr; CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&operationDesc, CUDA_R_32I)); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTranspose, sizeof(opTranspose))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTranspose, sizeof(opTranspose))); cublasOperation_t opTransA = BooleanToTranspose(transa); cublasOperation_t opTransB = BooleanToTranspose(transb); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTransA, sizeof(opTransA))); - CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute( - operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opTransB, sizeof(opTransB))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, + &opTransA, sizeof(opTransA))); + CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, + &opTransB, sizeof(opTransB))); // Create descriptors for the original matrices - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k , - opTransA == CUBLAS_OP_N ? k : m, lda)); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate( - &Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n , - opTransB == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, opTransA == CUBLAS_OP_N ? m : k, + opTransA == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, opTransB == CUBLAS_OP_N ? k : n, + opTransB == CUBLAS_OP_N ? n : k, ldb)); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32))); CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); - CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute( - Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); - - CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, - operationDesc, - &alpha, - B_data, - Adesc, - A_data, - Bdesc, - &beta, - C_data, - Cdesc, - C_data, - Cdesc, - NULL, - NULL, - 0, - 0)); + Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); + CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &order_COL32, sizeof(order_COL32))); + + CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, operationDesc, &alpha, B_data, Adesc, A_data, Bdesc, &beta, + C_data, Cdesc, C_data, Cdesc, NULL, NULL, 0, 0)); } #endif -inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; CHECK_EQ(A->ndim, 2); @@ -297,10 +224,10 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { transb = IsInPlaceTransposed(B) ? 
!transb : transb; CHECK(CheckMixPrecisionType(A->dtype, C->dtype)) << "Unsupported data type"; - CHECK(!TypeMatch(A->dtype, kDLInt, 8) || - ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - CHECK(!TypeMatch(B->dtype, kDLInt, 8) || - ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; @@ -320,28 +247,21 @@ inline void CallGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { beta_ptr = &beta_float; } - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - - CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, - BooleanToTranspose(transb), - BooleanToTranspose(transa), - ColumnCount(B, transb), - RowCount(A, transa), - ColumnCount(A, transa), - alpha_ptr, - B_data, cuda_in_type, ColumnStride(B), - A_data, cuda_in_type, ColumnStride(A), - beta_ptr, - C_data, cuda_out_type, ColumnStride(C), - cuda_out_type, algo)); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + + CHECK_CUBLAS_ERROR(cublasGemmEx(hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), + ColumnCount(B, transb), RowCount(A, transa), + ColumnCount(A, transa), alpha_ptr, B_data, cuda_in_type, + ColumnStride(B), A_data, cuda_in_type, ColumnStride(A), beta_ptr, + C_data, cuda_out_type, ColumnStride(C), cuda_out_type, algo)); } -inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; CHECK_EQ(A->ndim, 3); @@ -364,10 +284,10 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) transb = IsInPlaceTransposed(B) ? !transb : transb; CHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type"; - CHECK(!TypeMatch(A->dtype, kDLInt, 8) || - ColumnStride(A) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; - CHECK(!TypeMatch(B->dtype, kDLInt, 8) || - ColumnStride(B) % 4 == 0) << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; + CHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0) + << "leading dimension must divide 4 for int8 gemm"; double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? 
args[6] : 0.0; @@ -391,88 +311,76 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue *ret, cublasHandle_t hdl) beta_ptr = &beta_float; } - auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); - auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); - auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); - CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(hdl, - BooleanToTranspose(transb), - BooleanToTranspose(transa), - ColumnCount3D(B, transb), - RowCount3D(A, transa), - ColumnCount3D(A, transa), - alpha_ptr, - B_data, cuda_in_type, ColumnStride3D(B), B_size, - A_data, cuda_in_type, ColumnStride3D(A), A_size, - beta_ptr, - C_data, cuda_out_type, ColumnStride3D(C), C_size, - batch_size, cuda_out_type, algo)); + auto A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); + auto B_data = reinterpret_cast(static_cast(B->data) + B->byte_offset); + auto C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); + CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx( + hdl, BooleanToTranspose(transb), BooleanToTranspose(transa), ColumnCount3D(B, transb), + RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, cuda_in_type, + ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, beta_ptr, C_data, + cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo)); } // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - DLTensor* C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); + TryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - CHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || + if (TypeEqual(A->dtype, C->dtype)) { + CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); - else - CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); - } else { - CallGemmEx(args, ret, entry_ptr->handle); - } + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallGemm(args, ret, CublasHgemmOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallGemm(args, ret, CublasSgemmOp(entry_ptr->handle)); + else + CallGemm(args, ret, CublasDgemmOp(entry_ptr->handle)); + } else { + CallGemmEx(args, ret, entry_ptr->handle); + } }); #if CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublaslt.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); + TryEnableTensorCore(entry_ptr->handle); - CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; - cublasLtHandle_t ltHandle; - CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); - CallLtIgemm(args, ret, ltHandle); - 
CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); + CHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n"; + cublasLtHandle_t ltHandle; + CHECK_CUBLAS_ERROR(cublasLtCreate(<Handle)); + CallLtIgemm(args, ret, ltHandle); + CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle)); }); #endif // CUDART_VERSION >= 10010 -TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue* ret) { - DLTensor* A = args[0]; - DLTensor* C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.cublas.batch_matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* C = args[2]; - CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); + CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(); - TryEnableTensorCore(entry_ptr->handle); - if (TypeEqual(A->dtype, C->dtype)) { - CHECK(TypeMatch(A->dtype, kDLFloat, 16) || - TypeMatch(A->dtype, kDLFloat, 32) || + TryEnableTensorCore(entry_ptr->handle); + if (TypeEqual(A->dtype, C->dtype)) { + CHECK(TypeMatch(A->dtype, kDLFloat, 16) || TypeMatch(A->dtype, kDLFloat, 32) || TypeMatch(A->dtype, kDLFloat, 64)); - if (TypeMatch(A->dtype, kDLFloat, 16)) - CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); - else if (TypeMatch(A->dtype, kDLFloat, 32)) - CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); - else - CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); - } else { - CallBatchGemmEx(args, ret, entry_ptr->handle); - } + if (TypeMatch(A->dtype, kDLFloat, 16)) + CallBatchGemm(args, ret, CublasHgemmBatchOp(entry_ptr->handle)); + else if (TypeMatch(A->dtype, kDLFloat, 32)) + CallBatchGemm(args, ret, CublasSgemmBatchOp(entry_ptr->handle)); + else + CallBatchGemm(args, ret, CublasDgemmBatchOp(entry_ptr->handle)); + } else { + CallBatchGemmEx(args, ret, entry_ptr->handle); + } }); } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas_utils.cc b/src/runtime/contrib/cublas/cublas_utils.cc index 9953cda32379..d4ec08770723 100644 --- a/src/runtime/contrib/cublas/cublas_utils.cc +++ b/src/runtime/contrib/cublas/cublas_utils.cc @@ -21,18 +21,16 @@ * \file Use external cudnn utils function */ #include "cublas_utils.h" + #include #include + #include "../../cuda/cuda_common.h" namespace tvm { namespace contrib { - -CuBlasThreadEntry::CuBlasThreadEntry() { - CHECK_CUBLAS_ERROR(cublasCreate(&handle)); -} - +CuBlasThreadEntry::CuBlasThreadEntry() { CHECK_CUBLAS_ERROR(cublasCreate(&handle)); } CuBlasThreadEntry::~CuBlasThreadEntry() { if (handle) { @@ -41,10 +39,8 @@ CuBlasThreadEntry::~CuBlasThreadEntry() { } } - typedef dmlc::ThreadLocalStore CuBlasThreadStore; - CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream; CuBlasThreadEntry* retval = CuBlasThreadStore::Get(); @@ -52,6 +48,5 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() { return retval; } - } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h index 2e553e28493b..5189c4f483a8 100644 --- a/src/runtime/contrib/cublas/cublas_utils.h +++ b/src/runtime/contrib/cublas/cublas_utils.h @@ -24,11 +24,12 @@ #ifndef TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUBLAS_CUBLAS_UTILS_H_ -#include -#include #include #include #include +#include +#include + #include #if CUDART_VERSION >= 10010 #include @@ -39,27 +40,35 @@ namespace contrib { inline const char* GetCublasErrorString(int error) { switch (error) { - case 
CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; } return "Unrecognized error"; } #ifndef CHECK_CUBLAS_ERROR -#define CHECK_CUBLAS_ERROR(fn) \ - do { \ - int error = static_cast(fn); \ +#define CHECK_CUBLAS_ERROR(fn) \ + do { \ + int error = static_cast(fn); \ CHECK_EQ(error, CUBLAS_STATUS_SUCCESS) << "CUBLAS: " << GetCublasErrorString(error); \ } while (0) // ; intentionally left off. -#endif // CHECK_CUBLAS_ERROR - +#endif // CHECK_CUBLAS_ERROR struct CuBlasThreadEntry { CuBlasThreadEntry(); @@ -71,19 +80,26 @@ struct CuBlasThreadEntry { inline cudaDataType_t GetCudaDataType(DLDataType type) { if (type.code == kDLInt) { switch (type.bits) { - case 8: return CUDA_R_8I; - case 32: return CUDA_R_32I; + case 8: + return CUDA_R_8I; + case 32: + return CUDA_R_32I; } } else if (type.code == kDLUInt) { switch (type.bits) { - case 8: return CUDA_R_8U; - case 32: return CUDA_R_32U; + case 8: + return CUDA_R_8U; + case 32: + return CUDA_R_32U; } } else if (type.code == kDLFloat) { switch (type.bits) { - case 16: return CUDA_R_16F; - case 32: return CUDA_R_32F; - case 64: return CUDA_R_64F; + case 16: + return CUDA_R_16F; + case 32: + return CUDA_R_32F; + case 64: + return CUDA_R_64F; } } LOG(FATAL) << "Unsupported cuda type"; diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc index 95811332bbfa..223a5b4fe435 100644 --- a/src/runtime/contrib/cudnn/conv_forward.cc +++ b/src/runtime/contrib/cudnn/conv_forward.cc @@ -20,9 +20,10 @@ /*! 
* \file Use external cudnn utils function */ -#include #include #include +#include + #include "cudnn_utils.h" namespace tvm { @@ -30,18 +31,9 @@ namespace contrib { using namespace runtime; -void ConvolutionForward( - int mode, - int format, - int algo, - int dims, - const int pad[], - const int stride[], - const int dilation[], - DLTensor* x, - DLTensor* w, - DLTensor* y, - const std::string& conv_dtype) { +void ConvolutionForward(int mode, int format, int algo, int dims, int groups, const int pad[], + const int stride[], const int dilation[], DLTensor* x, DLTensor* w, + DLTensor* y, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); @@ -62,19 +54,15 @@ void ConvolutionForward( // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int + + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); if (dims == 2) { - // Set Desc - CUDNN_CALL(cudnnSetConvolution2dDescriptor(entry_ptr->conv_entry.conv_desc, - pad[0], - pad[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - entry_ptr->conv_entry.mode, - entry_ptr->conv_entry.data_type)); + // Set Desc + CUDNN_CALL(cudnnSetConvolution2dDescriptor( + entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0], + dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type)); int ni, ci, hi, wi; - if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { ni = 0; ci = 3; hi = 1; @@ -87,67 +75,46 @@ void ConvolutionForward( } // Set Filter - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - static_cast(w->shape[ni]), - static_cast(w->shape[ci]), - static_cast(w->shape[hi]), - static_cast(w->shape[wi]))); + CUDNN_CALL(cudnnSetFilter4dDescriptor( + entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format, + static_cast(w->shape[ni]), static_cast(w->shape[ci]), + static_cast(w->shape[hi]), static_cast(w->shape[wi]))); // Set Input - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(x->shape[ni]), - static_cast(x->shape[ci]), - static_cast(x->shape[hi]), - static_cast(x->shape[wi]))); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type, + static_cast(x->shape[ni]), static_cast(x->shape[ci]), + static_cast(x->shape[hi]), static_cast(x->shape[wi]))); // Set Output - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - static_cast(y->shape[ni]), - static_cast(y->shape[ci]), - static_cast(y->shape[hi]), - static_cast(y->shape[wi]))); + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type, + static_cast(y->shape[ni]), static_cast(y->shape[ci]), + static_cast(y->shape[hi]), static_cast(y->shape[wi]))); } else { - CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - entry_ptr->conv_entry.mode, + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, entry_ptr->conv_entry.mode, 
entry_ptr->conv_entry.data_type)); // Set Filter for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(w->shape[i]); } - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, dim.data())); // Set Input for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(x->shape[i]); } GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - dim.data(), - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + dim.data(), tensor_stride.data())); // Set Output for (int i = 0; i < full_dims; i++) { dim[i] = static_cast(y->shape[i]); } GetCudnnStride(full_dims, dim.data(), tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, - data_type, - full_dims, - dim.data(), - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, + dim.data(), tensor_stride.data())); } if (cudnnGetVersion() > 7000) { @@ -156,41 +123,23 @@ void ConvolutionForward( // Set workspace size_t workspace_size = 0; - CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.fwd_algo, - &workspace_size)); + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.fwd_algo, &workspace_size)); entry_ptr->conv_entry.UpdateWorkspace(workspace_size); - CUDNN_CALL(cudnnConvolutionForward(entry_ptr->handle, - CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), - entry_ptr->conv_entry.input_desc, - x->data, - entry_ptr->conv_entry.filter_desc, - w->data, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, - entry_ptr->conv_entry.workspace, - workspace_size, - CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), - entry_ptr->conv_entry.output_desc, - y->data)); + CUDNN_CALL(cudnnConvolutionForward( + entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo, + entry_ptr->conv_entry.workspace, workspace_size, + CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), + entry_ptr->conv_entry.output_desc, y->data)); } - -void OutputShape( - int format, - int dims, - const int pad[], - const int stride[], - const int dilation[], - const int x_dim[], - const int w_dim[], - void *out_shape, - const std::string& data_dtype, - const std::string& conv_dtype) { +void OutputShape(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int w_dim[], void* out_shape, + const std::string& data_dtype, const std::string& conv_dtype) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Data Type @@ -202,77 +151,47 @@ void OutputShape( int full_dims = dims + 2; // conv desc - 
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - CUDNN_CROSS_CORRELATION, + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); - if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { + if (dims == 2 && entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) { // Set Input CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.tensor_format, - data_type, - x_dim[0], - x_dim[3], - x_dim[1], - x_dim[2])); + entry_ptr->conv_entry.tensor_format, data_type, x_dim[0], + x_dim[3], x_dim[1], x_dim[2])); // filter desc - CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - w_dim[0], - w_dim[3], - w_dim[1], - w_dim[2])); - - CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - static_cast(out_shape), - static_cast(out_shape) + 3, - static_cast(out_shape) + 1, - static_cast(out_shape) + 2)); + CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3], + w_dim[1], w_dim[2])); + + CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, static_cast(out_shape), + static_cast(out_shape) + 3, static_cast(out_shape) + 1, + static_cast(out_shape) + 2)); } else { // Set Input std::vector tensor_stride(full_dims); GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - x_dim, - tensor_stride.data())); + + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + x_dim, tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, - w_dim)); - - CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - full_dims, - static_cast(out_shape))); + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); + + CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, full_dims, static_cast(out_shape))); } } - -void FindAlgo( - int format, - int dims, - const int pad[], - const int stride[], - const int dilation[], - const int x_dim[], - const int w_dim[], - const int y_dim[], - const std::string& data_dtype, - const std::string& conv_dtype, - TVMRetValue *ret) { +void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[], + const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[], + const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) { CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); // Set Data Type @@ -284,65 +203,47 @@ void FindAlgo( int full_dims = dims + 2; // conv desc - 
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, - dims, - pad, - stride, - dilation, - CUDNN_CROSS_CORRELATION, + CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups)); + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride, + dilation, CUDNN_CROSS_CORRELATION, entry_ptr->conv_entry.data_type)); std::vector tensor_stride(full_dims); // input desc GetCudnnStride(full_dims, x_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, - data_type, - full_dims, - x_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims, + x_dim, tensor_stride.data())); // filter desc - CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, - data_type, - entry_ptr->conv_entry.tensor_format, - full_dims, - w_dim)); + CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type, + entry_ptr->conv_entry.tensor_format, full_dims, w_dim)); // output desc GetCudnnStride(full_dims, y_dim, tensor_stride.data()); - CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, - data_type, - full_dims, - y_dim, - tensor_stride.data())); + CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims, + y_dim, tensor_stride.data())); if (cudnnGetVersion() > 7000) { CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH)) } int returned_algo_count = 0; cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT]; - CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - CUDNN_CONVOLUTION_FWD_ALGO_COUNT, - &returned_algo_count, - perf_results)); - - const std::vector fwd_algo_names{ - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_GEMM", - "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT", - "CUDNN_CONVOLUTION_FWD_ALGO_FFT", - "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING", - "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD", - "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED" - }; + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, + CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results)); + + const std::vector fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_GEMM", + "CUDNN_CONVOLUTION_FWD_ALGO_DIRECT", + "CUDNN_CONVOLUTION_FWD_ALGO_FFT", + "CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING", + "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD", + "CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"}; auto best_algo = perf_results[0].algo; - LOG(INFO) << "\tCUDNN Found " << returned_algo_count - << " fwd algorithms, choosing " << fwd_algo_names[best_algo]; + LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; for (int i = 0; i < returned_algo_count; ++i) { LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo] << " - time: " << perf_results[i].time << " ms" @@ -352,82 +253,83 @@ void FindAlgo( ret[0] = best_algo; } - 
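(Editor's note, outside the patch: the hunks above and below thread a new `groups` argument through the cuDNN convolution entry points and apply it with `cudnnSetConvolutionGroupCount`. For orientation only, a minimal hypothetical caller sketch follows; the tensor handles `x`, `w`, `y` and the mode/format/algo constants are assumptions, and the argument order mirrors the `tvm.contrib.cudnn.conv2d.forward` registration shown below.)

#include <string>

#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>

// Hypothetical caller sketch (assumed names): invokes the cuDNN conv2d forward
// packed function with an explicit group count. Argument order follows the
// registration: mode, format, algo, pad_h, pad_w, stride_h, stride_w,
// dilation_h, dilation_w, x, w, y, conv_dtype, groups.
void RunGroupedConv2dForward(DLTensor* x, DLTensor* w, DLTensor* y, int groups) {
  const tvm::runtime::PackedFunc* conv =
      tvm::runtime::Registry::Get("tvm.contrib.cudnn.conv2d.forward");
  CHECK(conv != nullptr) << "TVM was not built with the cuDNN contrib runtime";
  const int mode = 1;    // CUDNN_CROSS_CORRELATION (assumed for this sketch)
  const int format = 0;  // CUDNN_TENSOR_NCHW (assumed for this sketch)
  const int algo = 1;    // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM (assumed)
  (*conv)(mode, format, algo, /*pad*/ 1, 1, /*stride*/ 1, 1, /*dilation*/ 1, 1, x, w, y,
          std::string("float32"), groups);
}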
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_v[2], stride_v[2], dilation_v[2]; - for (int i = 0; i < 2; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[5 + i]; - dilation_v[i] = args[7 + i]; - } - DLTensor* x = args[9]; - DLTensor* w = args[10]; - DLTensor* y = args[11]; - std::string conv_dtype = args[12]; - - ConvolutionForward(mode, format, algo, 2, pad_v, stride_v, dilation_v, x, w, y, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[2], stride_v[2], dilation_v[2]; + for (int i = 0; i < 2; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[5 + i]; + dilation_v[i] = args[7 + i]; + } + DLTensor* x = args[9]; + DLTensor* w = args[10]; + DLTensor* y = args[11]; + std::string conv_dtype = args[12]; + int groups = args[13]; + + ConvolutionForward(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int mode = args[0]; - int format = args[1]; - int algo = args[2]; - int pad_v[3], stride_v[3], dilation_v[3]; - for (int i = 0; i < 3; i++) { - pad_v[i] = args[3 + i]; - stride_v[i] = args[6 + i]; - dilation_v[i] = args[9 + i]; - } - DLTensor *x = args[12]; - DLTensor *w = args[13]; - DLTensor *y = args[14]; - std::string conv_dtype = args[15]; - - ConvolutionForward(mode, format, algo, 3, pad_v, stride_v, dilation_v, x, w, y, - conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int mode = args[0]; + int format = args[1]; + int algo = args[2]; + int pad_v[3], stride_v[3], dilation_v[3]; + for (int i = 0; i < 3; i++) { + pad_v[i] = args[3 + i]; + stride_v[i] = args[6 + i]; + dilation_v[i] = args[9 + i]; + } + DLTensor* x = args[12]; + DLTensor* w = args[13]; + DLTensor* y = args[14]; + std::string conv_dtype = args[15]; + int groups = args[16]; + + ConvolutionForward(mode, format, algo, 3, groups, pad_v, stride_v, dilation_v, x, w, y, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int format = args[0]; - int dims = args[1]; - int* pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - void* out_shape = args[7]; - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - - OutputShape(format, dims, pad, stride, dilation, x_dim, - w_dim, out_shape, data_dtype, conv_dtype); -}); - + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + void* out_shape = args[7]; + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype, + conv_dtype); + }); TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo") -.set_body([](TVMArgs args, TVMRetValue *ret) { - int format = args[0]; - int dims = args[1]; - int* 
pad = static_cast(static_cast(args[2])); - int* stride = static_cast(static_cast(args[3])); - int* dilation = static_cast(static_cast(args[4])); - int* x_dim = static_cast(static_cast(args[5])); - int* w_dim = static_cast(static_cast(args[6])); - int* y_dim = static_cast(static_cast(args[7])); - std::string data_dtype = args[8]; - std::string conv_dtype = args[9]; - - FindAlgo(format, dims, pad, stride, dilation, x_dim, - w_dim, y_dim, data_dtype, conv_dtype, ret); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + int format = args[0]; + int dims = args[1]; + int* pad = static_cast(static_cast(args[2])); + int* stride = static_cast(static_cast(args[3])); + int* dilation = static_cast(static_cast(args[4])); + int* x_dim = static_cast(static_cast(args[5])); + int* w_dim = static_cast(static_cast(args[6])); + int* y_dim = static_cast(static_cast(args[7])); + std::string data_dtype = args[8]; + std::string conv_dtype = args[9]; + int groups = args[10]; + + FindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim, data_dtype, + conv_dtype, ret); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc index 9c895c5b7e06..cd934bcb7081 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.cc +++ b/src/runtime/contrib/cudnn/cudnn_utils.cc @@ -21,38 +21,44 @@ * \file Use external cudnn utils function */ #include "cudnn_utils.h" + #include #include - namespace tvm { namespace contrib { // CuDNN Data Type -cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) { +cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType& dtype) { switch (dtype.code) { - case kDLInt: - if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8; - else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32; - else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4; - else - LOG(FATAL) << "Unsupported type"; - break; - case kDLUInt: + case kDLInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return CUDNN_DATA_INT8; + else if (dtype.bits == 32 && dtype.lanes == 1) + return CUDNN_DATA_INT32; + else if (dtype.bits == 8 && dtype.lanes == 4) + return CUDNN_DATA_INT8x4; + else LOG(FATAL) << "Unsupported type"; - break; - case kDLFloat: - if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT; - else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE; - else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF; - else - LOG(FATAL) << "Unsupported type"; - break; - } - return CUDNN_DATA_FLOAT; + break; + case kDLUInt: + LOG(FATAL) << "Unsupported type"; + break; + case kDLFloat: + if (dtype.bits == 32 && dtype.lanes == 1) + return CUDNN_DATA_FLOAT; + else if (dtype.bits == 64 && dtype.lanes == 1) + return CUDNN_DATA_DOUBLE; + else if (dtype.bits == 16 && dtype.lanes == 1) + return CUDNN_DATA_HALF; + else + LOG(FATAL) << "Unsupported type"; + break; + } + return CUDNN_DATA_FLOAT; } -template<> +template <> const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { static const int int_v = 0; static const float float_v = 0; @@ -69,7 +75,7 @@ const void* CuDNNDataType::GetConst<0>(cudnnDataType_t type) { return nullptr; } -template<> +template <> const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { static const int int_v = 1; static const float float_v = 1.f; @@ -91,22 +97,18 @@ const void* CuDNNDataType::GetConst<1>(cudnnDataType_t type) { CuDNNThreadEntry::CuDNNThreadEntry() { auto stream = 
runtime::CUDAThreadEntry::ThreadLocal()->stream; auto func = runtime::Registry::Get("device_api.gpu"); - void *ret = (*func)(); + void* ret = (*func)(); cuda_api = static_cast(ret); CUDNN_CALL(cudnnCreate(&handle)); CUDNN_CALL(cudnnSetStream(handle, stream)); conv_entry.cuda_api = cuda_api; } -CuDNNThreadEntry::~CuDNNThreadEntry() { - CUDNN_CALL(cudnnDestroy(handle)); -} +CuDNNThreadEntry::~CuDNNThreadEntry() { CUDNN_CALL(cudnnDestroy(handle)); } typedef dmlc::ThreadLocalStore CuDNNThreadStore; -CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { - return CuDNNThreadStore::Get(); -} +CuDNNThreadEntry* CuDNNThreadEntry::ThreadLocal() { return CuDNNThreadStore::Get(); } // ConvEntry @@ -142,13 +144,9 @@ void ConvEntry::CleanWorkspace() { // SoftmaxEntry -SoftmaxEntry::SoftmaxEntry() { - CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); -} +SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); } -SoftmaxEntry::~SoftmaxEntry() { - CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); -} +SoftmaxEntry::~SoftmaxEntry() { CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc)); } } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h index ee6bb5089e38..1b4eb40f193f 100644 --- a/src/runtime/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/contrib/cudnn/cudnn_utils.h @@ -24,11 +24,11 @@ #ifndef TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ #define TVM_RUNTIME_CONTRIB_CUDNN_CUDNN_UTILS_H_ -#include #include +#include #include -#include "../../cuda/cuda_common.h" +#include "../../cuda/cuda_common.h" namespace tvm { namespace contrib { @@ -41,24 +41,22 @@ namespace contrib { /*! breif Convert DLTensor type to CuDNN type */ struct CuDNNDataType { - static cudnnDataType_t DLTypeToCuDNNType(const DLDataType &dtype); - template + static cudnnDataType_t DLTypeToCuDNNType(const DLDataType& dtype); + template static const void* GetConst(cudnnDataType_t type); }; // struct CuDNNDataType -inline void GetStride(int nbdim, const int *dims, int *strides) { +inline void GetStride(int nbdim, const int* dims, int* strides) { int mul = 1; - for (int i = nbdim - 1; i >=0; --i) { + for (int i = nbdim - 1; i >= 0; --i) { mul *= dims[i]; strides[i] = mul; } } -inline void GetCudnnStride(int nbdim, - const int* dims, - int* strides) { +inline void GetCudnnStride(int nbdim, const int* dims, int* strides) { int mul = 1; - for (int i = nbdim - 1; i >=0; --i) { + for (int i = nbdim - 1; i >= 0; --i) { strides[i] = mul; mul *= dims[i]; } @@ -75,10 +73,9 @@ struct ConvEntry { cudnnConvolutionFwdAlgo_t fwd_algo; // cudnnMathType_t math_type; TVMContext ctx; - runtime::DeviceAPI *cuda_api; - void *workspace{nullptr}; + runtime::DeviceAPI* cuda_api; + void* workspace{nullptr}; size_t workspace_size{0}; - int group_count {0}; ConvEntry(); ~ConvEntry(); void UpdateWorkspace(const size_t wsize); @@ -99,7 +96,7 @@ struct CuDNNThreadEntry { cudnnHandle_t handle{nullptr}; ConvEntry conv_entry; SoftmaxEntry softmax_entry; - runtime::DeviceAPI *cuda_api{nullptr}; + runtime::DeviceAPI* cuda_api{nullptr}; static CuDNNThreadEntry* ThreadLocal(); }; // CuDNNThreadEntry diff --git a/src/runtime/contrib/cudnn/softmax.cc b/src/runtime/contrib/cudnn/softmax.cc index fb6d8a6fdc56..ff6d6a1dbd81 100644 --- a/src/runtime/contrib/cudnn/softmax.cc +++ b/src/runtime/contrib/cudnn/softmax.cc @@ -21,8 +21,9 @@ * \file src/runtime/contrib/cudnn/softmax.cc * \brief Use external cudnn softmax function */ -#include #include +#include + #include 
"cudnn_utils.h" namespace tvm { @@ -31,64 +32,53 @@ namespace contrib { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.softmax.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* x = args[0]; - DLTensor* y = args[1]; - int axis = args[2]; - int ndim = x->ndim; - int64_t* shape = x->shape; - if (axis < 0) axis += ndim; - CHECK(axis >= 0 && axis < ndim); + .set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* x = args[0]; + DLTensor* y = args[1]; + int axis = args[2]; + int ndim = x->ndim; + int64_t* shape = x->shape; + if (axis < 0) axis += ndim; + CHECK(axis >= 0 && axis < ndim); - CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); - entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); + CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); + entry_ptr->softmax_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); - // Set mode and shape descriptor - if (axis == ndim - 1) { - int64_t N = 1; - for (int i = 0; i < ndim - 1; ++i) { - N *= shape[i]; - } - entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE; - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, - CUDNN_TENSOR_NCHW, - entry_ptr->softmax_entry.data_type, - static_cast(N), - static_cast(shape[ndim - 1]), - 1, - 1)); - } else { - int64_t pre_axis_dim = 1; - int64_t post_axis_dim = 1; - for (int i = 0; i < ndim; ++i) { - if (i < axis) { - pre_axis_dim *= shape[i]; - } else if (i > axis) { - post_axis_dim *= shape[i]; + // Set mode and shape descriptor + if (axis == ndim - 1) { + int64_t N = 1; + for (int i = 0; i < ndim - 1; ++i) { + N *= shape[i]; + } + entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_INSTANCE; + CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, + CUDNN_TENSOR_NCHW, entry_ptr->softmax_entry.data_type, + static_cast(N), + static_cast(shape[ndim - 1]), 1, 1)); + } else { + int64_t pre_axis_dim = 1; + int64_t post_axis_dim = 1; + for (int i = 0; i < ndim; ++i) { + if (i < axis) { + pre_axis_dim *= shape[i]; + } else if (i > axis) { + post_axis_dim *= shape[i]; + } + } + entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL; + CUDNN_CALL(cudnnSetTensor4dDescriptor( + entry_ptr->softmax_entry.shape_desc, CUDNN_TENSOR_NCHW, + entry_ptr->softmax_entry.data_type, static_cast(pre_axis_dim), + static_cast(shape[axis]), static_cast(post_axis_dim), 1)); } - } - entry_ptr->softmax_entry.mode = CUDNN_SOFTMAX_MODE_CHANNEL; - CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->softmax_entry.shape_desc, - CUDNN_TENSOR_NCHW, - entry_ptr->softmax_entry.data_type, - static_cast(pre_axis_dim), - static_cast(shape[axis]), - static_cast(post_axis_dim), - 1)); - } - auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type); - auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type); - CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, - CUDNN_SOFTMAX_ACCURATE, - entry_ptr->softmax_entry.mode, - alpha, - entry_ptr->softmax_entry.shape_desc, - x->data, - beta, - entry_ptr->softmax_entry.shape_desc, - y->data)); -}); + auto alpha = CuDNNDataType::GetConst<1>(entry_ptr->softmax_entry.data_type); + auto beta = CuDNNDataType::GetConst<0>(entry_ptr->softmax_entry.data_type); + CUDNN_CALL(cudnnSoftmaxForward(entry_ptr->handle, CUDNN_SOFTMAX_ACCURATE, + entry_ptr->softmax_entry.mode, alpha, + entry_ptr->softmax_entry.shape_desc, x->data, beta, + entry_ptr->softmax_entry.shape_desc, y->data)); + }); } // namespace contrib } // namespace tvm diff --git 
a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 0922ac1a65df..5b9f5e17232c 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -22,8 +22,6 @@ * \brief TVM compatible wrappers for dnnl kernels. */ -#include "dnnl_kernel.h" - #include #include #include @@ -34,6 +32,8 @@ #include #include +#include "dnnl_kernel.h" + namespace tvm { namespace runtime { namespace contrib { @@ -133,8 +133,7 @@ extern "C" void dnnl_fused_conv2d_bias_relu(float* data, float* weights, float* p_Pw_, p_Kh_, p_Kw_, p_Sh_, p_Sw_, create_attr_with_relu_post_op()); } -extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, - int p_I_, int p_O_) { +extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, int p_I_, int p_O_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -157,8 +156,8 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, auto bias_memory = memory(bias_md, eng, bias.data()); auto dst_memory = memory(dst_md, eng); - auto dense_desc = inner_product_forward::desc( - prop_kind::forward_inference, data_md, weight_md, bias_md, dst_md); + auto dense_desc = inner_product_forward::desc(prop_kind::forward_inference, data_md, weight_md, + bias_md, dst_md); auto dense_prim_desc = inner_product_forward::primitive_desc(dense_desc, eng); assert(dst_md == dense_prim_desc.dst_desc()); @@ -171,8 +170,7 @@ extern "C" void dnnl_dense(float* data, float* weight, float* out, int p_B_, read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, - int p_W_) { +extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -186,8 +184,8 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, auto data_memory = memory(data_md, eng, data); auto dst_memory = memory(data_md, eng); - auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference, - algorithm::eltwise_relu, data_md, 0); + auto relu_desc = + eltwise_forward::desc(prop_kind::forward_inference, algorithm::eltwise_relu, data_md, 0); auto relu_prim_desc = eltwise_forward::primitive_desc(relu_desc, eng); assert(data_md == relu_prim_desc.dst_desc()); @@ -215,8 +213,7 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo auto bn_desc = batch_normalization_forward::desc( prop_kind::forward_inference, data_md, p_E_, - normalization_flags::use_global_stats | - normalization_flags::use_scale_shift); + normalization_flags::use_global_stats | normalization_flags::use_scale_shift); auto bn_prim_desc = batch_normalization_forward::primitive_desc(bn_desc, eng); assert(data_md == bn_prim_desc.dst_desc()); @@ -239,8 +236,8 @@ extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, flo free(weight); } -extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, - int p_C_, int p_H_, int p_W_) { +extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, int p_C_, int p_H_, + int p_W_) { using tag = memory::format_tag; using dt = memory::data_type; @@ -257,15 +254,14 @@ extern "C" void dnnl_add(float* data, float* weight, float* out, int p_N_, auto weight_memory = memory(weight_md, eng, weight); auto dst_memory = memory(dst_md, eng); - auto add_desc = - binary::desc(algorithm::binary_add, data_md, weight_md, dst_md); + auto add_desc = binary::desc(algorithm::binary_add, 
data_md, weight_md, dst_md); auto add_prim_desc = binary::primitive_desc(add_desc, eng); assert(dst_md == add_prim_desc.dst_desc()); auto add = binary(add_prim_desc); - add.execute(s, {{DNNL_ARG_SRC_0, data_memory}, - {DNNL_ARG_SRC_1, weight_memory}, - {DNNL_ARG_DST, dst_memory}}); + add.execute( + s, + {{DNNL_ARG_SRC_0, data_memory}, {DNNL_ARG_SRC_1, weight_memory}, {DNNL_ARG_DST, dst_memory}}); s.wait(); read_from_dnnl_memory(out, dst_memory); } diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index f92d7679aeee..dbc064a6bc99 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -26,6 +26,7 @@ #define TVM_RUNTIME_CONTRIB_DNNL_DNNL_KERNEL_H_ #include + #include "dnnl.hpp" namespace tvm { diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc index 4823ef7de959..13b3c34a6b17 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.cc +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.cc @@ -20,25 +20,23 @@ /*! * \file edgetpu_runtime.cc */ -#include +#include "edgetpu_runtime.h" + +#include #include #include #include -#include - - -#include "edgetpu_runtime.h" +#include namespace tvm { namespace runtime { -void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, - TVMContext ctx) { +void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); // Load compiled model as a FlatBufferModel std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); // Build resolver tflite::ops::builtin::BuiltinOpResolver resolver; // Init EdgeTPUContext object @@ -58,16 +56,14 @@ void EdgeTPURuntime::Init(const std::string& tflite_model_bytes, ctx_ = ctx; } -Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, - TVMContext ctx) { +Module EdgeTPURuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); exec->Init(tflite_model_bytes, ctx); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = EdgeTPURuntimeCreate(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("tvm.edgetpu_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = EdgeTPURuntimeCreate(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/edgetpu/edgetpu_runtime.h b/src/runtime/contrib/edgetpu/edgetpu_runtime.h index 78730d530018..af3517ba76f3 100644 --- a/src/runtime/contrib/edgetpu/edgetpu_runtime.h +++ b/src/runtime/contrib/edgetpu/edgetpu_runtime.h @@ -25,8 +25,8 @@ #ifndef TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ #define TVM_RUNTIME_CONTRIB_EDGETPU_EDGETPU_RUNTIME_H_ -#include #include +#include #include "../tflite/tflite_runtime.h" @@ -44,17 +44,14 @@ class EdgeTPURuntime : public TFLiteRuntime { /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "EdgeTPURuntime"; - } + const char* type_key() const final { return "EdgeTPURuntime"; } /*! * \brief Initialize the edge TPU tflite runtime with tflite model and context. * \param tflite_model_bytes The tflite model. * \param ctx The context where the tflite model will be executed on. 
*/ - void Init(const std::string& tflite_model_bytes, - TVMContext ctx); + void Init(const std::string& tflite_model_bytes, TVMContext ctx); private: std::shared_ptr edgetpu_context_; diff --git a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc index 98078b68c23a..1a63eded5adf 100644 --- a/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc +++ b/src/runtime/contrib/example_ext_runtime/example_ext_runtime.cc @@ -42,8 +42,8 @@ #include #include -#include #include +#include #include #include #include @@ -76,9 +76,8 @@ int Add(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Add_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Add_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -93,9 +92,8 @@ int Sub(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Sub_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Sub_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -110,9 +108,8 @@ int Mul(TVMValue* value, int* type_code, int nargs) { DLTensor* arg0 = static_cast(value[0].v_handle); DLTensor* arg1 = static_cast(value[1].v_handle); DLTensor* out = static_cast(value[2].v_handle); - Mul_(static_cast(arg0->data), arg0->shape[0], - static_cast(arg1->data), arg1->shape[0], - static_cast(out->data)); + Mul_(static_cast(arg0->data), arg0->shape[0], static_cast(arg1->data), + arg1->shape[0], static_cast(out->data)); return 0; } @@ -136,8 +133,7 @@ class ExampleJsonModule : public ModuleNode { * * \return The function pointer when it is found, otherwise, PackedFunc(nullptr). */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (this->graph_.find(name) != this->graph_.end()) { this->curr_subgraph_ = name; return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -215,9 +211,7 @@ class ExampleJsonModule : public ModuleNode { * * \param stream. The stream to save the binary. */ - void SaveToBinary(dmlc::Stream* stream) final { - stream->Write(this->graph_json_); - } + void SaveToBinary(dmlc::Stream* stream) final { stream->Write(this->graph_json_); } /*! * \brief Parse the example json string. 
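(Editor's note, outside the patch: the `tvm.edgetpu_runtime.create` global reformatted above takes the serialized TFLite model bytes and a target context. A minimal hypothetical usage sketch follows; the file-loading helper, path, and context value are assumptions, not part of this change.)

#include <fstream>
#include <sstream>
#include <string>

#include <dmlc/logging.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

// Hypothetical usage sketch: read a compiled Edge TPU TFLite model from disk and
// wrap it in an EdgeTPURuntime module via the registered creation global.
tvm::runtime::Module LoadEdgeTPUModel(const std::string& model_path) {
  std::ifstream fin(model_path, std::ios::binary);
  CHECK(fin.good()) << "Cannot open " << model_path;
  std::stringstream buf;
  buf << fin.rdbuf();

  const tvm::runtime::PackedFunc* create =
      tvm::runtime::Registry::Get("tvm.edgetpu_runtime.create");
  CHECK(create != nullptr) << "TVM was not built with the Edge TPU runtime";

  TVMContext ctx{kDLCPU, 0};  // assumed host context for this sketch
  return (*create)(buf.str(), ctx);
}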
@@ -333,12 +327,10 @@ class ExampleJsonModule : public ModuleNode { }; TVM_REGISTER_GLOBAL("runtime.module.loadfile_examplejson") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = ExampleJsonModule::Create(args[0]); -}); + .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ExampleJsonModule::Create(args[0]); }); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_examplejson") -.set_body_typed(ExampleJsonModule::LoadFromBinary); + .set_body_typed(ExampleJsonModule::LoadFromBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/miopen/conv_forward.cc b/src/runtime/contrib/miopen/conv_forward.cc index d4575484320b..1353e2f996bb 100644 --- a/src/runtime/contrib/miopen/conv_forward.cc +++ b/src/runtime/contrib/miopen/conv_forward.cc @@ -20,9 +20,10 @@ /*! * \file Use external miopen utils function */ -#include #include #include +#include + #include "miopen_utils.h" namespace tvm { @@ -31,8 +32,7 @@ namespace miopen { using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup").set_body([](TVMArgs args, TVMRetValue* ret) { const int mode = args[0]; const int dtype = args[1]; const int pad_h = args[2]; @@ -50,72 +50,52 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") const int w_dim2 = args[14]; const int w_dim3 = args[15]; const int n_group = args[16]; - void *out_shape = args[17]; + void* out_shape = args[17]; MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); assert(n_group > 0 && "Group Size > 0 is expected"); - if (n_group > 1) - assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); + if (n_group > 1) assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1"); // Set Mode entry_ptr->conv_entry.mode = static_cast(mode); // Set Ctx entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0}; // Set Data Type - entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at - // this moment. + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), + // int32, int8 at this moment. 
// Set Desc MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w)); + entry_ptr->conv_entry.mode, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w)); if (n_group > 1) MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group)); // Set Filter MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, - w_dim0, - w_dim1/n_group, - w_dim2, - w_dim3)); + entry_ptr->conv_entry.data_type, w_dim0, w_dim1 / n_group, + w_dim2, w_dim3)); // Set Input MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, - x_dim0, - x_dim1, - x_dim2, + entry_ptr->conv_entry.data_type, x_dim0, x_dim1, x_dim2, x_dim3)); // Set Output shape - MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.filter_desc, - static_cast(out_shape), - static_cast(out_shape) + 1, - static_cast(out_shape) + 2, - static_cast(out_shape) + 3)); - - const int *oshape = static_cast(out_shape); + MIOPEN_CALL(miopenGetConvolutionForwardOutputDim( + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.filter_desc, static_cast(out_shape), + static_cast(out_shape) + 1, static_cast(out_shape) + 2, + static_cast(out_shape) + 3)); + + const int* oshape = static_cast(out_shape); // Set Output MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, - oshape[0], - oshape[1], - oshape[2], - oshape[3])); + entry_ptr->conv_entry.data_type, oshape[0], oshape[1], + oshape[2], oshape[3])); // Set workspace size_t workspace_size = 0; - MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle, - entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - &workspace_size)); + MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize( + entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc, &workspace_size)); entry_ptr->conv_entry.UpdateWorkspace(workspace_size); const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3; @@ -123,12 +103,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3]; runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api; - float* input_buf = static_cast(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, - input_size * sizeof(float))); - float* filter_buf = static_cast(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, - filter_size * sizeof(float))); - float* output_buf = static_cast(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, - output_size * sizeof(float))); + float* input_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, input_size * sizeof(float))); + float* filter_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, filter_size * sizeof(float))); + float* output_buf = static_cast( + rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx, output_size * sizeof(float))); const int request_algo_count = 4; const bool exhaustive_search = false; @@ -137,20 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") int returned_algo_count = 0; miopenConvAlgoPerf_t 
perfs[4]; - MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(entry_ptr->handle, - entry_ptr->conv_entry.input_desc, - input_buf, - entry_ptr->conv_entry.filter_desc, - filter_buf, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.output_desc, - output_buf, - request_algo_count, - &returned_algo_count, - perfs, - workspace, - workspace_size, - exhaustive_search)); + MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm( + entry_ptr->handle, entry_ptr->conv_entry.input_desc, input_buf, + entry_ptr->conv_entry.filter_desc, filter_buf, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.output_desc, output_buf, request_algo_count, &returned_algo_count, + perfs, workspace, workspace_size, exhaustive_search)); rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf); rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf); @@ -163,8 +134,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") "miopenConvolutionFwdAlgoWinograd", }; const auto best_algo = perfs[0].fwd_algo; - LOG(INFO) << "\tMIOpen Found " << returned_algo_count - << " fwd algorithms, choosing " << fwd_algo_names[best_algo]; + LOG(INFO) << "\tMIOpen Found " << returned_algo_count << " fwd algorithms, choosing " + << fwd_algo_names[best_algo]; for (int i = 0; i < returned_algo_count; ++i) { LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo] << " - time: " << perfs[i].time << " ms" @@ -174,79 +145,56 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup") ret[0] = static_cast(best_algo); }); - TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward") -.set_body([](TVMArgs args, TVMRetValue *ret) { - const int mode = args[0]; - const int dtype = args[1]; - const int pad_h = args[2]; - const int pad_w = args[3]; - const int stride_h = args[4]; - const int stride_w = args[5]; - const int dilation_h = args[6]; - const int dilation_w = args[7]; - const int algo = args[8]; - const DLTensor *x = args[9]; - const DLTensor *w = args[10]; - const DLTensor *y = args[11]; - - MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); - entry_ptr->conv_entry.fwd_algo = static_cast(algo); - // Set Mode - entry_ptr->conv_entry.mode = static_cast(mode); - // Set Ctx - entry_ptr->conv_entry.ctx = x->ctx; - // Set Data Type - entry_ptr->conv_entry.data_type = static_cast( - dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at - // this moment. 
- // Set Desc - MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.mode, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w)); - // Set Filter - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, - entry_ptr->conv_entry.data_type, - w->shape[0], - w->shape[1], - w->shape[2], - w->shape[3])); - // Set Input - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, - entry_ptr->conv_entry.data_type, - x->shape[0], - x->shape[1], - x->shape[2], - x->shape[3])); - // Set Output - MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, - entry_ptr->conv_entry.data_type, - y->shape[0], - y->shape[1], - y->shape[2], - y->shape[3])); - - const float alpha = 1.f; - const float beta = 0.f; - MIOPEN_CALL(miopenConvolutionForward(entry_ptr->handle, - &alpha, - entry_ptr->conv_entry.input_desc, - x->data, - entry_ptr->conv_entry.filter_desc, - w->data, - entry_ptr->conv_entry.conv_desc, - entry_ptr->conv_entry.fwd_algo, - &beta, - entry_ptr->conv_entry.output_desc, - y->data, - entry_ptr->conv_entry.workspace, - entry_ptr->conv_entry.workspace_size)); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + const int mode = args[0]; + const int dtype = args[1]; + const int pad_h = args[2]; + const int pad_w = args[3]; + const int stride_h = args[4]; + const int stride_w = args[5]; + const int dilation_h = args[6]; + const int dilation_w = args[7]; + const int algo = args[8]; + const DLTensor* x = args[9]; + const DLTensor* w = args[10]; + const DLTensor* y = args[11]; + + MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal(); + entry_ptr->conv_entry.fwd_algo = static_cast(algo); + // Set Mode + entry_ptr->conv_entry.mode = static_cast(mode); + // Set Ctx + entry_ptr->conv_entry.ctx = x->ctx; + // Set Data Type + entry_ptr->conv_entry.data_type = + static_cast(dtype); // MIOpen supports fp32(miopenFloat), + // fp16(miopenHalf) at this moment. 
+ // Set Desc + MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.mode, pad_h, pad_w, + stride_h, stride_w, dilation_h, dilation_w)); + // Set Filter + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc, + entry_ptr->conv_entry.data_type, w->shape[0], + w->shape[1], w->shape[2], w->shape[3])); + // Set Input + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc, + entry_ptr->conv_entry.data_type, x->shape[0], + x->shape[1], x->shape[2], x->shape[3])); + // Set Output + MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc, + entry_ptr->conv_entry.data_type, y->shape[0], + y->shape[1], y->shape[2], y->shape[3])); + + const float alpha = 1.f; + const float beta = 0.f; + MIOPEN_CALL(miopenConvolutionForward( + entry_ptr->handle, &alpha, entry_ptr->conv_entry.input_desc, x->data, + entry_ptr->conv_entry.filter_desc, w->data, entry_ptr->conv_entry.conv_desc, + entry_ptr->conv_entry.fwd_algo, &beta, entry_ptr->conv_entry.output_desc, y->data, + entry_ptr->conv_entry.workspace, entry_ptr->conv_entry.workspace_size)); + }); } // namespace miopen } // namespace contrib diff --git a/src/runtime/contrib/miopen/miopen_utils.cc b/src/runtime/contrib/miopen/miopen_utils.cc index 330ccdd043d0..a57918045d87 100644 --- a/src/runtime/contrib/miopen/miopen_utils.cc +++ b/src/runtime/contrib/miopen/miopen_utils.cc @@ -21,20 +21,22 @@ * \file Use external miopen utils function */ #include "miopen_utils.h" + #include #include -#include + #include +#include namespace tvm { namespace contrib { namespace miopen { std::string miopenGetErrorString(int error_code) { - const std::vector mio_err{ - "StatusSuccess ", "StatusNotInitialized ", "StatusInvalidValue ", - "StatusBadParm ", "StatusAllocFailed ", "StatusInternalError ", - "StatusNotImplemented ", "StatusUnknownError "}; + const std::vector mio_err{"StatusSuccess ", "StatusNotInitialized ", + "StatusInvalidValue ", "StatusBadParm ", + "StatusAllocFailed ", "StatusInternalError ", + "StatusNotImplemented ", "StatusUnknownError "}; return mio_err[error_code]; } @@ -42,22 +44,18 @@ std::string miopenGetErrorString(int error_code) { MIOpenThreadEntry::MIOpenThreadEntry() { auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream; auto func = runtime::Registry::Get("device_api.rocm"); - void *ret = (*func)(); + void* ret = (*func)(); rocm_api = static_cast(ret); MIOPEN_CALL(miopenCreate(&handle)); MIOPEN_CALL(miopenSetStream(handle, stream)); conv_entry.rocm_api = rocm_api; } -MIOpenThreadEntry::~MIOpenThreadEntry() { - MIOPEN_CALL(miopenDestroy(handle)); -} +MIOpenThreadEntry::~MIOpenThreadEntry() { MIOPEN_CALL(miopenDestroy(handle)); } typedef dmlc::ThreadLocalStore MIOpenThreadStore; -MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { - return MIOpenThreadStore::Get(); -} +MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() { return MIOpenThreadStore::Get(); } // ConvEntry diff --git a/src/runtime/contrib/miopen/miopen_utils.h b/src/runtime/contrib/miopen/miopen_utils.h index 8831e4fac95c..4dec2ad710ba 100644 --- a/src/runtime/contrib/miopen/miopen_utils.h +++ b/src/runtime/contrib/miopen/miopen_utils.h @@ -27,7 +27,9 @@ #include #include #include + #include + #include "../../rocm/rocm_common.h" namespace tvm { @@ -36,11 +38,10 @@ namespace miopen { std::string miopenGetErrorString(int error_code); -#define MIOPEN_CALL(func) \ - { \ - miopenStatus_t e = (func); \ - CHECK_EQ(e, miopenStatusSuccess) \ - << "miopen error: " << 
miopenGetErrorString(e); \ +#define MIOPEN_CALL(func) \ + { \ + miopenStatus_t e = (func); \ + CHECK_EQ(e, miopenStatusSuccess) << "miopen error: " << miopenGetErrorString(e); \ } struct ConvEntry { @@ -52,8 +53,8 @@ struct ConvEntry { miopenTensorDescriptor_t output_desc; miopenConvFwdAlgorithm_t fwd_algo; TVMContext ctx; - runtime::DeviceAPI *rocm_api; - void *workspace{nullptr}; + runtime::DeviceAPI* rocm_api; + void* workspace{nullptr}; size_t workspace_size{0}; ConvEntry(); ~ConvEntry(); @@ -66,8 +67,8 @@ struct MIOpenThreadEntry { ~MIOpenThreadEntry(); miopenHandle_t handle{nullptr}; ConvEntry conv_entry; - runtime::DeviceAPI *rocm_api{nullptr}; - static MIOpenThreadEntry *ThreadLocal(); + runtime::DeviceAPI* rocm_api{nullptr}; + static MIOpenThreadEntry* ThreadLocal(); }; // MIOpenThreadEntry } // namespace miopen diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 064e6d53cfb8..b598014f0267 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -24,69 +24,59 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *buf = args[0]; - DLTensor *img = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* buf = args[0]; + DLTensor* img = args[1]; // copy to temp id mtlbuf = (__bridge id)(buf->data); - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id dev = entry_ptr->metal_api->GetDevice(buf->ctx); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)mtlbuf, 0, (__bridge void *)temp, 0, [mtlbuf length], - buf->ctx, buf->ctx, nullptr - ); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, + [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); - MPSImageDescriptor *desc = [MPSImageDescriptor - imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 - width:buf->shape[2] - height:buf->shape[1] - featureChannels:buf->shape[3]]; + MPSImageDescriptor* desc = + [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 + width:buf->shape[2] + height:buf->shape[1] + featureChannels:buf->shape[3]]; - MPSImage *mpsimg = entry_ptr->AllocMPSImage(dev, desc); + MPSImage* mpsimg = entry_ptr->AllocMPSImage(dev, desc); [mpsimg writeBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; - img->data = (__bridge void *)mpsimg; + img->data = (__bridge void*)mpsimg; [mpsimg readBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; +}); - }); - -TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *img = args[0]; - DLTensor *buf = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* img = args[0]; + DLTensor* buf = args[1]; id mtlbuf = (__bridge id)(buf->data); - MPSImage *mpsimg = (__bridge MPSImage *)(img->data); - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MPSImage* mpsimg = (__bridge 
MPSImage*)(img->data); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); [mpsimg readBytes:[temp contents] dataLayout:MPSDataLayoutHeightxWidthxFeatureChannels imageIndex:0]; - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)temp, 0, (__bridge void *)mtlbuf, 0, [mtlbuf length], - buf->ctx, buf->ctx, nullptr); - - }); + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, + [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); +}); -TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d") -.set_body([](TVMArgs args, TVMRetValue *ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { // MPS-NHWC - DLTensor *data = args[0]; - DLTensor *weight = args[1]; - DLTensor *output = args[2]; + DLTensor* data = args[0]; + DLTensor* weight = args[1]; + DLTensor* output = args[2]; int pad = args[3]; int stride = args[4]; @@ -108,54 +98,48 @@ auto f_buf2img = runtime::Registry::Get("tvm.contrib.mps.buffer2img"); auto f_img2buf = runtime::Registry::Get("tvm.contrib.mps.img2buffer"); // Get Metal device API - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); - runtime::metal::MetalThreadEntry *rt = - runtime::metal::MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); + runtime::metal::MetalThreadEntry* rt = runtime::metal::MetalThreadEntry::ThreadLocal(); id dev = entry_ptr->metal_api->GetDevice(data->ctx); - id queue = - entry_ptr->metal_api->GetCommandQueue(data->ctx); + id queue = entry_ptr->metal_api->GetCommandQueue(data->ctx); id cb = [queue commandBuffer]; // data to MPSImage DLTensor tmp_in; (*f_buf2img)(data, &tmp_in); - MPSImage *tempA = (__bridge MPSImage *)tmp_in.data; + MPSImage* tempA = (__bridge MPSImage*)tmp_in.data; // weight to temp memory id bufB = (__bridge id)(weight->data); id tempB = rt->GetTempBuffer(weight->ctx, [bufB length]); - entry_ptr->metal_api->CopyDataFromTo( - (__bridge void *)bufB, 0, (__bridge void *)tempB, 0, [bufB length], - weight->ctx, weight->ctx, nullptr); - float *ptr_w = (float *)[tempB contents]; + entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, + [bufB length], weight -> ctx, weight -> ctx, nullptr); + float* ptr_w = (float*)[tempB contents]; // output to MPSImage DLTensor tmp_out; (*f_buf2img)(output, &tmp_out); - MPSImage *tempC = (__bridge MPSImage *)tmp_out.data; + MPSImage* tempC = (__bridge MPSImage*)tmp_out.data; // conv desc - MPSCNNConvolutionDescriptor *conv_desc = [MPSCNNConvolutionDescriptor - cnnConvolutionDescriptorWithKernelWidth:kW - kernelHeight:kH - inputFeatureChannels:iCh - outputFeatureChannels:oCh]; + MPSCNNConvolutionDescriptor* conv_desc = + [MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW + kernelHeight:kH + inputFeatureChannels:iCh + outputFeatureChannels:oCh]; [conv_desc setStrideInPixelsX:stride]; [conv_desc setStrideInPixelsY:stride]; - MPSCNNConvolution *conv = - [[MPSCNNConvolution alloc] initWithDevice:dev - convolutionDescriptor:conv_desc - kernelWeights:ptr_w - biasTerms:nil - flags:MPSCNNConvolutionFlagsNone]; + MPSCNNConvolution* conv = [[MPSCNNConvolution alloc] initWithDevice:dev + convolutionDescriptor:conv_desc + kernelWeights:ptr_w + biasTerms:nil + flags:MPSCNNConvolutionFlagsNone]; if (pad == 0) { - conv.padding = [MPSNNDefaultPadding - 
paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeSame]; + conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | + MPSNNPaddingMethodSizeSame]; } else if (pad == 1) { - conv.padding = [MPSNNDefaultPadding - paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | - MPSNNPaddingMethodAlignCentered | - MPSNNPaddingMethodSizeValidOnly]; + conv.padding = [MPSNNDefaultPadding paddingWithMethod:MPSNNPaddingMethodAddRemainderToTopLeft | + MPSNNPaddingMethodAlignCentered | + MPSNNPaddingMethodSizeValidOnly]; } [conv encodeToCommandBuffer:cb sourceImage:tempA destinationImage:tempC]; @@ -166,8 +150,7 @@ [cb waitUntilCompleted]; (*f_img2buf)(&tmp_out, output); +}); - }); - -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/mps/gemm.mm b/src/runtime/contrib/mps/gemm.mm index bc1216704cc4..109c952ff0c4 100644 --- a/src/runtime/contrib/mps/gemm.mm +++ b/src/runtime/contrib/mps/gemm.mm @@ -24,11 +24,10 @@ using namespace runtime; -TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *A = args[0]; - DLTensor *B = args[1]; - DLTensor *C = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.mps.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; // call gemm for simple compact code. @@ -42,7 +41,7 @@ CHECK(TypeMatch(B->dtype, kDLFloat, 32)); CHECK(TypeMatch(C->dtype, kDLFloat, 32)); // Get Metal device API - MetalThreadEntry *entry_ptr = MetalThreadEntry::ThreadLocal(); + MetalThreadEntry* entry_ptr = MetalThreadEntry::ThreadLocal(); // CHECK_EQ(A->ctx, B->ctx); // CHECK_EQ(A->ctx, C->ctx); id dev = entry_ptr->metal_api->GetDevice(A->ctx); @@ -55,36 +54,31 @@ CHECK_EQ(A->shape[1 - (transa ? 
1 : 0)], K); // mps a MPSDataType dtype = MPSType::DLTypeToMPSType(A->dtype); - MPSMatrixDescriptor *descA = [MPSMatrixDescriptor - matrixDescriptorWithDimensions:M - columns:K - rowBytes:K * sizeof(MPSDataTypeFloat32) - dataType:MPSDataTypeFloat32]; + MPSMatrixDescriptor* descA = + [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:K + rowBytes:K * sizeof(MPSDataTypeFloat32) + dataType:MPSDataTypeFloat32]; id bufA = (__bridge id)(A->data); - MPSMatrix *matrixA = - [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; + MPSMatrix* matrixA = [[MPSMatrix alloc] initWithBuffer:bufA descriptor:descA]; // mps b - MPSMatrixDescriptor *descB = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:K - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; + MPSMatrixDescriptor* descB = [MPSMatrixDescriptor matrixDescriptorWithDimensions:K + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; id bufB = (__bridge id)(B->data); - MPSMatrix *matrixB = - [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; + MPSMatrix* matrixB = [[MPSMatrix alloc] initWithBuffer:bufB descriptor:descB]; // mps c - MPSMatrixDescriptor *descC = - [MPSMatrixDescriptor matrixDescriptorWithDimensions:M - columns:N - rowBytes:N * sizeof(dtype) - dataType:dtype]; + MPSMatrixDescriptor* descC = [MPSMatrixDescriptor matrixDescriptorWithDimensions:M + columns:N + rowBytes:N * sizeof(dtype) + dataType:dtype]; id bufC = (__bridge id)(C->data); - MPSMatrix *matrixC = - [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; + MPSMatrix* matrixC = [[MPSMatrix alloc] initWithBuffer:bufC descriptor:descC]; // kernel - MPSMatrixMultiplication *mul_obj = [[MPSMatrixMultiplication alloc] init]; - MPSMatrixMultiplication *sgemm = [mul_obj initWithDevice:dev + MPSMatrixMultiplication* mul_obj = [[MPSMatrixMultiplication alloc] init]; + MPSMatrixMultiplication* sgemm = [mul_obj initWithDevice:dev transposeLeft:transa transposeRight:transb resultRows:M @@ -93,13 +87,9 @@ alpha:1.0f beta:0.0f]; CHECK(sgemm != nil); - [sgemm encodeToCommandBuffer:cb - leftMatrix:matrixA - rightMatrix:matrixB - resultMatrix:matrixC]; + [sgemm encodeToCommandBuffer:cb leftMatrix:matrixA rightMatrix:matrixB resultMatrix:matrixC]; [cb commit]; +}); - }); - -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/mps/mps_utils.h b/src/runtime/contrib/mps/mps_utils.h index f1fff95c1df3..170451ea385b 100644 --- a/src/runtime/contrib/mps/mps_utils.h +++ b/src/runtime/contrib/mps/mps_utils.h @@ -27,10 +27,12 @@ #import #include #include +#include #include #include -#include + #include + #include "../../metal/metal_common.h" namespace tvm { @@ -38,18 +40,17 @@ namespace contrib { /*! 
breif Convert DLTensor type to MPS type */ struct MPSType { - static MPSDataType DLTypeToMPSType(const DLDataType &dtype); + static MPSDataType DLTypeToMPSType(const DLDataType& dtype); }; // struct MPSType struct MetalThreadEntry { MetalThreadEntry(); ~MetalThreadEntry(); - MPSImage *AllocMPSImage(id dev, MPSImageDescriptor *desc); - MPSTemporaryImage *AllocTempImage(id cb, - MPSImageDescriptor *desc); - runtime::metal::MetalWorkspace *metal_api{nullptr}; - static MetalThreadEntry *ThreadLocal(); - std::vector img_table; + MPSImage* AllocMPSImage(id dev, MPSImageDescriptor* desc); + MPSTemporaryImage* AllocTempImage(id cb, MPSImageDescriptor* desc); + runtime::metal::MetalWorkspace* metal_api{nullptr}; + static MetalThreadEntry* ThreadLocal(); + std::vector img_table; }; // MetalThreadEntry } // namespace contrib diff --git a/src/runtime/contrib/mps/mps_utils.mm b/src/runtime/contrib/mps/mps_utils.mm index b3d4070ca6b7..f9f80431165e 100644 --- a/src/runtime/contrib/mps/mps_utils.mm +++ b/src/runtime/contrib/mps/mps_utils.mm @@ -23,60 +23,58 @@ namespace contrib { // MPS Data Type -MPSDataType MPSType::DLTypeToMPSType(const DLDataType &dtype) { +MPSDataType MPSType::DLTypeToMPSType(const DLDataType& dtype) { switch (dtype.code) { - case kDLInt: - if (dtype.bits == 8 && dtype.lanes == 1) - return MPSDataTypeInt8; - else if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeInt16; - else + case kDLInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return MPSDataTypeInt8; + else if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeInt16; + else + LOG(FATAL) << "Unsupported type"; + break; + case kDLUInt: + if (dtype.bits == 8 && dtype.lanes == 1) + return MPSDataTypeUInt8; + else if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeUInt16; + else if (dtype.bits == 32 && dtype.lanes == 1) + return MPSDataTypeUInt32; LOG(FATAL) << "Unsupported type"; - break; - case kDLUInt: - if (dtype.bits == 8 && dtype.lanes == 1) - return MPSDataTypeUInt8; - else if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeUInt16; - else if (dtype.bits == 32 && dtype.lanes == 1) - return MPSDataTypeUInt32; - LOG(FATAL) << "Unsupported type"; - break; - case kDLFloat: - if (dtype.bits == 16 && dtype.lanes == 1) - return MPSDataTypeFloat16; - else if (dtype.bits == 32 && dtype.lanes == 1) - return MPSDataTypeFloat32; - else + break; + case kDLFloat: + if (dtype.bits == 16 && dtype.lanes == 1) + return MPSDataTypeFloat16; + else if (dtype.bits == 32 && dtype.lanes == 1) + return MPSDataTypeFloat32; + else + LOG(FATAL) << "Unsupported type"; + break; + default: LOG(FATAL) << "Unsupported type"; - break; - default: - LOG(FATAL) << "Unsupported type"; } return MPSDataTypeFloat32; } // MetalThreadEntry -MPSImage *MetalThreadEntry::AllocMPSImage(id dev, - MPSImageDescriptor *desc) { - MPSImage *mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]; +MPSImage* MetalThreadEntry::AllocMPSImage(id dev, MPSImageDescriptor* desc) { + MPSImage* mpsimg = [[MPSImage alloc] initWithDevice:dev imageDescriptor:desc]; img_table.push_back(mpsimg); return mpsimg; } -MPSTemporaryImage *MetalThreadEntry::AllocTempImage(id cb, - MPSImageDescriptor *desc) { - MPSTemporaryImage *mpsimg = - [MPSTemporaryImage temporaryImageWithCommandBuffer:cb - imageDescriptor:desc]; +MPSTemporaryImage* MetalThreadEntry::AllocTempImage(id cb, + MPSImageDescriptor* desc) { + MPSTemporaryImage* mpsimg = [MPSTemporaryImage temporaryImageWithCommandBuffer:cb + imageDescriptor:desc]; return mpsimg; } 
MetalThreadEntry::MetalThreadEntry() { auto func = runtime::Registry::Get("device_api.metal"); - void *ret = (*func)(); - metal_api = static_cast(ret); + void* ret = (*func)(); + metal_api = static_cast(ret); } MetalThreadEntry::~MetalThreadEntry() { @@ -87,9 +85,7 @@ typedef dmlc::ThreadLocalStore MetalThreadStore; -MetalThreadEntry *MetalThreadEntry::ThreadLocal() { - return MetalThreadStore::Get(); -} +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -} // namespace contrib -} // namespace tvm +} // namespace contrib +} // namespace tvm diff --git a/src/runtime/contrib/nnpack/convolution.cc b/src/runtime/contrib/nnpack/convolution.cc index 79ea19175d65..54c9ea4f969b 100644 --- a/src/runtime/contrib/nnpack/convolution.cc +++ b/src/runtime/contrib/nnpack/convolution.cc @@ -20,11 +20,12 @@ /*! * \file Use external nnpack library call. */ -#include -#include -#include #include #include +#include +#include +#include + #include "nnpack_utils.h" namespace tvm { @@ -32,28 +33,25 @@ namespace contrib { using namespace runtime; TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *input = args[0]; - DLTensor *kernel = args[1]; - DLTensor *bias = nullptr; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* input = args[0]; + DLTensor* kernel = args[1]; + DLTensor* bias = nullptr; if (args[2].type_code() == kTVMDLTensorHandle) { bias = args[2]; } - DLTensor *output = args[3]; - uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], - pad_left = args[7]; + DLTensor* output = args[3]; + uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; NNPackConfig(args[10]); uint64_t algo_ = args[11]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(input->ndim, 4); CHECK_EQ(kernel->ndim, 4); if (bias) { @@ -93,10 +91,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") size_t workspace_size = 0; nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_compute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); // Division with rounding up, in case size is not multiple of sizeof(float) @@ -107,24 +104,21 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") DeviceAPI* cpu_api = DeviceAPI::Get(ctx); void* workspace_buffer = - cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); + cpu_api->AllocWorkspace(ctx, 
workspace_elements * sizeof(float), type_hint); CHECK(workspace_buffer != nullptr); for (auto n = 0; n < input->shape[0]; ++n) { nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_compute, input_channels, - output_channels, input_size, input_padding, kernel_size, - stride_size, - static_cast(input->data) + n * input->shape[1] * - input->shape[2] * - input->shape[3], - static_cast(kernel->data), - bias ? static_cast(bias->data) : zero_bias->data(), - static_cast(output->data) + n * output->shape[1] * - output->shape[2] * - output->shape[3], - workspace_buffer, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_compute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, + static_cast(input->data) + + n * input->shape[1] * input->shape[2] * input->shape[3], + static_cast(kernel->data), + bias ? static_cast(bias->data) : zero_bias->data(), + static_cast(output->data) + + n * output->shape[1] * output->shape[2] * output->shape[3], + workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); } @@ -132,28 +126,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") }); TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *input = args[0]; - DLTensor *transformed_kernel = args[1]; - DLTensor *bias = nullptr; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* input = args[0]; + DLTensor* transformed_kernel = args[1]; + DLTensor* bias = nullptr; if (args[2].type_code() == kTVMDLTensorHandle) { bias = args[2]; } - DLTensor *output = args[3]; - uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], - pad_left = args[7]; + DLTensor* output = args[3]; + uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; NNPackConfig(args[10]); uint64_t algo_ = args[11]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(input->ndim, 4); if (bias) { CHECK_EQ(bias->ndim, 1); @@ -189,10 +180,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra size_t workspace_size = 0; nnp_status status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_reuse, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); // Division with rounding up, in case size 
is not multiple of sizeof(float) @@ -203,38 +193,34 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra DeviceAPI* cpu_api = DeviceAPI::Get(ctx); void* workspace_buffer = - cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); + cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint); CHECK(workspace_buffer != nullptr); for (auto n = 0; n < input->shape[0]; ++n) { nnp_status status = nnp_convolution_inference( algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels, input_size, input_padding, kernel_size, stride_size, - static_cast(input->data) + n * input->shape[1] * - input->shape[2] * - input->shape[3], - static_cast(transformed_kernel->data), - bias ? static_cast(bias->data) : zero_bias->data(), - static_cast(output->data) + n * output->shape[1] * - output->shape[2] * - output->shape[3], - workspace_buffer, &workspace_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + static_cast(input->data) + + n * input->shape[1] * input->shape[2] * input->shape[3], + static_cast(transformed_kernel->data), + bias ? static_cast(bias->data) : zero_bias->data(), + static_cast(output->data) + + n * output->shape[1] * output->shape[2] * output->shape[3], + workspace_buffer, &workspace_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); } cpu_api->FreeWorkspace(ctx, workspace_buffer); }); -TVM_REGISTER_GLOBAL( - "tvm.contrib.nnpack.convolution_inference_weight_transform") - .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); +TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_weight_transform") + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); static std::once_flag flag; - std::call_once(flag, - []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); - DLTensor *kernel = args[0]; - DLTensor *transformed_kernel = args[1]; + std::call_once(flag, []() { CHECK_EQ(nnp_initialize(), nnp_status_success); }); + DLTensor* kernel = args[0]; + DLTensor* transformed_kernel = args[1]; // Dummy sizes nnp_padding input_padding{1, 1, 1, 1}; nnp_size stride_size{1, 1}; @@ -244,8 +230,7 @@ TVM_REGISTER_GLOBAL( NNPackConfig(args[2]); uint64_t algo_ = args[3]; - nnp_convolution_algorithm algo = - static_cast(algo_); + nnp_convolution_algorithm algo = static_cast(algo_); CHECK_EQ(kernel->ndim, 4); size_t input_channels = kernel->shape[1]; size_t output_channels = kernel->shape[0]; @@ -259,21 +244,20 @@ TVM_REGISTER_GLOBAL( size_t transformed_kernel_size = 0; nnp_status status; status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_precompute, input_channels, - output_channels, input_size, input_padding, kernel_size, stride_size, - nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size, - nnp_activation_identity, nullptr, entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, nullptr, nullptr, nullptr, + nullptr, &transformed_kernel_size, nnp_activation_identity, nullptr, entry->threadpool, + nullptr); CHECK_EQ(status, nnp_status_success); CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel)); status = nnp_convolution_inference( - algo, nnp_convolution_transform_strategy_precompute, input_channels, - output_channels, 
input_size, input_padding, kernel_size, stride_size, - nullptr, static_cast(kernel->data), nullptr, nullptr, - static_cast(transformed_kernel->data), - &transformed_kernel_size, nnp_activation_identity, nullptr, - entry->threadpool, nullptr); + algo, nnp_convolution_transform_strategy_precompute, input_channels, output_channels, + input_size, input_padding, kernel_size, stride_size, nullptr, + static_cast(kernel->data), nullptr, nullptr, + static_cast(transformed_kernel->data), &transformed_kernel_size, + nnp_activation_identity, nullptr, entry->threadpool, nullptr); CHECK_EQ(status, nnp_status_success); }); } // namespace contrib diff --git a/src/runtime/contrib/nnpack/fully_connected.cc b/src/runtime/contrib/nnpack/fully_connected.cc index 5f111efac4df..543d23958633 100644 --- a/src/runtime/contrib/nnpack/fully_connected.cc +++ b/src/runtime/contrib/nnpack/fully_connected.cc @@ -20,10 +20,11 @@ /*! * \file Use external nnpack library call. */ -#include -#include #include #include +#include +#include + #include "nnpack_utils.h" namespace tvm { @@ -33,33 +34,30 @@ using namespace runtime; // matrix multiplication for row major TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference") -.set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); - nnp_initialize(); - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; - NNPackConfig(args[3]); + .set_body([](TVMArgs args, TVMRetValue* ret) { + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); + nnp_initialize(); + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + NNPackConfig(args[3]); - CHECK_EQ(A->ndim, 1); - CHECK_EQ(B->ndim, 2); - CHECK_EQ(C->ndim, 1); - CHECK_EQ(B->shape[0], C->shape[0]); - CHECK_EQ(B->shape[1], A->shape[0]); - CHECK(C->strides == nullptr); - CHECK(B->strides == nullptr); - CHECK(A->strides == nullptr); - CHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CHECK(TypeMatch(B->dtype, kDLFloat, 32)); - CHECK(TypeMatch(C->dtype, kDLFloat, 32)); + CHECK_EQ(A->ndim, 1); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 1); + CHECK_EQ(B->shape[0], C->shape[0]); + CHECK_EQ(B->shape[1], A->shape[0]); + CHECK(C->strides == nullptr); + CHECK(B->strides == nullptr); + CHECK(A->strides == nullptr); + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CHECK(TypeMatch(B->dtype, kDLFloat, 32)); + CHECK(TypeMatch(C->dtype, kDLFloat, 32)); - nnp_fully_connected_inference(B->shape[1], - B->shape[0], - static_cast(A->data), - static_cast(B->data), - static_cast(C->data), - entry->threadpool); - }); + nnp_fully_connected_inference(B->shape[1], B->shape[0], static_cast(A->data), + static_cast(B->data), static_cast(C->data), + entry->threadpool); + }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/nnpack/nnpack_utils.cc b/src/runtime/contrib/nnpack/nnpack_utils.cc index f01ad8557fee..91cf865128e9 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.cc +++ b/src/runtime/contrib/nnpack/nnpack_utils.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,13 +28,12 @@ using namespace runtime; typedef dmlc::ThreadLocalStore NNPackThreadLocalStore; - NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() { return NNPackThreadLocalStore::Get(); } bool NNPackConfig(uint64_t nthreads) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + NNPackThreadLocalEntry* entry = NNPackThreadLocalEntry::ThreadLocal(); if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) { CHECK_NE(nthreads, 1); return true; @@ -55,11 +54,9 @@ bool NNPackConfig(uint64_t nthreads) { return true; } - -TVM_REGISTER_GLOBAL("contrib.nnpack._initialize") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = nnp_initialize(); - }); +TVM_REGISTER_GLOBAL("contrib.nnpack._initialize").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = nnp_initialize(); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/nnpack/nnpack_utils.h b/src/runtime/contrib/nnpack/nnpack_utils.h index 4ba586fe08ac..bbb0d16bc868 100644 --- a/src/runtime/contrib/nnpack/nnpack_utils.h +++ b/src/runtime/contrib/nnpack/nnpack_utils.h @@ -22,11 +22,11 @@ */ #ifndef TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ #define TVM_RUNTIME_CONTRIB_NNPACK_NNPACK_UTILS_H_ -#include -#include -#include #include +#include #include +#include +#include namespace tvm { namespace contrib { diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index 37166e2c8d0f..c628e327643e 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,6 +22,7 @@ * \brief mt19937 random engine */ #include + #include #include #include @@ -34,45 +35,37 @@ namespace contrib { */ class RandomEngine { public: - /*! - * \brief Creates a RandomEngine using a default seed. - */ - RandomEngine() { - this->Seed(time(0)); - } - - /*! - * \brief Creates a RandomEngine, suggesting the use of a provided seed. - */ - explicit RandomEngine(unsigned seed) { - this->Seed(seed); - } - - /*! - * \brief Seeds the underlying RNG, if possible. - */ + /*! + * \brief Creates a RandomEngine using a default seed. + */ + RandomEngine() { this->Seed(time(0)); } + + /*! + * \brief Creates a RandomEngine, suggesting the use of a provided seed. + */ + explicit RandomEngine(unsigned seed) { this->Seed(seed); } + + /*! + * \brief Seeds the underlying RNG, if possible. + */ inline void Seed(unsigned seed) { rnd_engine_.seed(seed); this->rseed_ = static_cast(seed); } - /*! - * \return the seed associated with the underlying RNG. - */ - inline unsigned GetSeed() const { - return rseed_; - } + /*! + * \return the seed associated with the underlying RNG. + */ + inline unsigned GetSeed() const { return rseed_; } - /*! - * \return a random integer sampled from the RNG. 
- */ - inline unsigned GetRandInt() { - return rnd_engine_(); - } + /*! + * \return a random integer sampled from the RNG. + */ + inline unsigned GetRandInt() { return rnd_engine_(); } - /*! - * \brief Fills a tensor with values drawn from Unif(low, high) - */ + /*! + * \brief Fills a tensor with values drawn from Unif(low, high) + */ void SampleUniform(DLTensor* data, float low, float high) { CHECK_GT(high, low) << "high must be bigger than low"; CHECK(data->strides == nullptr); @@ -87,17 +80,16 @@ class RandomEngine { if (data->ctx.device_type == kDLCPU) { std::uniform_real_distribution uniform_dist(low, high); - std::generate_n(static_cast(data->data), size, [&] () { - return uniform_dist(rnd_engine_); - }); + std::generate_n(static_cast(data->data), size, + [&]() { return uniform_dist(rnd_engine_); }); } else { LOG(FATAL) << "Do not support random.uniform on this device yet"; } } - /*! - * \brief Fills a tensor with values drawn from Normal(loc, scale**2) - */ + /*! + * \brief Fills a tensor with values drawn from Normal(loc, scale**2) + */ void SampleNormal(DLTensor* data, float loc, float scale) { CHECK_GT(scale, 0) << "standard deviation must be positive"; CHECK(data->strides == nullptr); @@ -112,9 +104,8 @@ class RandomEngine { if (data->ctx.device_type == kDLCPU) { std::normal_distribution normal_dist(loc, scale); - std::generate_n(static_cast(data->data), size, [&] () { - return normal_dist(rnd_engine_); - }); + std::generate_n(static_cast(data->data), size, + [&]() { return normal_dist(rnd_engine_); }); } else { LOG(FATAL) << "Do not support random.normal on this device yet"; } diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index 8ae1f8668c87..acba193c1230 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -20,32 +20,34 @@ /*! * \file External random functions for tensor. */ -#include -#include #include #include +#include +#include + #include + #include "mt_random_engine.cc" #define DLPACK_INTEGER_TYPE_SWITCH(type, DType, ...) 
\ if (type.code == kDLInt && type.bits == 32) { \ typedef int32_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLInt && type.bits == 16) { \ typedef int16_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLInt && type.bits == 8) { \ typedef int8_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLUInt && type.bits == 32) { \ typedef uint32_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLUInt && type.bits == 16) { \ typedef uint16_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else if (type.code == kDLUInt && type.bits == 8) { \ typedef uint8_t DType; \ - {__VA_ARGS__} \ + { __VA_ARGS__ } \ } else { \ LOG(FATAL) << "unknown data type"; \ } @@ -66,61 +68,54 @@ RandomThreadLocalEntry* RandomThreadLocalEntry::ThreadLocal() { return RandomThreadLocalStore::Get(); } +TVM_REGISTER_GLOBAL("tvm.contrib.random.randint").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + int64_t low = args[0]; + int64_t high = args[1]; + DLTensor* out = args[2]; + CHECK_GT(high, low) << "high must be bigger than low"; + CHECK(out->strides == nullptr); + + DLDataType dtype = out->dtype; + int64_t size = 1; + for (int i = 0; i < out->ndim; ++i) { + size *= out->shape[i]; + } -TVM_REGISTER_GLOBAL("tvm.contrib.random.randint") -.set_body([](TVMArgs args, TVMRetValue *ret) { - RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); - int64_t low = args[0]; - int64_t high = args[1]; - DLTensor* out = args[2]; - CHECK_GT(high, low) << "high must be bigger than low"; - CHECK(out->strides == nullptr); - - DLDataType dtype = out->dtype; - int64_t size = 1; - for (int i = 0; i < out->ndim; ++i) { - size *= out->shape[i]; + DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { + int64_t numeric_low = std::numeric_limits::min(); + int64_t numeric_high = std::numeric_limits::max(); + numeric_high += 1; // exclusive upper bound + low = std::max(low, numeric_low); + high = std::min(high, numeric_high); + + if (out->ctx.device_type == kDLCPU) { + // file the data with random byte + std::generate_n(static_cast(out->data), size, [&]() { + unsigned rint = entry->random_engine.GetRandInt(); + return low + rint % (high - low); + }); + } else { + LOG(FATAL) << "Do not support random.randint on this device yet"; } - - DLPACK_INTEGER_TYPE_SWITCH(dtype, DType, { - int64_t numeric_low = std::numeric_limits::min(); - int64_t numeric_high = std::numeric_limits::max(); - numeric_high += 1; // exclusive upper bound - low = std::max(low, numeric_low); - high = std::min(high, numeric_high); - - if (out->ctx.device_type == kDLCPU) { - // file the data with random byte - std::generate_n(static_cast(out->data), size, [&] () { - unsigned rint = entry->random_engine.GetRandInt(); - return low + rint % (high - low); - }); - } else { - LOG(FATAL) << "Do not support random.randint on this device yet"; - } - }) - }); - - -TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform") -.set_body([](TVMArgs args, TVMRetValue *ret) { - RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); - double low = args[0]; - double high = args[1]; - DLTensor* out = args[2]; - entry->random_engine.SampleUniform(out, low, high); - }); - - -TVM_REGISTER_GLOBAL("tvm.contrib.random.normal") -.set_body([](TVMArgs args, TVMRetValue *ret) { - RandomThreadLocalEntry *entry = RandomThreadLocalEntry::ThreadLocal(); - double loc = args[0]; - double scale = args[1]; - DLTensor* out = args[2]; - 
entry->random_engine.SampleNormal(out, loc, scale); - }); - + }) +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.random.uniform").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + double low = args[0]; + double high = args[1]; + DLTensor* out = args[2]; + entry->random_engine.SampleUniform(out, low, high); +}); + +TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + double loc = args[0]; + double scale = args[1]; + DLTensor* out = args[2]; + entry->random_engine.SampleNormal(out, loc, scale); +}); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/rocblas/rocblas.cc b/src/runtime/contrib/rocblas/rocblas.cc index dda4ee30fde5..0e6f4bd69686 100644 --- a/src/runtime/contrib/rocblas/rocblas.cc +++ b/src/runtime/contrib/rocblas/rocblas.cc @@ -20,75 +20,68 @@ /*! * \file Use external rocblas library call. */ -#include -#include -#include #include "rocblas.h" +#include +#include +#include + namespace tvm { namespace contrib { using namespace runtime; #ifndef CHECK_ROCBLAS_ERROR -#define CHECK_ROCBLAS_ERROR(error) \ -if (error != rocblas_status_success) { \ - fprintf(stderr, "rocBLAS error: "); \ - if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \ - if (error == rocblas_status_not_implemented) fprintf(stderr, " rocblas_status_not_implemented"); \ - if (error == rocblas_status_invalid_pointer) fprintf(stderr, "rocblas_status_invalid_pointer"); \ - if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \ - if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \ - if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \ - fprintf(stderr, "\n"); \ - exit(EXIT_FAILURE); \ -} +#define CHECK_ROCBLAS_ERROR(error) \ + if (error != rocblas_status_success) { \ + fprintf(stderr, "rocBLAS error: "); \ + if (error == rocblas_status_invalid_handle) fprintf(stderr, "rocblas_status_invalid_handle"); \ + if (error == rocblas_status_not_implemented) \ + fprintf(stderr, " rocblas_status_not_implemented"); \ + if (error == rocblas_status_invalid_pointer) \ + fprintf(stderr, "rocblas_status_invalid_pointer"); \ + if (error == rocblas_status_invalid_size) fprintf(stderr, "rocblas_status_invalid_size"); \ + if (error == rocblas_status_memory_error) fprintf(stderr, "rocblas_status_memory_error"); \ + if (error == rocblas_status_internal_error) fprintf(stderr, "rocblas_status_internal_error"); \ + fprintf(stderr, "\n"); \ + exit(EXIT_FAILURE); \ + } #endif - // matrix multiplication for row major -TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor* A = args[0]; - DLTensor* B = args[1]; - DLTensor* C = args[2]; - bool transa = args[3]; - bool transb = args[4]; - // call gemm for simple compact code. 
- CHECK_EQ(A->ndim, 2); - CHECK_EQ(B->ndim, 2); - CHECK_EQ(C->ndim, 2); - CHECK(C->strides == nullptr); - CHECK(B->strides == nullptr); - CHECK(A->strides == nullptr); - CHECK(TypeMatch(A->dtype, kDLFloat, 32)); - CHECK(TypeMatch(B->dtype, kDLFloat, 32)); - CHECK(TypeMatch(C->dtype, kDLFloat, 32)); +TVM_REGISTER_GLOBAL("tvm.contrib.rocblas.matmul").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* A = args[0]; + DLTensor* B = args[1]; + DLTensor* C = args[2]; + bool transa = args[3]; + bool transb = args[4]; + // call gemm for simple compact code. + CHECK_EQ(A->ndim, 2); + CHECK_EQ(B->ndim, 2); + CHECK_EQ(C->ndim, 2); + CHECK(C->strides == nullptr); + CHECK(B->strides == nullptr); + CHECK(A->strides == nullptr); + CHECK(TypeMatch(A->dtype, kDLFloat, 32)); + CHECK(TypeMatch(B->dtype, kDLFloat, 32)); + CHECK(TypeMatch(C->dtype, kDLFloat, 32)); - rocblas_handle handle; - CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); - float alpha = 1.0; - float beta = 0.0; - float *A_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); - float *B_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); - float *C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); + rocblas_handle handle; + CHECK_ROCBLAS_ERROR(rocblas_create_handle(&handle)); + float alpha = 1.0; + float beta = 0.0; + float* A_ptr = reinterpret_cast(static_cast(B->data) + B->byte_offset); + float* B_ptr = reinterpret_cast(static_cast(A->data) + A->byte_offset); + float* C_ptr = reinterpret_cast(static_cast(C->data) + C->byte_offset); - CHECK_ROCBLAS_ERROR(rocblas_sgemm(handle, - transb ? rocblas_operation_transpose : rocblas_operation_none, - transa ? rocblas_operation_transpose : rocblas_operation_none, - transb ? B->shape[0] : B->shape[1], - transa ? A->shape[1] : A->shape[0], - transb ? B->shape[1] : B->shape[0], - &alpha, - A_ptr, - B->shape[1], - B_ptr, - A->shape[1], - &beta, - C_ptr, - C->shape[1])); + CHECK_ROCBLAS_ERROR( + rocblas_sgemm(handle, transb ? rocblas_operation_transpose : rocblas_operation_none, + transa ? rocblas_operation_transpose : rocblas_operation_none, + transb ? B->shape[0] : B->shape[1], transa ? A->shape[1] : A->shape[0], + transb ? B->shape[1] : B->shape[0], &alpha, A_ptr, B->shape[1], B_ptr, + A->shape[1], &beta, C_ptr, C->shape[1])); - CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); + CHECK_ROCBLAS_ERROR(rocblas_destroy_handle(handle)); }); } // namespace contrib } // namespace tvm diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 0c9c57533dbe..9543e4b4c64e 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -21,8 +21,9 @@ * \file Use standard C library call. */ -#include #include +#include + #include #include @@ -31,19 +32,16 @@ namespace contrib { using namespace runtime; -template -bool CompareAscend(const std::pair& lhs, - const std::pair& rhs) { +template +bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { return lhs.second < rhs.second; } -template -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { +template +bool CompareDescend(const std::pair& lhs, const std::pair& rhs) { return lhs.second > rhs.second; } - // Argsort implemented C library sort for nms. // Return indices of sorted tensor. // By default, the last axis will be used to sort. @@ -51,17 +49,16 @@ bool CompareDescend(const std::pair& lhs, // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. 
sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *input = args[0]; - DLTensor *sort_num = args[1]; - DLTensor *output = args[2]; +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* sort_num = args[1]; + DLTensor* output = args[2]; int32_t axis = args[3]; bool is_ascend = args[4]; auto dtype = input->dtype; - auto data_ptr = static_cast(input->data); - auto sort_num_ptr = static_cast(sort_num->data); + auto data_ptr = static_cast(input->data); + auto sort_num_ptr = static_cast(sort_num->data); std::vector> sorter; int64_t axis_mul_before = 1; int64_t axis_mul_after = 1; @@ -72,13 +69,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") // Currently only supports input dtype to be float32. CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " - "to be float."; + "to be float."; #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC != 1) CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " - "to be float32."; + "to be float32."; #endif CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " + << input->ndim; for (int i = 0; i < input->ndim; ++i) { if (i < axis) { @@ -88,8 +86,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") } } - for (int64_t i = 0 ; i < axis_mul_before; ++i) { - for (int64_t j = 0 ; j < axis_mul_after; ++j) { + for (int64_t i = 0; i < axis_mul_before; ++i) { + for (int64_t j = 0; j < axis_mul_after; ++j) { sorter.clear(); int32_t current_sort_num = *(sort_num_ptr + i * axis_mul_after + j); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; @@ -103,7 +101,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>); } else { #endif - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif @@ -113,24 +111,24 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>); } else { #endif - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); #if (__ARM_FEATURE_FP16_SCALAR_ARITHMETIC == 1) } #endif } for (int32_t k = 0; k < input->shape[axis]; ++k) { - *(static_cast(output->data) + base_idx + k * axis_mul_after) - = k < static_cast(sorter.size()) ? sorter[k].first : k; + *(static_cast(output->data) + base_idx + k * axis_mul_after) = + k < static_cast(sorter.size()) ? 
sorter[k].first : k; } } } }); -template +template void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { - auto data_ptr = static_cast(input->data); - auto out_ptr = static_cast(output->data); - std::vector > sorter; + auto data_ptr = static_cast(input->data); + auto out_ptr = static_cast(output->data); + std::vector> sorter; int axis_mul_before = 1; int axis_mul_after = 1; @@ -142,8 +140,8 @@ void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { } } - for (int i = 0 ; i < axis_mul_before; ++i) { - for (int j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0; i < axis_mul_before; ++i) { + for (int j = 0; j < axis_mul_after; ++j) { sorter.clear(); int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; for (int64_t k = 0; k < input->shape[axis]; ++k) { @@ -169,17 +167,17 @@ void argsort(DLTensor* input, DLTensor* output, int32_t axis, bool is_ascend) { // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLTensor *input = args[0]; - DLTensor *output = args[1]; +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort").set_body([](TVMArgs args, TVMRetValue* ret) { + DLTensor* input = args[0]; + DLTensor* output = args[1]; int32_t axis = args[2]; bool is_ascend = args[3]; if (axis < 0) { axis = input->ndim + axis; } CHECK_LT(axis, input->ndim) << "Axis out of boundary for " - "input ndim " << input->ndim; + "input ndim " + << input->ndim; auto data_dtype = DLDataType2String(input->dtype); auto out_dtype = DLDataType2String(output->dtype); @@ -228,7 +226,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { argsort(input, output, axis, is_ascend); } else if (out_dtype == "int64") { @@ -245,19 +243,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } }); -template -void topk(DLTensor* input, - DLTensor* out_values, - DLTensor* out_indices, - int k, - int axis, +template +void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, int axis, bool is_ascend) { - DataType* data_ptr = static_cast(input->data); - DataType* values_ptr = (out_values == nullptr) ? nullptr : - static_cast(out_values->data); - IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : - static_cast(out_indices->data); - std::vector > sorter; + DataType* data_ptr = static_cast(input->data); + DataType* values_ptr = + (out_values == nullptr) ? nullptr : static_cast(out_values->data); + IndicesType* indices_ptr = + (out_indices == nullptr) ? 
nullptr : static_cast(out_indices->data); + std::vector> sorter; int axis_mul_before = 1; int axis_mul_after = 1; @@ -272,8 +266,8 @@ void topk(DLTensor* input, k = input->shape[axis]; } - for (int i = 0 ; i < axis_mul_before; ++i) { - for (int j = 0 ; j < axis_mul_after; ++j) { + for (int i = 0; i < axis_mul_before; ++i) { + for (int j = 0; j < axis_mul_after; ++j) { sorter.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; @@ -290,11 +284,10 @@ void topk(DLTensor* input, for (int64_t kk = 0; kk < cnt; ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].first); + static_cast(sorter[kk].first); } if (values_ptr != nullptr) { - values_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].second); + values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast(sorter[kk].second); } } } @@ -308,8 +301,7 @@ void topk(DLTensor* input, // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk").set_body([](TVMArgs args, TVMRetValue* ret) { DLTensor* input = args[0]; DLTensor* values_out = nullptr; DLTensor* indices_out = nullptr; @@ -371,7 +363,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } - } else if (data_dtype == "int64") { + } else if (data_dtype == "int64") { if (out_dtype == "int32") { topk(input, values_out, indices_out, k, axis, is_ascend); } else if (out_dtype == "int64") { diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.h b/src/runtime/contrib/tensorrt/tensorrt_ops.h index beaede9dc67a..9aa3f13af4d3 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.h +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.h @@ -35,8 +35,8 @@ #include #include #include + #include "NvInfer.h" -// #include "NvInferPlugin.h" #include "utils.h" #if TRT_VERSION_GE(6, 0, 1) @@ -102,8 +102,8 @@ class TrtOpConverter { * true. input_types vector will be ignored and any number of input tensors * can be used for this op. All inputs will be tensors and not weights. */ - TrtOpConverter(const std::vector& input_types, - bool variable_input_count = false) + explicit TrtOpConverter(const std::vector& input_types, + bool variable_input_count = false) : input_types(input_types), variable_input_count(variable_input_count) {} /*! @@ -121,8 +121,7 @@ class TrtOpConverter { * \param new_shape New shape, does not include batch dim. * \return Reshaped tensor */ - nvinfer1::ITensor* Reshape(AddTrtLayerParams* params, - nvinfer1::ITensor* input, + nvinfer1::ITensor* Reshape(AddTrtLayerParams* params, nvinfer1::ITensor* input, const std::vector& new_shape) const { auto layer = params->network->addShuffle(*input); CHECK(layer != nullptr); @@ -137,8 +136,7 @@ class TrtOpConverter { * \param order New order of axes, does include batch dim. 
* \return Transposed tensor */ - nvinfer1::ITensor* Transpose(AddTrtLayerParams* params, - nvinfer1::ITensor* input, + nvinfer1::ITensor* Transpose(AddTrtLayerParams* params, nvinfer1::ITensor* input, const std::vector& order) const { auto layer = params->network->addShuffle(*input); CHECK(layer != nullptr); @@ -189,16 +187,14 @@ class TrtOpConverter { * \param broadcast_to_dims Dims that scalar should be broadcastable against. * \return Constant tensor. */ - nvinfer1::ITensor* CreateScalar( - AddTrtLayerParams* params, float value, - const nvinfer1::Dims& broadcast_to_dims) const { + nvinfer1::ITensor* CreateScalar(AddTrtLayerParams* params, float value, + const nvinfer1::Dims& broadcast_to_dims) const { nvinfer1::Dims dims; dims.nbDims = broadcast_to_dims.nbDims; std::fill_n(dims.d, dims.nbDims, 1); float* values = new float[1]; values[0] = value; - nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, - static_cast(values), 1}; + nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT, static_cast(values), 1}; params->trt_weights->push_back(weights); return params->network->addConstant(dims, weights)->getOutput(0); } @@ -215,24 +211,24 @@ class TrtOpConverter { CHECK(padding.size() == 1 || padding.size() == 2 || padding.size() == 4); if (padding.size() == 4) { // four int : padding width in the order of (top, left, bottom, right). - *prepadding = nvinfer1::DimsHW(padding[0].as()->value, - padding[1].as()->value); - *postpadding = nvinfer1::DimsHW(padding[2].as()->value, - padding[3].as()->value); + *prepadding = + nvinfer1::DimsHW(padding[0].as()->value, padding[1].as()->value); + *postpadding = + nvinfer1::DimsHW(padding[2].as()->value, padding[3].as()->value); *use_asymmetric_padding = true; } else if (padding.size() == 2) { // two int : bottom, right will use same padding as top, left - *prepadding = nvinfer1::DimsHW(padding[0].as()->value, - padding[1].as()->value); - *postpadding = nvinfer1::DimsHW(padding[0].as()->value, - padding[1].as()->value); + *prepadding = + nvinfer1::DimsHW(padding[0].as()->value, padding[1].as()->value); + *postpadding = + nvinfer1::DimsHW(padding[0].as()->value, padding[1].as()->value); *use_asymmetric_padding = false; } else { // one int : same padding used on all sides - *prepadding = nvinfer1::DimsHW(padding[0].as()->value, - padding[0].as()->value); - *postpadding = nvinfer1::DimsHW(padding[0].as()->value, - padding[0].as()->value); + *prepadding = + nvinfer1::DimsHW(padding[0].as()->value, padding[0].as()->value); + *postpadding = + nvinfer1::DimsHW(padding[0].as()->value, padding[0].as()->value); *use_asymmetric_padding = false; } } @@ -244,21 +240,19 @@ class ActivationOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { CHECK_EQ(params->inputs.size(), 1) << "Activation op expects 1 input."; - static const std::unordered_map - op_map = { - {"nn.relu", nvinfer1::ActivationType::kRELU}, - {"sigmoid", nvinfer1::ActivationType::kSIGMOID}, - {"tanh", nvinfer1::ActivationType::kTANH}, + static const std::unordered_map op_map = { + {"nn.relu", nvinfer1::ActivationType::kRELU}, + {"sigmoid", nvinfer1::ActivationType::kSIGMOID}, + {"tanh", nvinfer1::ActivationType::kTANH}, #if TRT_VERSION_GE(5, 1, 5) - {"clip", nvinfer1::ActivationType::kCLIP}, - {"nn.leaky_relu", nvinfer1::ActivationType::kLEAKY_RELU}, + {"clip", nvinfer1::ActivationType::kCLIP}, + {"nn.leaky_relu", nvinfer1::ActivationType::kLEAKY_RELU}, #endif - }; + }; auto it = op_map.find(params->op_name); - CHECK(it != op_map.end()) << "Unsupported activation type " - 
<< params->op_name; - nvinfer1::IActivationLayer* act_layer = params->network->addActivation( - *params->inputs.at(0).tensor, it->second); + CHECK(it != op_map.end()) << "Unsupported activation type " << params->op_name; + nvinfer1::IActivationLayer* act_layer = + params->network->addActivation(*params->inputs.at(0).tensor, it->second); #if TRT_VERSION_GE(5, 1, 5) if (params->op_name == "clip") { const auto* clip_attr = params->call->attrs.as(); @@ -286,24 +280,22 @@ class ClipLegacyOpConverter : public TrtOpConverter { nvinfer1::ITensor* output = nullptr; if (attrs->a_min == 0.0f) { // Use relu instead of max(x, 0) because relu can be fused. - nvinfer1::IActivationLayer* relu_layer = params->network->addActivation( - *input, nvinfer1::ActivationType::kRELU); + nvinfer1::IActivationLayer* relu_layer = + params->network->addActivation(*input, nvinfer1::ActivationType::kRELU); CHECK(relu_layer != nullptr); output = relu_layer->getOutput(0); } else { // max(x, a_min) - nvinfer1::ITensor* a_min = - CreateScalar(params, attrs->a_min, input->getDimensions()); - nvinfer1::IElementWiseLayer* max_layer = params->network->addElementWise( - *input, *a_min, nvinfer1::ElementWiseOperation::kMAX); + nvinfer1::ITensor* a_min = CreateScalar(params, attrs->a_min, input->getDimensions()); + nvinfer1::IElementWiseLayer* max_layer = + params->network->addElementWise(*input, *a_min, nvinfer1::ElementWiseOperation::kMAX); CHECK(max_layer != nullptr); output = max_layer->getOutput(0); } // min(relu(x), a_max) - nvinfer1::ITensor* a_max = - CreateScalar(params, attrs->a_max, input->getDimensions()); - nvinfer1::IElementWiseLayer* min_layer = params->network->addElementWise( - *output, *a_max, nvinfer1::ElementWiseOperation::kMIN); + nvinfer1::ITensor* a_max = CreateScalar(params, attrs->a_max, input->getDimensions()); + nvinfer1::IElementWiseLayer* min_layer = + params->network->addElementWise(*output, *a_max, nvinfer1::ElementWiseOperation::kMIN); params->outputs.push_back(min_layer->getOutput(0)); } }; @@ -313,17 +305,16 @@ class ElementWiseBinaryOpConverter : public TrtOpConverter { ElementWiseBinaryOpConverter() : TrtOpConverter({kTensor, kTensor}) {} void Convert(AddTrtLayerParams* params) const { - static const std::unordered_map - op_map = {{"add", nvinfer1::ElementWiseOperation::kSUM}, - {"subtract", nvinfer1::ElementWiseOperation::kSUB}, - {"multiply", nvinfer1::ElementWiseOperation::kPROD}, - {"divide", nvinfer1::ElementWiseOperation::kDIV}, - {"power", nvinfer1::ElementWiseOperation::kPOW}, - {"maximum", nvinfer1::ElementWiseOperation::kMAX}, - {"minimum", nvinfer1::ElementWiseOperation::kMIN}}; + static const std::unordered_map op_map = { + {"add", nvinfer1::ElementWiseOperation::kSUM}, + {"subtract", nvinfer1::ElementWiseOperation::kSUB}, + {"multiply", nvinfer1::ElementWiseOperation::kPROD}, + {"divide", nvinfer1::ElementWiseOperation::kDIV}, + {"power", nvinfer1::ElementWiseOperation::kPOW}, + {"maximum", nvinfer1::ElementWiseOperation::kMAX}, + {"minimum", nvinfer1::ElementWiseOperation::kMIN}}; auto it = op_map.find(params->op_name); - CHECK(it != op_map.end()) << "Unsupported elementwise type " - << params->op_name; + CHECK(it != op_map.end()) << "Unsupported elementwise type " << params->op_name; // Broadcast auto input0 = params->inputs.at(0).tensor; auto input0_dims = TrtDimsToVector(input0->getDimensions()); @@ -333,13 +324,11 @@ class ElementWiseBinaryOpConverter : public TrtOpConverter { if (need_broadcast) { if (input0_dims.size() < input1_dims.size()) { std::vector 
new_shape(input0_dims); - while (new_shape.size() < input1_dims.size()) - new_shape.insert(new_shape.begin(), 1); + while (new_shape.size() < input1_dims.size()) new_shape.insert(new_shape.begin(), 1); input0 = Reshape(params, input0, new_shape); } else if (input1_dims.size() < input0_dims.size()) { std::vector new_shape(input1_dims); - while (new_shape.size() < input0_dims.size()) - new_shape.insert(new_shape.begin(), 1); + while (new_shape.size() < input0_dims.size()) new_shape.insert(new_shape.begin(), 1); input1 = Reshape(params, input1, new_shape); } } @@ -371,8 +360,7 @@ class Conv2DOpConverter : public TrtOpConverter { GetPadding(conv2d_attr->padding, &use_asymmetric_padding, &prepadding, &postpadding); #if !TRT_VERSION_GE(5, 1, 5) if (use_asymmetric_padding) { - auto pad_layer = - params->network->addPadding(*input_tensor, prepadding, postpadding); + auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); CHECK(pad_layer != nullptr); input_tensor = pad_layer->getOutput(0); // No need for conv op to do any padding. @@ -385,9 +373,8 @@ class Conv2DOpConverter : public TrtOpConverter { const int num_outputs = weight_shape[0]; const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; - auto conv_layer = - params->network->addConvolution(*input_tensor, num_outputs, kernel_size, - params->inputs.at(1).weight, bias); + auto conv_layer = params->network->addConvolution(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); CHECK(conv_layer != nullptr); if (use_asymmetric_padding) { #if TRT_VERSION_GE(5, 1, 5) @@ -398,14 +385,12 @@ class Conv2DOpConverter : public TrtOpConverter { conv_layer->setPadding(prepadding); } CHECK_EQ(conv2d_attr->strides.size(), 2); - const auto strides = - nvinfer1::DimsHW(conv2d_attr->strides[0].as()->value, - conv2d_attr->strides[1].as()->value); + const auto strides = nvinfer1::DimsHW(conv2d_attr->strides[0].as()->value, + conv2d_attr->strides[1].as()->value); conv_layer->setStride(strides); CHECK_EQ(conv2d_attr->dilation.size(), 2); - const auto dilation = - nvinfer1::DimsHW(conv2d_attr->dilation[0].as()->value, - conv2d_attr->dilation[1].as()->value); + const auto dilation = nvinfer1::DimsHW(conv2d_attr->dilation[0].as()->value, + conv2d_attr->dilation[1].as()->value); conv_layer->setDilation(dilation); conv_layer->setNbGroups(conv2d_attr->groups); params->outputs.push_back(conv_layer->getOutput(0)); @@ -426,17 +411,15 @@ class DenseOpConverter : public TrtOpConverter { if (need_reshape_on_input) { // Add dims of size 1 until rank is required_rank. std::vector new_shape(input_dims); - while (new_shape.size() < required_rank) - new_shape.insert(new_shape.end(), 1); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); input_tensor = Reshape(params, input_tensor, new_shape); } // Weights are in KC format. 
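// --- Editor's note (illustrative sketch, not part of the patch): the element-wise and
// dense converters above both pad a shape with 1s until it reaches a target rank, either
// at the front (so two tensors of different rank can be broadcast) or at the back (so the
// input reaches the rank the fully connected layer expects). The helper names below are
// hypothetical; only the shape arithmetic is taken from the code above.
#include <cstddef>
#include <vector>

static std::vector<int> PadRankFront(std::vector<int> shape, size_t rank) {
  // Prepend 1s, e.g. {32} aligned to rank 3 becomes {1, 1, 32}.
  while (shape.size() < rank) shape.insert(shape.begin(), 1);
  return shape;
}

static std::vector<int> PadRankBack(std::vector<int> shape, size_t rank) {
  // Append 1s, e.g. {10, 512} aligned to rank 3 becomes {10, 512, 1}.
  while (shape.size() < rank) shape.push_back(1);
  return shape;
}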
CHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); const int num_units = params->inputs.at(1).weight_shape[0]; nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; - nvinfer1::IFullyConnectedLayer* fc_layer = - params->network->addFullyConnected(*input_tensor, num_units, - params->inputs.at(1).weight, bias); + nvinfer1::IFullyConnectedLayer* fc_layer = params->network->addFullyConnected( + *input_tensor, num_units, params->inputs.at(1).weight, bias); CHECK(fc_layer != nullptr); auto output_tensor = fc_layer->getOutput(0); if (need_reshape_on_input) { @@ -450,8 +433,7 @@ class DenseOpConverter : public TrtOpConverter { class BatchNormOpConverter : public TrtOpConverter { public: - BatchNormOpConverter() - : TrtOpConverter({kTensor, kWeight, kWeight, kWeight, kWeight}) {} + BatchNormOpConverter() : TrtOpConverter({kTensor, kWeight, kWeight, kWeight, kWeight}) {} void Convert(AddTrtLayerParams* params) const { auto input = params->inputs.at(0).tensor; @@ -467,12 +449,10 @@ class BatchNormOpConverter : public TrtOpConverter { const bool need_transpose = bn_attr->axis == 3; void* weight_scale_ptr = new float[gamma.count]; - nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, - gamma.count}; + nvinfer1::Weights weight_scale{nvinfer1::DataType::kFLOAT, weight_scale_ptr, gamma.count}; params->trt_weights->push_back(weight_scale); void* weight_shift_ptr = new float[gamma.count]; - nvinfer1::Weights weight_shift{nvinfer1::DataType::kFLOAT, weight_shift_ptr, - gamma.count}; + nvinfer1::Weights weight_shift{nvinfer1::DataType::kFLOAT, weight_shift_ptr, gamma.count}; params->trt_weights->push_back(weight_shift); nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; @@ -496,9 +476,8 @@ class BatchNormOpConverter : public TrtOpConverter { if (need_transpose) { input = Transpose(params, input, {0, 3, 1, 2}); } - nvinfer1::IScaleLayer* scale_layer = - params->network->addScale(*input, nvinfer1::ScaleMode::kCHANNEL, - weight_shift, weight_scale, power); + nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + *input, nvinfer1::ScaleMode::kCHANNEL, weight_shift, weight_scale, power); CHECK(scale_layer != nullptr); auto output = scale_layer->getOutput(0); if (need_transpose) { @@ -515,11 +494,9 @@ class BatchFlattenOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { std::vector new_shape{-1}; if (!TRT_HAS_IMPLICIT_BATCH(params)) { - new_shape.insert(new_shape.begin(), - params->inputs.at(0).tensor->getDimensions().d[0]); + new_shape.insert(new_shape.begin(), params->inputs.at(0).tensor->getDimensions().d[0]); } - params->outputs.push_back( - Reshape(params, params->inputs.at(0).tensor, new_shape)); + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, new_shape)); } }; @@ -532,8 +509,7 @@ class SoftmaxOpConverter : public TrtOpConverter { const int input_rank = input->getDimensions().nbDims; const auto* softmax_attr = params->call->attrs.as(); const int axis = ConvertAxis(params, softmax_attr->axis, input_rank); - nvinfer1::ISoftMaxLayer* softmax_layer = - params->network->addSoftMax(*input); + nvinfer1::ISoftMaxLayer* softmax_layer = params->network->addSoftMax(*input); softmax_layer->setAxes(1 << axis); CHECK(softmax_layer != nullptr); params->outputs.push_back(softmax_layer->getOutput(0)); @@ -549,51 +525,44 @@ class PoolingOpConverter : public TrtOpConverter { // in prepadding only. 
template void GetPoolAttrs(const PoolAttrs* attrs, nvinfer1::DimsHW* prepadding, - nvinfer1::DimsHW* postpadding, - nvinfer1::DimsHW* window_size, nvinfer1::DimsHW* strides, - bool* ceil_mode, bool* use_asymmetric_padding) const { + nvinfer1::DimsHW* postpadding, nvinfer1::DimsHW* window_size, + nvinfer1::DimsHW* strides, bool* ceil_mode, + bool* use_asymmetric_padding) const { CHECK_EQ(attrs->layout, "NCHW"); GetPadding(attrs->padding, use_asymmetric_padding, prepadding, postpadding); - *window_size = - nvinfer1::DimsHW(attrs->pool_size[0].template as()->value, - attrs->pool_size[1].template as()->value); - *strides = - nvinfer1::DimsHW(attrs->strides[0].template as()->value, - attrs->strides[1].template as()->value); + *window_size = nvinfer1::DimsHW(attrs->pool_size[0].template as()->value, + attrs->pool_size[1].template as()->value); + *strides = nvinfer1::DimsHW(attrs->strides[0].template as()->value, + attrs->strides[1].template as()->value); *ceil_mode = attrs->ceil_mode; } void Convert(AddTrtLayerParams* params) const { auto input = params->inputs.at(0).tensor; - static const std::unordered_map op_map = - {{"nn.max_pool2d", nvinfer1::PoolingType::kMAX}, - {"nn.avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + static const std::unordered_map op_map = { + {"nn.max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(params->op_name); - CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name - << " in TensorRT"; + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; nvinfer1::DimsHW prepadding, postpadding, window_size, strides; - bool use_asymmetric_padding = false, ceil_mode = false, - count_include_pad = true; + bool use_asymmetric_padding = false, ceil_mode = false, count_include_pad = true; if (params->op_name == "nn.max_pool2d") { const auto* attrs = params->call->attrs.as(); - GetPoolAttrs(attrs, &prepadding, &postpadding, - &window_size, &strides, &ceil_mode, - &use_asymmetric_padding); + GetPoolAttrs(attrs, &prepadding, &postpadding, &window_size, &strides, + &ceil_mode, &use_asymmetric_padding); } else if (params->op_name == "nn.avg_pool2d") { const auto* attrs = params->call->attrs.as(); count_include_pad = attrs->count_include_pad; - GetPoolAttrs(attrs, &prepadding, &postpadding, - &window_size, &strides, &ceil_mode, - &use_asymmetric_padding); + GetPoolAttrs(attrs, &prepadding, &postpadding, &window_size, &strides, + &ceil_mode, &use_asymmetric_padding); } // TRT pooling op doesn't support asymmetric padding before 5.1, so we // workaround by adding a padding layer before the pooling op. #if !TRT_VERSION_GE(5, 1, 5) if (use_asymmetric_padding) { - auto pad_layer = - params->network->addPadding(*input, prepadding, postpadding); + auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); CHECK(pad_layer != nullptr); input = pad_layer->getOutput(0); // No need for pooling op to do any padding. 
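// --- Editor's note (illustrative sketch, not part of the patch): GetPadding above accepts
// a Relay-style padding attribute of 1, 2 or 4 integers and expands it to (pre, post) HxW
// pairs; only the 4-integer form can be asymmetric, which is why older TensorRT versions
// (before 5.1.5) need the explicit padding layer inserted above. Plain ints stand in for
// the TensorRT types, and the struct and function names are hypothetical.
#include <cassert>
#include <utility>
#include <vector>

struct PadHW { int h; int w; };

static std::pair<PadHW, PadHW> ExpandPadding(const std::vector<int>& p) {
  assert(p.size() == 1 || p.size() == 2 || p.size() == 4);
  if (p.size() == 4) return {PadHW{p[0], p[1]}, PadHW{p[2], p[3]}};  // top,left / bottom,right
  if (p.size() == 2) return {PadHW{p[0], p[1]}, PadHW{p[0], p[1]}};  // same pad on both sides
  return {PadHW{p[0], p[0]}, PadHW{p[0], p[0]}};                     // one value for every side
}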
@@ -602,8 +571,7 @@ class PoolingOpConverter : public TrtOpConverter { } #endif - auto pool_layer = - params->network->addPooling(*input, it->second, window_size); + auto pool_layer = params->network->addPooling(*input, it->second, window_size); CHECK(pool_layer != nullptr); pool_layer->setStride(strides); if (use_asymmetric_padding) { @@ -643,18 +611,17 @@ class GlobalPoolingOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { auto input_tensor = params->inputs.at(0).tensor; auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); - static const std::unordered_map op_map = - {{"nn.global_max_pool2d", nvinfer1::PoolingType::kMAX}, - {"nn.global_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + static const std::unordered_map op_map = { + {"nn.global_max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.global_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(params->op_name); - CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name - << " in TensorRT"; + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; const auto* pool_attr = params->call->attrs.as(); CHECK_EQ(pool_attr->layout, "NCHW"); const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; - auto pool_layer = params->network->addPooling(*input_tensor, it->second, - nvinfer1::DimsHW(h, w)); + auto pool_layer = + params->network->addPooling(*input_tensor, it->second, nvinfer1::DimsHW(h, w)); CHECK(pool_layer != nullptr); params->outputs.push_back(pool_layer->getOutput(0)); } @@ -672,8 +639,7 @@ class ExpandDimsOpConverter : public TrtOpConverter { for (int i = 0; i < attrs->num_newaxis; ++i) { input_dims.insert(input_dims.begin() + axis, 1); } - params->outputs.push_back( - Reshape(params, params->inputs.at(0).tensor, input_dims)); + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, input_dims)); } }; @@ -688,14 +654,12 @@ class SqueezeOpConverter : public TrtOpConverter { // TODO(tmorris): if axis not defined, squeeze all dimensions with size 1. 
CHECK(attrs->axis.defined()); for (size_t i = 0; i < attrs->axis.size(); ++i) { - const int axis = ConvertAxis( - params, attrs->axis[i].as()->value, input_dims.size()); + const int axis = + ConvertAxis(params, attrs->axis[i].as()->value, input_dims.size()); input_dims[axis] = 0; } - input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), - input_dims.end()); - params->outputs.push_back( - Reshape(params, params->inputs.at(0).tensor, input_dims)); + input_dims.erase(std::remove(input_dims.begin(), input_dims.end(), 0), input_dims.end()); + params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, input_dims)); } }; @@ -706,21 +670,20 @@ class UnaryOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { // The following ops are supported by TRT but don't exist in relay yet: // recip, tan, sinh, cosh, asin, acos, asinh, acosh, atanh - static const std::unordered_map - op_map = { - {"exp", nvinfer1::UnaryOperation::kEXP}, - {"log", nvinfer1::UnaryOperation::kLOG}, - {"sqrt", nvinfer1::UnaryOperation::kSQRT}, - {"abs", nvinfer1::UnaryOperation::kABS}, - {"negative", nvinfer1::UnaryOperation::kNEG}, + static const std::unordered_map op_map = { + {"exp", nvinfer1::UnaryOperation::kEXP}, + {"log", nvinfer1::UnaryOperation::kLOG}, + {"sqrt", nvinfer1::UnaryOperation::kSQRT}, + {"abs", nvinfer1::UnaryOperation::kABS}, + {"negative", nvinfer1::UnaryOperation::kNEG}, #if TRT_VERSION_GE(5, 1, 5) - {"sin", nvinfer1::UnaryOperation::kSIN}, - {"cos", nvinfer1::UnaryOperation::kCOS}, - {"atan", nvinfer1::UnaryOperation::kATAN}, - {"ceil", nvinfer1::UnaryOperation::kCEIL}, - {"floor", nvinfer1::UnaryOperation::kFLOOR}, + {"sin", nvinfer1::UnaryOperation::kSIN}, + {"cos", nvinfer1::UnaryOperation::kCOS}, + {"atan", nvinfer1::UnaryOperation::kATAN}, + {"ceil", nvinfer1::UnaryOperation::kCEIL}, + {"floor", nvinfer1::UnaryOperation::kFLOOR}, #endif - }; + }; auto it = op_map.find(params->op_name); CHECK(it != op_map.end()) << "Unsupported unary type " << params->op_name; nvinfer1::IUnaryLayer* unary_layer = @@ -749,8 +712,7 @@ class ConcatOpConverter : public TrtOpConverter { const int axis = ConvertAxis(params, concat_attr->axis, input_rank); nvinfer1::IConcatenationLayer* concat_layer = - params->network->addConcatenation(input_tensors.data(), - input_tensors.size()); + params->network->addConcatenation(input_tensors.data(), input_tensors.size()); CHECK(concat_layer != nullptr); concat_layer->setAxis(axis); params->outputs.push_back(concat_layer->getOutput(0)); @@ -770,16 +732,14 @@ class BiasAddOpConverter : public TrtOpConverter { if (need_reshape_on_input) { // Add dims of size 1 until rank is required_rank. 
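// --- Editor's note (illustrative sketch, not part of the patch): the squeeze converter
// above flags each squeezed axis by writing 0 into the shape and then drops the flagged
// entries with the erase/remove idiom. A standalone version of that shape manipulation,
// with a hypothetical function name and plain ints:
#include <algorithm>
#include <vector>

static std::vector<int> SqueezeAxes(std::vector<int> dims, const std::vector<int>& axes) {
  for (int axis : axes) dims[axis] = 0;  // mark the axes to drop
  dims.erase(std::remove(dims.begin(), dims.end(), 0), dims.end());
  return dims;  // e.g. {1, 32, 1, 8} with axes {0, 2} becomes {32, 8}
}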
std::vector new_shape(input_dims); - while (new_shape.size() < required_rank) - new_shape.insert(new_shape.end(), 1); + while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); input_tensor = Reshape(params, input_tensor, new_shape); } nvinfer1::Weights shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; nvinfer1::Weights power{nvinfer1::DataType::kFLOAT, nullptr, 0}; - nvinfer1::IScaleLayer* scale_layer = - params->network->addScale(*input_tensor, nvinfer1::ScaleMode::kCHANNEL, - params->inputs.at(1).weight, shift, power); + nvinfer1::IScaleLayer* scale_layer = params->network->addScale( + *input_tensor, nvinfer1::ScaleMode::kCHANNEL, params->inputs.at(1).weight, shift, power); CHECK(scale_layer != nullptr); auto output_tensor = scale_layer->getOutput(0); if (need_reshape_on_input) { @@ -811,8 +771,7 @@ class Conv2DTransposeOpConverter : public TrtOpConverter { GetPadding(conv2d_attr->padding, &use_asymmetric_padding, &prepadding, &postpadding); #if !TRT_VERSION_GE(5, 1, 5) if (use_asymmetric_padding) { - auto pad_layer = - params->network->addPadding(*input_tensor, prepadding, postpadding); + auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); CHECK(pad_layer != nullptr); input_tensor = pad_layer->getOutput(0); // No need for conv op to do any padding. @@ -825,9 +784,8 @@ class Conv2DTransposeOpConverter : public TrtOpConverter { const int num_outputs = weight_shape[1]; const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0}; - auto deconv_layer = params->network->addDeconvolution( - *input_tensor, num_outputs, kernel_size, params->inputs.at(1).weight, - bias); + auto deconv_layer = params->network->addDeconvolution(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); CHECK(deconv_layer != nullptr); if (use_asymmetric_padding) { #if TRT_VERSION_GE(5, 1, 5) @@ -837,9 +795,8 @@ class Conv2DTransposeOpConverter : public TrtOpConverter { } else { deconv_layer->setPadding(prepadding); } - const auto strides = - nvinfer1::DimsHW(conv2d_attr->strides[0].as()->value, - conv2d_attr->strides[1].as()->value); + const auto strides = nvinfer1::DimsHW(conv2d_attr->strides[0].as()->value, + conv2d_attr->strides[1].as()->value); deconv_layer->setStride(strides); deconv_layer->setNbGroups(conv2d_attr->groups); nvinfer1::ITensor* output = deconv_layer->getOutput(0); @@ -880,12 +837,13 @@ class ReshapeOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { auto input = params->inputs.at(0).tensor; const auto* attrs = params->call->attrs.as(); + CHECK(attrs->newshape); CHECK_EQ(attrs->reverse, false); std::vector new_shape; - const int start_index = - TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0; - for (size_t i = start_index; i < attrs->newshape.size(); ++i) { - const int value = attrs->newshape[i].as()->value; + const int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0; + for (size_t i = start_index; i < attrs->newshape.value().size(); ++i) { + CHECK(attrs->newshape.value()[i].defined()); + const int value = attrs->newshape.value()[i].as()->value; CHECK_GE(value, -1); new_shape.push_back(value); } @@ -903,34 +861,27 @@ class PadOpConverter : public TrtOpConverter { const int input_rank_with_batch = input->getDimensions().nbDims + (TRT_HAS_IMPLICIT_BATCH(params) ? 
1 : 0); CHECK_EQ(input_rank_with_batch, attrs->pad_width.size()); - CHECK(!TRT_HAS_IMPLICIT_BATCH(params) || - (attrs->pad_width[0][0].as()->value == 0 && - attrs->pad_width[0][1].as()->value == 0)) + CHECK(!TRT_HAS_IMPLICIT_BATCH(params) || (attrs->pad_width[0][0].as()->value == 0 && + attrs->pad_width[0][1].as()->value == 0)) << "Cannot pad on batch dimension."; nvinfer1::DimsHW prepadding, postpadding; // Check if we need to transpose from NHWC -> NCHW. - const bool need_transpose = - attrs->pad_width[1][0].as()->value != 0 || - attrs->pad_width[1][1].as()->value != 0; + const bool need_transpose = attrs->pad_width[1][0].as()->value != 0 || + attrs->pad_width[1][1].as()->value != 0; if (need_transpose) { input = Transpose(params, input, {0, 3, 1, 2}); - prepadding = - nvinfer1::DimsHW(attrs->pad_width[1][0].as()->value, - attrs->pad_width[2][0].as()->value); - postpadding = - nvinfer1::DimsHW(attrs->pad_width[1][1].as()->value, - attrs->pad_width[2][1].as()->value); + prepadding = nvinfer1::DimsHW(attrs->pad_width[1][0].as()->value, + attrs->pad_width[2][0].as()->value); + postpadding = nvinfer1::DimsHW(attrs->pad_width[1][1].as()->value, + attrs->pad_width[2][1].as()->value); } else { - prepadding = - nvinfer1::DimsHW(attrs->pad_width[2][0].as()->value, - attrs->pad_width[3][0].as()->value); - postpadding = - nvinfer1::DimsHW(attrs->pad_width[2][1].as()->value, - attrs->pad_width[3][1].as()->value); + prepadding = nvinfer1::DimsHW(attrs->pad_width[2][0].as()->value, + attrs->pad_width[3][0].as()->value); + postpadding = nvinfer1::DimsHW(attrs->pad_width[2][1].as()->value, + attrs->pad_width[3][1].as()->value); } - auto pad_layer = - params->network->addPadding(*input, prepadding, postpadding); + auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); CHECK(pad_layer != nullptr); auto output = pad_layer->getOutput(0); if (need_transpose) { @@ -946,12 +897,12 @@ class ReduceOpConverter : public TrtOpConverter { ReduceOpConverter() : TrtOpConverter({kTensor}) {} void Convert(AddTrtLayerParams* params) const { - static const std::unordered_map - op_map = {{"sum", nvinfer1::ReduceOperation::kSUM}, - {"prod", nvinfer1::ReduceOperation::kPROD}, - {"max", nvinfer1::ReduceOperation::kMAX}, - {"min", nvinfer1::ReduceOperation::kMIN}, - {"mean", nvinfer1::ReduceOperation::kAVG}}; + static const std::unordered_map op_map = { + {"sum", nvinfer1::ReduceOperation::kSUM}, + {"prod", nvinfer1::ReduceOperation::kPROD}, + {"max", nvinfer1::ReduceOperation::kMAX}, + {"min", nvinfer1::ReduceOperation::kMIN}, + {"mean", nvinfer1::ReduceOperation::kAVG}}; auto it = op_map.find(params->op_name); CHECK(it != op_map.end()) << "Unsupported reduce type " << params->op_name; @@ -962,13 +913,12 @@ class ReduceOpConverter : public TrtOpConverter { CHECK(attrs->axis.defined() && attrs->axis.size() > 0); uint32_t reduce_axes = 0; for (size_t i = 0; i < attrs->axis.size(); ++i) { - const int axis = - ConvertAxis(params, attrs->axis[i].as()->value, - input->getDimensions().nbDims); + const int axis = ConvertAxis(params, attrs->axis[i].as()->value, + input->getDimensions().nbDims); reduce_axes |= 1 << axis; } - auto reduce_layer = params->network->addReduce( - *input, it->second, reduce_axes, attrs->keepdims); + auto reduce_layer = + params->network->addReduce(*input, it->second, reduce_axes, attrs->keepdims); params->outputs.push_back(reduce_layer->getOutput(0)); } }; @@ -982,15 +932,17 @@ class StridedSliceOpConverter : public TrtOpConverter { auto input = params->inputs.at(0).tensor; auto 
input_dims = TrtDimsToVector(input->getDimensions()); const auto* attrs = params->call->attrs.as(); + // Dynamic shapes not supported. + CHECK(attrs->begin && attrs->end && attrs->strides); const int input_rank_with_batch = input->getDimensions().nbDims + (TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0); - CHECK_EQ(input_rank_with_batch, attrs->begin.size()); - CHECK_EQ(input_rank_with_batch, attrs->end.size()); + CHECK_EQ(input_rank_with_batch, attrs->begin.value().size()); + CHECK_EQ(input_rank_with_batch, attrs->end.value().size()); const bool default_strides = - !attrs->strides.defined() || attrs->strides.size() == 0; + !attrs->strides.value().defined() || attrs->strides.value().size() == 0; if (TRT_HAS_IMPLICIT_BATCH(params)) { - CHECK(default_strides || !attrs->strides[0].defined() || - attrs->strides[0].as()->value == 1); + CHECK(default_strides || !attrs->strides.value()[0].defined() || + attrs->strides.value()[0].as()->value == 1); } auto process_slice_index = [](Integer x, int default_value, int dim_value) { @@ -1002,19 +954,17 @@ class StridedSliceOpConverter : public TrtOpConverter { const int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0; std::vector start, size, strides; - for (size_t i = start_index; i < attrs->begin.size(); ++i) { + for (size_t i = start_index; i < attrs->begin.value().size(); ++i) { const int begin_value = - process_slice_index(attrs->begin[i], 0, input_dims[i - start_index]); - const int end_value = - process_slice_index(attrs->end[i], input_dims[i - start_index], - input_dims[i - start_index]); - const int stride_value = (default_strides || i >= attrs->strides.size() || - !attrs->strides[i].defined()) + process_slice_index(attrs->begin.value()[i], 0, input_dims[i - start_index]); + const int end_value = process_slice_index(attrs->end.value()[i], input_dims[i - start_index], + input_dims[i - start_index]); + const int stride_value = (default_strides || i >= attrs->strides.value().size() || + !attrs->strides.value()[i].defined()) ? 
1 - : attrs->strides[i].as()->value; + : attrs->strides.value()[i].as()->value; CHECK_GT(stride_value, 0); - const int size_value = - (end_value - begin_value + stride_value - 1) / stride_value; + const int size_value = (end_value - begin_value + stride_value - 1) / stride_value; CHECK_GE(begin_value, 0); CHECK_GT(size_value, 0); start.push_back(begin_value); @@ -1023,8 +973,7 @@ class StridedSliceOpConverter : public TrtOpConverter { } auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), - VectorToTrtDims(size), - VectorToTrtDims(strides)); + VectorToTrtDims(size), VectorToTrtDims(strides)); params->outputs.push_back(slice_layer->getOutput(0)); } }; @@ -1037,12 +986,11 @@ class AdaptivePoolingOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { auto input_tensor = params->inputs.at(0).tensor; auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); - static const std::unordered_map op_map = - {{"nn.adaptive_max_pool2d", nvinfer1::PoolingType::kMAX}, - {"nn.adaptive_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; + static const std::unordered_map op_map = { + {"nn.adaptive_max_pool2d", nvinfer1::PoolingType::kMAX}, + {"nn.adaptive_avg_pool2d", nvinfer1::PoolingType::kAVERAGE}}; auto it = op_map.find(params->op_name); - CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name - << " in TensorRT"; + CHECK(it != op_map.end()) << "Unsupported pooling type " << params->op_name << " in TensorRT"; const auto* attrs = params->call->attrs.as(); CHECK_EQ(attrs->layout, "NCHW"); @@ -1052,13 +1000,10 @@ class AdaptivePoolingOpConverter : public TrtOpConverter { auto output_size = nvinfer1::DimsHW(1, 1); const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; const int w = TRT_HAS_IMPLICIT_BATCH(params) ? 
input_dims[2] : input_dims[3]; - const auto stride = - nvinfer1::DimsHW(h / output_size.h(), w / output_size.w()); - const auto window_size = - nvinfer1::DimsHW(h - (output_size.h() - 1) * stride.h(), - w - (output_size.w() - 1) * stride.w()); - auto pool_layer = - params->network->addPooling(*input_tensor, it->second, window_size); + const auto stride = nvinfer1::DimsHW(h / output_size.h(), w / output_size.w()); + const auto window_size = nvinfer1::DimsHW(h - (output_size.h() - 1) * stride.h(), + w - (output_size.w() - 1) * stride.w()); + auto pool_layer = params->network->addPooling(*input_tensor, it->second, window_size); CHECK(pool_layer != nullptr); pool_layer->setStride(stride); params->outputs.push_back(pool_layer->getOutput(0)); @@ -1073,9 +1018,9 @@ class ResizeOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { auto input = params->inputs.at(0).tensor; const auto* attrs = params->call->attrs.as(); - static const std::unordered_map op_map = - {{"nearest_neighbor", nvinfer1::ResizeMode::kNEAREST}, - {"bilinear", nvinfer1::ResizeMode::kLINEAR}}; + static const std::unordered_map op_map = { + {"nearest_neighbor", nvinfer1::ResizeMode::kNEAREST}, + {"bilinear", nvinfer1::ResizeMode::kLINEAR}}; auto it = op_map.find(attrs->method); CHECK(it != op_map.end()) << "Unsupported resize type " << attrs->method; CHECK_EQ(attrs->size.size(), 2); @@ -1096,8 +1041,7 @@ class ResizeOpConverter : public TrtOpConverter { CHECK(resize_layer != nullptr); resize_layer->setResizeMode(it->second); resize_layer->setOutputDimensions(VectorToTrtDims(output_dims)); - resize_layer->setAlignCorners(attrs->coordinate_transformation_mode == - "align_corners"); + resize_layer->setAlignCorners(attrs->coordinate_transformation_mode == "align_corners"); params->outputs.push_back(resize_layer->getOutput(0)); } }; @@ -1123,8 +1067,7 @@ class SplitOpConverter : public TrtOpConverter { for (int i = 0; i < sections; ++i) { start[axis] = i * size[axis]; auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), - VectorToTrtDims(size), - VectorToTrtDims(strides)); + VectorToTrtDims(size), VectorToTrtDims(strides)); params->outputs.push_back(slice_layer->getOutput(0)); } @@ -1148,8 +1091,8 @@ class SliceLikeOpConverter : public TrtOpConverter { const auto* attrs = params->call->attrs.as(); if (attrs->axes.defined()) { for (int i = 0; i < attrs->axes.size(); i++) { - const int axis = ConvertAxis( - params, attrs->axes[i].as()->value, input_dims.size()); + const int axis = + ConvertAxis(params, attrs->axes[i].as()->value, input_dims.size()); input_dims[axis] = new_dims[axis]; } } else { @@ -1161,9 +1104,8 @@ class SliceLikeOpConverter : public TrtOpConverter { // slice_like always begins at 0. 
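// --- Editor's note (illustrative sketch, not part of the patch): both slice converters
// above describe the slice to TensorRT as a (start, size, stride) triple per axis. For
// strided_slice the per-axis output extent follows the ceiling-division formula used in
// the loop above; slice_like simply starts at 0 with stride 1. Hypothetical helper:
#include <cassert>

static int SliceExtent(int begin, int end, int stride) {
  assert(stride > 0 && end > begin);
  return (end - begin + stride - 1) / stride;  // e.g. begin=0, end=5, stride=2 -> extent 3
}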
std::vector start(input_dims.size(), 0); std::vector strides(input_dims.size(), 1); - auto slice_layer = params->network->addSlice(*input, VectorToTrtDims(start), - VectorToTrtDims(input_dims), - VectorToTrtDims(strides)); + auto slice_layer = params->network->addSlice( + *input, VectorToTrtDims(start), VectorToTrtDims(input_dims), VectorToTrtDims(strides)); params->outputs.push_back(slice_layer->getOutput(0)); } @@ -1178,9 +1120,9 @@ class UpsamplingOpConverter : public TrtOpConverter { void Convert(AddTrtLayerParams* params) const { auto input = params->inputs.at(0).tensor; const auto* attrs = params->call->attrs.as(); - static const std::unordered_map op_map = - {{"nearest_neighbor", nvinfer1::ResizeMode::kNEAREST}, - {"bilinear", nvinfer1::ResizeMode::kLINEAR}}; + static const std::unordered_map op_map = { + {"nearest_neighbor", nvinfer1::ResizeMode::kNEAREST}, + {"bilinear", nvinfer1::ResizeMode::kLINEAR}}; auto it = op_map.find(attrs->method); CHECK(it != op_map.end()) << "Unsupported resize type " << attrs->method; auto output_dims = TrtDimsToVector(input->getDimensions()); diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 56d3ce93433e..8b34e90312b0 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -20,53 +20,52 @@ /*! * \file tflite_runtime.cc */ -#include +#include "tflite_runtime.h" + #include #include #include - - -#include "tflite_runtime.h" +#include namespace tvm { namespace runtime { -#define TVM_DTYPE_DISPATCH(type, DType, ...) \ - if (type == DataType::Float(64)) { \ - typedef double DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(32)) { \ - typedef float DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Float(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(64)) { \ - typedef int64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(32)) { \ - typedef int32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(16)) { \ - typedef int16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::Int(8)) { \ - typedef int8_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(64)) { \ - typedef uint64_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(32)) { \ - typedef uint32_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(16)) { \ - typedef uint16_t DType; \ - {__VA_ARGS__} \ - } else if (type == DataType::UInt(8)) { \ - typedef uint8_t DType; \ - {__VA_ARGS__} \ - } else { \ - LOG(FATAL) << "unknown data type " << type; \ +#define TVM_DTYPE_DISPATCH(type, DType, ...) 
\ + if (type == DataType::Float(64)) { \ + typedef double DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(32)) { \ + typedef float DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Float(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(64)) { \ + typedef int64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(32)) { \ + typedef int32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(16)) { \ + typedef int16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::Int(8)) { \ + typedef int8_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(64)) { \ + typedef uint64_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(32)) { \ + typedef uint32_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(16)) { \ + typedef uint16_t DType; \ + { __VA_ARGS__ } \ + } else if (type == DataType::UInt(8)) { \ + typedef uint8_t DType; \ + { __VA_ARGS__ } \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ } DataType TfLiteDType2TVMDType(TfLiteType dtype) { @@ -91,12 +90,15 @@ DataType TfLiteDType2TVMDType(TfLiteType dtype) { } } -void TFLiteRuntime::Init(const std::string& tflite_model_bytes, - TVMContext ctx) { +void TFLiteRuntime::Init(const std::string& tflite_model_bytes, TVMContext ctx) { const char* buffer = tflite_model_bytes.c_str(); size_t buffer_size = tflite_model_bytes.size(); + // The buffer used to construct the model must be kept alive for + // dependent interpreters to be used. + flatBuffersBuffer_ = std::unique_ptr(new char[buffer_size]); + std::memcpy(flatBuffersBuffer_.get(), buffer, buffer_size); std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buffer, buffer_size); + tflite::FlatBufferModel::BuildFromBuffer(flatBuffersBuffer_.get(), buffer_size); tflite::ops::builtin::BuiltinOpResolver resolver; // Build interpreter TfLiteStatus status = tflite::InterpreterBuilder(*model, resolver)(&interpreter_); @@ -108,24 +110,22 @@ void TFLiteRuntime::Init(const std::string& tflite_model_bytes, ctx_ = ctx; } -void TFLiteRuntime::Invoke() { - interpreter_->Invoke(); -} +void TFLiteRuntime::Invoke() { interpreter_->Invoke(); } void TFLiteRuntime::SetInput(int index, DLTensor* data_in) { DataType dtype(data_in->dtype); TVM_DTYPE_DISPATCH(dtype, DType, { - DType* dest = interpreter_->typed_input_tensor(index); - DType* src = static_cast(data_in->data); - CHECK(data_in->strides == NULL); - int64_t size = 1; - for (int64_t i = 0; i < data_in->ndim; ++i) { - size *= data_in->shape[i]; - } - for (int64_t i = 0; i < size; ++i) { - dest[i] = src[i]; - } - }); + DType* dest = interpreter_->typed_input_tensor(index); + DType* src = static_cast(data_in->data); + CHECK(data_in->strides == NULL); + int64_t size = 1; + for (int64_t i = 0; i < data_in->ndim; ++i) { + size *= data_in->shape[i]; + } + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); } NDArray TFLiteRuntime::GetOutput(int index) const { @@ -140,48 +140,44 @@ NDArray TFLiteRuntime::GetOutput(int index) const { } NDArray ret = NDArray::Empty(shape, dtype, ctx_); TVM_DTYPE_DISPATCH(dtype, DType, { - DType* dest = static_cast(ret->data); - DType* src = interpreter_->typed_output_tensor(index); - for (int64_t i = 0; i < size; ++i) { - dest[i] = src[i]; - } - }); + DType* dest = static_cast(ret->data); + DType* src = interpreter_->typed_output_tensor(index); + for (int64_t i = 0; i < size; ++i) { + dest[i] = src[i]; + } + }); return ret; } -PackedFunc 
TFLiteRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc TFLiteRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = args[0]; - CHECK_GE(in_idx, 0); - this->SetInput(in_idx, args[1]); - }); + int in_idx = args[0]; + CHECK_GE(in_idx, 0); + this->SetInput(in_idx, args[1]); + }); } else if (name == "get_output") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutput(args[0]); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); } else if (name == "invoke") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Invoke(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); } else { return PackedFunc(); } } -Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, - TVMContext ctx) { +Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, TVMContext ctx) { auto exec = make_object(); exec->Init(tflite_model_bytes, ctx); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = TFLiteRuntimeCreate(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = TFLiteRuntimeCreate(args[0], args[1]); +}); + +TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); } // namespace runtime } // namespace tvm diff --git a/src/runtime/contrib/tflite/tflite_runtime.h b/src/runtime/contrib/tflite/tflite_runtime.h index d823690126b1..f3e3bd90bba4 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.h +++ b/src/runtime/contrib/tflite/tflite_runtime.h @@ -26,12 +26,13 @@ #define TVM_RUNTIME_CONTRIB_TFLITE_TFLITE_RUNTIME_H_ #include +#include #include #include -#include -#include #include +#include +#include namespace tvm { namespace runtime { @@ -52,18 +53,15 @@ class TFLiteRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const { - return "TFLiteRuntime"; - } + const char* type_key() const { return "TFLiteRuntime"; } /*! - * \brief Invoke the internal tflite interpreter and run the whole model in + * \brief Invoke the internal tflite interpreter and run the whole model in * dependency order. */ void Invoke(); @@ -73,8 +71,7 @@ class TFLiteRuntime : public ModuleNode { * \param tflite_model_bytes The tflite model. * \param ctx The context where the tflite model will be executed on. */ - void Init(const std::string& tflite_model_bytes, - TVMContext ctx); + void Init(const std::string& tflite_model_bytes, TVMContext ctx); /*! * \brief set index-th input to the model. 
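// --- Editor's note (illustrative sketch, not part of the patch): the Init change above
// copies the serialized TFLite model into a buffer owned by the runtime because the
// interpreter keeps referencing that memory after it is built, so the caller's string
// must not be the only owner. Stripped of the TFLite calls, the ownership pattern is
// simply "keep a stable private copy of the caller's bytes"; the type name is hypothetical.
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>

struct OwnedModelBytes {
  std::unique_ptr<char[]> data;
  size_t size = 0;

  explicit OwnedModelBytes(const std::string& bytes)
      : data(new char[bytes.size()]), size(bytes.size()) {
    // The copy lives as long as this object, independent of the caller's string.
    std::memcpy(data.get(), bytes.data(), bytes.size());
  }
};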
@@ -97,6 +94,8 @@ class TFLiteRuntime : public ModuleNode { */ NDArray GetOutput(int index) const; + // Buffer backing the interpreter's model + std::unique_ptr flatBuffersBuffer_; // TFLite interpreter std::unique_ptr interpreter_; // TVM context diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index 920bdae5b964..c70a4f29ccbe 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -22,10 +22,12 @@ */ #include #include -#include #include +#include + #include #include + #include "workspace_pool.h" #ifdef __ANDROID__ @@ -42,9 +44,7 @@ class CPUDeviceAPI final : public DeviceAPI { *rv = 1; } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { void* ptr; #if _MSC_VER @@ -69,53 +69,38 @@ class CPUDeviceAPI final : public DeviceAPI { #endif } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { - memcpy(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {} void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; struct CPUWorkspacePool : public WorkspacePool { - CPUWorkspacePool() : - WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} + CPUWorkspacePool() : WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {} }; -void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { - return dmlc::ThreadLocalStore::Get() - ->AllocWorkspace(ctx, size); +void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(ctx, size); } void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(ctx, data); } -TVM_REGISTER_GLOBAL("device_api.cpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CPUDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CPUDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/crt/crt_backend_api.c b/src/runtime/crt/crt_backend_api.c index 52cefafe3980..7589ce479014 100644 --- a/src/runtime/crt/crt_backend_api.c +++ b/src/runtime/crt/crt_backend_api.c @@ -17,13 +17,14 @@ * under the License. 
*/ -#include -#include - +#include #include #include -#include #include +#include +#include + +#include "packed_func.h" void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t nbytes, int dtype_code_hint, int dtype_bits_hint) { @@ -48,7 +49,7 @@ int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_ta int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr) { g_fexecs = vrealloc(g_fexecs, sizeof(TVMPackedFunc) * (g_fexecs_count + 1)); - snprintf(g_fexecs[g_fexecs_count].name, sizeof(g_fexecs[g_fexecs_count].name), name); + snprintf(g_fexecs[g_fexecs_count].name, sizeof(g_fexecs[g_fexecs_count].name), "%s", name); g_fexecs[g_fexecs_count].fexec = ptr; g_fexecs_count++; return 0; diff --git a/src/runtime/crt/crt_runtime_api.c b/src/runtime/crt/crt_runtime_api.c index 6d7c010e3757..bd7d35e119bc 100644 --- a/src/runtime/crt/crt_runtime_api.c +++ b/src/runtime/crt/crt_runtime_api.c @@ -17,15 +17,14 @@ * under the License. */ -#include - +#include #include #include -#include #include +#include -#include "ndarray.h" #include "graph_runtime.h" +#include "ndarray.h" #include "packed_func.h" // Handle internal errors @@ -41,14 +40,8 @@ const char* TVMGetLastError(void) { return g_last_error; } // Manipulate NDArray on target device -int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { DLDataType dtype; dtype.code = dtype_code; dtype.bits = dtype_bits; @@ -67,14 +60,10 @@ int TVMArrayFree(TVMArrayHandle handle) { return TVMNDArray_Release(&arr); } -void * SystemLibraryCreate() { - return 0; -} +void* SystemLibraryCreate() { return 0; } -int TVMModGetFunction(TVMModuleHandle mod, - const char* func_name, - int query_imports, - TVMFunctionHandle *out) { +int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports, + TVMFunctionHandle* out) { int status = 0; if (!strcmp(func_name, "load_params")) { *out = &TVMGraphRuntime_LoadParams; diff --git a/src/runtime/crt/graph_runtime.c b/src/runtime/crt/graph_runtime.c index b5ed3b70281b..0ddbb41ae730 100644 --- a/src/runtime/crt/graph_runtime.c +++ b/src/runtime/crt/graph_runtime.c @@ -22,26 +22,29 @@ * \brief implement graph runtime in pure C */ +#include "graph_runtime.h" + #include #include "logging.h" -#include "graph_runtime.h" #ifndef MAX #define MAX(a, b) (((a) > (b)) ? 
(a) : (b)) #endif // MAX -uint32_t Shape_Accumulate(int64_t * shape, uint32_t ndim) { +uint32_t Shape_Accumulate(int64_t* shape, uint32_t ndim) { int64_t accum = 1; uint32_t idx; for (idx = 0; idx < ndim; idx++) { - if (shape[idx] == 0) { break; } + if (shape[idx] == 0) { + break; + } accum *= shape[idx]; } return accum; } -int NodeEntry_Load(TVMGraphRuntimeNodeEntry * entry, JSONReader * reader) { +int NodeEntry_Load(TVMGraphRuntimeNodeEntry* entry, JSONReader* reader) { int status = 0; reader->BeginArray(reader); if (!(reader->NextArrayItem(reader))) { @@ -66,8 +69,8 @@ int NodeEntry_Load(TVMGraphRuntimeNodeEntry * entry, JSONReader * reader) { return status; } -void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode * node, JSONReader *reader, - TVMOpParam* param) { +void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode* node, JSONReader* reader, + TVMOpParam* param) { int bitmask = 0; char key[20], value[120]; memset(param, 0, sizeof(TVMOpParam)); @@ -92,10 +95,12 @@ void TVMGraphRuntimeNode_LoadAttrs(TVMGraphRuntimeNode * node, JSONReader *reade fprintf(stderr, "do not support key %s", key); } } - if (bitmask != (1|2|4|8)) { fprintf(stderr, "invalid format\n"); } + if (bitmask != (1 | 2 | 4 | 8)) { + fprintf(stderr, "invalid format\n"); + } } -int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node, JSONReader *reader) { +int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode* node, JSONReader* reader) { int status = 0; reader->BeginObject(reader); int bitmask = 0; @@ -111,8 +116,8 @@ int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node, JSONReader *reader) { size_t count = node->inputs_count; reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { - node->inputs = vrealloc(node->inputs, sizeof(TVMGraphRuntimeNodeEntry)*(count+1)); - TVMGraphRuntimeNodeEntry * inputs = node->inputs + count; + node->inputs = vrealloc(node->inputs, sizeof(TVMGraphRuntimeNodeEntry) * (count + 1)); + TVMGraphRuntimeNodeEntry* inputs = node->inputs + count; reader->BeginArray(reader); if (!reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); @@ -152,9 +157,11 @@ int TVMGraphRuntimeNode_Load(TVMGraphRuntimeNode * node, JSONReader *reader) { fprintf(stderr, "do not support key %s", key); status = -1; } - if (status != 0) { break; } + if (status != 0) { + break; + } } - if (bitmask != (1|2|4)) { + if (bitmask != (1 | 2 | 4)) { fprintf(stderr, "invalid format\n"); status = -1; } @@ -169,15 +176,17 @@ TVMGraphRuntimeNode TVMGraphRuntimeNodeCreate() { return node; } -void TVMGraphRuntimeNodeRelease(TVMGraphRuntimeNode * node) { - if (!node) { return; } +void TVMGraphRuntimeNodeRelease(TVMGraphRuntimeNode* node) { + if (!node) { + return; + } if (node->inputs) { vfree(node->inputs); node->inputs = 0; } } -int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *reader) { +int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr* attr, JSONReader* reader) { int status = 0; int bitmask = 0; char key[16], type[16]; @@ -211,7 +220,8 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r reader->ReadString(reader, attr->dltype + dltype_count * TVM_CRT_STRLEN_DLTYPE); dltype_count++; } - attr->dltype_count = dltype_count;; + attr->dltype_count = dltype_count; + if (reader->NextArrayItem(reader)) { fprintf(stderr, "Invalid json format\n"); status = -1; @@ -238,7 +248,7 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r } reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { - 
attr->storage_id = vrealloc(attr->storage_id, sizeof(uint32_t)*(storage_id_count+1)); + attr->storage_id = vrealloc(attr->storage_id, sizeof(uint32_t) * (storage_id_count + 1)); reader->ReadUnsignedInteger(reader, &(attr->storage_id[storage_id_count])); storage_id_count++; } @@ -269,10 +279,10 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { attr->shape = - vrealloc(attr->shape, sizeof(attr->shape[0])*(shape_count+1)*TVM_CRT_MAX_NDIM); - attr->ndim = vrealloc(attr->ndim, sizeof(attr->ndim[0])*(shape_count+1)); + vrealloc(attr->shape, sizeof(attr->shape[0]) * (shape_count + 1) * TVM_CRT_MAX_NDIM); + attr->ndim = vrealloc(attr->ndim, sizeof(attr->ndim[0]) * (shape_count + 1)); reader->BeginArray(reader); - int64_t * attr_shape_ptr = attr->shape + shape_count*TVM_CRT_MAX_NDIM; + int64_t* attr_shape_ptr = attr->shape + shape_count * TVM_CRT_MAX_NDIM; reader->ReadInteger(reader, attr_shape_ptr + 0); uint32_t ndim = 1; if (reader->NextArrayItem(reader)) { @@ -316,7 +326,8 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r break; } while (reader->NextArrayItem(reader)) { - attr->device_index = vrealloc(attr->device_index, sizeof(uint32_t)*(device_index_count+1)); + attr->device_index = + vrealloc(attr->device_index, sizeof(uint32_t) * (device_index_count + 1)); reader->ReadUnsignedInteger(reader, &(attr->device_index[device_index_count])); device_index_count++; } @@ -339,7 +350,7 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r status = -1; break; } - uint32_t * temp = 0; + uint32_t* temp = 0; uint32_t temp_count = 0; reader->BeginArray(reader); while (reader->NextArrayItem(reader)) { @@ -371,15 +382,17 @@ int TVMGraphRuntimeGraphAttr_Load(TVMGraphRuntimeGraphAttr * attr, JSONReader *r } } } - if (bitmask != (1|2|4)) { + if (bitmask != (1 | 2 | 4)) { fprintf(stderr, "invalid format\n"); status = -1; } return status; } -void TVMGraphRuntimeGraphAttr_Release(TVMGraphRuntimeGraphAttr * attr) { - if (!attr) { return; } +void TVMGraphRuntimeGraphAttr_Release(TVMGraphRuntimeGraphAttr* attr) { + if (!attr) { + return; + } if (attr->storage_id) { vfree(attr->storage_id); attr->storage_id = 0; @@ -402,90 +415,90 @@ void TVMGraphRuntimeGraphAttr_Release(TVMGraphRuntimeGraphAttr * attr) { } } -int TVMGraphRuntime_Load(TVMGraphRuntime * runtime, JSONReader *reader) { - int status = 0; - reader->BeginObject(reader); - int bitmask = 0; - char key[20]; - while (reader->NextObjectItem(reader, key)) { - if (!strcmp(key, "nodes")) { - reader->BeginArray(reader); - while (reader->NextArrayItem(reader)) { - runtime->nodes = +int TVMGraphRuntime_Load(TVMGraphRuntime* runtime, JSONReader* reader) { + int status = 0; + reader->BeginObject(reader); + int bitmask = 0; + char key[20]; + while (reader->NextObjectItem(reader, key)) { + if (!strcmp(key, "nodes")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + runtime->nodes = vrealloc(runtime->nodes, sizeof(TVMGraphRuntimeNode) * (runtime->nodes_count + 1)); - TVMGraphRuntimeNode * node = runtime->nodes + runtime->nodes_count; - status = TVMGraphRuntimeNode_Load(node, reader); - if (status != 0) { - fprintf(stderr, "failed to load an element in `nodes` field in graph runtime node.\n"); - break; + TVMGraphRuntimeNode* node = runtime->nodes + runtime->nodes_count; + status = TVMGraphRuntimeNode_Load(node, reader); + if (status != 0) { + fprintf(stderr, "failed to load an element 
in `nodes` field in graph runtime node.\n"); + break; #if TVM_CRT_DEBUG - } else { - printf("loading: node (%u) %s loaded.\n", runtime->nodes_count, node->name); + } else { + printf("loading: node (%u) %s loaded.\n", runtime->nodes_count, node->name); #endif // TVM_CRT_DEBUG - } - runtime->nodes_count++; } - bitmask |= 1; - } else if (!strcmp(key, "arg_nodes")) { - reader->BeginArray(reader); - while (reader->NextArrayItem(reader)) { - runtime->input_nodes = + runtime->nodes_count++; + } + bitmask |= 1; + } else if (!strcmp(key, "arg_nodes")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + runtime->input_nodes = vrealloc(runtime->input_nodes, sizeof(uint32_t) * (runtime->input_nodes_count + 1)); - uint32_t * node = runtime->input_nodes + runtime->input_nodes_count; - reader->ReadUnsignedInteger(reader, node); - runtime->input_nodes_count++; - } - bitmask |= 2; - } else if (!strcmp(key, "node_row_ptr")) { - reader->BeginArray(reader); - while (reader->NextArrayItem(reader)) { - runtime->node_row_ptr = + uint32_t* node = runtime->input_nodes + runtime->input_nodes_count; + reader->ReadUnsignedInteger(reader, node); + runtime->input_nodes_count++; + } + bitmask |= 2; + } else if (!strcmp(key, "node_row_ptr")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + runtime->node_row_ptr = vrealloc(runtime->node_row_ptr, sizeof(uint32_t) * (runtime->node_row_ptr_count + 1)); - uint32_t count = runtime->node_row_ptr_count; - uint32_t * node = runtime->node_row_ptr + count; - reader->ReadUnsignedInteger(reader, node); - runtime->node_row_ptr_count++; - } - bitmask |= 4; - } else if (!strcmp(key, "heads")) { - reader->BeginArray(reader); - while (reader->NextArrayItem(reader)) { - runtime->outputs = - vrealloc(runtime->outputs, - sizeof(TVMGraphRuntimeNodeEntry) * (runtime->outputs_count + 1)); - TVMGraphRuntimeNodeEntry * entry = runtime->outputs + runtime->outputs_count; - status = NodeEntry_Load(entry, reader); - if (status != 0) { - fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n"); - break; - } - runtime->outputs_count++; - } - bitmask |= 8; - } else if (!strcmp(key, "attrs")) { - status = TVMGraphRuntimeGraphAttr_Load(&(runtime->attrs), reader); + uint32_t count = runtime->node_row_ptr_count; + uint32_t* node = runtime->node_row_ptr + count; + reader->ReadUnsignedInteger(reader, node); + runtime->node_row_ptr_count++; + } + bitmask |= 4; + } else if (!strcmp(key, "heads")) { + reader->BeginArray(reader); + while (reader->NextArrayItem(reader)) { + runtime->outputs = vrealloc( + runtime->outputs, sizeof(TVMGraphRuntimeNodeEntry) * (runtime->outputs_count + 1)); + TVMGraphRuntimeNodeEntry* entry = runtime->outputs + runtime->outputs_count; + status = NodeEntry_Load(entry, reader); if (status != 0) { fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n"); break; } - bitmask |= 16; - } else if (!strcmp(key, "metadata")) { + runtime->outputs_count++; + } + bitmask |= 8; + } else if (!strcmp(key, "attrs")) { + status = TVMGraphRuntimeGraphAttr_Load(&(runtime->attrs), reader); + if (status != 0) { + fprintf(stderr, "Fail to load an element in `heads` field in graph runtime node.\n"); break; - } else { - fprintf(stderr, "key %s is not supported\n", key); - status = -1; } - if (status != 0) { break; } + bitmask |= 16; + } else if (!strcmp(key, "metadata")) { + break; + } else { + fprintf(stderr, "key %s is not supported\n", key); + status = -1; } - if (!(bitmask == (1|2|4|8|16))) { - 
fprintf(stderr, "invalid format\n"); - status = -1; + if (status != 0) { + break; } - return status; + } + if (!(bitmask == (1 | 2 | 4 | 8 | 16))) { + fprintf(stderr, "invalid format\n"); + status = -1; + } + return status; } -uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime * runtime, - uint32_t nid, uint32_t index) { +uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime* runtime, uint32_t nid, uint32_t index) { return runtime->node_row_ptr[nid] + index; } @@ -495,10 +508,10 @@ uint32_t TVMGraphRuntime_GetEntryId(TVMGraphRuntime * runtime, * \param name The name of the input. * \return The index of input. */ -int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime, const char * name) { +int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime* runtime, const char* name) { uint32_t i; int32_t rv = -1; - for (i = 0; i< runtime->input_nodes_count; ++i) { + for (i = 0; i < runtime->input_nodes_count; ++i) { uint32_t nid = runtime->input_nodes[i]; if (!strcmp(runtime->nodes[nid].name, name)) { rv = i; @@ -515,7 +528,7 @@ int TVMGraphRuntime_GetInputIndex(TVMGraphRuntime * runtime, const char * name) * \param name The name of the input. * \param data_in The input data. */ -void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in) { +void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in) { uint32_t index = runtime->GetInputIndex(runtime, name); if (index >= runtime->input_nodes_count) { fprintf(stderr, "given index is greater than num of input nodes.\n"); @@ -531,10 +544,10 @@ void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTe * \param param_size The parameter size. * \return The result of this function execution. */ -int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob, +int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size) { int status = 0; - const char * bptr = param_blob; + const char* bptr = param_blob; uint64_t header, reserved; header = ((uint64_t*)bptr)[0]; // NOLINT(*) bptr += sizeof(header); @@ -546,7 +559,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo bptr += sizeof(reserved); // read names - char * names = vmalloc(TVM_CRT_STRLEN_NAME * runtime->nodes_count); + char* names = vmalloc(TVM_CRT_STRLEN_NAME * runtime->nodes_count); memset(names, 0, TVM_CRT_STRLEN_NAME * runtime->nodes_count); uint64_t names_count; int idx; @@ -576,12 +589,12 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo for (idx = 0; idx < size; idx++) { int32_t in_idx = runtime->GetInputIndex(runtime, names + TVM_CRT_STRLEN_NAME * idx); - CHECK_GT(in_idx, 0, - "Found param for non-existent input: %s\n", names + TVM_CRT_STRLEN_NAME * idx); + CHECK_GT(in_idx, 0, "Found param for non-existent input: %s\n", + names + TVM_CRT_STRLEN_NAME * idx); uint32_t eid = runtime->GetEntryId(runtime, runtime->input_nodes[in_idx], 0); if (!(eid < runtime->data_entry_count)) { - fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", - eid, runtime->data_entry_count); + fprintf(stderr, "`entry_id`=%d is greater than expected(%d).\n", eid, + runtime->data_entry_count); status = -1; } @@ -595,11 +608,11 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo } status |= TVMNDArray_Load(&(runtime->data_entry[eid]), &bptr); #if TVM_CRT_DEBUG - TVMNDArray * entry = &(runtime->data_entry[eid]); + TVMNDArray* entry = &(runtime->data_entry[eid]); 
printf("loading: param %s loaded, in_idx=%d, eid=%d, ndim=%d, data[0]=%f\n", names + TVM_CRT_STRLEN_NAME * idx, in_idx, eid, entry->dl_tensor.ndim, ((float*)entry->dl_tensor.data)[0]); // NOLINT(*) -#endif // TVM_CRT_DEBUG +#endif // TVM_CRT_DEBUG } // Release memory @@ -612,7 +625,7 @@ int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blo * \brief Run all the operations one by one. * \param runtime The graph runtime. */ -void TVMGraphRuntime_Run(TVMGraphRuntime * runtime) { +void TVMGraphRuntime_Run(TVMGraphRuntime* runtime) { // setup the array and requirements. uint32_t idx; for (idx = 0; idx < runtime->op_execs_count; ++idx) { @@ -625,7 +638,7 @@ void TVMGraphRuntime_Run(TVMGraphRuntime * runtime) { } } -int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out) { +int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out) { int status = 0; uint32_t nid = runtime->outputs[idx].node_id; uint32_t index = runtime->outputs[idx].index; @@ -634,7 +647,7 @@ int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTe // copy data section to allocated output tensor int32_t elem_bytes = out->dtype.bits / 8; int64_t size = Shape_Accumulate(out->shape, out->ndim); - DLTensor * tensor = &(runtime->data_entry[eid].dl_tensor); + DLTensor* tensor = &(runtime->data_entry[eid].dl_tensor); CHECK(out->ndim == tensor->ndim); CHECK(out->dtype.bits == tensor->dtype.bits); CHECK(Shape_Accumulate(out->shape, out->ndim) == Shape_Accumulate(tensor->shape, tensor->ndim)); @@ -642,27 +655,27 @@ int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTe return status; } -void TVMGraphRuntime_SetupStorage(TVMGraphRuntime * runtime) { +void TVMGraphRuntime_SetupStorage(TVMGraphRuntime* runtime) { uint32_t idx; // Grab saved optimization plan from graph. - TVMGraphRuntimeGraphAttr * attrs = &(runtime->attrs); - DLDataType * vtype = vmalloc(sizeof(DLDataType) * attrs->dltype_count); + TVMGraphRuntimeGraphAttr* attrs = &(runtime->attrs); + DLDataType* vtype = vmalloc(sizeof(DLDataType) * attrs->dltype_count); for (idx = 0; idx < attrs->dltype_count; idx++) { vtype[idx] = String2DLDataType(attrs->dltype + idx * TVM_CRT_STRLEN_DLTYPE); } // Size and device type of each storage pool entry. - TVMGraphRuntimePoolEntry * pool_entry = - vmalloc(sizeof(TVMGraphRuntimePoolEntry) * runtime->nodes_count); + TVMGraphRuntimePoolEntry* pool_entry = + vmalloc(sizeof(TVMGraphRuntimePoolEntry) * runtime->nodes_count); memset(pool_entry, 0, sizeof(TVMGraphRuntimePoolEntry) * runtime->nodes_count); - uint32_t pool_entry_count = 0; + uint32_t pool_entry_count = 0; // Find the maximum space size. for (idx = 0; idx < attrs->shape_count; idx++) { int storage_id = attrs->storage_id[idx]; // Use the fallback device if no device index is available. int device_type = runtime->ctxs[0].device_type; - uint32_t size = Shape_Accumulate(attrs->shape+idx*TVM_CRT_MAX_NDIM, attrs->ndim[idx]); + uint32_t size = Shape_Accumulate(attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx]); DLDataType t = vtype[idx]; uint32_t bits = t.bits * t.lanes; size_t bytes = ((bits + 7U) / 8U) * size; @@ -678,9 +691,11 @@ void TVMGraphRuntime_SetupStorage(TVMGraphRuntime * runtime) { // Allocate the space. 
for (idx = 0; idx < pool_entry_count; idx++) { runtime->storage_pool = - vrealloc(runtime->storage_pool, sizeof(TVMNDArray) * (runtime->storage_pool_count + 1)); + vrealloc(runtime->storage_pool, sizeof(TVMNDArray) * (runtime->storage_pool_count + 1)); TVMGraphRuntimePoolEntry pit = pool_entry[idx]; - int64_t shape[TVM_CRT_MAX_NDIM] = {0, }; + int64_t shape[TVM_CRT_MAX_NDIM] = { + 0, + }; TVMContext ctx = runtime->ctxs[0]; DLDataType dtype = {kDLFloat, 32, 1}; shape[0] = (pit.size + 3) / 4; @@ -696,13 +711,13 @@ void TVMGraphRuntime_SetupStorage(TVMGraphRuntime * runtime) { runtime->data_entry_count = runtime->node_row_ptr[runtime->node_row_ptr_count - 1]; runtime->data_entry = vmalloc(sizeof(TVMNDArray) * runtime->data_entry_count); for (idx = 0; idx < runtime->data_entry_count; ++idx) { - size_t storage_id = attrs->storage_id[idx]; + uint32_t storage_id = attrs->storage_id[idx]; CHECK(storage_id < runtime->storage_pool_count); runtime->data_entry[idx] = - TVMNDArray_CreateView(&(runtime->storage_pool[storage_id]), - attrs->shape+idx*TVM_CRT_MAX_NDIM, attrs->ndim[idx], vtype[idx]); + TVMNDArray_CreateView(&(runtime->storage_pool[storage_id]), + attrs->shape + idx * TVM_CRT_MAX_NDIM, attrs->ndim[idx], vtype[idx]); CHECK_NE(runtime->data_entry[idx].dl_tensor.data, 0, - "fail to create for node with idx=%d, storage_id=%lu\n", idx, storage_id); + "fail to create for node with idx=%d, storage_id=%u\n", idx, storage_id); } // Release memory @@ -710,18 +725,18 @@ void TVMGraphRuntime_SetupStorage(TVMGraphRuntime * runtime) { vfree(pool_entry); } -int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime * runtime) { +int TVMGraphRuntime_SetupOpExecs(TVMGraphRuntime* runtime) { int status = 0; uint32_t nid, idx; runtime->op_execs_count = runtime->nodes_count; runtime->op_execs = vmalloc(sizeof(TVMPackedFunc) * runtime->op_execs_count); for (nid = 0; nid < runtime->nodes_count; nid++) { - const TVMGraphRuntimeNode * inode = runtime->nodes + nid; + const TVMGraphRuntimeNode* inode = runtime->nodes + nid; if (strcmp(inode->op_type, "null")) { DLTensorPtr args[TVM_CRT_MAX_ARGS]; uint32_t args_count = 0; for (idx = 0; idx < inode->inputs_count; idx++) { - const TVMGraphRuntimeNodeEntry * entry = inode->inputs + idx; + const TVMGraphRuntimeNodeEntry* entry = inode->inputs + idx; uint32_t eid = runtime->GetEntryId(runtime, entry->node_id, entry->index); args[idx] = &(runtime->data_entry[eid].dl_tensor); args_count++; @@ -760,13 +775,13 @@ typedef struct TVMOpArgs { uint32_t arg_values_count; uint32_t arg_tcodes[TVM_CRT_MAX_ARGS]; uint32_t arg_tcodes_count; - int64_t shape_data[TVM_CRT_MAX_ARGS]; + int64_t shape_data[TVM_CRT_MAX_ARGS]; uint32_t shape_data_count; } TVMOpArgs; -int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam * param, - DLTensorPtr * args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc * pf) { +int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime* runtime, const TVMOpParam* param, + DLTensorPtr* args, const uint32_t args_count, + uint32_t num_inputs, TVMPackedFunc* pf) { int status = 0; uint32_t idx; TVMOpArgs arg_ptr; @@ -778,7 +793,7 @@ int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, const TVMOpParam for (idx = 0; idx < arg_ptr.args_count; ++idx) { TVMValue v; memset(&v, 0, sizeof(v)); - DLTensor * t = &(arg_ptr.args[idx]); + DLTensor* t = &(arg_ptr.args[idx]); /* v.v_handle = &((*args)[idx]); */ v.v_handle = args[idx]; arg_ptr.arg_values[idx] = v; @@ -811,19 +826,19 @@ int32_t TVMGraphRuntime_CreateTVMOp(TVMGraphRuntime * runtime, 
const TVMOpParam * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ -void TVMGraphRuntime_Init(TVMGraphRuntime * runtime, const char * graph_json, - const TVMModule * module, const TVMContext * ctxs) { +void TVMGraphRuntime_Init(TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module, + const TVMContext* ctxs) { JSONReader reader = JSONReader_Create(graph_json); runtime->Load(runtime, &reader); + JSONReader_Release(&reader); runtime->ctxs[0] = ctxs[0]; runtime->SetupStorage(runtime); runtime->SetupOpExecs(runtime); - JSONReader_Release(&reader); } -TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, - const TVMModule * m, const TVMContext * ctxs) { - TVMGraphRuntime * runtime = (TVMGraphRuntime*)vmalloc(sizeof(TVMGraphRuntime)); // NOLINT(*) +TVMGraphRuntime* TVMGraphRuntimeCreate(const char* sym_json, const TVMModule* m, + const TVMContext* ctxs) { + TVMGraphRuntime* runtime = (TVMGraphRuntime*)vmalloc(sizeof(TVMGraphRuntime)); // NOLINT(*) memset(runtime, 0, sizeof(TVMGraphRuntime)); runtime->GetEntryId = TVMGraphRuntime_GetEntryId; runtime->GetInputIndex = TVMGraphRuntime_GetInputIndex; @@ -842,9 +857,9 @@ TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, return runtime; } -void TVMGraphRuntimeRelease(TVMGraphRuntime ** pptr) { +void TVMGraphRuntimeRelease(TVMGraphRuntime** pptr) { int32_t idx; - TVMGraphRuntime * runtime = *pptr; + TVMGraphRuntime* runtime = *pptr; for (idx = 0; idx < runtime->nodes_count; ++idx) { TVMGraphRuntimeNodeRelease(&(runtime->nodes[idx])); } diff --git a/src/runtime/crt/graph_runtime.h b/src/runtime/crt/graph_runtime.h index 3cb8ba95e0fa..fd3b14633222 100644 --- a/src/runtime/crt/graph_runtime.h +++ b/src/runtime/crt/graph_runtime.h @@ -27,9 +27,9 @@ #include #include "load_json.h" +#include "module.h" #include "ndarray.h" #include "packed_func.h" -#include "module.h" /*! \brief operator attributes about tvm op */ typedef struct TVMOpParam { @@ -51,7 +51,7 @@ typedef struct TVMGraphRuntimeNodeEntry { uint32_t index; uint32_t version; // JSON Loader - void (*Load)(JSONReader *reader); + void (*Load)(JSONReader* reader); } TVMGraphRuntimeNodeEntry; // Node @@ -63,26 +63,26 @@ typedef struct TVMGraphRuntimeNode { // parameters TVMOpParam param; // inputs - TVMGraphRuntimeNodeEntry * inputs; + TVMGraphRuntimeNodeEntry* inputs; // number of inputs size_t inputs_count; // control deps uint32_t control_deps[20]; // JSON Loader - void (*LoadAttrs)(struct TVMGraphRuntimeNode * node, JSONReader *reader, TVMOpParam* param); + void (*LoadAttrs)(struct TVMGraphRuntimeNode* node, JSONReader* reader, TVMOpParam* param); // JSON Loader - int (*Load)(struct TVMGraphRuntimeNode * node, JSONReader *reader); + int (*Load)(struct TVMGraphRuntimeNode* node, JSONReader* reader); } TVMGraphRuntimeNode; // Graph attribute typedef struct TVMGraphRuntimeGraphAttr { uint32_t storage_num_not_alloctaed; - uint32_t * storage_id; - uint32_t * device_index; - char * dltype; // "int8", "int16", "float32" + uint32_t* storage_id; + uint32_t* device_index; + char* dltype; // "int8", "int16", "float32" uint32_t dltype_count; - int64_t * shape; - uint32_t * ndim; + int64_t* shape; + uint32_t* ndim; uint32_t shape_count; } TVMGraphRuntimeGraphAttr; @@ -96,7 +96,7 @@ typedef DLTensor* DLTensorPtr; */ /* class GraphRuntime : public ModuleNode { */ typedef struct TVMGraphRuntime { - void (*Run)(struct TVMGraphRuntime * runtime); + void (*Run)(struct TVMGraphRuntime* runtime); /*! 
* \brief Initialize the graph executor with graph and context. @@ -107,10 +107,8 @@ typedef struct TVMGraphRuntime { * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ - void (*Init)(struct TVMGraphRuntime * runtime, - const char * graph_json, - const TVMModule * module, - const TVMContext * ctxs); + void (*Init)(struct TVMGraphRuntime* runtime, const char* graph_json, const TVMModule* module, + const TVMContext* ctxs); /*! * \brief Get the input index given the name of input. @@ -118,7 +116,7 @@ typedef struct TVMGraphRuntime { * \param name The name of the input. * \return The index of input. */ - int (*GetInputIndex)(struct TVMGraphRuntime * runtime, const char * name); + int (*GetInputIndex)(struct TVMGraphRuntime* runtime, const char* name); /*! * \brief set input to the graph based on name. @@ -126,7 +124,7 @@ typedef struct TVMGraphRuntime { * \param name The name of the input. * \param data_in The input data. */ - void (*SetInput)(struct TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); + void (*SetInput)(struct TVMGraphRuntime* runtime, const char* name, DLTensor* data_in); /*! * \brief Return NDArray for given output index. @@ -135,7 +133,7 @@ typedef struct TVMGraphRuntime { * \param out The DLTensor corresponding to given output node index. * \return The result of this function execution. */ - int (*GetOutput)(struct TVMGraphRuntime * runtime, const int32_t index, DLTensor * out); + int (*GetOutput)(struct TVMGraphRuntime* runtime, const int32_t index, DLTensor* out); /*! * \brief Load parameters from parameter blob. * \param runtime The graph runtime. @@ -143,15 +141,15 @@ typedef struct TVMGraphRuntime { * \param param_size The parameter size. * \return The result of this function execution. */ - int (*LoadParams)(struct TVMGraphRuntime * runtime, const char * param_blob, + int (*LoadParams)(struct TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size); // The graph attribute fields. - int (*Load)(struct TVMGraphRuntime * runtime, JSONReader *reader); + int (*Load)(struct TVMGraphRuntime* runtime, JSONReader* reader); /*! \brief Setup the temporal storage */ - void (*SetupStorage)(struct TVMGraphRuntime * runtime); + void (*SetupStorage)(struct TVMGraphRuntime* runtime); /*! \brief Setup the executors. */ - int (*SetupOpExecs)(struct TVMGraphRuntime * runtime); + int (*SetupOpExecs)(struct TVMGraphRuntime* runtime); /*! * \brief Create an execution function given input. @@ -163,25 +161,25 @@ typedef struct TVMGraphRuntime { * \param pf The created executor. * \return The result of this function execution. */ - int32_t (*CreateTVMOp)(struct TVMGraphRuntime * runtime, const TVMOpParam * attrs, - DLTensorPtr * args, const uint32_t args_count, - uint32_t num_inputs, TVMPackedFunc * pf); + int32_t (*CreateTVMOp)(struct TVMGraphRuntime* runtime, const TVMOpParam* attrs, + DLTensorPtr* args, const uint32_t args_count, uint32_t num_inputs, + TVMPackedFunc* pf); // Get node entry index. - uint32_t (*GetEntryId)(struct TVMGraphRuntime * runtime, uint32_t nid, uint32_t index); + uint32_t (*GetEntryId)(struct TVMGraphRuntime* runtime, uint32_t nid, uint32_t index); /*! \brief The graph nodes. */ - TVMGraphRuntimeNode * nodes; + TVMGraphRuntimeNode* nodes; /*! \brief The graph nodes counter. */ uint32_t nodes_count; /*! \brief The argument nodes. */ - uint32_t * input_nodes; + uint32_t* input_nodes; uint32_t input_nodes_count; /*! \brief Used for quick entry indexing. 
*/ - uint32_t * node_row_ptr; + uint32_t* node_row_ptr; uint32_t node_row_ptr_count; /*! \brief Output entries. */ - TVMGraphRuntimeNodeEntry * outputs; + TVMGraphRuntimeNodeEntry* outputs; /*! \brief Output entries counter. */ uint32_t outputs_count; /*! \brief Additional graph attributes. */ @@ -190,28 +188,28 @@ typedef struct TVMGraphRuntime { TVMModule module; /*! \brief Execution context of all devices including the host. */ TVMContext ctxs[1]; - uint32_t ctxs_count; + uint32_t ctxs_count; /*! \brief Common storage pool for all devices. */ - TVMNDArray * storage_pool; + TVMNDArray* storage_pool; uint32_t storage_pool_count; /*! \brief Data entry of each node. */ - TVMNDArray * data_entry; + TVMNDArray* data_entry; uint32_t data_entry_count; /*! \brief Operator on each node. */ - TVMPackedFunc * op_execs; + TVMPackedFunc* op_execs; uint32_t op_execs_count; } TVMGraphRuntime; // public functions -TVMGraphRuntime * TVMGraphRuntimeCreate(const char * sym_json, const TVMModule * m, - const TVMContext * ctxs); -void TVMGraphRuntimeRelease(TVMGraphRuntime ** runtime); +TVMGraphRuntime* TVMGraphRuntimeCreate(const char* sym_json, const TVMModule* m, + const TVMContext* ctxs); +void TVMGraphRuntimeRelease(TVMGraphRuntime** runtime); // private functions -void TVMGraphRuntime_SetInput(TVMGraphRuntime * runtime, const char * name, DLTensor* data_in); -int TVMGraphRuntime_LoadParams(TVMGraphRuntime * runtime, const char * param_blob, +void TVMGraphRuntime_SetInput(TVMGraphRuntime* runtime, const char* name, DLTensor* data_in); +int TVMGraphRuntime_LoadParams(TVMGraphRuntime* runtime, const char* param_blob, const uint32_t param_size); -void TVMGraphRuntime_Run(TVMGraphRuntime * runtime); -int TVMGraphRuntime_GetOutput(TVMGraphRuntime * runtime, const int32_t idx, DLTensor * out); +void TVMGraphRuntime_Run(TVMGraphRuntime* runtime); +int TVMGraphRuntime_GetOutput(TVMGraphRuntime* runtime, const int32_t idx, DLTensor* out); #endif // TVM_RUNTIME_CRT_GRAPH_RUNTIME_H_ diff --git a/src/runtime/crt/load_json.c b/src/runtime/crt/load_json.c index cf9492b8e2fa..5ae60cca86b6 100644 --- a/src/runtime/crt/load_json.c +++ b/src/runtime/crt/load_json.c @@ -21,36 +21,41 @@ * \file load_json.c * \brief Load graph from JSON file. 
*/ -#include - #include "load_json.h" +#include + // the node entry structure in serialized format typedef struct JSONNodeEntry { uint32_t node_id; uint32_t index; uint32_t version; - void (*Load)(struct JSONNodeEntry * entry, JSONReader *reader); + void (*Load)(struct JSONNodeEntry* entry, JSONReader* reader); } JSONNodeEntry; -void JSONNodeEntryLoad(JSONNodeEntry * entry, JSONReader *reader) { +void JSONNodeEntryLoad(JSONNodeEntry* entry, JSONReader* reader) { reader->BeginArray(reader); - if (reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format\n"); + } reader->ReadUnsignedInteger(reader, &(entry->node_id)); - if (reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); } + if (reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format\n"); + } reader->ReadUnsignedInteger(reader, &(entry->index)); if (reader->NextArrayItem(reader)) { reader->ReadUnsignedInteger(reader, &(entry->version)); - if (!reader->NextArrayItem(reader)) { fprintf(stderr, "invalid json format\n"); } + if (!reader->NextArrayItem(reader)) { + fprintf(stderr, "invalid json format\n"); + } } else { entry->version = 0; } } - // implementation of Seq class -void SeqPush(Seq * seq, uint32_t src) { +void SeqPush(Seq* seq, uint32_t src) { if (seq->size >= seq->allocated) { printf("seq too large.\n"); } @@ -58,14 +63,14 @@ void SeqPush(Seq * seq, uint32_t src) { seq->size += 1; } -uint32_t * SeqBack(Seq * seq) { +uint32_t* SeqBack(Seq* seq) { if (seq->size >= seq->allocated) { printf("seq too large.\n"); } - return seq->data + (seq->size-1); + return seq->data + (seq->size - 1); } -void SeqPop(Seq * seq) { +void SeqPop(Seq* seq) { if (seq->size >= seq->allocated) { printf("seq size is too large.\n"); } @@ -75,30 +80,29 @@ void SeqPop(Seq * seq) { seq->size -= 1; } -Seq * SeqCreate(uint64_t len) { - Seq * seq = (Seq*)vmalloc(sizeof(Seq)); // NOLINT(*) +Seq* SeqCreate(uint64_t len) { + Seq* seq = (Seq*)vmalloc(sizeof(Seq)); // NOLINT(*) memset(seq, 0, sizeof(Seq)); seq->allocated = len; - seq->data = (uint32_t*)vmalloc(sizeof(uint32_t)*len); // NOLINT(*) + seq->data = (uint32_t*)vmalloc(sizeof(uint32_t) * len); // NOLINT(*) seq->push_back = SeqPush; seq->back = SeqBack; seq->pop_back = SeqPop; return seq; } -void SeqRelease(Seq ** seq) { +void SeqRelease(Seq** seq) { vfree((*seq)->data); vfree(*seq); } - // implementations of JSONReader /*! * \brief Takes the next char from the input source. * \return the next character. */ -char JSONReader_NextChar(JSONReader * reader) { +char JSONReader_NextChar(JSONReader* reader) { char ch = reader->isptr[0]; reader->isptr += 1; return ch; @@ -108,20 +112,22 @@ char JSONReader_NextChar(JSONReader * reader) { * \brief Returns the next char from the input source. * \return the next character. */ -char JSONReader_PeekNextChar(JSONReader * reader) { - return reader->isptr[0]; -} +char JSONReader_PeekNextChar(JSONReader* reader) { return reader->isptr[0]; } /*! * \brief Read next nonspace character. * \return the next nonspace character. 
*/ -char JSONReader_NextNonSpace(JSONReader * reader) { +char JSONReader_NextNonSpace(JSONReader* reader) { int ch; do { ch = reader->NextChar(reader); - if (ch == '\n') { ++(reader->line_count_n_); } - if (ch == '\r') { ++(reader->line_count_r_); } + if (ch == '\n') { + ++(reader->line_count_n_); + } + if (ch == '\r') { + ++(reader->line_count_r_); + } } while (isspace(ch)); return ch; } @@ -130,12 +136,16 @@ char JSONReader_NextNonSpace(JSONReader * reader) { * \brief Read just before next nonspace but not read that. * \return the next nonspace character. */ -char JSONReader_PeekNextNonSpace(JSONReader * reader) { +char JSONReader_PeekNextNonSpace(JSONReader* reader) { int ch; while (1) { ch = reader->PeekNextChar(reader); - if (ch == '\n') { ++(reader->line_count_n_); } - if (ch == '\r') { ++(reader->line_count_r_); } + if (ch == '\n') { + ++(reader->line_count_n_); + } + if (ch == '\r') { + ++(reader->line_count_r_); + } if (!isspace(ch)) break; reader->NextChar(reader); } @@ -147,7 +157,7 @@ char JSONReader_PeekNextNonSpace(JSONReader * reader) { * \param out_str the output string. * \throw dmlc::Error when next token is not string */ -int JSONReader_ReadString(JSONReader * reader, char * out_str) { +int JSONReader_ReadString(JSONReader* reader, char* out_str) { int status = 0; char ch = reader->NextNonSpace(reader); char output[128]; @@ -158,15 +168,28 @@ int JSONReader_ReadString(JSONReader * reader, char * out_str) { if (ch == '\\') { char sch = reader->NextChar(reader); switch (sch) { - case 'r': snprintf(output + strlen(output), sizeof(output), "\r"); break; - case 'n': snprintf(output + strlen(output), sizeof(output), "\n"); break; - case '\\': snprintf(output + strlen(output), sizeof(output), "\\"); break; - case 't': snprintf(output + strlen(output), sizeof(output), "\t"); break; - case '\"': snprintf(output + strlen(output), sizeof(output), "\""); break; - default: fprintf(stderr, "unknown string escape %c\n", sch); + case 'r': + snprintf(output + strlen(output), sizeof(output), "\r"); + break; + case 'n': + snprintf(output + strlen(output), sizeof(output), "\n"); + break; + case '\\': + snprintf(output + strlen(output), sizeof(output), "\\"); + break; + case 't': + snprintf(output + strlen(output), sizeof(output), "\t"); + break; + case '\"': + snprintf(output + strlen(output), sizeof(output), "\""); + break; + default: + fprintf(stderr, "unknown string escape %c\n", sch); } } else { - if (ch == '\"') { break; } + if (ch == '\"') { + break; + } if (strlen(output) >= 127) { fprintf(stderr, "Error: detected buffer overflow.\n"); status = -1; @@ -182,13 +205,14 @@ int JSONReader_ReadString(JSONReader * reader, char * out_str) { } if (ch == EOF || ch == '\r' || ch == '\n') { fprintf(stderr, "Error at line X, Expect \'\"\' but reach end of line\n"); + status = -1; } } snprintf(out_str, sizeof(output), "%s", output); return status; } -int JSONReader_ReadUnsignedInteger(JSONReader * reader, unsigned int * out_value) { +int JSONReader_ReadUnsignedInteger(JSONReader* reader, unsigned int* out_value) { int status = 0; char* endptr; const char* icstr = reader->isptr; @@ -198,8 +222,7 @@ int JSONReader_ReadUnsignedInteger(JSONReader * reader, unsigned int * out_value return status; } - -int JSONReader_ReadInteger(JSONReader * reader, int64_t * out_value) { +int JSONReader_ReadInteger(JSONReader* reader, int64_t* out_value) { int status = 0; char* endptr; const char* icstr = reader->isptr; @@ -222,12 +245,12 @@ int JSONReader_ReadInteger(JSONReader * reader, int64_t * out_value) { * } * 
\endcode */ -void JSONReader_BeginObject(JSONReader * reader) { +void JSONReader_BeginObject(JSONReader* reader) { int ch = reader->NextNonSpace(reader); if (!(ch == '{')) { fprintf(stderr, "Error at line X, Expect \'{\' but got \'%c\'\n", ch); } - Seq * scope_counter_ = reader->scope_counter_; + Seq* scope_counter_ = reader->scope_counter_; scope_counter_->push_back(scope_counter_, 0); } @@ -238,9 +261,9 @@ void JSONReader_BeginObject(JSONReader * reader) { * \param out_key the key to the next object. * \return true if the read is successful, false if we are at end of the object. */ -uint8_t JSONReader_NextObjectItem(JSONReader * reader, char * out_key) { +uint8_t JSONReader_NextObjectItem(JSONReader* reader, char* out_key) { uint8_t next = 1; - Seq * scope_counter_ = reader->scope_counter_; + Seq* scope_counter_ = reader->scope_counter_; if (scope_counter_->back(scope_counter_)[0] != 0) { int ch = reader->NextNonSpace(reader); if (ch == EOF) { @@ -284,12 +307,12 @@ uint8_t JSONReader_NextObjectItem(JSONReader * reader, char * out_key) { * } * \endcode */ -void JSONReader_BeginArray(JSONReader * reader) { +void JSONReader_BeginArray(JSONReader* reader) { int ch = reader->NextNonSpace(reader); if (ch != '[') { fprintf(stderr, "Error at line X, Expect \'[\' but get \'%c\'\n", ch); } - Seq * scope_counter_ = reader->scope_counter_; + Seq* scope_counter_ = reader->scope_counter_; scope_counter_->push_back(scope_counter_, 0); } @@ -299,9 +322,9 @@ void JSONReader_BeginArray(JSONReader * reader) { * reader->Read to read in the value. * \return true if the read is successful, false if we are at end of the array. */ -uint8_t JSONReader_NextArrayItem(JSONReader * reader) { +uint8_t JSONReader_NextArrayItem(JSONReader* reader) { uint8_t next = 1; - Seq * scope_counter_ = reader->scope_counter_; + Seq* scope_counter_ = reader->scope_counter_; if (scope_counter_->back(scope_counter_)[0] != 0) { int ch = reader->NextNonSpace(reader); if (ch == EOF) { @@ -333,7 +356,7 @@ uint8_t JSONReader_NextArrayItem(JSONReader * reader) { * \brief Constructor. * \param is the input source. 
*/ -JSONReader JSONReader_Create(const char * is) { +JSONReader JSONReader_Create(const char* is) { JSONReader reader; memset(&reader, 0, sizeof(JSONReader)); reader.scope_counter_ = SeqCreate(200); @@ -348,14 +371,14 @@ JSONReader JSONReader_Create(const char * is) { reader.BeginObject = JSONReader_BeginObject; reader.NextArrayItem = JSONReader_NextArrayItem; reader.NextObjectItem = JSONReader_NextObjectItem; - reader.is_ = (char*)vmalloc(strlen(is)+1); // NOLINT(*) - memset(reader.is_, 0, strlen(is)+1); - snprintf(reader.is_, strlen(is)+1, "%s", is); + reader.is_ = (char*)vmalloc(strlen(is) + 1); // NOLINT(*) + memset(reader.is_, 0, strlen(is) + 1); + snprintf(reader.is_, strlen(is) + 1, "%s", is); reader.isptr = reader.is_; return reader; } -void JSONReader_Release(JSONReader * reader) { +void JSONReader_Release(JSONReader* reader) { SeqRelease(&(reader->scope_counter_)); vfree(reader->is_); } diff --git a/src/runtime/crt/load_json.h b/src/runtime/crt/load_json.h index a5df7a055af0..0c9324777c1d 100644 --- a/src/runtime/crt/load_json.h +++ b/src/runtime/crt/load_json.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_CRT_LOAD_JSON_H_ #define TVM_RUNTIME_CRT_LOAD_JSON_H_ -#include #include +#include enum { JSON_READ_TYPE_U8 = 1, @@ -42,12 +42,12 @@ enum { }; typedef struct Seq { - uint32_t * data; + uint32_t* data; uint64_t allocated; uint32_t size; - void (*push_back)(struct Seq * seq, uint32_t src); - uint32_t * (*back)(struct Seq * seq); - void (*pop_back)(struct Seq * seq); + void (*push_back)(struct Seq* seq, uint32_t src); + uint32_t* (*back)(struct Seq* seq); + void (*pop_back)(struct Seq* seq); } Seq; /*! @@ -56,8 +56,8 @@ typedef struct Seq { */ typedef struct JSONReader { /*! \brief internal reader string */ - char * is_; - char * isptr; + char* is_; + char* isptr; /*! \brief "\\r" counter */ size_t line_count_r_; /*! \brief "\\n" counter */ @@ -66,27 +66,27 @@ typedef struct JSONReader { * \brief record how many element processed in * current array/object scope. */ - Seq * scope_counter_; + Seq* scope_counter_; - char (*NextChar)(struct JSONReader * reader); - char (*NextNonSpace)(struct JSONReader * reader); - char (*PeekNextChar)(struct JSONReader * reader); - char (*PeekNextNonSpace)(struct JSONReader * reader); - int (*ReadUnsignedInteger)(struct JSONReader * reader, unsigned int * out_value); - int (*ReadInteger)(struct JSONReader * reader, int64_t * out_value); - int (*ReadString)(struct JSONReader * reader, char * out_value); - void (*BeginArray)(struct JSONReader * reader); - void (*BeginObject)(struct JSONReader * reader); - uint8_t (*NextObjectItem)(struct JSONReader * reader, char * out_key); - uint8_t (*NextArrayItem)(struct JSONReader * reader); + char (*NextChar)(struct JSONReader* reader); + char (*NextNonSpace)(struct JSONReader* reader); + char (*PeekNextChar)(struct JSONReader* reader); + char (*PeekNextNonSpace)(struct JSONReader* reader); + int (*ReadUnsignedInteger)(struct JSONReader* reader, unsigned int* out_value); + int (*ReadInteger)(struct JSONReader* reader, int64_t* out_value); + int (*ReadString)(struct JSONReader* reader, char* out_value); + void (*BeginArray)(struct JSONReader* reader); + void (*BeginObject)(struct JSONReader* reader); + uint8_t (*NextObjectItem)(struct JSONReader* reader, char* out_key); + uint8_t (*NextArrayItem)(struct JSONReader* reader); } JSONReader; /*! * \brief Constructor of JSONReader class * \param is the input source. 
*/ -JSONReader JSONReader_Create(const char * is); +JSONReader JSONReader_Create(const char* is); -void JSONReader_Release(JSONReader * reader); +void JSONReader_Release(JSONReader* reader); #endif // TVM_RUNTIME_CRT_LOAD_JSON_H_ diff --git a/src/runtime/crt/logging.h b/src/runtime/crt/logging.h index 2c58834ca6a9..c711b3aa3bb9 100644 --- a/src/runtime/crt/logging.h +++ b/src/runtime/crt/logging.h @@ -27,31 +27,31 @@ #define TVM_RUNTIME_CRT_LOGGING_H_ #ifndef CHECK -#define CHECK(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "Check failed: %s\n", #x); \ - exit(-1); \ - } \ - }while(0) +#define CHECK(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "Check failed: %s\n", #x); \ + exit(-1); \ + } \ + } while (0) #endif #ifndef CHECK_BINARY_OP -#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ - do { \ - if (!(x op y)) { \ +#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ + do { \ + if (!(x op y)) { \ fprintf(stderr, "Check failed: %s %s %s: " fmt "\n", #x, #op, #y, ##__VA_ARGS__); \ - exit(-1); \ - } \ - }while(0) + exit(-1); \ + } \ + } while (0) #endif #ifndef CHECK_LT -#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__) +#define CHECK_LT(x, y, fmt, ...) CHECK_BINARY_OP(<, x, y, fmt, ##__VA_ARGS__) #endif #ifndef CHECK_GT -#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__) +#define CHECK_GT(x, y, fmt, ...) CHECK_BINARY_OP(>, x, y, fmt, ##__VA_ARGS__) #endif #ifndef CHECK_LE diff --git a/src/runtime/crt/memory.c b/src/runtime/crt/memory.c index 24175f6d6e55..c25749e44493 100644 --- a/src/runtime/crt/memory.c +++ b/src/runtime/crt/memory.c @@ -24,11 +24,10 @@ * To maximize portability, thread-safe feature has been dropped for now. */ +#include #include #include -#include - #include "logging.h" /*! Number of bits in a page */ @@ -61,7 +60,7 @@ typedef struct Page { /*! \brief The total number of pages */ tvm_index_t num_pages; /*! 
\brief Data */ - char * data; + char* data; } Page; // construct a new page @@ -76,12 +75,12 @@ Page PageCreate(tvm_index_t ptable_begin, tvm_index_t num_pages) { typedef struct PageTable { Page page[TVM_CRT_MAX_PAGES]; uint32_t count; - void (*resize)(struct PageTable * ptable, uint32_t size, Page * page); + void (*resize)(struct PageTable* ptable, uint32_t size, Page* page); } PageTable; -void PageTable_Resize(struct PageTable * ptable, uint32_t new_size, Page * page) { - CHECK_LE(ptable->count, new_size, - "size value (%d) is smaller than expected (%d).", new_size, ptable->count); +void PageTable_Resize(struct PageTable* ptable, uint32_t new_size, Page* page) { + CHECK_LE(ptable->count, new_size, "size value (%d) is smaller than expected (%d).", new_size, + ptable->count); for (uint32_t idx = ptable->count; idx < new_size; idx++) { ptable->page[idx] = *page; } @@ -89,19 +88,19 @@ void PageTable_Resize(struct PageTable * ptable, uint32_t new_size, Page * page) } typedef struct PageEntry { - char * addr; + char* addr; Page page; } PageEntry; typedef struct TLB { PageEntry entries[TVM_CRT_MAX_PAGES]; uint32_t count; - void (*set)(struct TLB * tlb, char * data, Page * page); - PageEntry * (*find)(struct TLB * tlb, char * data); + void (*set)(struct TLB* tlb, char* data, Page* page); + PageEntry* (*find)(struct TLB* tlb, char* data); } TLB; -void TLB_Set(TLB * tlb, char * data, Page * page) { - PageEntry * entry = tlb->find(tlb, data); +void TLB_Set(TLB* tlb, char* data, Page* page) { + PageEntry* entry = tlb->find(tlb, data); if (entry == 0) { tlb->entries[tlb->count].addr = data; tlb->entries[tlb->count].page = *page; @@ -112,8 +111,8 @@ void TLB_Set(TLB * tlb, char * data, Page * page) { } } -PageEntry * TLB_Find(TLB * tlb, char * data) { - PageEntry * entry = 0; +PageEntry* TLB_Find(TLB* tlb, char* data) { + PageEntry* entry = 0; for (uint32_t idx = 0; idx < tlb->count; idx++) { if (tlb->entries[idx].addr == data) { entry = tlb->entries + idx; @@ -131,14 +130,14 @@ typedef struct IndexedEntry { typedef struct MultiMap { IndexedEntry entries[TVM_CRT_MAX_PAGES]; uint32_t count; - IndexedEntry * (*lower_bound)(struct MultiMap * map, uint32_t npage); - IndexedEntry * (*end)(struct MultiMap * map); - void (*erase)(struct MultiMap * map, IndexedEntry * entry); - void (*insert)(struct MultiMap * map, uint32_t npage, Page * p); + IndexedEntry* (*lower_bound)(struct MultiMap* map, uint32_t npage); + IndexedEntry* (*end)(struct MultiMap* map); + void (*erase)(struct MultiMap* map, IndexedEntry* entry); + void (*insert)(struct MultiMap* map, uint32_t npage, Page* p); } MultiMap; -IndexedEntry * MultiMap_LowerBound(struct MultiMap * map, uint32_t npage) { - IndexedEntry * entry = 0; +IndexedEntry* MultiMap_LowerBound(struct MultiMap* map, uint32_t npage) { + IndexedEntry* entry = 0; for (uint32_t idx = 0; idx < map->count; idx++) { if (map->entries[idx].index >= npage) { entry = map->entries + idx; @@ -148,12 +147,12 @@ IndexedEntry * MultiMap_LowerBound(struct MultiMap * map, uint32_t npage) { return entry; } -IndexedEntry * MultiMap_End(struct MultiMap * map) { - IndexedEntry * entry = 0; +IndexedEntry* MultiMap_End(struct MultiMap* map) { + IndexedEntry* entry = 0; return entry; } -void MultiMap_Erase(struct MultiMap * map, IndexedEntry * entry) { +void MultiMap_Erase(struct MultiMap* map, IndexedEntry* entry) { for (uint32_t idx = 0; idx < map->count; idx++) { if ((map->entries + idx) == entry) { memcpy(map->entries + idx, map->entries + (idx + 1), @@ -164,7 +163,7 @@ void MultiMap_Erase(struct 
MultiMap * map, IndexedEntry * entry) { } } -void MultiMap_Insert(struct MultiMap * map, uint32_t npage, Page * p) { +void MultiMap_Insert(struct MultiMap* map, uint32_t npage, Page* p) { CHECK_LE(map->count + 1, TVM_CRT_MAX_PAGES, "invalid number of free pages."); for (uint32_t idx = map->count; idx < (map->count + npage); idx++) { map->entries[map->count].index = npage; @@ -183,20 +182,20 @@ typedef struct MemoryManager { * \param size The size of memory * \return The virtual address */ - void* (*Alloc)(struct MemoryManager * mgr, tvm_index_t size); + void* (*Alloc)(struct MemoryManager* mgr, tvm_index_t size); /*! * \brief Allocate memory from manager * \param ptr The pointer to the memory area to be reallocated * \param size The size of memory * \return The virtual address */ - void* (*Realloc)(struct MemoryManager * mgr, void * ptr, tvm_index_t size); + void* (*Realloc)(struct MemoryManager* mgr, void* ptr, tvm_index_t size); /*! * \brief Free the memory. * \param ptr The pointer to the memory to deallocate * \return The virtual address */ - void (*Free)(struct MemoryManager * mgr, void* data); + void (*Free)(struct MemoryManager* mgr, void* data); // Physical address -> page PageTable ptable; @@ -211,11 +210,11 @@ typedef struct MemoryManager { * \param size The size of memory * \return The virtual address */ -void* MemoryManager_Alloc(MemoryManager * mgr, tvm_index_t size) { - char * data = 0; +void* MemoryManager_Alloc(MemoryManager* mgr, tvm_index_t size) { + char* data = 0; tvm_index_t npage = (size + kPageSize - 1) / kPageSize; - MultiMap * free_map = &(mgr->free_map); - IndexedEntry * it = free_map->lower_bound(free_map, npage); + MultiMap* free_map = &(mgr->free_map); + IndexedEntry* it = free_map->lower_bound(free_map, npage); tvm_index_t start = 0; if (it != free_map->end(free_map)) { Page p = it->page; @@ -224,22 +223,22 @@ void* MemoryManager_Alloc(MemoryManager * mgr, tvm_index_t size) { start = p.ptable_begin; npage = p.num_pages; } else { - PageTable * ptable = &(mgr->ptable); + PageTable* ptable = &(mgr->ptable); start = ptable->count; CHECK_LE((unsigned)(start + npage), (sizeof(g_memory_pool) / kPageSize), - "insufficient memory, start=%" PRId64 ", npage=%" PRId64 ", total=%" PRId64 "", - start, npage, start + npage); + "insufficient memory, start=%" PRId64 ", npage=%" PRId64 ", total=%" PRId64 "", start, + npage, start + npage); /* insert page entry */ Page p = PageCreate(start, npage); ptable->resize(ptable, start + npage, &p); data = p.data; - TLB * pmap = &(mgr->pmap); + TLB* pmap = &(mgr->pmap); pmap->set(pmap, data, &p); } vleak_size++; #if TVM_CRT_DEBUG > 1 - printf("allocate: addr=%p, start=%d/%d, npage=%d, vleak=%d\n", - data, start, TVM_CRT_MAX_PAGES, npage, vleak_size); + printf("allocate: addr=%p, start=%d/%d, npage=%d, vleak=%d\n", data, start, TVM_CRT_MAX_PAGES, + npage, vleak_size); #endif // TVM_CRT_DEBUG return data; } @@ -250,26 +249,26 @@ void* MemoryManager_Alloc(MemoryManager * mgr, tvm_index_t size) { * \param size The size of memory * \return The virtual address */ -void* MemoryManager_Realloc(MemoryManager * mgr, void * ptr, tvm_index_t size) { - char * data = (char*)ptr; // NOLINT(*) - PageTable * ptable = &(mgr->ptable); - TLB * pmap = &(mgr->pmap); - MultiMap * free_map = &(mgr->free_map); +void* MemoryManager_Realloc(MemoryManager* mgr, void* ptr, tvm_index_t size) { + char* data = (char*)ptr; // NOLINT(*) + PageTable* ptable = &(mgr->ptable); + TLB* pmap = &(mgr->pmap); + MultiMap* free_map = &(mgr->free_map); tvm_index_t start = 0; 
tvm_index_t npage = (size + kPageSize - 1) / kPageSize; if (ptr) { // get page size for given pointer CHECK_NE(pmap->count, 0, "invalid translation look-aside buffer."); - PageEntry * entry = pmap->find(pmap, (char*)ptr); // NOLINT(*) + PageEntry* entry = pmap->find(pmap, (char*)ptr); // NOLINT(*) CHECK_NE(entry, 0, "no valid page entry found."); - Page * pptr = &(entry->page); + Page* pptr = &(entry->page); // if the page size is smaller than target page size, // try allocate new space if (pptr->num_pages < npage) { // TODO(liangfu): found out whether we can extend current entry // // insert new page entry - IndexedEntry * it = free_map->lower_bound(free_map, npage); + IndexedEntry* it = free_map->lower_bound(free_map, npage); if (it != free_map->end(free_map)) { data = it->page.data; start = it->page.ptable_begin; @@ -293,7 +292,7 @@ void* MemoryManager_Realloc(MemoryManager * mgr, void * ptr, tvm_index_t size) { start = pptr->ptable_begin; } } else { - IndexedEntry * it = free_map->lower_bound(free_map, npage); + IndexedEntry* it = free_map->lower_bound(free_map, npage); if (it != free_map->end(free_map)) { Page p = it->page; free_map->erase(free_map, it); @@ -301,7 +300,7 @@ void* MemoryManager_Realloc(MemoryManager * mgr, void * ptr, tvm_index_t size) { start = p.ptable_begin; npage = p.num_pages; } else { - PageTable * ptable = &(mgr->ptable); + PageTable* ptable = &(mgr->ptable); start = ptable->count; CHECK_LE((unsigned)(start + npage), (sizeof(g_memory_pool) / kPageSize), "insufficient memory, start=%" PRId64 ", npage=%" PRId64 ", total=%" PRId64 "", @@ -310,14 +309,14 @@ void* MemoryManager_Realloc(MemoryManager * mgr, void * ptr, tvm_index_t size) { Page p = PageCreate(start, npage); ptable->resize(ptable, start + npage, &p); data = p.data; - TLB * pmap = &(mgr->pmap); + TLB* pmap = &(mgr->pmap); pmap->set(pmap, data, &p); } vleak_size++; } #if TVM_CRT_DEBUG > 1 - printf("reallocate: addr=%p, start=%d/%d, npage=%d, vleak=%d, size=%d\n", - data, start, TVM_CRT_MAX_PAGES, npage, vleak_size, size); + printf("reallocate: addr=%p, start=%d/%d, npage=%d, vleak=%d, size=%d\n", data, start, + TVM_CRT_MAX_PAGES, npage, vleak_size, size); #endif // TVM_CRT_DEBUG return data; } @@ -327,22 +326,22 @@ void* MemoryManager_Realloc(MemoryManager * mgr, void * ptr, tvm_index_t size) { * \param ptr The pointer to the memory to deallocate * \return The virtual address */ -void MemoryManager_Free(MemoryManager * mgr, void* ptr) { - TLB * pmap = &(mgr->pmap); +void MemoryManager_Free(MemoryManager* mgr, void* ptr) { + TLB* pmap = &(mgr->pmap); CHECK_NE(pmap->count, 0, "invalid translation look-aside buffer."); - PageEntry * entry = pmap->find(pmap, (char*)ptr); // NOLINT(*) + PageEntry* entry = pmap->find(pmap, (char*)ptr); // NOLINT(*) CHECK_NE(entry, 0, "no valid page entry found."); - Page * p = &(entry->page); - MultiMap * free_map = &(mgr->free_map); + Page* p = &(entry->page); + MultiMap* free_map = &(mgr->free_map); free_map->insert(free_map, p->num_pages, p); vleak_size--; #if TVM_CRT_DEBUG > 1 - printf("release: addr=%p, start=%d/%d, npage=%d, vleak=%d\n", - ptr, entry->page.ptable_begin, TVM_CRT_MAX_PAGES, entry->page.num_pages, vleak_size); + printf("release: addr=%p, start=%d/%d, npage=%d, vleak=%d\n", ptr, entry->page.ptable_begin, + TVM_CRT_MAX_PAGES, entry->page.num_pages, vleak_size); #endif // TVM_CRT_DEBUG } -MemoryManager * MemoryManagerCreate() { +MemoryManager* MemoryManagerCreate() { static MemoryManager mgr; memset(&mgr, 0, sizeof(MemoryManager)); /* handle MemoryManager 
member functions */ @@ -362,10 +361,10 @@ MemoryManager * MemoryManagerCreate() { return &mgr; } -MemoryManager * TVMGetGlobalMemoryManager() { +MemoryManager* TVMGetGlobalMemoryManager() { /* initialize once */ static uint32_t initialized = 0; - static MemoryManager * mgr; + static MemoryManager* mgr; if (!initialized) { mgr = MemoryManagerCreate(); memset(g_memory_pool, 0, sizeof(g_memory_pool)); @@ -375,19 +374,19 @@ MemoryManager * TVMGetGlobalMemoryManager() { } /** \brief Allocate memory from manager */ -void * vmalloc(size_t size) { - MemoryManager * mgr = TVMGetGlobalMemoryManager(); +void* vmalloc(size_t size) { + MemoryManager* mgr = TVMGetGlobalMemoryManager(); return mgr->Alloc(mgr, size); } /** \brief Reallocate memory from manager */ -void * vrealloc(void * ptr, size_t size) { - MemoryManager * mgr = TVMGetGlobalMemoryManager(); +void* vrealloc(void* ptr, size_t size) { + MemoryManager* mgr = TVMGetGlobalMemoryManager(); return mgr->Realloc(mgr, ptr, size); } /** \brief Release memory from manager */ -void vfree(void * ptr) { - MemoryManager * mgr = TVMGetGlobalMemoryManager(); +void vfree(void* ptr) { + MemoryManager* mgr = TVMGetGlobalMemoryManager(); mgr->Free(mgr, ptr); } diff --git a/src/runtime/crt/module.h b/src/runtime/crt/module.h index 9ef287d650d8..57f8dd708f88 100644 --- a/src/runtime/crt/module.h +++ b/src/runtime/crt/module.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_CRT_MODULE_H_ #define TVM_RUNTIME_CRT_MODULE_H_ -#include #include +#include struct TVMPackedFunc; @@ -41,7 +41,7 @@ typedef struct TVMModule { * * This function will return PackedFunc(nullptr) if function do not exist. */ - void (*GetFunction)(struct TVMModule * mod, const char * name, struct TVMPackedFunc * pf); + void (*GetFunction)(struct TVMModule* mod, const char* name, struct TVMPackedFunc* pf); } TVMModule; #endif // TVM_RUNTIME_CRT_MODULE_H_ diff --git a/src/runtime/crt/ndarray.c b/src/runtime/crt/ndarray.c index ed623fbf3de8..17e210785aa1 100644 --- a/src/runtime/crt/ndarray.c +++ b/src/runtime/crt/ndarray.c @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,25 +22,25 @@ * \brief NDArray container infratructure. 
*/ -#include - #include "ndarray.h" -TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx) { +#include + +TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx) { TVMNDArray ret; memset(&ret, 0, sizeof(TVMNDArray)); ret.dl_tensor.ndim = ndim; - ret.dl_tensor.shape = (int64_t*)vmalloc(sizeof(int64_t)*ndim); // NOLINT(*) - memcpy(ret.dl_tensor.shape, shape, sizeof(int64_t)*ndim); + ret.dl_tensor.shape = (int64_t*)vmalloc(sizeof(int64_t) * ndim); // NOLINT(*) + memcpy(ret.dl_tensor.shape, shape, sizeof(int64_t) * ndim); ret.dl_tensor.dtype = dtype; ret.dl_tensor.ctx = ctx; ret.dl_tensor.data = 0; return ret; } -TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx) { +TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx) { TVMNDArray ret = TVMNDArray_Create(ndim, shape, dtype, ctx); int64_t num_elems = 1; int elem_bytes = (dtype.bits + 7) / 8; @@ -53,21 +53,26 @@ TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape, return ret; } -int TVMNDArray_Load(TVMNDArray * ret, const char ** strm) { +int TVMNDArray_Load(TVMNDArray* ret, const char** strm) { int32_t status = 0; uint64_t header, reserved; - header = ((uint64_t*)*strm)[0]; *strm += sizeof(header); // NOLINT(*) + header = ((uint64_t*)*strm)[0]; // NOLINT(*) + *strm += sizeof(header); if (header != kTVMNDArrayMagic) { fprintf(stderr, "Invalid DLTensor file format\n"); status = -1; } - reserved = ((uint64_t*)*strm)[0]; *strm += sizeof(reserved); // NOLINT(*) + reserved = ((uint64_t*)*strm)[0]; // NOLINT(*) + *strm += sizeof(reserved); DLContext ctx; uint32_t ndim; DLDataType dtype; - ctx = ((DLContext*)*strm)[0]; *strm += sizeof(ctx); // NOLINT(*) - ndim = ((uint32_t*)*strm)[0]; *strm += sizeof(ndim); // NOLINT(*) - dtype = ((DLDataType*)*strm)[0]; *strm += sizeof(dtype); // NOLINT(*) + ctx = ((DLContext*)*strm)[0]; // NOLINT(*) + *strm += sizeof(ctx); + ndim = ((uint32_t*)*strm)[0]; // NOLINT(*) + *strm += sizeof(ndim); + dtype = ((DLDataType*)*strm)[0]; // NOLINT(*) + *strm += sizeof(dtype); if ((ndim < 0) || (ndim > TVM_CRT_MAX_NDIM)) { fprintf(stderr, "Invalid ndim=%d: expected to be 0 ~ %d.\n", ndim, TVM_CRT_MAX_NDIM); status = -1; @@ -80,7 +85,8 @@ int TVMNDArray_Load(TVMNDArray * ret, const char ** strm) { uint32_t idx; if (ndim != 0) { for (idx = 0; idx < ndim; idx++) { - shape[idx] = ((int64_t*)*strm)[0]; *strm += sizeof(shape[idx]); // NOLINT(*) + shape[idx] = ((int64_t*)*strm)[0]; // NOLINT(*) + *strm += sizeof(shape[idx]); } } *ret = TVMNDArray_Empty(ndim, shape, dtype, ctx); @@ -90,11 +96,13 @@ int TVMNDArray_Load(TVMNDArray * ret, const char ** strm) { num_elems *= ret->dl_tensor.shape[idx]; } int64_t data_byte_size; - data_byte_size = ((int64_t*)*strm)[0]; *strm += sizeof(data_byte_size); // NOLINT(*) + data_byte_size = ((int64_t*)*strm)[0]; // NOLINT(*) + *strm += sizeof(data_byte_size); if (!(data_byte_size == num_elems * elem_bytes)) { - fprintf(stderr, "invalid DLTensor file format: data_byte_size=%ld, " - "while num_elems*elem_bytes=%ld\n", - data_byte_size, (num_elems * elem_bytes)); + fprintf(stderr, + "invalid DLTensor file format: data_byte_size=%d, " + "while num_elems*elem_bytes=%d\n", + (int)data_byte_size, (int)(num_elems * elem_bytes)); // NOLINT(*) status = -1; } memcpy(ret->dl_tensor.data, *strm, data_byte_size); @@ -103,14 +111,14 @@ int TVMNDArray_Load(TVMNDArray * ret, const char ** strm) { 
return status; } -TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape, - uint32_t ndim, DLDataType dtype) { +TVMNDArray TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, uint32_t ndim, + DLDataType dtype) { TVMNDArray ret = TVMNDArray_Create(ndim, shape, dtype, arr->dl_tensor.ctx); ret.dl_tensor.data = arr->dl_tensor.data; return ret; } -int TVMNDArray_Release(TVMNDArray * arr) { +int TVMNDArray_Release(TVMNDArray* arr) { vfree(arr->dl_tensor.data); arr->dl_tensor.data = 0; vfree(arr->dl_tensor.shape); diff --git a/src/runtime/crt/ndarray.h b/src/runtime/crt/ndarray.h index dde23ca6cd41..ae76726ae0b9 100644 --- a/src/runtime/crt/ndarray.h +++ b/src/runtime/crt/ndarray.h @@ -24,13 +24,12 @@ #ifndef TVM_RUNTIME_CRT_NDARRAY_H_ #define TVM_RUNTIME_CRT_NDARRAY_H_ -#include -#include #include - -#include #include #include +#include +#include +#include /*! \brief Magic number for NDArray file */ static const uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F; @@ -42,17 +41,17 @@ typedef struct TVMNDArray { DLTensor dl_tensor; } TVMNDArray; -TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx); +TVMNDArray TVMNDArray_Create(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx); -TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t * shape, - DLDataType dtype, DLContext ctx); +TVMNDArray TVMNDArray_Empty(uint32_t ndim, const tvm_index_t* shape, DLDataType dtype, + DLContext ctx); -int TVMNDArray_Load(TVMNDArray * ret, const char ** strm); +int TVMNDArray_Load(TVMNDArray* ret, const char** strm); -TVMNDArray TVMNDArray_CreateView(TVMNDArray * arr, const tvm_index_t * shape, - uint32_t ndim, DLDataType dtype); +TVMNDArray TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, uint32_t ndim, + DLDataType dtype); -int TVMNDArray_Release(TVMNDArray * arr); +int TVMNDArray_Release(TVMNDArray* arr); #endif // TVM_RUNTIME_CRT_NDARRAY_H_ diff --git a/src/runtime/crt/packed_func.h b/src/runtime/crt/packed_func.h index 93898a436c88..d4597e62fd0f 100644 --- a/src/runtime/crt/packed_func.h +++ b/src/runtime/crt/packed_func.h @@ -24,29 +24,34 @@ #ifndef TVM_RUNTIME_CRT_PACKED_FUNC_H_ #define TVM_RUNTIME_CRT_PACKED_FUNC_H_ -#include - +#include #include #include -#include +#include #include "module.h" -static inline DLDataType String2DLDataType(const char * s) { +static inline DLDataType String2DLDataType(const char* s) { DLDataType t; // handle None type if (strlen(s) == 0) { - t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; + t.bits = 0; + t.lanes = 0; + t.code = kTVMOpaqueHandle; return t; } - t.bits = 32; t.lanes = 1; + t.bits = 32; + t.lanes = 1; const char* scan; if (!strncmp(s, "int", 3)) { - t.code = kDLInt; scan = s + 3; + t.code = kDLInt; + scan = s + 3; } else if (!strncmp(s, "uint", 4)) { - t.code = kDLUInt; scan = s + 4; + t.code = kDLUInt; + scan = s + 4; } else if (!strncmp(s, "float", 5)) { - t.code = kDLFloat; scan = s + 5; + t.code = kDLFloat; + scan = s + 5; } else if (!strncmp(s, "handle", 6)) { t.code = kTVMOpaqueHandle; t.bits = 64; // handle uses 64 bit by default. 
@@ -75,11 +80,11 @@ static inline DLDataType String2DLDataType(const char * s) { typedef struct TVMArgs { TVMValue values[TVM_CRT_MAX_ARGS]; - int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */ + int tcodes[TVM_CRT_MAX_ARGS]; /* Data type should be identical to type_codes in TVMPackedCFunc */ uint32_t values_count; } TVMArgs; -static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint32_t values_count) { +static inline TVMArgs TVMArgs_Create(TVMValue* values, uint32_t* tcodes, uint32_t values_count) { uint32_t idx; TVMArgs args; memset(&args, 0, sizeof(args)); @@ -91,8 +96,8 @@ static inline TVMArgs TVMArgs_Create(TVMValue * values, uint32_t * tcodes, uint3 return args; } -static inline int TVMNoOperation(TVMValue * args, int * type_codes, int num_args, - TVMRetValueHandle ret, void * res) { +static inline int TVMNoOperation(TVMValue* args, int* type_codes, int num_args, + TVMRetValueHandle ret, void* res) { return 0; } @@ -100,24 +105,24 @@ typedef struct TVMPackedFunc { char name[200]; TVMPackedCFunc fexec; TVMArgs args; - void (*Call)(struct TVMPackedFunc * pf); - void (*SetArgs)(struct TVMPackedFunc * pf, const struct TVMArgs * args); + void (*Call)(struct TVMPackedFunc* pf); + void (*SetArgs)(struct TVMPackedFunc* pf, const struct TVMArgs* args); } TVMPackedFunc; -static inline void TVMPackedFunc_Call(TVMPackedFunc * pf) { +static inline void TVMPackedFunc_Call(TVMPackedFunc* pf) { pf->fexec(pf->args.values, pf->args.tcodes, pf->args.values_count, 0, 0); } -static inline void TVMPackedFunc_SetArgs(TVMPackedFunc * pf, const TVMArgs * args) { +static inline void TVMPackedFunc_SetArgs(TVMPackedFunc* pf, const TVMArgs* args) { memcpy(&(pf->args), args, sizeof(TVMArgs)); } -TVMPackedFunc * g_fexecs = 0; +TVMPackedFunc* g_fexecs = 0; uint32_t g_fexecs_count = 0; // Implement TVMModule::GetFunction // Put implementation in this file so we have seen the TVMPackedFunc -static inline void TVMModule_GetFunction(TVMModule * mod, const char * name, TVMPackedFunc * pf) { +static inline void TVMModule_GetFunction(TVMModule* mod, const char* name, TVMPackedFunc* pf) { int idx; memset(pf, 0, sizeof(TVMPackedFunc)); assert(strlen(name) <= sizeof(pf->name)); diff --git a/src/runtime/cuda/cuda_common.h b/src/runtime/cuda/cuda_common.h index 87cf3be5491d..25ff28a91a6c 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/runtime/cuda/cuda_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,7 +26,9 @@ #include #include + #include + #include "../workspace_pool.h" namespace tvm { @@ -36,18 +38,16 @@ namespace runtime { { \ CUresult result = x; \ if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \ - const char *msg; \ + const char* msg; \ cuGetErrorName(result, &msg); \ - LOG(FATAL) \ - << "CUDAError: " #x " failed with error: " << msg; \ + LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \ } \ } -#define CUDA_CALL(func) \ - { \ - cudaError_t e = (func); \ - CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \ - << "CUDA: " << cudaGetErrorString(e); \ +#define CUDA_CALL(func) \ + { \ + cudaError_t e = (func); \ + CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) << "CUDA: " << cudaGetErrorString(e); \ } /*! \brief Thread local workspace */ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index d9f03e773bc9..a6d4a5499469 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -21,13 +21,14 @@ * \file cuda_device_api.cc * \brief GPU specific API */ -#include - -#include -#include #include #include +#include +#include +#include + #include + #include "cuda_common.h" namespace tvm { @@ -35,40 +36,32 @@ namespace runtime { class CUDADeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - CUDA_CALL(cudaSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { CUDA_CALL(cudaSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { case kExist: - value = ( - cudaDeviceGetAttribute( - &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) - == cudaSuccess); + value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) == + cudaSuccess); break; case kMaxThreadsPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrWarpSize, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id)); break; } case kMaxSharedMemoryPerBlock: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); + CUDA_CALL( + cudaDeviceGetAttribute(&value, cudaDevAttrMaxSharedMemoryPerBlock, ctx.device_id)); break; } case kComputeVersion: { std::ostringstream os; - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrComputeCapabilityMajor, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMajor, ctx.device_id)); os << value << "."; - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrComputeCapabilityMinor, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrComputeCapabilityMinor, ctx.device_id)); os << value; *rv = os.str(); return; @@ -81,40 +74,33 @@ class CUDADeviceAPI final : public DeviceAPI { return; } case kMaxClockRate: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrClockRate, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrClockRate, ctx.device_id)); break; } case 
kMultiProcessorCount: { - CUDA_CALL(cudaDeviceGetAttribute( - &value, cudaDevAttrMultiProcessorCount, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMultiProcessorCount, ctx.device_id)); break; } case kMaxThreadDimensions: { int dims[3]; - CUDA_CALL(cudaDeviceGetAttribute( - &dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); - CUDA_CALL(cudaDeviceGetAttribute( - &dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); - CUDA_CALL(cudaDeviceGetAttribute( - &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[0], cudaDevAttrMaxBlockDimX, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[1], cudaDevAttrMaxBlockDimY, ctx.device_id)); + CUDA_CALL(cudaDeviceGetAttribute(&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; *rv = ss.str(); return; } - case kGcnArch: return; + case kGcnArch: + return; } *rv = value; } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { - CHECK_EQ(256 % alignment, 0U) - << "CUDA space is aligned at 256 bytes"; - void *ret; + CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes"; + void* ret; if (ctx.device_type == kDLCPUPinned) { CUDA_CALL(cudaMallocHost(&ret, nbytes)); } else { @@ -133,14 +119,8 @@ class CUDADeviceAPI final : public DeviceAPI { } } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { cudaStream_t cu_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -156,8 +136,8 @@ class CUDADeviceAPI final : public DeviceAPI { // In case there is a copy from host mem to host mem */ if (ctx_to.device_type == kDLCPU && ctx_from.device_type == kDLCPU) { - memcpy(to, from, size); - return; + memcpy(to, from, size); + return; } if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) { @@ -165,9 +145,7 @@ class CUDADeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream); } else { - cudaMemcpyPeerAsync(to, ctx_to.device_id, - from, ctx_from.device_id, - size, cu_stream); + cudaMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream); } } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) { CUDA_CALL(cudaSetDevice(ctx_from.device_id)); @@ -210,8 +188,7 @@ class CUDADeviceAPI final : public DeviceAPI { } void SetStream(TVMContext ctx, TVMStreamHandle stream) final { - CUDAThreadEntry::ThreadLocal() - ->stream = static_cast(stream); + CUDAThreadEntry::ThreadLocal()->stream = static_cast(stream); } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { @@ -223,16 +200,12 @@ class CUDADeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, - void* to, - size_t size, - 
cudaMemcpyKind kind, + static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind kind, cudaStream_t stream) { if (stream != 0) { CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream)); @@ -244,25 +217,19 @@ class CUDADeviceAPI final : public DeviceAPI { typedef dmlc::ThreadLocalStore CUDAThreadStore; -CUDAThreadEntry::CUDAThreadEntry() - : pool(kDLGPU, CUDADeviceAPI::Global()) { -} +CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {} -CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { - return CUDAThreadStore::Get(); -} +CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.gpu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); -TVM_REGISTER_GLOBAL("device_api.cpu_pinned") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 0550712de9ab..498a9b703a7b 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -22,19 +22,21 @@ */ #include "cuda_module.h" -#include #include #include -#include +#include + #include -#include #include +#include #include -#include "cuda_common.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "cuda_common.h" namespace tvm { namespace runtime { @@ -45,8 +47,7 @@ namespace runtime { // The modules will be lazily loaded class CUDAModuleNode : public runtime::ModuleNode { public: - explicit CUDAModuleNode(std::string data, - std::string fmt, + explicit CUDAModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string cuda_source) : data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) { @@ -62,16 +63,11 @@ class CUDAModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { - return "cuda"; - } + const char* type_key() const final { return "cuda"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cu") { @@ -79,8 +75,7 @@ class CUDAModuleNode : public runtime::ModuleNode { SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, cuda_source_); } else { - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); } @@ -112,18 +107,14 @@ class CUDAModuleNode : public runtime::ModuleNode { CUfunction func; CUresult result = cuModuleGetFunction(&func, module_[device_id], 
func_name.c_str()); if (result != CUDA_SUCCESS) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) - << "CUDAError: cuModuleGetFunction " << func_name - << " failed with error: " << msg; + LOG(FATAL) << "CUDAError: cuModuleGetFunction " << func_name << " failed with error: " << msg; } return func; } // get a global var from primary context in device_id - CUdeviceptr GetGlobal(int device_id, - const std::string& global_name, - size_t expect_nbytes) { + CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); // must recheck under the lock scope if (module_[device_id] == nullptr) { @@ -132,15 +123,12 @@ class CUDAModuleNode : public runtime::ModuleNode { CUdeviceptr global; size_t nbytes; - CUresult result = cuModuleGetGlobal(&global, &nbytes, - module_[device_id], global_name.c_str()); + CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], global_name.c_str()); CHECK_EQ(nbytes, expect_nbytes); if (result != CUDA_SUCCESS) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); - LOG(FATAL) - << "CUDAError: cuModuleGetGlobal " << global_name - << " failed with error: " << msg; + LOG(FATAL) << "CUDAError: cuModuleGetGlobal " << global_name << " failed with error: " << msg; } return global; } @@ -164,11 +152,8 @@ class CUDAModuleNode : public runtime::ModuleNode { class CUDAWrappedFunc { public: // initialize the CUDA function. - void Init(CUDAModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_void_args, - const std::vector& thread_axis_tags) { + void Init(CUDAModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_void_args, const std::vector& thread_axis_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; @@ -176,9 +161,7 @@ class CUDAWrappedFunc { thread_axis_cfg_.Init(num_void_args, thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void** void_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -186,24 +169,17 @@ class CUDAWrappedFunc { } CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - CUresult result = cuLaunchKernel( - fcache_[device_id], - wl.grid_dim(0), - wl.grid_dim(1), - wl.grid_dim(2), - wl.block_dim(0), - wl.block_dim(1), - wl.block_dim(2), - 0, strm, void_args, 0); + CUresult result = + cuLaunchKernel(fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), + wl.block_dim(0), wl.block_dim(1), wl.block_dim(2), 0, strm, void_args, 0); if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { - const char *msg; + const char* msg; cuGetErrorName(result, &msg); std::ostringstream os; os << "CUDALaunch Error: " << msg << "\n" - << " grid=(" << wl.grid_dim(0) << "," - << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " - << " block=(" << wl.block_dim(0) << "," - << wl.block_dim(1) << "," << wl.block_dim(2) << ")\n"; + << " grid=(" << wl.grid_dim(0) << "," << wl.grid_dim(1) << "," << wl.grid_dim(2) << "), " + << " block=(" << wl.block_dim(0) << "," << wl.block_dim(1) << "," << wl.block_dim(2) + << ")\n"; std::string cuda = m_->GetSource(""); if (cuda.length() != 0) { os << "// func_name=" << func_name_ << "\n" @@ -231,9 +207,7 @@ class CUDAWrappedFunc { class CUDAPrepGlobalBarrier { public: - 
CUDAPrepGlobalBarrier(CUDAModuleNode* m, - ObjectPtr sptr) - : m_(m), sptr_(sptr) { + CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr sptr) : m_(m), sptr_(sptr) { std::fill(pcache_.begin(), pcache_.end(), 0); } @@ -241,8 +215,8 @@ class CUDAPrepGlobalBarrier { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); if (pcache_[device_id] == 0) { - pcache_[device_id] = m_->GetGlobal( - device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned)); + pcache_[device_id] = + m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, sizeof(unsigned)); } CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1)); } @@ -256,12 +230,10 @@ class CUDAPrepGlobalBarrier { mutable std::array pcache_; }; -PackedFunc CUDAModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc CUDAModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; if (name == symbol::tvm_prepare_global_barrier) { return PackedFunc(CUDAPrepGlobalBarrier(this, sptr_to_self)); } @@ -273,18 +245,15 @@ PackedFunc CUDAModuleNode::GetFunction( return PackFuncVoidAddr(f, info.arg_types); } -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { auto n = make_object(data, fmt, fmap, cuda_source); return Module(n); } // Load module from module. -Module CUDAModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module CUDAModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -305,13 +274,10 @@ Module CUDAModuleLoadBinary(void* strm) { return CUDAModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin") -.set_body_typed(CUDAModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_cubin").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx") -.set_body_typed(CUDAModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_ptx").set_body_typed(CUDAModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda") -.set_body_typed(CUDAModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_cuda").set_body_typed(CUDAModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h index bce0d63e98a1..e65c5fe60811 100644 --- a/src/runtime/cuda/cuda_module.h +++ b/src/runtime/cuda/cuda_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_CUDA_CUDA_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -45,11 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. 
* \param cuda_source Optional, cuda source file */ -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source); +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/dso_library.cc b/src/runtime/dso_library.cc index 378f976dead1..6d3eec402306 100644 --- a/src/runtime/dso_library.cc +++ b/src/runtime/dso_library.cc @@ -21,10 +21,11 @@ * \file dso_libary.cc * \brief Create library module to load from dynamic shared library. */ -#include #include -#include +#include #include +#include + #include "library_module.h" #if defined(_WIN32) @@ -43,13 +44,9 @@ class DSOLibrary final : public Library { ~DSOLibrary() { if (lib_handle_) Unload(); } - void Init(const std::string& name) { - Load(name); - } + void Init(const std::string& name) { Load(name); } - void* GetSymbol(const char* name) final { - return GetSymbol_(name); - } + void* GetSymbol(const char* name) final { return GetSymbol_(name); } private: // Platform dependent handling. @@ -58,8 +55,7 @@ class DSOLibrary final : public Library { HMODULE lib_handle_{nullptr}; void* GetSymbol_(const char* name) { - return reinterpret_cast( - GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) + return reinterpret_cast(GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) } // Load the library @@ -67,8 +63,7 @@ class DSOLibrary final : public Library { // use wstring version that is needed by LLVM. std::wstring wname(name.begin(), name.end()); lib_handle_ = LoadLibraryW(wname.c_str()); - CHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name; + CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name; } void Unload() { @@ -81,14 +76,11 @@ class DSOLibrary final : public Library { // load the library void Load(const std::string& name) { lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - CHECK(lib_handle_ != nullptr) - << "Failed to load dynamic shared library " << name - << " " << dlerror(); + CHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " " + << dlerror(); } - void* GetSymbol_(const char* name) { - return dlsym(lib_handle_, name); - } + void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); } void Unload() { dlclose(lib_handle_); @@ -97,11 +89,10 @@ class DSOLibrary final : public Library { #endif }; -TVM_REGISTER_GLOBAL("runtime.module.loadfile_so") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->Init(args[0]); - *rv = CreateModuleFromLibrary(n); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_so").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->Init(args[0]); + *rv = CreateModuleFromLibrary(n); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_util.cc b/src/runtime/file_util.cc index f94b2d37b72b..68d174e470a2 100644 --- a/src/runtime/file_util.cc +++ b/src/runtime/file_util.cc @@ -20,13 +20,15 @@ /*! 
* \file file_util.cc */ +#include "file_util.h" + #include #include #include + #include -#include #include -#include "file_util.h" +#include namespace tvm { namespace runtime { @@ -69,8 +71,7 @@ bool FunctionInfo::Load(dmlc::Stream* reader) { return true; } -std::string GetFileFormat(const std::string& file_name, - const std::string& format) { +std::string GetFileFormat(const std::string& file_name, const std::string& format) { std::string fmt = format; if (fmt.length() == 0) { size_t pos = file_name.find_last_of("."); @@ -103,7 +104,7 @@ std::string GetFileBasename(const std::string& file_name) { } std::string GetMetaFilePath(const std::string& file_name) { - size_t pos = file_name.find_last_of("."); + size_t pos = file_name.find_last_of("."); if (pos != std::string::npos) { return file_name.substr(0, pos) + ".tvm_meta.json"; } else { @@ -111,8 +112,7 @@ std::string GetMetaFilePath(const std::string& file_name) { } } -void LoadBinaryFromFile(const std::string& file_name, - std::string* data) { +void LoadBinaryFromFile(const std::string& file_name, std::string* data) { std::ifstream fs(file_name, std::ios::in | std::ios::binary); CHECK(!fs.fail()) << "Cannot open " << file_name; // get its size: @@ -123,17 +123,14 @@ void LoadBinaryFromFile(const std::string& file_name, fs.read(&(*data)[0], size); } -void SaveBinaryToFile( - const std::string& file_name, - const std::string& data) { +void SaveBinaryToFile(const std::string& file_name, const std::string& data) { std::ofstream fs(file_name, std::ios::out | std::ios::binary); CHECK(!fs.fail()) << "Cannot open " << file_name; fs.write(&data[0], data.length()); } -void SaveMetaDataToFile( - const std::string& file_name, - const std::unordered_map& fmap) { +void SaveMetaDataToFile(const std::string& file_name, + const std::unordered_map& fmap) { std::string version = "0.1.0"; std::ofstream fs(file_name.c_str()); CHECK(!fs.fail()) << "Cannot open file " << file_name; @@ -145,9 +142,8 @@ void SaveMetaDataToFile( fs.close(); } -void LoadMetaDataFromFile( - const std::string& file_name, - std::unordered_map* fmap) { +void LoadMetaDataFromFile(const std::string& file_name, + std::unordered_map* fmap) { std::ifstream fs(file_name.c_str()); CHECK(!fs.fail()) << "Cannot open file " << file_name; std::string version; @@ -159,9 +155,7 @@ void LoadMetaDataFromFile( fs.close(); } -void RemoveFile(const std::string& file_name) { - std::remove(file_name.c_str()); -} +void RemoveFile(const std::string& file_name) { std::remove(file_name.c_str()); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/file_util.h b/src/runtime/file_util.h index dfbaa16bded6..1c350357ec9a 100644 --- a/src/runtime/file_util.h +++ b/src/runtime/file_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -26,6 +26,7 @@ #include #include + #include "meta_data.h" namespace tvm { @@ -35,8 +36,7 @@ namespace runtime { * \param file_name The name of the file. * \param format The format of the file. */ -std::string GetFileFormat(const std::string& file_name, - const std::string& format); +std::string GetFileFormat(const std::string& file_name, const std::string& format); /*! 
* \return the directory in which TVM stores cached files. @@ -62,34 +62,30 @@ std::string GetFileBasename(const std::string& file_name); * \param file_name The name of the file. * \param data The data to be loaded. */ -void LoadBinaryFromFile(const std::string& file_name, - std::string* data); +void LoadBinaryFromFile(const std::string& file_name, std::string* data); /*! * \brief Load binary file into a in-memory buffer. * \param file_name The name of the file. * \param data The binary data to be saved. */ -void SaveBinaryToFile(const std::string& file_name, - const std::string& data); +void SaveBinaryToFile(const std::string& file_name, const std::string& data); /*! * \brief Save meta data to file. * \param file_name The name of the file. * \param fmap The function info map. */ -void SaveMetaDataToFile( - const std::string& file_name, - const std::unordered_map& fmap); +void SaveMetaDataToFile(const std::string& file_name, + const std::unordered_map& fmap); /*! * \brief Load meta data to file. * \param file_name The name of the file. * \param fmap The function info map. */ -void LoadMetaDataFromFile( - const std::string& file_name, - std::unordered_map* fmap); +void LoadMetaDataFromFile(const std::string& file_name, + std::unordered_map* fmap); /*! * \brief Remove (unlink) a file. diff --git a/src/runtime/graph/debug/graph_runtime_debug.cc b/src/runtime/graph/debug/graph_runtime_debug.cc index 1c85de859273..5439be9109f9 100644 --- a/src/runtime/graph/debug/graph_runtime_debug.cc +++ b/src/runtime/graph/debug/graph_runtime_debug.cc @@ -20,12 +20,14 @@ /*! * \file graph_runtime_debug.cc */ +#include +#include #include #include -#include #include #include + #include "../graph_runtime.h" namespace tvm { @@ -59,15 +61,14 @@ class GraphRuntimeDebug : public GraphRuntime { std::ostringstream os; std::vector time_per_op(op_execs_.size(), 0); for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; + std::chrono::time_point tbegin, + tend; double duration_ms = 0.0; do { std::fill(time_per_op.begin(), time_per_op.end(), 0); if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random + number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random } tbegin = std::chrono::high_resolution_clock::now(); for (int k = 0; k < number; k++) { @@ -78,15 +79,17 @@ class GraphRuntimeDebug : public GraphRuntime { op_execs_[index](); TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_tend = std::chrono::high_resolution_clock::now(); - double op_duration = std::chrono::duration_cast< - std::chrono::duration >(op_tend - op_tbegin).count(); + double op_duration = + std::chrono::duration_cast >(op_tend - op_tbegin) + .count(); time_per_op[index] += op_duration * 1e6; // us } } } tend = std::chrono::high_resolution_clock::now(); - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; + duration_ms = + std::chrono::duration_cast >(tend - tbegin).count() * + 1000; } while (duration_ms < min_repeat_ms); LOG(INFO) << "Iteration: " << i; @@ -94,8 +97,8 @@ class GraphRuntimeDebug : public GraphRuntime { for (size_t index = 0; index < time_per_op.size(); index++) { if (op_execs_[index]) { time_per_op[index] /= number; - LOG(INFO) << "Op #" << op++ << " " << GetNodeName(index) << ": " - << time_per_op[index] << " us/iter"; + LOG(INFO) << "Op #" << op++ 
<< " " << GetNodeName(index) << ": " << time_per_op[index] + << " us/iter"; } } } @@ -110,17 +113,14 @@ class GraphRuntimeDebug : public GraphRuntime { * \param index The index of op which needs to be returned. * \param eid The Entry id of the op. */ - NDArray GetOutputByLayer(int index, int eid) { - return data_entry_[entry_id(index, eid)]; - } + NDArray GetOutputByLayer(int index, int eid) { return data_entry_[entry_id(index, eid)]; } /*! * \brief GetFunction Get the function based on input. * \param name The function which needs to be invoked. * \param sptr_to_self Packed function pointer. */ - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \brief Get the node index given the name of node. @@ -135,53 +135,51 @@ class GraphRuntimeDebug : public GraphRuntime { } LOG(FATAL) << "cannot find " << name << " among nodex"; return -1; -} + } -/*! - * \brief Copy index-th node to data_out. - * - * This method will do a partial run of the the graph - * from begining upto the index-th node and return output of index-th node. - * This is costly operation and suggest to use only for debug porpose. - * - * \param index: The index of the node. - * \param data_out the node data. - */ -void DebugGetNodeOutput(int index, DLTensor* data_out) { - CHECK_LT(static_cast(index), op_execs_.size()); - uint32_t eid = index; + /*! + * \brief Copy index-th node to data_out. + * + * This method will do a partial run of the the graph + * from begining upto the index-th node and return output of index-th node. + * This is costly operation and suggest to use only for debug porpose. + * + * \param index: The index of the node. + * \param data_out the node data. + */ + void DebugGetNodeOutput(int index, DLTensor* data_out) { + CHECK_LT(static_cast(index), op_execs_.size()); + uint32_t eid = index; - for (size_t i = 0; i < op_execs_.size(); ++i) { - if (op_execs_[i]) op_execs_[i](); - if (static_cast(i) == index) break; - } + for (size_t i = 0; i < op_execs_.size(); ++i) { + if (op_execs_[i]) op_execs_[i](); + if (static_cast(i) == index) break; + } - data_entry_[eid].CopyTo(data_out); -} + data_entry_[eid].CopyTo(data_out); + } }; - /*! * \brief GetFunction Get the function based on input. * \param name The function which needs to be invoked. * \param sptr_to_self Packed function pointer. */ -PackedFunc GraphRuntimeDebug::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // return member functions during query. 
if (name == "get_output_by_layer") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetOutputByLayer(args[0], args[1]); - }); + *rv = this->GetOutputByLayer(args[0], args[1]); + }); } else if (name == "debug_get_output") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); - } else { - this->DebugGetNodeOutput(args[0], args[1]); - } - }); + if (String::CanConvertFrom(args[0])) { + this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); + } else { + this->DebugGetNodeOutput(args[0], args[1]); + } + }); } else if (name == "run_individual") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { int number = args[0]; @@ -203,21 +201,18 @@ PackedFunc GraphRuntimeDebug::GetFunction( * \param m Compiled module which will be loaded. * \param ctxs All devices contexts. */ -Module GraphRuntimeDebugCreate(const std::string& sym_json, - const tvm::runtime::Module& m, +Module GraphRuntimeDebugCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { auto exec = make_object(); exec->Init(sym_json, m, ctxs); return Module(exec); } -TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) - << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); - }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime_debug.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + *rv = GraphRuntimeDebugCreate(args[0], args[1], GetAllContext(args)); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 0427e400ab8c..daa2c68ecac4 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -20,6 +20,9 @@ /*! * \file graph_runtime.cc */ +#include "graph_runtime.h" + +#include #include #include #include @@ -35,8 +38,6 @@ #include #include -#include "graph_runtime.h" - namespace tvm { namespace runtime { namespace details { @@ -64,8 +65,7 @@ void GraphRuntime::Run() { * \param ctxs The context of the host and devices where graph nodes will be * executed on. */ -void GraphRuntime::Init(const std::string& graph_json, - tvm::runtime::Module module, +void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module, const std::vector& ctxs) { std::istringstream is(graph_json); dmlc::JSONReader reader(&is); @@ -172,9 +172,7 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { * * \return The number of outputs from graph. */ -int GraphRuntime::NumOutputs() const { - return outputs_.size(); -} +int GraphRuntime::NumOutputs() const { return outputs_.size(); } /*! * \brief Get the type of the index-th output. * \param index The output index. 
@@ -239,20 +237,14 @@ void GraphRuntime::LoadParams(const std::string& param_blob) { void GraphRuntime::LoadParams(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; - - CHECK(strm->Read(&weight_names_)) - << "Invalid parameters file format"; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; + CHECK(strm->Read(&weight_names_)) << "Invalid parameters file format"; uint64_t sz; strm->Read(&sz); size_t size = static_cast(sz); - CHECK(size == weight_names_.size()) - << "Invalid parameters file format"; + CHECK(size == weight_names_.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { int in_idx = GetInputIndex(weight_names_[i]); CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << weight_names_[i]; @@ -267,13 +259,10 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { } void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { - uint64_t header, reserved; - CHECK(strm->Read(&header)) - << "Invalid parameters file format"; - CHECK(header == kTVMNDArrayListMagic) - << "Invalid parameters file format"; - CHECK(strm->Read(&reserved)) - << "Invalid parameters file format"; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; std::vector names; CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; @@ -318,15 +307,14 @@ void GraphRuntime::SetupStorage() { CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; DLDataType t = vtype[i]; size_t bits = t.bits * t.lanes; - CHECK(bits % 8U == 0U || bits ==1U); + CHECK(bits % 8U == 0U || bits == 1U); size_t bytes = ((bits + 7U) / 8U) * size; uint32_t sid = static_cast(storage_id); if (sid >= pool_entry.size()) { pool_entry.resize(sid + 1, {0, -1}); } else { - CHECK(pool_entry[sid].device_type == -1 || - pool_entry[sid].device_type == device_type) + CHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type) << "The same pool entry cannot be assigned to multiple devices"; } pool_entry[sid].size = std::max(pool_entry[sid].size, bytes); @@ -338,14 +326,12 @@ void GraphRuntime::SetupStorage() { std::vector shape; // This for loop is very fast since there are usually only a couple of // devices available on the same hardware. - const auto& cit = - std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { - return pit.device_type == static_cast(c.device_type); - }); + const auto& cit = std::find_if(ctxs_.begin(), ctxs_.end(), [&pit](const TVMContext& c) { + return pit.device_type == static_cast(c.device_type); + }); TVMContext ctx = cit == ctxs_.end() ? ctxs_[0] : *cit; shape.push_back(static_cast(pit.size + 3) / 4); - storage_pool_.push_back( - NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); + storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); } // Assign the pooled entries. 
A unified memory pool is used to simplifiy @@ -356,8 +342,7 @@ void GraphRuntime::SetupStorage() { for (size_t i = 0; i < data_entry_.size(); ++i) { int storage_id = attrs_.storage_id[i]; CHECK_LT(static_cast(storage_id), storage_pool_.size()); - data_entry_[i] = - storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); + data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]); const DLTensor* tmp = data_entry_[i].operator->(); data_alignment_[i] = details::GetDataAlignment(*tmp); } @@ -385,39 +370,23 @@ void GraphRuntime::SetupOpExecs() { uint32_t eid = this->entry_id(nid, index); args.push_back(*(data_entry_[eid].operator->())); } + CHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op"; - if (inode.op_type == "tvm_op") { - std::shared_ptr op_args = nullptr; - std::tie(op_execs_[nid], op_args) = - CreateTVMOp(inode.param, args, inode.inputs.size()); + std::shared_ptr op_args = nullptr; + std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args, inode.inputs.size()); - for (size_t i = 0; i < inode.inputs.size(); i++) { - uint32_t eid = this->entry_id(inode.inputs[i]); - // check if op input is model input - if (input_node_eids.count(eid) > 0) { - input_dltensors_[eid].push_back( - static_cast(op_args->arg_values[i].v_handle)); - } + for (size_t i = 0; i < inode.inputs.size(); i++) { + uint32_t eid = this->entry_id(inode.inputs[i]); + // check if op input is model input + if (input_node_eids.count(eid) > 0) { + input_dltensors_[eid].push_back(static_cast(op_args->arg_values[i].v_handle)); } - } else if (inode.op_type == "_tensorrt_subgraph_op") { -#ifdef TVM_GRAPH_RUNTIME_TENSORRT - CHECK_EQ(inode.subgraphs.size(), 1U) << "Only supports one subgraph per node"; - CHECK_EQ(inode.subgraphs[0].arg_nodes.size(), inode.inputs.size()); - op_execs_[nid] = tensorrt_exec_manager_.CreateExec( - inode.name, inode.subgraphs[0], args); -#else - LOG(FATAL) << "TensorRT NOT enabled for operator " << inode.op_type; -#endif // TVM_GRAPH_RUNTIME_TENSORRT - } else { - LOG(FATAL) << "Unknown op type " << inode.op_type << " in graph runtime"; } } } std::pair, std::shared_ptr > GraphRuntime::CreateTVMOp( - const TVMOpParam& param, - const std::vector& args, - size_t num_inputs) { + const TVMOpParam& param, const std::vector& args, size_t num_inputs) { std::shared_ptr arg_ptr = std::make_shared(); // setup address. arg_ptr->args = args; @@ -431,15 +400,15 @@ std::pair, std::shared_ptr > GraphRu arg_ptr->arg_values.push_back(v); arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle); if (param.flatten_data) { - arg_ptr->shape_data[i] = std::accumulate( - t->shape, t->shape + t->ndim, 1, std::multiplies()); + arg_ptr->shape_data[i] = + std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies()); t->ndim = 1; t->shape = &(arg_ptr->shape_data[i]); } } if (param.func_name == "__nop") { - return {[](){}, arg_ptr}; + return {[]() {}, arg_ptr}; } else if (param.func_name == "__copy") { // Perform cross device data copy. // Directly copy data from the input to the output. @@ -450,8 +419,6 @@ std::pair, std::shared_ptr > GraphRu }; return {fexec, arg_ptr}; } - CHECK(!module_.IsEmpty()) - << "Module cannot be empty in order to get functions from the lib"; // Get compiled function from the module that contains both host and device // code. 
@@ -460,31 +427,29 @@ std::pair, std::shared_ptr > GraphRu auto fexec = [arg_ptr, pf]() { TVMRetValue rv; - TVMArgs targs(arg_ptr->arg_values.data(), - arg_ptr->arg_tcodes.data(), + TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(), static_cast(arg_ptr->arg_values.size())); pf.CallPacked(targs, &rv); }; return {fexec, arg_ptr}; } -PackedFunc GraphRuntime::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc GraphRuntime::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { // Return member functions during query. if (name == "set_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - int in_idx = this->GetInputIndex(args[0]); - if (in_idx >= 0) this->SetInput(in_idx, args[1]); - } else { - this->SetInput(args[0], args[1]); - } - }); + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); + if (in_idx >= 0) this->SetInput(in_idx, args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); } else if (name == "set_input_zero_copy") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - if (args[0].type_code() == kTVMStr) { - int in_idx = this->GetInputIndex(args[0]); + if (String::CanConvertFrom(args[0])) { + int in_idx = this->GetInputIndex(args[0].operator String()); if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); } else { this->SetInputZeroCopy(args[0], args[1]); @@ -500,42 +465,38 @@ PackedFunc GraphRuntime::GetFunction( }); } else if (name == "get_input") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - int in_idx = 0; - if (args[0].type_code() == kTVMStr) { - in_idx = this->GetInputIndex(args[0]); - } else { - in_idx = args[0]; - } - CHECK_GE(in_idx, 0); - *rv = this->GetInput(in_idx); - }); + int in_idx = 0; + if (String::CanConvertFrom(args[0])) { + in_idx = this->GetInputIndex(args[0].operator String()); + } else { + in_idx = args[0]; + } + CHECK_GE(in_idx, 0); + *rv = this->GetInput(in_idx); + }); } else if (name == "get_num_outputs") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->NumOutputs(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); } else if (name == "run") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->Run(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); } else if (name == "load_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - this->LoadParams(args[0].operator std::string()); - }); + this->LoadParams(args[0].operator std::string()); + }); } else if (name == "share_params") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - const auto& module = args[0].operator Module(); - CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); - const auto& param_blob = args[1].operator std::string(); - dmlc::MemoryStringStream strm(const_cast(¶m_blob)); - this->ShareParams(dynamic_cast(*module.operator->()), &strm); - }); + const auto& module = args[0].operator Module(); + CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); + const auto& param_blob = args[1].operator std::string(); + dmlc::MemoryStringStream strm(const_cast(¶m_blob)); + this->ShareParams(dynamic_cast(*module.operator->()), &strm); + }); } else { return PackedFunc(); } } -Module GraphRuntimeCreate(const 
std::string& sym_json, - const tvm::runtime::Module& m, +Module GraphRuntimeCreate(const std::string& sym_json, const tvm::runtime::Module& m, const std::vector& ctxs) { auto exec = make_object(); exec->Init(sym_json, m, ctxs); @@ -561,14 +522,12 @@ std::vector GetAllContext(const TVMArgs& args) { // execution support yet. For heterogenenous execution, at least 5 arguments will // be passed in. The third one is the number of devices. // Eventually, we will only probably pass TVMContext for all the languages. -TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") - .set_body([](TVMArgs args, TVMRetValue* rv) { - CHECK_GE(args.num_args, 4) - << "The expected number of arguments for graph_runtime.create is " - "at least 4, but it has " - << args.num_args; - const auto& contexts = GetAllContext(args); - *rv = GraphRuntimeCreate(args[0], args[1], contexts); - }); +TVM_REGISTER_GLOBAL("tvm.graph_runtime.create").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_runtime.create is " + "at least 4, but it has " + << args.num_args; + const auto& contexts = GetAllContext(args); + *rv = GraphRuntimeCreate(args[0], args[1], contexts); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index 345a953b38c3..327dbe079cdb 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -26,16 +26,16 @@ #define TVM_RUNTIME_GRAPH_GRAPH_RUNTIME_H_ #include -#include #include +#include #include #include #include +#include #include #include #include -#include #include "../../contrib/subgraph/subgraph.h" #ifdef TVM_GRAPH_RUNTIME_TENSORRT @@ -46,11 +46,10 @@ namespace tvm { namespace runtime { /*! \brief macro to do C API call */ -#define TVM_CCALL(func) \ - { \ - int ret = (func); \ - CHECK_EQ(ret, 0) \ - << TVMGetLastError(); \ +#define TVM_CCALL(func) \ + { \ + int ret = (func); \ + CHECK_EQ(ret, 0) << TVMGetLastError(); \ } /*! \brief Magic number for NDArray list file */ @@ -85,15 +84,12 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "GraphRuntime"; - } + const char* type_key() const final { return "GraphRuntime"; } void Run(); /*! @@ -105,8 +101,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { * executed on. */ - void Init(const std::string& graph_json, - tvm::runtime::Module module, + void Init(const std::string& graph_json, tvm::runtime::Module module, const std::vector& ctxs); /*! @@ -210,14 +205,9 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \brief Get total number of nodes. * \return Total number of nodes. */ - uint32_t GetNumOfNodes() const { - return static_cast(nodes_.size()); - } - - std::string GetNodeName(uint32_t nid) const { - return nodes_[nid].name; - } + uint32_t GetNumOfNodes() const { return static_cast(nodes_.size()); } + std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } protected: // Memory pool entry. 
@@ -232,7 +222,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { uint32_t index; uint32_t version; // JSON Loader - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginArray(); CHECK(reader->NextArrayItem()) << "invalid json format"; reader->Read(&node_id); @@ -261,7 +251,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { // subgraphs std::vector subgraphs; // JSON Loader - void LoadAttrs(dmlc::JSONReader *reader, TVMOpParam* param) { + void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) { int bitmask = 0; std::string key, value; reader->BeginObject(); @@ -281,7 +271,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { bitmask |= 8; } } - CHECK_EQ(bitmask, 1|2|4|8) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format"; } // Subgraph loader @@ -297,7 +287,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { } // JSON Loader - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginObject(); int bitmask = 0; std::string key; @@ -321,7 +311,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { LOG(FATAL) << "do not support key " << key; } } - CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format"; } }; struct GraphAttr { @@ -329,9 +319,9 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector storage_id; std::vector device_index; std::vector dltype; - std::vector > shape; + std::vector> shape; // The graph attribute fields. - void Load(dmlc::JSONReader *reader) { + void Load(dmlc::JSONReader* reader) { reader->BeginObject(); int bitmask = 0; std::string key, type; @@ -389,37 +379,37 @@ class TVM_DLL GraphRuntime : public ModuleNode { CHECK(!reader->NextArrayItem()); } } - CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + CHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format"; } }; // The graph attribute fields. - void Load(dmlc::JSONReader *reader) { - reader->BeginObject(); - int bitmask = 0; - std::string key; - while (reader->NextObjectItem(&key)) { - if (key == "nodes") { - reader->Read(&nodes_); - bitmask |= 1; - } else if (key == "arg_nodes") { - reader->Read(&input_nodes_); - bitmask |= 2; - } else if (key == "node_row_ptr") { - reader->Read(&node_row_ptr_); - bitmask |= 4; - } else if (key == "heads") { - reader->Read(&outputs_); - bitmask |= 8; - } else if (key == "attrs") { - reader->Read(&attrs_); - bitmask |= 16; - } else if (key == "metadata") { - break; - } else { - LOG(FATAL) << "key " << key << " is not supported"; - } + void Load(dmlc::JSONReader* reader) { + reader->BeginObject(); + int bitmask = 0; + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "nodes") { + reader->Read(&nodes_); + bitmask |= 1; + } else if (key == "arg_nodes") { + reader->Read(&input_nodes_); + bitmask |= 2; + } else if (key == "node_row_ptr") { + reader->Read(&node_row_ptr_); + bitmask |= 4; + } else if (key == "heads") { + reader->Read(&outputs_); + bitmask |= 8; + } else if (key == "attrs") { + reader->Read(&attrs_); + bitmask |= 16; + } else if (key == "metadata") { + break; + } else { + LOG(FATAL) << "key " << key << " is not supported"; } - CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format"; + } + CHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format"; } /*! \brief Setup the temporal storage */ void SetupStorage(); @@ -432,21 +422,14 @@ class TVM_DLL GraphRuntime : public ModuleNode { * \param num_inputs Number of inputs. * \return The created executor. 
*/ - std::pair, std::shared_ptr > CreateTVMOp( - const TVMOpParam& attrs, const std::vector& args, - size_t num_inputs); + std::pair, std::shared_ptr> CreateTVMOp( + const TVMOpParam& attrs, const std::vector& args, size_t num_inputs); // Get node entry index. - uint32_t entry_id(uint32_t nid, uint32_t index) const { - return node_row_ptr_[nid] + index; - } + uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; } // Get node entry index. - uint32_t entry_id(const NodeEntry& e) const { - return entry_id(e.node_id, e.index); - } + uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); } // Number of node entries. - uint32_t num_node_entries() const { - return node_row_ptr_.back(); - } + uint32_t num_node_entries() const { return node_row_ptr_.back(); } /*! \brief The weight names. */ std::vector weight_names_; /*! \brief The graph nodes. */ @@ -474,13 +457,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { /*! \brief Data alignment of each node. */ std::vector data_alignment_; /*! \brief Operator on each node. */ - std::vector > op_execs_; -#ifdef TVM_GRAPH_RUNTIME_TENSORRT - contrib::TensorRTExecManager tensorrt_exec_manager_; -#endif // TVM_GRAPH_RUNTIME_TENSORRT - - /*! \brief Arg info of TVM ops */ - std::vector > op_args_; + std::vector> op_execs_; }; std::vector GetAllContext(const TVMArgs& args); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index d88e6d7284a3..fd6f32374005 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -33,21 +33,17 @@ class HexagonDeviceAPI : public DeviceAPI { public: void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t num_bytes, TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, - TVMStreamHandle stream) final; + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; - void* AllocWorkspace(TVMContext ctx, size_t nbytes, - DLDataType type_hint = {}) final; + void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final; void FreeWorkspace(TVMContext ctx, void* ptr) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; @@ -56,13 +52,11 @@ class HexagonDeviceAPI : public DeviceAPI { inline void HexagonDeviceAPI::SetDevice(TVMContext ctx) {} -inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, - TVMRetValue* rv) { +inline void HexagonDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { if (kind == kExist) *rv = 1; } -inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, - size_t alignment, +inline void* HexagonDeviceAPI::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { 
CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id)); return hexagon::Device::Global()->Alloc(nbytes, alignment); @@ -73,10 +67,10 @@ inline void HexagonDeviceAPI::FreeDataSpace(TVMContext ctx, void* ptr) { hexagon::Device::Global()->Free(ptr); } -inline void HexagonDeviceAPI::CopyDataFromTo( - const void* from, size_t from_offset, void* to, size_t to_offset, - size_t num_bytes, TVMContext ctx_from, TVMContext ctx_to, - DLDataType type_hint, TVMStreamHandle stream) { +inline void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t num_bytes, + TVMContext ctx_from, TVMContext ctx_to, + DLDataType type_hint, TVMStreamHandle stream) { const char* src = static_cast(from) + from_offset; char* dst = static_cast(to) + to_offset; @@ -110,11 +104,9 @@ inline void HexagonDeviceAPI::CopyDataFromTo( } } -inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, - TVMStreamHandle stream) {} +inline void HexagonDeviceAPI::StreamSync(TVMContext ctx, TVMStreamHandle stream) {} -inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, - DLDataType type_hint) { +inline void* HexagonDeviceAPI::AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint) { CHECK(hexagon::Device::ValidateDeviceId(ctx.device_id)); if (type_hint.code == 100) { size_t align = std::min(nbytes, 2048lu); @@ -128,11 +120,10 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { DeviceAPI::FreeWorkspace(ctx, ptr); } -TVM_REGISTER_GLOBAL("device_api.hexagon") - .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); - *rv = ptr; - }); +TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); + *rv = ptr; +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index e14843688b73..f76ac1670e24 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -176,8 +176,7 @@ void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { if (!InReg) { // Allocate on stack. - CHECK_EQ((t_align & (t_align - 1)), 0) - << "Alignment should be a power of 2"; + CHECK_EQ((t_align & (t_align - 1)), 0) << "Alignment should be a power of 2"; CHECK_GE(t_align, 4) << "Alignment should be at least 4"; // Round t_size up to a multiple of 4. 
unsigned s_size = Stack.size(); @@ -193,9 +192,8 @@ void ArgLayout::Push(uint32_t* v, unsigned t_size, unsigned t_align) { class HexagonModuleNode final : public runtime::ModuleNode { public: HexagonModuleNode(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) : hexagon_device_(hexagon::Device::Global()), data_(data), @@ -214,13 +212,11 @@ class HexagonModuleNode final : public runtime::ModuleNode { } } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; const char* type_key() const final { return "hexagon"; } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); if (fmt == "so" || fmt == "dll" || fmt == "hexagon") { std::string meta_file = GetMetaFilePath(file_name); @@ -240,8 +236,7 @@ class HexagonModuleNode final : public runtime::ModuleNode { CHECK(!bc_.empty()) << "LLVM IR bitcode not available"; SaveBinaryToFile(file_name, bc_); } else { - LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt - << "'"; + LOG(FATAL) << "HexagonModuleNode::SaveToFile: unhandled format `" << fmt << "'"; } } void SaveToBinary(dmlc::Stream* stream) final { @@ -251,10 +246,8 @@ class HexagonModuleNode final : public runtime::ModuleNode { } private: - void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const; - void CallRemoteDirect(void* func_ptr, const TVMArgs& args, - TVMRetValue* rv) const; + void CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; + void CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const; void RemapArgs(const TVMArgs& args, std::vector& values, // NOLINT(*) std::vector& type_codes, // NOLINT(*) @@ -274,8 +267,7 @@ class HexagonModuleNode final : public runtime::ModuleNode { std::set packed_c_abi_funcs_; }; -void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, - const TVMArgs& args, +void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const { // Remap all arguments, creating remote DLTensors. std::vector values; @@ -297,8 +289,8 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, int num_args = args.size(); int values_size = num_args * sizeof(TVMValue); int codes_size = num_args * sizeof(int); - void* remote = hexagon_device_->Alloc( - values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); + void* remote = + hexagon_device_->Alloc(values_size + sizeof(TVMValue) + codes_size + sizeof(int), 8); // Copy all argument TVMValues to the remote space. 
void* remote_values = remote; @@ -316,12 +308,12 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, temp_values[2].v_int64 = num_args; temp_values[3].v_handle = remote_ret_value; temp_values[4].v_handle = remote_ret_code; - int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, - kTVMOpaqueHandle, kTVMOpaqueHandle}; + int temp_codes[5] = {kTVMOpaqueHandle, kTVMOpaqueHandle, kDLInt, kTVMOpaqueHandle, + kTVMOpaqueHandle}; TVMArgs temp_args(temp_values, temp_codes, 5); hexagon::ArgLayout as = BuildArgLayout(temp_args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), - as.Stack.data(), as.Stack.size()); + hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), + as.Stack.size()); // TODO(kparzysz-quic): copy return value back std::for_each(remote_tensors.begin(), remote_tensors.end(), @@ -332,12 +324,12 @@ void HexagonModuleNode::CallRemotePackedCABI(void* func_ptr, void HexagonModuleNode::CallRemoteDirect(void* func_ptr, const TVMArgs& args, TVMRetValue* rv) const { hexagon::ArgLayout as = BuildArgLayout(args); - hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), - as.Stack.data(), as.Stack.size()); + hexagon_device_->Call(func_ptr, as.Scalar.data(), as.Scalar.size(), as.Stack.data(), + as.Stack.size()); } -PackedFunc HexagonModuleNode::GetFunction( - const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc HexagonModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { auto f = fmap_.find(name); if (f == fmap_.end()) return PackedFunc(nullptr); @@ -363,8 +355,7 @@ PackedFunc HexagonModuleNode::GetFunction( } } -void HexagonModuleNode::RemapArgs(const TVMArgs& args, - std::vector& values, +void HexagonModuleNode::RemapArgs(const TVMArgs& args, std::vector& values, std::vector& type_codes, std::vector& remote_tensors) const { for (unsigned i = 0, e = args.size(); i != e; ++i) { @@ -437,18 +428,17 @@ void* HexagonModuleNode::CreateRemoteTensor(const DLTensor* t) const { uint32_t remote_as_int = reinterpret_cast(remote); void* remote_ss = reinterpret_cast(remote_as_int + size_ht); - HexagonDLTensor local = { - .data = static_cast(reinterpret_cast(t->data)), - .ctx_device_type = uint8_t(t->ctx.device_type), - .pad0 = {0, 0, 0}, - .ctx_device_id = t->ctx.device_id, - .ndim = t->ndim, - .dtype_code = t->dtype.code, - .dtype_bits = t->dtype.bits, - .dtype_lanes = t->dtype.lanes, - .shape = remote_as_int + size_ht, - .strides = t->strides ? remote_as_int + size_ht + size_s : 0u, - .byte_offset = t->byte_offset}; + HexagonDLTensor local = {.data = static_cast(reinterpret_cast(t->data)), + .ctx_device_type = uint8_t(t->ctx.device_type), + .pad0 = {0, 0, 0}, + .ctx_device_id = t->ctx.device_id, + .ndim = t->ndim, + .dtype_code = t->dtype.code, + .dtype_bits = t->dtype.bits, + .dtype_lanes = t->dtype.lanes, + .shape = remote_as_int + size_ht, + .strides = t->strides ? 
remote_as_int + size_ht + size_s : 0u, + .byte_offset = t->byte_offset}; std::vector local_ss(size_ss / 8); for (int i = 0; i != ndim; ++i) local_ss[i] = t->shape[i]; @@ -505,18 +495,16 @@ hexagon::ArgLayout HexagonModuleNode::BuildArgLayout(const TVMArgs& As) const { } Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) { - auto n = make_object(data, fmt, fmap, asm_str, obj_str, - ir_str, bc_str, packed_c_abi); + auto n = make_object(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str, + packed_c_abi); return Module(n); } // Load module from file. -Module HexagonModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module HexagonModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data = file_name; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -552,10 +540,9 @@ std::shared_ptr Device::Global() { } // namespace hexagon -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = HexagonModuleLoadFile(args[0], args[1]); - }); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = HexagonModuleLoadFile(args[0], args[1]); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index c9e23a77776e..b922b169bd61 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -47,9 +47,8 @@ namespace runtime { * convention. */ Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi); namespace hexagon { @@ -91,24 +90,21 @@ class Device { * \param src Pointer (local to device) of the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyDeviceToDevice(void* dst, const void* src, - unsigned len) = 0; + virtual void CopyDeviceToDevice(void* dst, const void* src, unsigned len) = 0; /*! * \brief Copy a block of data from device to host. * \param host_dst Pointer (local to host) to the destination buffer. * \param src Pointer (local to device) to the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) = 0; + virtual void CopyDeviceToHost(void* host_dst, const void* src, unsigned len) = 0; /*! * \brief Copy a block of data from host to device. * \param dst Pointer (local to device) to the destination buffer. * \param host_src Pointer (local to host) to the source buffer. * \param len Number of bytes to copy. */ - virtual void CopyHostToDevice(void* dst, const void* host_src, - unsigned len) = 0; + virtual void CopyHostToDevice(void* dst, const void* host_src, unsigned len) = 0; /*! * \brief Load a module (typically a shared library) into device. * \param data Name of the shared library. @@ -141,8 +137,8 @@ class Device { * for padding. * \param st_num Number of values in the "stack" array. 
*/ - virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, - uint32_t* stack, unsigned st_num) = 0; + virtual void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) = 0; virtual ~Device() = 0; diff --git a/src/runtime/hexagon/hexagon_posix.cc b/src/runtime/hexagon/hexagon_posix.cc index 627963f384f5..e98fefd1da22 100644 --- a/src/runtime/hexagon/hexagon_posix.cc +++ b/src/runtime/hexagon/hexagon_posix.cc @@ -23,12 +23,10 @@ #include extern "C" { -int posix_memalign(void** memptr, size_t alignment, size_t size) - __attribute__((nothrow)); +int posix_memalign(void** memptr, size_t alignment, size_t size) __attribute__((nothrow)); } -__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, - size_t size) { +__attribute__((nothrow)) int posix_memalign(void** memptr, size_t alignment, size_t size) { if (void* p = memalign(alignment, size)) { *memptr = p; return 0; diff --git a/src/runtime/hexagon/sim/driver/CMakeLists.txt b/src/runtime/hexagon/sim/driver/CMakeLists.txt new file mode 100644 index 000000000000..8632b491f259 --- /dev/null +++ b/src/runtime/hexagon/sim/driver/CMakeLists.txt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +project(SIM_DEV C CXX) +cmake_minimum_required(VERSION 3.0.2) + +set(CMAKE_SYSTEM_NAME "Linux") + +if(EXISTS ${CMAKE_CURRENT_BINARY_DIR}/config.cmake) + include(${CMAKE_CURRENT_BINARY_DIR}/config.cmake) +endif() + +set(EXTRA_CXX_FLAGS + "-O2" + "-Wno-format" + "-mhvx -mhvx-length=128b" + "-mv60" + "-stdlib=libc++" +) + +set(EXTRA_LINK_FLAGS + "-stdlib=libc++" + "-G0" + "-Wl,--force-dynamic" + "-Wl,--export-dynamic" + "-Wl,--whole-archive" # This should link entire libc, libc++ and libc+abi. + "-Wl,--defsym=HEAP_SIZE=0x40000000" +) + +string(REGEX REPLACE ";" " " EXTRA_CXX_FLAGS_STR "${EXTRA_CXX_FLAGS}") +string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_FLAGS "${EXTRA_CXX_FLAGS_STR} ${CMAKE_CXX_FLAGS}") +set(CMAKE_EXE_LINKER_FLAGS "${EXTRA_LINK_FLAGS_STR} ${CMAKE_EXE_LINKER_FLAGS}") + +# Set project properties. + +file(GLOB SOURCE_FILES "*.cc") +add_executable(sim_dev ${SOURCE_FILES}) +target_include_directories(sim_dev + PUBLIC "." + PUBLIC ".." 
+ PUBLIC "../../../../../include" + PUBLIC "../../../../../3rdparty/dlpack/include" +) + +target_link_libraries(sim_dev "-ldl") diff --git a/tests/webgl/README.md b/src/runtime/hexagon/sim/driver/README.md similarity index 50% rename from tests/webgl/README.md rename to src/runtime/hexagon/sim/driver/README.md index 5303cc059740..3aee1a14b796 100644 --- a/tests/webgl/README.md +++ b/src/runtime/hexagon/sim/driver/README.md @@ -15,10 +15,24 @@ -## Test cases for the WebGL backend +# Hexagon simulator driver -Any test case with name `test_local_...` tests the C++ OpenGL backend on the -local OS, which can be executed automatically. +The driver (`sim_dev` executable) is the process running on the Hexagon simulator that handles the Hexagon-side communication with the TVM runtime running on x86. The location of `sim_dev` should be added to `PATH` before running any python code that uses Hexagon. The `sim_dev` executable is not intended to be run by users, it is automatically loaded by the simulator control code (in `hexagon_device_sim.cc`). -Any test case with name `test_remote_...` tests the WebGL backend within the -browser, which must be run manually. See instruction within the test. +### Prerequisites + +1. Hexagon C/C++ toolchain (such as the one in Hexagon SDK version 3.5.0 or later). + +Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. + +### Configuring + +Set +``` +CMAKE_C_COMPILER=hexagon-clang +CMAKE_CXX_COMPILER=hexagon-clang++ +``` + +### Building + +There are no special options required for `make` (or the tool selected with `cmake`). The location of the resulting binary `sim_dev` should be added to `PATH`. diff --git a/src/runtime/hexagon/sim/driver/fake_pthread.cc b/src/runtime/hexagon/sim/driver/fake_pthread.cc new file mode 100644 index 000000000000..3613186908a2 --- /dev/null +++ b/src/runtime/hexagon/sim/driver/fake_pthread.cc @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "pthread.h" +#include "sched.h" + +/*! + * Implementation of a subset of pthread API for single-threaded execution. + * + * They main idea is that the thread function ("start_routine" in the call + * to pthread_create) is executed immediately. When pthread_create returns, + * the thread function has already finished. + * + * Since the thread routine can itself call pthread_create, it is possible + * to have multiple threads existing at the same time, although only the + * last one is running. + * + * There are two main things that need to be taken care of: + * - thread-specific data, i.e. pthread_setspecific, pthread_getspecific, + * and the handling of thread keys, + * - handling of thread return values. 
+ * + * Threads are identified by thread ids (of type pthread_t). The main process + * thread has the id of 0, the remaining threads have ids starting at 1 and + * incrementing by 1. For each thread there is some data (thread_info_t) + * associated with it, and stored in "thread_data" map. When a thread + * terminates, the corresponding entry from "thread_data" cannot be removed + * until the return value is claimed (pthread_join), unless it is explicitly + * discarded (pthread_detach). When a new thread is created, it gets the + * first available id for which there is no entry in "thread_data". This + * could be an id that was never allocated, or an id that was used, but + * has since been removed from the map. + * A thread can terminate through thread_exit. This means that when the + * thread function calls thread_exit, the execution should return to the + * pthread_create call that ran it. This is implemented via setjmp/longjmp + * (neither longjmp nor pthread_exit unwind the stack). + * + * Any mutexes or condition variables cannot block, or else it would cause + * a deadlock. Since there is only one thread running at a time, locking + * a mutex or waiting for a condition always succeeds (returns immediately). + */ + +struct key_entry_t { + key_entry_t(void* v, void (*d)(void*)) : value(v), dtor(d) {} + void* value = nullptr; + void (*dtor)(void*) = nullptr; +}; + +struct thread_info_t { + thread_info_t() = default; + std::map keys; + std::jmp_buf env; + void* ret_value = nullptr; + bool finished = false; + bool detached = false; +}; + +static pthread_t main_thread_id = 0; + +static std::map thread_data = { + // Reserve the 0th entry. + {main_thread_id, {}}}; + +static std::vector running_threads = {main_thread_id}; + +template +K first_available_key(const std::map& m) { + auto i = m.begin(), e = m.end(); + K key = 1; + for (; i != e && key == i->first; ++i, ++key) { + } + return key; +} + +int pthread_cond_destroy(pthread_cond_t* cond) { return 0; } + +int pthread_cond_init(pthread_cond_t* __restrict cond, const pthread_condattr_t* __restrict attr) { + return 0; +} + +int pthread_cond_signal(pthread_cond_t* cond) { return 0; } + +int pthread_cond_broadcast(pthread_cond_t* cond) { return 0; } + +int pthread_cond_timedwait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex, + const struct timespec* __restrict abstime) { + return 0; +} + +int pthread_cond_wait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex) { + return 0; +} + +int pthread_mutexattr_init(pthread_mutexattr_t* attr) { return 0; } + +int pthread_mutexattr_destroy(pthread_mutexattr_t* attr) { return 0; } + +int pthread_mutexattr_settype(pthread_mutexattr_t* attr, int type) { return 0; } + +int pthread_mutexattr_gettype(const pthread_mutexattr_t* __restrict attr, int* __restrict type) { + *type = PTHREAD_MUTEX_NORMAL; + return 0; +} + +int pthread_mutex_init(pthread_mutex_t* __restrict mutex, + const pthread_mutexattr_t* __restrict attr) { + return 0; +} + +int pthread_mutex_destroy(pthread_mutex_t* mutex) { return 0; } + +int pthread_mutex_lock(pthread_mutex_t* mutex) { return 0; } + +int pthread_mutex_trylock(pthread_mutex_t* mutex) { return 0; } + +int pthread_mutex_unlock(pthread_mutex_t* mutex) { return 0; } + +int pthread_once(pthread_once_t* once_control, void (*init_routine)(void)) { + static_assert(PTHREAD_ONCE_INIT != PTHREAD_ONCE_DONE, + "PTHREAD_ONCE_INIT must be different from PTHREAD_ONCE_DONE"); + if (*once_control == PTHREAD_ONCE_INIT) { + init_routine(); + *once_control = 
PTHREAD_ONCE_DONE; + } + return 0; +} + +int pthread_equal(pthread_t t1, pthread_t t2) { return t1 == t2; } + +int pthread_create(pthread_t* thread, const pthread_attr_t* attr, void* (*start_routine)(void*), + void* arg) { + std::jmp_buf& env = thread_data[pthread_self()].env; + volatile pthread_t tid; + if (setjmp(env) == 0) { + tid = first_available_key(thread_data); + *thread = tid; + running_threads.push_back(pthread_t(tid)); + thread_info_t& thr = thread_data[pthread_t(tid)]; + thr.ret_value = start_routine(arg); + } + thread_info_t& thr = thread_data[pthread_t(tid)]; + thr.finished = true; + running_threads.pop_back(); + + // Destroy all keys. + bool repeat = true; + size_t iter = 0; + while (repeat && iter++ < PTHREAD_DESTRUCTOR_ITERATIONS) { + repeat = false; + // Assume that destructors can create new keys (i.e. modify the map). + for (size_t k = 0; k != PTHREAD_KEYS_MAX; ++k) { + auto f = thr.keys.find(k); + if (f == thr.keys.end()) { + continue; + } + key_entry_t& key = f->second; + if (key.dtor == nullptr || key.value == nullptr) { + continue; + } + key.dtor(key.value); + repeat = true; + } + } + + if (thr.detached) { + thread_data.erase(pthread_t(tid)); + } + + return 0; +} + +int pthread_join(pthread_t thread, void** retval) { + auto f = thread_data.find(thread); + if (f == thread_data.end()) { + return ESRCH; + } + thread_info_t& thr = f->second; + if (!thr.finished) { + return EDEADLK; + } + if (retval != nullptr) { + *retval = thr.ret_value; + } + thread_data.erase(f); + return 0; +} + +int pthread_detach(pthread_t thread) { + auto f = thread_data.find(thread); + if (f == thread_data.end()) { + return ESRCH; + } + // Can discard the return value. + f->second.detached = true; + return 0; +} + +void pthread_exit(void* retval) { + pthread_t sid = pthread_self(); + if (sid != main_thread_id) { + thread_info_t& self = thread_data[sid]; + self.ret_value = retval; + self.finished = true; + longjmp(self.env, 1); + } + exit(0); // Only executes for the main thread, plus silences + // the "should not return" warning. +} + +int pthread_key_create(pthread_key_t* key, void (*destructor)(void*)) { + if (key == nullptr) { + return EINVAL; + } + auto& keys = thread_data[pthread_self()].keys; + pthread_key_t k = first_available_key(keys); + if (k >= PTHREAD_KEYS_MAX) { + return EAGAIN; + } + *key = k; + keys.emplace(k, key_entry_t{nullptr, destructor}); + return 0; +} + +int pthread_key_delete(pthread_key_t key) { + auto& keys = thread_data[pthread_self()].keys; + auto f = keys.find(key); + if (f == keys.end()) { + return EINVAL; + } + // pthread_key_delete does not call key destructors. 
+ keys.erase(f); + return 0; +} + +int pthread_setspecific(pthread_key_t key, const void* value) { + auto& keys = thread_data[pthread_self()].keys; + auto f = keys.find(key); + if (f == keys.end()) { + return EINVAL; + } + f->second.value = const_cast(value); + return 0; +} + +void* pthread_getspecific(pthread_key_t key) { + auto& keys = thread_data[pthread_self()].keys; + auto f = keys.find(key); + if (f != keys.end()) { + return f->second.value; + } + return nullptr; +} + +pthread_t pthread_self(void) { return running_threads.back(); } + +int sched_yield(void) { return 0; } + +#ifdef __cplusplus_ +extern "C" int nanosleep(const struct timespec* req, struct timespec* rem); +#endif + +int nanosleep(const struct timespec* req, struct timespec* rem) { return 0; } diff --git a/src/runtime/hexagon/sim/driver/pthread.h b/src/runtime/hexagon/sim/driver/pthread.h new file mode 100644 index 000000000000..7ec74b4f99f5 --- /dev/null +++ b/src/runtime/hexagon/sim/driver/pthread.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ +#define TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ + +#define _PROVIDE_POSIX_TIME_DECLS 1 +#include +#undef _PROVIDE_POSIX_TIME_DECLS + +typedef int pthread_t; +typedef int pthread_attr_t; +typedef int pthread_cond_t; +typedef int pthread_condattr_t; +typedef int pthread_key_t; +typedef int pthread_mutex_t; +typedef int pthread_mutexattr_t; +typedef int pthread_once_t; + +enum { + PTHREAD_COND_INITIALIZER, + PTHREAD_MUTEX_DEFAULT, + PTHREAD_MUTEX_ERRORCHECK, + PTHREAD_MUTEX_INITIALIZER, + PTHREAD_MUTEX_NORMAL, + PTHREAD_MUTEX_RECURSIVE, + PTHREAD_ONCE_INIT = 0, // Must be same as in QuRT + PTHREAD_ONCE_DONE, // Non-standard +}; + +const size_t PTHREAD_KEYS_MAX = 128; +const size_t PTHREAD_DESTRUCTOR_ITERATIONS = 4; + +#ifdef __cplusplus +extern "C" { +#endif +int pthread_cond_destroy(pthread_cond_t* cond); +int pthread_cond_init(pthread_cond_t* __restrict cond, const pthread_condattr_t* __restrict attr); +int pthread_cond_signal(pthread_cond_t* cond); +int pthread_cond_broadcast(pthread_cond_t* cond); +int pthread_cond_timedwait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex, + const struct timespec* __restrict abstime); +int pthread_cond_wait(pthread_cond_t* __restrict cond, pthread_mutex_t* __restrict mutex); + +int pthread_mutexattr_init(pthread_mutexattr_t* attr); +int pthread_mutexattr_destroy(pthread_mutexattr_t* attr); +int pthread_mutexattr_gettype(const pthread_mutexattr_t* __restrict attr, int* __restrict type); +int pthread_mutexattr_settype(pthread_mutexattr_t* attr, int type); + +int pthread_mutex_init(pthread_mutex_t* __restrict mutex, + const pthread_mutexattr_t* __restrict attr); +int pthread_mutex_destroy(pthread_mutex_t* mutex); +int pthread_mutex_lock(pthread_mutex_t* mutex); +int pthread_mutex_trylock(pthread_mutex_t* mutex); +int pthread_mutex_unlock(pthread_mutex_t* mutex); + +int pthread_once(pthread_once_t* once_control, void (*init_routine)(void)); +int pthread_equal(pthread_t t1, pthread_t t2); + +int pthread_create(pthread_t* thread, const pthread_attr_t* attr, void* (*start_routine)(void*), + void* arg); +int pthread_join(pthread_t thread, void** retval); +int pthread_detach(pthread_t thread); +void pthread_exit(void* retval) __attribute__((__noreturn__)); + +int pthread_key_create(pthread_key_t* key, void (*destructor)(void*)); +int pthread_key_delete(pthread_key_t key); +int pthread_setspecific(pthread_key_t key, const void* value); +void* pthread_getspecific(pthread_key_t key); + +pthread_t pthread_self(void); +#ifdef __cplusplus +} +#endif + +#endif // TVM_RUNTIME_HEXAGON_SIM_DRIVER_PTHREAD_H_ diff --git a/web/example_rpc_node.js b/src/runtime/hexagon/sim/driver/sched.h similarity index 60% rename from web/example_rpc_node.js rename to src/runtime/hexagon/sim/driver/sched.h index 45f917a3234b..cc63630f2072 100644 --- a/web/example_rpc_node.js +++ b/src/runtime/hexagon/sim/driver/sched.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,17 +17,15 @@ * under the License. */ -// Javascript RPC server example -// Start and connect to websocket proxy. 
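For reference, a minimal usage sketch of the single-threaded pthread shim above (fake_pthread.cc plus the local pthread.h). It is illustrative only and not part of the patch; the function names `triple` and `example` are made up. Because `pthread_create` runs `start_routine` to completion before returning, `pthread_join` merely collects the already-stored return value.

```
#include <cassert>
#include <cstdint>

#include "pthread.h"  // the single-threaded shim above

static void* triple(void* arg) {
  return reinterpret_cast<void*>(3 * reinterpret_cast<uintptr_t>(arg));
}

static void example() {
  pthread_t tid;
  pthread_create(&tid, nullptr, triple, reinterpret_cast<void*>(7));
  // The "thread" has already finished by the time pthread_create returns.
  void* ret = nullptr;
  pthread_join(tid, &ret);
  assert(reinterpret_cast<uintptr_t>(ret) == 21);
}
```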
+#ifndef TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ +#define TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ -// Load Emscripten Module, need to change path to root/lib -const path = require("path"); -process.chdir(path.join(__dirname, "../lib")); -var Module = require("../lib/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +#ifdef __cplusplus +extern "C" { +#endif +int sched_yield(void); +#ifdef __cplusplus +} +#endif -var websock_proxy = "ws://localhost:9190/ws"; -var num_sess = 100; -tvm.startRPCServer(websock_proxy, "js", num_sess) +#endif // TVM_RUNTIME_HEXAGON_SIM_DRIVER_SCHED_H_ diff --git a/src/runtime/hexagon/sim/driver/sim_device.cc b/src/runtime/hexagon/sim/driver/sim_device.cc new file mode 100644 index 000000000000..c8cf7838948e --- /dev/null +++ b/src/runtime/hexagon/sim/driver/sim_device.cc @@ -0,0 +1,560 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + Required options: + -ldl -G0 For dlinit/dlopen/dlclose. + -Wl,--force-dynamic Make this a dynamic executable (with dynamic + symbol table). + -Wl,-E Export all defined symbols as dynamic. + -Wl,--whole-archive Link the entire contents of libc. + -mhvx -mhvx-length=128b Enable HVX. + -Wno-format Silence format warning (unsigned vs uint32_t). +*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "hexagon_sim_proto.h" +#include "pthread.h" +#include "tvm/runtime/c_runtime_api.h" + +static std::string timeNow() { + char str[11]; // [hh:mm:ss] + time_t time_value = time(NULL); + tm* pnow = localtime(&time_value); // NOLINT(runtime/threadsafe_fn) + + snprintf(str, sizeof(str), "[%02d:%02d:%02d]", pnow->tm_hour, pnow->tm_min, pnow->tm_sec); + return std::string(str); +} + +#define LOG(FMT, ...) \ + fprintf(stderr, "%s %s:%d: " FMT "\n", timeNow().c_str(), __FILE__, __LINE__, ##__VA_ARGS__) + +using HVX_Vector = int __attribute__((__vector_size__(128))) __attribute__((aligned(128))); + +static unsigned getVectorLength() { + HVX_Vector v = __builtin_HEXAGON_V6_lvsplatw_128B(0x01010101); + unsigned char* p = reinterpret_cast(&v); + if (p[127] == 1) return 128; + assert(p[63] == 1); + return 64; +} + +extern "C" { +// Print vector functions. They can be used to help debug tensorized +// code, via +// ib.emit(tvm.call_extern('int32', 'V6_pv8', 'vector:', v)) +// ib.emit(tvm.call_extern('int32', 'V6_pv16', 'info:', v)) +// ib.emit(tvm.call_extern('int32', 'V6_pv32', 'value:', v)) + +// The first argument is a string printed before the vector contents. 
+int V6_pv8(const char* s, HVX_Vector v); +int V6_pv16(const char* s, HVX_Vector v); +int V6_pv32(const char* s, HVX_Vector v); +} + +int V6_pv8(const char* s, HVX_Vector v) { + unsigned vlen = getVectorLength(); + uint8_t* ptr = reinterpret_cast(&v); + fprintf(stderr, "%s:", s); + for (unsigned i = 0; i != vlen; ++i) { + fprintf(stderr, " %02x", ptr[i]); + } + fprintf(stderr, "\n"); + return 0; +} + +int V6_pv16(const char* s, HVX_Vector v) { + unsigned vlen = getVectorLength(); + uint16_t* ptr = reinterpret_cast(&v); + fprintf(stderr, "%s:", s); + for (unsigned i = 0; i != vlen / sizeof(uint16_t); ++i) { + fprintf(stderr, " %04x", ptr[i]); + } + fprintf(stderr, "\n"); + return 0; +} + +int V6_pv32(const char* s, HVX_Vector v) { + unsigned vlen = getVectorLength(); + uint32_t* ptr = reinterpret_cast(&v); + fprintf(stderr, "%s:", s); + for (unsigned i = 0; i != vlen / sizeof(uint32_t); ++i) { + fprintf(stderr, " %08x", ptr[i]); + } + fprintf(stderr, "\n"); + return 0; +} + +extern "C" { +// Function referenced from libc++.a, but not defined in libc.a. +int clock_gettime(clockid_t clock_id, struct timespec* tp); +// pthread_create is wrapped so that we can set a bigger stack size +// for QuRT. Here this isn't needed, but we still need to implement +// the wrapper. +int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr, + void* (*start_routine)(void*), void* arg); +} + +int clock_gettime(clockid_t clock_id, struct timespec* tp) { + // Stub implementation. + return 0; +} + +int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr, + void* (*start_routine)(void*), void* arg) { + LOG("%s", __func__); + return pthread_create(thread, attr, start_routine, arg); +} + +// FIXME(kparzysz-quic): query the cfg register to compute the VTCM base. +// This works now. +const unsigned int TCM_BASE = 0xD8000000; +const unsigned int VTCM_BASE = TCM_BASE + 0x400000; + +class Allocator { + private: + struct Block { + Block(void* p, size_t s) : ptr_(p), size_(s), vtcm_(false) {} + Block(void* p, size_t s, bool v) : ptr_(p), size_(s), vtcm_(v) {} + bool operator<(const Block& b) const { return uintptr_t(ptr_) < uintptr_t(b.ptr_); } + void* ptr_; + size_t size_; + bool vtcm_; + }; + + using vector_type = std::vector; + using iterator = vector_type::iterator; + vector_type allocations_; + + uintptr_t cur_vtcm = VTCM_BASE; + + public: + void* alloc(unsigned size, size_t align); + void* vtcm_alloc(unsigned size, size_t align); + void free(void* p); +}; + +void* Allocator::alloc(unsigned size, size_t align) { + void* ptr = aligned_alloc(align, size); + if (ptr == nullptr) { + perror("device: error allocating memory:"); + return ptr; + } + + Block b(ptr, size); + iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), b); + iterator w = allocations_.insert(i, b); + if (w != allocations_.begin()) { + iterator pw = w - 1; + assert(uintptr_t(pw->ptr_) + pw->size_ < uintptr_t(w->ptr_)); + } + if (w + 1 != allocations_.end()) { + iterator nw = w + 1; + assert(uintptr_t(w->ptr_) + w->size_ <= uintptr_t(nw->ptr_)); + } + + LOG("device: allocated %d bytes aligned at %d: %p", size, align, ptr); + return ptr; +} + +// For now, just allocation sequentially. This needs to be improved to use a +// free list. 
+void* Allocator::vtcm_alloc(unsigned size, size_t align) { + uintptr_t a = cur_vtcm; + a = (a + (align - 1)) & -align; + cur_vtcm = a + size; + void* ptr = reinterpret_cast(a); + if (ptr == nullptr) { + perror("device: error allocating vtcm memory:"); + return ptr; + } + + Block b(ptr, size, true); + iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), b); + iterator w = allocations_.insert(i, b); + if (w != allocations_.begin()) { + iterator pw = w - 1; + assert(uintptr_t(pw->ptr_) + pw->size_ <= uintptr_t(w->ptr_)); + } + if (w + 1 != allocations_.end()) { + iterator nw = w + 1; + assert(uintptr_t(w->ptr_) + w->size_ <= uintptr_t(nw->ptr_)); + } + + LOG("device: allocated vtcm %d bytes aligned at %d: %p", size, align, ptr); + return ptr; +} + +void Allocator::free(void* ptr) { + LOG("device: freeing %p", ptr); + iterator i = std::lower_bound(allocations_.begin(), allocations_.end(), Block(ptr, 0)); + assert(i != allocations_.end()); + assert(i->ptr_ == ptr); + if (!i->vtcm_) ::free(i->ptr_); + allocations_.erase(i); +} + +static void printMsgCall(const MsgCall& mc) { + auto to_dec_string = [](int v) { + char tmp[11]; + snprintf(tmp, sizeof(tmp), "%d", v); + return std::string(tmp); + }; + auto to_hex_string = [](uint32_t v) { + char tmp[9]; + snprintf(tmp, sizeof(tmp), "%lx", v); + return std::string(tmp); + }; + std::string str = "device: launching " + to_hex_string(mc.func_va) + + " sc:" + to_dec_string(mc.scalar_num) + " {"; + for (unsigned i = 0; i != mc.scalar_num; ++i) { + str += ' ' + to_hex_string(mc.data[i]); + if (i + 1 != mc.scalar_num) str += ','; + } + str += " }, st:" + to_dec_string(mc.stack_num) + " {"; + for (unsigned i = 0; i != mc.stack_num; ++i) { + str += ' ' + to_hex_string(mc.data[i + mc.scalar_num]); + if (i + 1 != mc.stack_num) str += ','; + } + str += " }"; + LOG("%s", str.c_str()); +} + +static std::vector task_queue; + +struct Environment { + Allocator alloc; + void* dl_handle = nullptr; +}; + +extern "C" { +volatile Message message_buffer; +int dispatch(Environment* env) __attribute__((noinline)); +} + +static volatile unsigned char payload_buffer[4096]; + +static void setMsg(uint32_t code, uint32_t len, uint32_t va) { + message_buffer.code = code; + message_buffer.len = len; + message_buffer.va = va; +} + +inline void* pointer(uint32_t v) { return reinterpret_cast(static_cast(v)); } + +inline uint32_t va(const volatile void* p) { + return static_cast(reinterpret_cast(p)); +} + +__attribute__((naked)) uint32_t launcher(volatile MsgCall* mc, uint64_t* pcc) { + __asm__( + "// This function is intentionally written to be readable, \n" + "// rather than fast. \n" + "// r0 = value of 'volatile MsgCall *mc' \n" + "// r1 = address where to store the program cycle count \n" + "{ memd(r29+#-16) = r21:20 \n" + " allocframe(#24) } \n" + "{ memd(r29+#0) = r17:16 \n" + " memd(r29+#8) = r19:18 } \n" + "{ r17:16 = combine(r1,r0) \n" + " r18 = r29 \n" + " r1 = memw(r0+#4) // scalar_num \n" + " r2 = memw(r0+#8) } // stack_num \n" + "// If there are no stack values, skip the stack setup. \n" + "{ p0 = cmp.eq(r2,#0) \n" + " if (p0.new) jump:t .Llauncher1 } \n" + + "// Allocate space on the stack. Let r2 = needed space \n" + "// rounded up to a multiple of 8. \n" + "{ loop0(.Llauncher0,r2) \n" + " r2 = asl(r2,#2) } \n" + "{ r2 = add(r2,#4) } \n" + "{ r2 = clrbit(r2,#2) } \n" + "{ r29 = sub(r29,r2) } \n" + + "// Copy stack contents onto the stack. 
Stack contents start \n" + "// at r3 = r0 + offsetof(data) + scalar_num*4 \n" + "{ r3 = addasl(r0,r1,#2) \n" + " r4 = r29 } \n" + "{ r3 = add(r3,#12) } // offsetof(data) \n" + ".Llauncher0: \n" + "{ r5 = memw(r3++#4) \n" + " memw(r4++#4) = r5.new } :endloop0 \n" + + "// Load registers. Some of the loaded data may actually be \n" + "// values from the stack part of 'data', but it's not an issue.\n" + ".Llauncher1: \n" + "{ r0 = memw(r16+#12) // mc + offsetof(data) \n" + " r1 = memw(r16+#16) } \n" + "{ r2 = memw(r16+#20) \n" + " r3 = memw(r16+#24) } \n" + "{ r4 = memw(r16+#28) \n" + " r5 = memw(r16+#32) } \n" + + "// Call. \n" + "{ r6 = memw(r16+#0) \n" + " r21:20 = upcycle } \n" + "{ callr r6 } \n" + + "// Restore stack pointer (free up r18), calculate cycle count. \n" + "{ r29 = r18 \n" + " r19:18 = upcycle } \n" + "{ r19:18 = sub(r19:18, r21:20) } \n" + + "// Store pcount, restore non-volatile registers, and return. \n" + "{ memd(r17+#0) = r19:18 \n" + " r21:20 = memd(r29+#16) } \n" + "{ r19:18 = memd(r29+#8) \n" + " r17:16 = memd(r29+#0) } \n" + "{ dealloc_return } // implicit-use r1:0 \n"); +} + +int dispatch(Environment* env) { + uint32_t code = message_buffer.code; + // Special handling of MsgReq. + if (code == kMsgReq) { + assert(message_buffer.len <= sizeof(payload_buffer)); + setMsg(kMsgAck, sizeof(payload_buffer), va(payload_buffer)); + return 0; + } + + switch (code) { + case kAlloc: { + LOG("device: {kAlloc, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgAlloc)); + auto* ma = reinterpret_cast(message_buffer.va); + void* p = env->alloc.alloc(ma->size, ma->align); + reinterpret_cast(payload_buffer)->va = va(p); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kFree: { + LOG("device: {kFree, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgPointer)); + auto* mp = reinterpret_cast(message_buffer.va); + env->alloc.free(pointer(mp->va)); + setMsg(kNone, 0u, 0u); + break; + } + case kAllocVtcm: { + LOG("device: {kAllocVtcm, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgAlloc)); + auto* ma = reinterpret_cast(message_buffer.va); + void* p = env->alloc.vtcm_alloc(ma->size, ma->align); + reinterpret_cast(payload_buffer)->va = va(p); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kCopy: { + LOG("device: {kCopy, %lu, %lx}", message_buffer.len, message_buffer.va); + assert(message_buffer.len == sizeof(MsgCopy)); + auto* mc = reinterpret_cast(message_buffer.va); + memcpy(pointer(mc->dst), pointer(mc->src), mc->len); + setMsg(kNone, 0u, 0u); + break; + } + case kLoad: { + if (env->dl_handle != nullptr) dlclose(env->dl_handle); + const char* name = static_cast(pointer(message_buffer.va)); + // LOG(stderr, "device: dlopen(%s)", name); + env->dl_handle = dlopen(name, RTLD_LAZY); + if (env->dl_handle == nullptr) LOG("dlopen: %s\n", dlerror()); + assert(env->dl_handle != nullptr); + reinterpret_cast(payload_buffer)->va = va(env->dl_handle); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kUnload: { + assert(env->dl_handle != nullptr); + assert(message_buffer.len == sizeof(MsgPointer)); + auto* mp = reinterpret_cast(message_buffer.va); + assert(pointer(mp->va) == env->dl_handle); + dlclose(env->dl_handle); + env->dl_handle = nullptr; + setMsg(kNone, 0u, 0u); + break; + } + case kResolve: { + LOG("device: {kResolve, %lu, %lx}", message_buffer.len, message_buffer.va); + 
assert(env->dl_handle != nullptr); + dlerror(); + const char* name = static_cast(pointer(message_buffer.va)); + void* s = dlsym(env->dl_handle, name); + reinterpret_cast(payload_buffer)->va = va(s); + setMsg(kNone, sizeof(MsgPointer), va(payload_buffer)); + break; + } + case kCall: { + LOG("device: {kCall, %lu, %lx}", message_buffer.len, message_buffer.va); + // Add the task to the queue. + auto* mc = reinterpret_cast(message_buffer.va); + uint32_t size = 4 * (3 + mc->scalar_num + mc->stack_num); + MsgCall* t = static_cast(malloc(size)); + memcpy(t, mc, size); + task_queue.push_back(t); + // Return 0. + *reinterpret_cast(payload_buffer) = 0; + setMsg(kNone, sizeof(uint32_t), va(payload_buffer)); + break; + } + case kFlush: { + LOG("device: {kFlush}"); + LOG("device: %d tasks in the queue", task_queue.size()); + // Execute all tasks from the queue and release memory buffers + // for as long as the return values are 0. Upon receiving a non-zero + // return value, continue freeing memory but no longer execute + // any tasks. The task queue will be cleared in any case. + uint32_t rv = 0; + uint64_t pcc; // Pcycle counter, will be 0 under simulator (upcycle). + for (MsgCall* t : task_queue) { + if (rv == 0) { + printMsgCall(*t); + rv = launcher(t, &pcc); + LOG("device: execution took %lld pcycles", pcc); + } + free(t); + } + task_queue.clear(); + *reinterpret_cast(payload_buffer) = rv; + setMsg(kNone, sizeof(uint32_t), va(payload_buffer)); + break; + } + default: + LOG("device: unknown code: %lu", message_buffer.code); + abort(); + break; + } + return 0; +} + +extern "C" { +int acquire_vector_unit(int); +void release_vector_unit(); +} + +static void makePathList(const std::string& arg, std::vector* list) { + size_t p = 0, e = arg.size(); + std::vector tmp; + + while (p < e) { + tmp.clear(); + bool check_next = true; + size_t i = p; + for (; i != e; ++i) { + char c = arg[i]; + if (check_next) { + if (c == '\\') { + check_next = false; + continue; + } else if (c == ':') { + break; + } + } + check_next = true; + tmp.push_back(c); + } + if (!tmp.empty()) list->emplace_back(tmp.begin(), tmp.end()); + p = i + 1; + } +} + +static std::string findInPaths(const std::string& filename, const std::string& paths) { + std::vector path_list; + makePathList(paths, &path_list); + + for (const auto& p : path_list) { + std::string pf = p + '/' + filename; + if (access(pf.c_str(), X_OK) == 0) return std::move(pf); + } + // If the search failed, try bare filename. If it cannot be loaded, + // dlerror will print a meaningful message. + return filename; +} + +// Presence of this function indicates that sim_dev is running. +extern "C" int running_in_sim_dev_17bc90206f6cf5a7(); +int running_in_sim_dev_17bc90206f6cf5a7() { return 0; } + +int main(int argc, char* argv[]) { + int opt; + std::string ld_path; + while ((opt = getopt(argc, argv, "L:")) != -1) { + switch (opt) { + case 'L': + ld_path += ':' + std::string(optarg); + break; + case '?': + LOG("Usage %s: [-L path1[:path2...]]", argv[0]); + return 1; + } + } + + std::string rt_path = findInPaths("libtvm_runtime.so", ld_path); + LOG("TVM runtime path: %s", rt_path.c_str()); + + Environment env; + acquire_vector_unit(0); + + const char* builtin[] = { + "libgcc.so", "libc.so", "libc++.so", + "libc++abi.so", "libc++.so.1", "libc++abi.so.1" // Alternative names. 
+ }; + dlinit(sizeof(builtin) / sizeof(builtin[0]), const_cast(builtin)); + void* rt_handle = dlopen(rt_path.c_str(), RTLD_GLOBAL); + if (rt_handle == nullptr) { + LOG("error loading TVM runtime: %s", dlerror()); + return 1; + } + + // When running TVM runtime on Hexagon there is no longer a device + // for Hexagon, but standalone ops can still refer to it. All of + // required DeviceAPI's functionality is adequately implemented + // via the CPU device, so remap device_api.hexagon to device_api.cpu. + auto* get_global = + reinterpret_cast(dlsym(rt_handle, "TVMFuncGetGlobal")); + assert(get_global != nullptr); + auto* register_global = + reinterpret_cast(dlsym(rt_handle, "TVMFuncRegisterGlobal")); + assert(register_global != nullptr); + + TVMFunctionHandle cpu_api; + if (get_global("device_api.cpu", &cpu_api) != 0 || + register_global("device_api.hexagon", cpu_api, true) != 0) { + LOG("error setting device_api.hexagon"); + return 1; + } + + while (!dispatch(&env)) { + } + + dlclose(rt_handle); + release_vector_unit(); + return 0; +} diff --git a/src/runtime/hexagon/sim/hexagon_device_sim.cc b/src/runtime/hexagon/sim/hexagon_device_sim.cc index b58377baa947..477da09c1c65 100644 --- a/src/runtime/hexagon/sim/hexagon_device_sim.cc +++ b/src/runtime/hexagon/sim/hexagon_device_sim.cc @@ -41,8 +41,7 @@ namespace tvm { namespace runtime { namespace hexagon { -static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), - "Hexagon VA must be uint32"); +static_assert(sizeof(HEX_VA_t) == sizeof(uint32_t), "Hexagon VA must be uint32"); template struct unalign { @@ -89,8 +88,7 @@ std::unique_ptr make_unique(size_t size) { // user from memory reallocation and copying. struct non_const_str { non_const_str() {} - explicit non_const_str(const std::string& str) - : non_const_str(std::vector{str}) {} + explicit non_const_str(const std::string& str) : non_const_str(std::vector{str}) {} explicit non_const_str(const std::vector& vec) { for (const std::string& s : vec) { auto c = detail::make_unique(s.size() + 1); @@ -220,8 +218,7 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { void* Load(const std::string& data, const std::string& fmt) final; void Unload(void* mod) final; void* Resolve(const std::string& sym) final; - void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, - unsigned st_num) final; + void Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, unsigned st_num) final; static std::string to_string(HEXAPI_Status status); @@ -312,10 +309,8 @@ class HexagonSimulator final : public tvm::runtime::hexagon::Device { bool should_parse_next(const string_list& rest); llvm::Optional to_interval(const detail::MaybeString& str); - llvm::Optional to_timingmode( - const detail::MaybeString& str); - llvm::Optional to_verbosemode( - const detail::MaybeString& str); + llvm::Optional to_timingmode(const detail::MaybeString& str); + llvm::Optional to_verbosemode(const detail::MaybeString& str); llvm::Optional to_nullptr(const detail::MaybeString& str); MaybeUIntRange ahb_, axi2_; @@ -399,12 +394,11 @@ decltype(HexagonSimulator::opt_map_) HexagonSimulator::opt_map_ = { {"--verbose", &HexagonSimulator::HandleVerbose}, }; -#define CHECKED_CALL(func, ...) \ - do { \ - HEXAPI_Status s = sim_->func(__VA_ARGS__); \ - CHECK_EQ(s, HEX_STAT_SUCCESS) \ - << "HexagonSimulator: " #func " failed with code " \ - << HexagonSimulator::to_string(s); \ +#define CHECKED_CALL(func, ...) 
\ + do { \ + HEXAPI_Status s = sim_->func(__VA_ARGS__); \ + CHECK_EQ(s, HEX_STAT_SUCCESS) << "HexagonSimulator: " #func " failed with code " \ + << HexagonSimulator::to_string(s); \ } while (false) inline HEX_VA_t HexagonSimulator::p2va(const void* p) { @@ -444,8 +438,7 @@ void HexagonSimulator::CopyNFromV(void* host_dst, HEX_VA_t src) { pd->value = v; } -void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, - unsigned len) { +void HexagonSimulator::CopyToV(HEX_VA_t dst, const void* host_src, unsigned len) { const uint8_t* src = static_cast(host_src); while (len >= 8) { @@ -556,18 +549,15 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { using iterator = std::istream_iterator; auto sim_args = string_list(iterator(sim_args_iss), iterator()); - std::string target_str = - !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); + std::string target_str = !sim_args.empty() ? *detail::pop_front(sim_args) : std::string("v66"); arch_ = target_str; - sim_ = - detail::make_unique(detail::non_const_str(target_str)); + sim_ = detail::make_unique(detail::non_const_str(target_str)); LOG(INFO) << "HexagonSimulator: Core version: " << arch_; // Locate the sim_dev binary in PATH, or in the current working directory. llvm::StringRef sim_dev = "sim_dev"; - detail::MaybeString path_sim_dev = - llvm::sys::Process::FindInEnvPath("PATH", sim_dev); + detail::MaybeString path_sim_dev = llvm::sys::Process::FindInEnvPath("PATH", sim_dev); if (!path_sim_dev) { if (!llvm::sys::fs::exists(sim_dev)) { LOG(FATAL) << "Cannot find sim_dev in PATH."; @@ -615,8 +605,7 @@ HexagonSimulator::HexagonSimulator(bool enable_queuing) { } void* HexagonSimulator::Alloc(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align - << ')'; + LOG(INFO) << "HexagonSimulator::Alloc(size=" << size << ", align=" << align << ')'; Message m = {kAlloc, sizeof(MsgAlloc), 0u}; MsgAlloc ma = {size, align}; SendMsg(m, &ma, true); @@ -631,8 +620,7 @@ void* HexagonSimulator::Alloc(unsigned size, unsigned align) { } void HexagonSimulator::Free(void* ptr) { - LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec - << ')'; + LOG(INFO) << "HexagonSimulator::Free(ptr=" << std::hex << ptr << std::dec << ')'; if (task_queuing_) { Message mf = {kFlush, 0, 0}; SendMsg(mf, 0, true); @@ -643,8 +631,7 @@ void HexagonSimulator::Free(void* ptr) { } void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { - LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size - << ", align=" << align << ')'; + LOG(INFO) << "HexagonSimulator::AllocVtcm(size=" << size << ", align=" << align << ')'; Message m = {kAllocVtcm, sizeof(MsgAlloc), 0u}; MsgAlloc ma = {size, align}; SendMsg(m, &ma, true); @@ -653,28 +640,25 @@ void* HexagonSimulator::AllocVtcm(unsigned size, unsigned align) { MsgPointer mp; CopyFromV(&mp, m.va, m.len); - LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va - << std::dec; + LOG(INFO) << "HexagonSimulator::AllocVtcm -> " << std::hex << mp.va << std::dec; CHECK_NE(mp.va, 0); return va2p(mp.va); } void HexagonSimulator::FreeVtcm(void* ptr) {} -void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst - << ", src=" << src << ", len=" << std::dec << len << ')'; +void HexagonSimulator::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToDevice(dst=" << std::hex << dst 
<< ", src=" << src + << ", len=" << std::dec << len << ')'; CHECK(dst != nullptr && src != nullptr); Message m = {kCopy, sizeof(MsgCopy), 0u}; MsgCopy mc = {p2va(dst), p2va(src), len}; SendMsg(m, &mc, true); } -void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst - << ", src=" << src << ", len=" << len << ')'; +void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyDeviceToHost(host_dst=" << host_dst << ", src=" << src + << ", len=" << len << ')'; if (task_queuing_) { Message mf = {kFlush, 0, 0}; SendMsg(mf, 0, true); @@ -682,10 +666,9 @@ void HexagonSimulator::CopyDeviceToHost(void* host_dst, const void* src, CopyFromV(host_dst, p2va(src), len); } -void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, - unsigned len) { - LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst - << ", host_src=" << host_src << ", len=" << len << ')'; +void HexagonSimulator::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { + LOG(INFO) << "HexagonSimulator::CopyHostToDevice(dst=" << dst << ", host_src=" << host_src + << ", len=" << len << ')'; CopyToV(p2va(dst), host_src, len); } @@ -717,19 +700,17 @@ void* HexagonSimulator::Resolve(const std::string& sym) { MsgPointer mp; CopyFromV(&mp, m.va, sizeof(mp)); - LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va - << std::dec; + LOG(INFO) << "HexagonSimulator::Resolve -> " << std::hex << mp.va << std::dec; return va2p(mp.va); } -void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, - uint32_t* stack, unsigned st_num) { - LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func - << ", scalar=" << scalar << ", sc_num=" << std::dec +void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, uint32_t* stack, + unsigned st_num) { + LOG(INFO) << "HexagonSimulator::Call(func=" << std::hex << func << ", scalar=" << scalar + << ", sc_num=" << std::dec << sc_num // NOLINTNEXTLINE(build/include_what_you_use) - << ", stack=" << std::hex << stack << ", st_num=" << std::dec - << st_num; + << ", stack=" << std::hex << stack << ", st_num=" << std::dec << st_num; std::vector data; @@ -753,8 +734,7 @@ void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, log_data << std::dec << " }" << std::flush; LOG(INFO) << log_data.str(); - Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), - 0u}; + Message m = {kCall, static_cast(data.size() * sizeof(uint32_t)), 0u}; SendMsg(m, data.data(), true); if (!task_queuing_) { @@ -768,8 +748,7 @@ void HexagonSimulator::Call(void* func, uint32_t* scalar, unsigned sc_num, std::ostringstream log_rv; log_rv << "HexagonSimulator::Call -> {" << std::hex; for (unsigned i = 0, e = std::min(rv.size(), 4u); i != e; ++i) { - log_rv << ' ' << std::setw(2) << std::setfill('0') - << static_cast(rv[i]); + log_rv << ' ' << std::setw(2) << std::setfill('0') << static_cast(rv[i]); } if (rv.size() > 4) log_rv << "..."; log_rv << std::dec << " }"; @@ -1059,8 +1038,7 @@ bool HexagonSimulator::HandlePacketAnalyze(string_list& rest) { } bool HexagonSimulator::HandlePCFilter(string_list& rest) { - auto range = - detail::to_range(detail::pop_front(rest)); + auto range = detail::to_range(detail::pop_front(rest)); if (range) { CHECKED_CALL(ConfigurePCRangeFilter, range->first, range->second); } @@ -1222,11 +1200,9 @@ bool 
HexagonSimulator::HandleTCMLowAddr(string_list& rest) { } bool HexagonSimulator::HandleTimeFilterNS(string_list& rest) { - auto range = - detail::to_range(detail::pop_front(rest)); + auto range = detail::to_range(detail::pop_front(rest)); if (range) { - CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, - range->second, HEX_NANOSEC); + CHECKED_CALL(ConfigureTimeRangeFilter, range->first, HEX_NANOSEC, range->second, HEX_NANOSEC); } return static_cast(range); } @@ -1284,8 +1260,7 @@ bool HexagonSimulator::should_parse_next(const string_list& rest) { return false; } -llvm::Optional HexagonSimulator::to_interval( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_interval(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; @@ -1309,8 +1284,7 @@ llvm::Optional HexagonSimulator::to_interval( .Default(none); } -llvm::Optional HexagonSimulator::to_timingmode( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_timingmode(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; @@ -1357,8 +1331,7 @@ llvm::Optional HexagonSimulator::to_verbosemode( .Default(none); } -llvm::Optional HexagonSimulator::to_nullptr( - const detail::MaybeString& str) { +llvm::Optional HexagonSimulator::to_nullptr(const detail::MaybeString& str) { auto none = llvm::Optional(); if (!str) return none; diff --git a/src/runtime/hexagon/target/fastrpc/CMakeLists.txt b/src/runtime/hexagon/target/fastrpc/CMakeLists.txt new file mode 100644 index 000000000000..072b9ca62fb2 --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/CMakeLists.txt @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +cmake_minimum_required(VERSION 3.2) +project(HexagonIDL C CXX) + +if(NOT "${FASTRPC_LIBS}" STREQUAL "SKEL" AND + NOT "${FASTRPC_LIBS}" STREQUAL "STUB") + message(SEND_ERROR "Please set FASTRPC_LIBS to either SKEL or STUB") +endif() + + +set(FASTRPC_SRC "${CMAKE_CURRENT_SOURCE_DIR}") + +include_directories(include) +include_directories(${HEXAGON_SDK_ROOT}/incs) +include_directories(${HEXAGON_SDK_ROOT}/incs/stddef) +include_directories( + ${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64) + +set(QAIC_EXE "${HEXAGON_SDK_ROOT}/tools/qaic/Ubuntu16/qaic") +set(QAIC_FLAGS + "-I${HEXAGON_SDK_ROOT}/incs/stddef" + "-I${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64" + "-I${HEXAGON_SDK_ROOT}/libs/common/rpcmem/inc" +) + +set(CMAKE_SKIP_RPATH TRUE) + +# Qaic for the non-domain header. +# +# Don't add paths to these filenames, or otherwise cmake may spontaneously +# add -o option to the qaic invocation (with an undesirable path). 
+set(TVM_REMOTE_ND_IDL "tvm_remote_nd.idl") +set(TVM_REMOTE_ND_H "tvm_remote_nd.h") +set(TVM_REMOTE_ND_SKEL_C "tvm_remote_nd_skel.c") +set(TVM_REMOTE_ND_STUB_C "tvm_remote_nd_stub.c") + +add_custom_command( + OUTPUT ${TVM_REMOTE_ND_SKEL_C} ${TVM_REMOTE_ND_STUB_C} + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" + COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_ND_H}" + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_IDL}" +) + +# Qaic for the domain header. +# +# Don't add paths to these filenames, or otherwise cmake may spontaneously +# add -o option to the qaic invocation (with an undesirable path). +set(TVM_REMOTE_D_IDL "tvm_remote.idl") +set(TVM_REMOTE_D_H "tvm_remote.h") +set(TVM_REMOTE_D_SKEL_C "tvm_remote_skel.c") +set(TVM_REMOTE_D_STUB_C "tvm_remote_stub.c") + +add_custom_command( + OUTPUT ${TVM_REMOTE_D_SKEL_C} ${TVM_REMOTE_D_STUB_C} + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + COMMAND ${QAIC_EXE} ${QAIC_FLAGS} + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" + COMMAND ${CMAKE_COMMAND} -E rename "${TVM_REMOTE_D_H}" + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + MAIN_DEPENDENCY "${FASTRPC_SRC}/include/${TVM_REMOTE_D_IDL}" +) + + +if("${FASTRPC_LIBS}" STREQUAL "SKEL") + # Skel libraries. + # + set(HEXARCH_DIR_v60 "ADSPv60MP") + set(HEXARCH_DIR_v62 "ADSPv62MP") + set(HEXARCH_DIR_v65 "computev65") + set(HEXARCH_DIR_v66 "computev66") + set(HEXARCH_DIR_STR "HEXARCH_DIR_${HEXAGON_ARCH}") + set(HEXARCH_DIR ${${HEXARCH_DIR_STR}}) + + if(NOT HEXARCH_DIR) + message(SEND_ERROR + "Please set HEXAGON_ARCH to one of v60, v62, v65, v66") + endif() + + include_directories( + ${HEXAGON_SDK_ROOT}/libs/common/qurt/${HEXARCH_DIR}/include/qurt) + include_directories( + ${HEXAGON_SDK_ROOT}/libs/common/qurt/${HEXARCH_DIR}/include/posix) + + # Extra compile flags (both C and C++). + set(EXTRA_COMP_FLAGS + "-O3" + "-m${HEXAGON_ARCH}" + ) + string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") + message(STATUS "EXTRA_COMP_FLAGS_STR: ${EXTRA_COMP_FLAGS_STR}") + set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") + + set(EXTRA_LINK_FLAGS + "-Wl,--no-threads" + "-Wl,--wrap=malloc" + "-Wl,--wrap=calloc" + "-Wl,--wrap=free" + "-Wl,--wrap=realloc" + "-Wl,--wrap=memalign" + "-Wl,--wrap=posix_memalign" + "-Wl,--wrap=__stack_chk_fail" + ) + string(REGEX REPLACE ";" " " EXTRA_LINK_FLAGS_STR "${EXTRA_LINK_FLAGS}") + + # Extra linker flags for linking shared libraries. + set(CMAKE_SHARED_LINKER_FLAGS + "${EXTRA_LINK_FLAGS_STR} ${CMAKE_SHARED_LINKER_FLAGS}") + + set(SKEL_ND_SRCS + "src/tvm_hvx.cc" + "src/tvm_remote_nd_imp.cc" + ) + add_library(tvm_remote_nd_skel SHARED + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + ${TVM_REMOTE_ND_SKEL_C} + ${SKEL_ND_SRCS} + ) + + set(SKEL_D_SRCS + # Also includes src/tvm_remote_nd_imp.cc + ${SKEL_ND_SRCS} + "src/tvm_remote_imp.cc" + ) + add_library(tvm_remote_skel SHARED + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + ${TVM_REMOTE_D_SKEL_C} + ${SKEL_D_SRCS} + ) + + # Separate shared library with __wrap_pthread_create. + # It is necessary to have it as a separate library because it defines + # a function that libtvm_runtime.so will call. Because of that, this + # function needs to be in the global dynamic symbol table, but the + # skel libraries are loaded as private by FastRPC. 
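The actual `src/tvm_wrap_pthread.cc` is not included in this excerpt. The sketch below shows what a wrapper of this kind might look like, under the assumption that callers such as libtvm_runtime.so are linked with `-Wl,--wrap=pthread_create` (so they reach `__wrap_pthread_create`), while this library itself calls the real `pthread_create`. The stack-size constant is an assumption, not taken from the patch.

```
#include <cstddef>
#include <pthread.h>

// Illustrative sketch only; kMinStackSize is an assumed value.
extern "C" int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr,
                                     void* (*start_routine)(void*), void* arg) {
  constexpr size_t kMinStackSize = 128 * 1024;
  pthread_attr_t default_attr;
  if (attr == nullptr) {
    // No attributes given: supply a larger stack, as described in
    // sim_device.cc ("set a bigger stack size for QuRT").
    pthread_attr_init(&default_attr);
    pthread_attr_setstacksize(&default_attr, kMinStackSize);
    attr = &default_attr;
  }
  // Resolves to the real pthread_create: only the callers are linked with
  // -Wl,--wrap=pthread_create, not this library.
  return pthread_create(thread, attr, start_routine, arg);
}
```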
+ set(WRAP_PTHREAD_SRCS "src/tvm_wrap_pthread.cc") + add_library(tvm_wrap_pthread SHARED ${WRAP_PTHREAD_SRCS}) + +else() + # Stub libraries. + # + include_directories(${HEXAGON_SDK_ROOT}/incs/a1std) + include_directories(${HEXAGON_SDK_ROOT}/incs/qlist) + include_directories(${HEXAGON_SDK_ROOT}/libs/common/rpcmem/inc) + link_directories( + ${HEXAGON_SDK_ROOT}/libs/common/remote/ship/android_Release_aarch64) + + add_library(tvm_remote_nd_stub SHARED + "${FASTRPC_SRC}/include/${TVM_REMOTE_ND_H}" + "${HEXAGON_SDK_ROOT}/libs/common/rpcmem/src/rpcmem_android.c" + "${TVM_REMOTE_ND_STUB_C}" + ) + add_library(tvm_remote_stub SHARED + "${FASTRPC_SRC}/include/${TVM_REMOTE_D_H}" + "${HEXAGON_SDK_ROOT}/libs/common/rpcmem/src/rpcmem_android.c" + "${TVM_REMOTE_D_STUB_C}" + ) + target_link_libraries(tvm_remote_nd_stub adsprpc) + target_link_libraries(tvm_remote_stub adsprpc) +endif() diff --git a/src/runtime/hexagon/target/fastrpc/README.md b/src/runtime/hexagon/target/fastrpc/README.md new file mode 100644 index 000000000000..2d85679bdc65 --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/README.md @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + +# Hexagon IDL libraries + +This directory hosts IDL files and their implementations to offload TVM kernels to Hexagon via FastRPC. The implementations can be used to generate stub and skel libraries. + +### Prerequisites + +1. Android NDK version r19c or later. +2. Hexagon SDK version 3.5.0 or later. + +Android NDK can be downloaded from https://developer.android.com/ndk. +Hexagon SDK is available at //developer.qualcomm.com/software/hexagon-dsp-sdk. + +### Configuring + +Skel and stub libraries need to be configured and built separately. Please use different subdirectories for each. Otherwise the cmake cache from one configuration can interfere with the next. + +For skel libraries, set +``` +FASTRPC_LIBS=SKEL +HEXAGON_SDK_ROOT=/path/to/sdk +CMAKE_C_COMPILER=hexagon-clang +CMAKE_CXX_COMPILER=hexagon-clang++ +HEXAGON_ARCH= one of v60, v62, v65, v66 +``` + +Please note that support for older versions of the Hexagon processor may be removed from the future versions of the Hexagon toolchain. + + +For stub libraries, set +``` +FASTRPC_LIBS=STUB +HEXAGON_SDK_ROOT=/path/to/sdk +CMAKE_C_COMPILER=aarch64-linux-android28-clang # or later +CMAKE_CXX_COMPILER=aarch64-linux-android28-clang++ # or later +``` + +### Building + +In each instance, simple `make` command will create header files `fastrpc/include/tvm_remote.h` and `fastrpc/include/tvm_remote_nd.h`. These headers are needed to compile the TVM runtime for Android (and the stub/skel libraries themselves). diff --git a/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl b/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl new file mode 100644 index 000000000000..bb7d8a29550d --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/include/tvm_remote.idl @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * IDL to offload TVM kernels to Hexagon from APPS for multi-domains. + */ +#include "remote.idl" +#include "AEEStdDef.idl" + +interface tvm_remote : remote_handle64 { + typedef sequence buffer; + typedef unsigned long handle_t; + + long load_library(in sequence soname, + rout handle_t mod_ptr); + long get_symbol(in handle_t mod, + in sequence name, + rout handle_t sym_ptr); + long kernel(in handle_t mod, + in handle_t symbol, + in sequence scalar, + in sequence stack, + in sequence scalar_in_octet, + rout sequence scalar_out_octet, + in sequence stack_in_octet, + rout sequence stack_out_octet, + rout unsigned long long pcycles, + rout unsigned long long time_usec); + long release_library(in handle_t mod); + long alloc_vtcm(in unsigned long size, + in unsigned long align, + rout unsigned long dsp_va); + long free_vtcm(in unsigned long dsp_va); + long call_mmap64(); +}; diff --git a/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl b/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl new file mode 100644 index 000000000000..845ddeffa26f --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/include/tvm_remote_nd.idl @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * IDL to offload TVM kernels to Hexagon from APPS for non-domains. 
+ */ +#include "remote.idl" +#include "AEEStdDef.idl" + +interface tvm_remote_nd { + typedef sequence buffer; + typedef unsigned long handle_t; + + long open(); + long close(); + long load_library(in sequence soname, + rout handle_t mod_ptr); + long get_symbol(in handle_t mod, + in sequence name, + rout handle_t sym_ptr); + long kernel(in handle_t mod, + in handle_t symbol, + in sequence scalar, + in sequence stack, + in sequence scalar_in_octet, + rout sequence scalar_out_octet, + in sequence stack_in_octet, + rout sequence stack_out_octet, + rout unsigned long long pcycles, + rout unsigned long long time_usec); + long release_library(in handle_t mod); + long call_mmap64(); +}; diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.cc new file mode 100644 index 000000000000..54c06e10243b --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.cc @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "tvm_hvx.h" + +#include "AEEStdErr.h" +#include "HAP_farf.h" +#include "HAP_power.h" + +extern "C" { +#include "qurt_error.h" +#include "qurt_hvx.h" +} + +namespace hvx { + +#if __HEXAGON_ARCH__ >= 65 +#define DEFAULT_HVX_MODE MODE_128B +#else +#define DEFAULT_HVX_MODE MODE_DONT_CARE +#endif + +static constexpr mode_t default_hvx_mode = DEFAULT_HVX_MODE; + +int reserve(unsigned num_units) { + if (qurt_hvx_get_units() <= 0) { + return -1; // HVX not supported in this target. 
+ } + + if (num_units == 0) num_units = QURT_HVX_RESERVE_ALL_AVAILABLE; + int ret_val = qurt_hvx_reserve(num_units); + switch (ret_val) { + case QURT_HVX_RESERVE_ALREADY_MADE: + case QURT_HVX_RESERVE_NOT_SUPPORTED: + case QURT_HVX_RESERVE_NOT_SUCCESSFUL: + return 0; + + default: + if (ret_val < 0) { + return -1; + } + break; + } + return ret_val; +} + +int unreserve() { + int ret_val = qurt_hvx_cancel_reserve(); + if (ret_val != QURT_EOK) { + return -1; + } + return 0; +} + +int power_on() { + HAP_power_request_t request; + request.type = HAP_power_set_HVX; + request.hvx.power_up = 1; + int rc = HAP_power_set(nullptr, &request); + if (rc != AEE_SUCCESS) { + FARF(ERROR, "%s: unable to power on HVX, rc=%08x", rc); + return -1; + } + return 0; +} + +int power_off() { + HAP_power_request_t request; + request.type = HAP_power_set_HVX; + request.hvx.power_up = 0; + int rc = HAP_power_set(nullptr, &request); + if (rc != AEE_SUCCESS) { + FARF(ERROR, "%s: unable to power off HVX, rc=%08x", rc); + return -1; + } + return 0; +} + +int lock(mode_t mode) { + qurt_hvx_mode_t qurt_mode; + int vlen; + + if (MODE_DONT_CARE == mode) mode = default_hvx_mode; + + switch (mode) { + case MODE_DONT_CARE: { + int ret_val = qurt_hvx_get_mode(); + if (ret_val < 0) { + FARF(HIGH, "%s: unknown HVX mode %d", __func__, qurt_mode); + return -1; + } + qurt_mode = static_cast(ret_val); + switch (qurt_mode) { + case QURT_HVX_MODE_64B: + vlen = 64; + break; + case QURT_HVX_MODE_128B: + vlen = 128; + break; + } + break; + } + + case MODE_64B: + qurt_mode = QURT_HVX_MODE_64B; + vlen = 64; + break; + + case MODE_128B: + qurt_mode = QURT_HVX_MODE_128B; + vlen = 128; + break; + + default: + FARF(HIGH, "%s: unknown HVX mode %d", __func__, qurt_mode); + return -3; + } + + // Starting with v65, the RTOS supports HVX context switching. + // Treat all hvx locks as blocking now, so they can succeed, and + // be scheduled according to RTOS scheduler via thread priority. + // Nonblocking call: qurt_hvx_try_lock(qurt_mode). + int ret_val = qurt_hvx_lock(qurt_mode); + + if (ret_val != QURT_EOK) { + return -1; + } + return vlen; +} + +int unlock() { + int ret_val = qurt_hvx_unlock(); + if (ret_val != QURT_EOK) { + return -1; + } + return 0; +} + +int prepare_mt_job(config_t* hvx_config) { + int num_units = qurt_hvx_get_units(); + if (num_units <= 0) { + return -1; + } + + // Check whether HVX is reserved for this protection domain. If not, + // see if we can temporarily reserve them for this invocation only. + hvx_config->temp_reserve = false; + if (hvx_config->num_reserved == 0) { + hvx_config->num_reserved = reserve(0); // Reserve all units. + if (hvx_config->num_reserved <= 0) { + return -1; + } + hvx_config->temp_reserve = true; + } + + // If client doesn't specify required mode, fallback to default. + if (hvx_config->mode == MODE_DONT_CARE) hvx_config->mode = default_hvx_mode; + + // Choose 64 byte or 128 byte mode, based on whether there are odd or even + // number of units + if (hvx_config->mode == MODE_64B || + (hvx_config->mode == MODE_DONT_CARE && (hvx_config->num_reserved & 1))) { + hvx_config->vlen = 64; + hvx_config->mode = MODE_64B; + hvx_config->num_threads = hvx_config->num_reserved; + } else { + hvx_config->vlen = 128; + hvx_config->mode = MODE_128B; + hvx_config->num_threads = (num_units >> 8) & 0xFF; + // Handle case where only 1 64-byte unit was available. + if (hvx_config->num_threads == 0) { + if (hvx_config->temp_reserve) unreserve(); + return -1; + } + } + + // If using HVX, make sure it turns on properly. 
+ if (hvx_config->num_reserved > 0 && power_on() != 0) { + return -1; + } + return 0; +} + +int cleanup_mt_job(const config_t* hvx_config) { + // If HVX was used, indicate it can be turned off. + if (hvx_config->num_reserved > 0) power_off(); + // If HVX was temporarily reserved, unreserve it. + if (hvx_config->temp_reserve) unreserve(); + return 0; +} + +} // namespace hvx diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.h b/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.h new file mode 100644 index 000000000000..2fe947574bbb --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_hvx.h @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_SRC_TVM_HVX_H_ +#define TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_SRC_TVM_HVX_H_ + +// Utility providing functions for accessing the Hexagon Vector Extensions +// (HVX) hardware. + +#include + +namespace hvx { + +enum mode_t : uint32_t { + MODE_DONT_CARE = 0, /*!< Don't-care, just use whatever current mode is. */ + MODE_64B, /*!< 64 byte HVX vector width. */ + MODE_128B /*!< 128 byte HVX vector width. */ +}; + +/*! + * \brief HVX configuration data. + */ +struct config_t { + int num_reserved; /*!< Number of reserved HVX units. */ + bool temp_reserve; /*!< Indicates that HVX pool reservation is */ + /*!< temporary and needs to be released after use. */ + mode_t mode; /*!< Configured HVX mode. */ + int vlen; /*!< Configured HVX vector width (64 or 128 bytes). */ + int num_threads; /*!< Number of threads that can lock HVX units. */ +}; + +/*! + * \brief + * This function reserves HVX units for the protection domain to which + * the caller belongs. Reservation is optional before locking HVX units. + * Typically it would be called by applications that want to guarantee + * up front that the requested number of HVX units will be available + * for the duration of the application. + * + * \param num_units + * Number of HVX units to reserve. 0 indicates to reserve all the units + * present in the given target. > 0 indicates the number of single HVX + * units to reserve. Mode (64 byte vs. 128 byte) is not specified. + * + * \return + * The number of HVX units (in terms of 64 byte single units) successfully + * reserved. The return value of -1 indicates no HVX hardware is available + * on the target. + */ +int reserve(unsigned num_units); + +/*! + * \brief + * This function releases all HVX unit from reservation. A call to this + * function nullifies all previous calls to reserve HVX units from within + * this worker pool's protection domain. + * + * \return + * 0 on success, -1 if there was an error. + */ +int unreserve(); + +/*! + * \brief + * This function turns on the HVX hardware. 
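Taken together, the implementation above (tvm_hvx.cc) is meant to bracket vector work: `prepare_mt_job` reserves HVX units if necessary and powers the hardware on, each thread then locks a context around its vector code, and `cleanup_mt_job` undoes the setup. A minimal single-threaded sketch of that sequence, mirroring how `tvm_remote_nd_kernel` uses these helpers later in this patch:

```
// Sketch of the intended calling sequence for the hvx:: helpers;
// error handling is reduced to early returns.
#include "tvm_hvx.h"

int run_with_hvx(void (*vector_body)(int vlen)) {
  hvx::config_t cfg = {0};
  cfg.mode = hvx::MODE_DONT_CARE;                 // let prepare_mt_job pick 64B/128B
  if (hvx::prepare_mt_job(&cfg) != 0) return -1;  // reserves (if needed) and powers on

  // In a multi-threaded job each worker thread would do this lock/unlock pair.
  int vlen = hvx::lock(cfg.mode);
  if (vlen > 0) {
    vector_body(vlen);                            // vlen is 64 or 128 bytes
    hvx::unlock();
  }

  hvx::cleanup_mt_job(&cfg);                      // powers off, drops temp reservation
  return vlen > 0 ? 0 : -1;
}
```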
It must be called sometime + * before (possibly multiple) software threads lock HVX units. + * + * \return + * 0 on success, -1 if there was an error. + */ +int power_on(); + +/*! + * \brief + * This function turns off the HVX hardware. It must be called sometime + * after all threads have unlocked their HVX units. + * + * \return + * 0 on success, -1 if there was an error. + */ +int power_off(); + +/*! + * \brief + * This function locks the HVX units for the calling threads. + * + * \param mode + * The HVX mode. + * + * \return + * 0 on success, -1 if there was an error. + */ +int lock(mode_t mode); + +/*! + * \brief + * This function unlocks the HVX units for the calling threads. + * + * \return + * 0 on success, -1 if there was an error. + */ +int unlock(); + +/*! + * \brief + * This function performs preparations for multithreaded job. + * It does so by filling out data members in the configuration + * structure passed as a parameter, and by setting up the hardware: + * - it performs a temporary reservation of HVX units, if no units + * have yet been reserved, + * - it powers on the HVX hardware. + * + * \param hvx_config + * Structure describing the HVX configuration. Two data members + * must be set prior to calling \ref prepare_mt_job: + * \ref num_reserved, indicating the number of previously reserved + * HVX units (can be 0), and \ref mode indicating the HVX mode. + * + * \return + * 0 on success, -1 if there was an error. + */ +int prepare_mt_job(config_t* hvx_config); + +/*! + * \brief + * This function cleans up after \ref prepare_mt_job, in particular + * it releases temporarily reserved HVX units and turns the HVX + * hardware off. + * + * \return + * 0 on success, -1 if there was an error. + */ +int cleanup_mt_job(const config_t* hvx_config); + +} // namespace hvx + +#endif // TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_SRC_TVM_HVX_H_ diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc new file mode 100644 index 000000000000..c9e3332d59a7 --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_imp.cc @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#define FARF_ERROR 1 +#include "AEEStdErr.h" +#include "HAP_farf.h" +#include "HAP_perf.h" +#include "apps_mem.h" +#include "qurt.h" +#include "tvm_remote.h" +#include "tvm_remote_nd.h" + +#if __HEXAGON_ARCH__ >= 65 +#include "HAP_vtcm_mgr.h" +#else +// Stub functions for targets that don't support VTCM. 
+static void* HAP_request_VTCM(int a, int b) { return 0; } +static int HAP_release_VTCM(void* a) { return 0; } +static int HAP_query_avail_VTCM(unsigned* avail_block_size, unsigned* max_page_size, + unsigned* num_pages) { + FARF(ALWAYS, "%s: running on architecture V62 or less", __func__); + return AEE_ENOMEMORY; +} +#endif // __HEXAGON_ARCH__ + +#define MIN_GATHER_SCATTER_SZ (32 * 1024) +#define MAX_GATHER_SCATTER_SZ (64 * 1024) +#define MIN_VTCM_SZ (64 * 1024) + +/*! + * \brief Open a domain channel. + * + * \param uri URI of the channel description. + * \param handle_ptr Where to store the channel handle. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_open(const char* uri, remote_handle64* handle_ptr) { + FARF(ALWAYS, "%s, uri=%s", __func__, uri); + int rc = tvm_remote_nd_open(); + if (rc != AEE_SUCCESS) { + FARF(ERROR, "%s: tvm_remote_nd_open failed rc=%08x", __func__, rc); + return rc; + } + + *handle_ptr = static_cast(reinterpret_cast(malloc(1))); + if (!*handle_ptr) { + FARF(ERROR, "%s: cannot allocate memory", __func__); + return AEE_ENOMEMORY; + } + return AEE_SUCCESS; +} + +/*! + * \brief Close domain channel. + * + * \param handle Domain channel handle to close. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_close(remote_handle64 handle) { + FARF(ALWAYS, "%s", __func__); + if (handle) free(reinterpret_cast(static_cast(handle))); + int rc = tvm_remote_nd_close(); + if (rc != AEE_SUCCESS) { + FARF(ERROR, "%s: tvm_remote_nd_close failed rc=%08x", __func__, rc); + } + return rc; +} + +/*! + * \brief Dummy function. + * + * \param handle Domain channel handle. + * + * \return This function always returns 0. + * + * This function is present as a workaround. See comment at the call site + * in hexagon_device_target.cc. + */ +int tvm_remote_call_mmap64(remote_handle64 handle) { return AEE_SUCCESS; } + +/*! + * \brief Load a shared library. + * + * \param handle Domain channel handle. + * \param soname Name of the shared library. + * \param soname_len Length of the name. + * \param lib_ptr Where to store the handle of the loaded libarary. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_load_library(remote_handle64 handle, const char* soname, int soname_len, + tvm_remote_handle_t* lib_ptr) { + return tvm_remote_nd_load_library(soname, soname_len, lib_ptr); +} + +/*! + * \brief Resolve symbol name to an address. + * + * \param handle Domain channel handle. + * \param lib Handle of the shared library with the symbol. + * \param name Symbol name. + * \param name_len Length of the name. + * \param sym_ptr Where to store the resolved address. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_get_symbol(remote_handle64 handle, tvm_remote_handle_t lib, const char* name, + int name_len, tvm_remote_handle_t* sym_ptr) { + return tvm_remote_nd_get_symbol(lib, name, name_len, sym_ptr); +} + +/*! + * \brief Call the specified function. + * + * \param handle Domain channel handle. + * \param lib Handle of the library containing + * the function to call. + * \param symbol Address of the function to call. + * \param scalar Address of values to pass in registers. + * \param scalar_len Number of values to pass in registers. + * \param stack Address of values to pass on stack. + * \param stack_len Number of values to pass on stack. + * + * \param scalar_in_octet Address of the incoming scalar buffer. + * \param scalar_in_octet_len Length of the incoming scalar buffer. 
+ * \param scalar_out_octet Address of the outgoing scalar buffer. + * \param scalar_out_octet_len Length of the outgoing scalar buffer. + * \param stack_in_octet Address of the incoming stack buffer. + * \param stack_in_octet_len Length of the incoming stack buffer. + * \param stack_out_octet Address of the outgoing stack buffer. + * \param stack_out_octet_len Length of the outgoing stack buffer. + * + * \param pcycles Pointer to where to store cycle count. + * \param time_usec Pointer to where to store time in usec. + * + * \return 0 on success, negative value on error. + * + * The 8 "octet" arguments in this function are used for cache operations + * only. They are not used for procesing. + */ +int tvm_remote_kernel(remote_handle64 handle, tvm_remote_handle_t lib, tvm_remote_handle_t symbol, + const int* scalar, int scalar_len, const int* stack, int stack_len, + const tvm_remote_buffer* scalar_in_octet, int scalar_in_octet_len, + tvm_remote_buffer* scalar_out_octet, int scalar_out_octet_len, + const tvm_remote_buffer* stack_in_octet, int stack_in_octet_len, + tvm_remote_buffer* stack_out_octet, int stack_out_octet_len, uint64* pcycles, + uint64* time_usec) { + return tvm_remote_nd_kernel( + lib, symbol, scalar, scalar_len, stack, stack_len, + reinterpret_cast(scalar_in_octet), scalar_in_octet_len, + reinterpret_cast(scalar_out_octet), scalar_out_octet_len, + reinterpret_cast(stack_in_octet), stack_in_octet_len, + reinterpret_cast(stack_out_octet), stack_out_octet_len, pcycles, + time_usec); +} + +/*! + * \brief Release previously loaded shared object. + * + * \param handle Domain channel handle. + * \param lib Handle of shared library to release. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_release_library(remote_handle64 handle, tvm_remote_handle_t lib) { + // FARF(ALWAYS, "tvm_remote_release_library begin "); + return tvm_remote_nd_release_library(lib); +} + +/*! + * \brief Allocate VTCM memory. + * + * \param handle Domain channel handle. + * \param size Number of bytes to allocate. + * \param align Requested alignment. + * \param dsp_va Address of variable to store the allocated VTCM + * address to. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_alloc_vtcm(remote_handle64 handle, unsigned size, unsigned align, unsigned* dsp_va) { + FARF(ALWAYS, "%s: size=%u, align=%u", __func__, size, align); + unsigned avail_block_size, max_page_size, num_pages; + int rc = HAP_query_avail_VTCM(&avail_block_size, &max_page_size, &num_pages); + if (rc != AEE_SUCCESS) { + FARF(ERROR, "%s: HAP_query_avail_VTCM failed, rc=%08x", __func__, rc); + return rc; + } + FARF(ALWAYS, "%s: avail_block_size=%u, max_page_size=%u, num_pages=%u", __func__, + avail_block_size, max_page_size, num_pages); + + if (max_page_size < MIN_VTCM_SZ) { + FARF(ERROR, "%s: available VTCM size less than %d KB, aborting", __func__, MIN_VTCM_SZ / 1024); + return AEE_ENOMEMORY; + } + + void* vtcm_base = HAP_request_VTCM(size, /*single_page_flag=*/1); + if (!vtcm_base) { + FARF(ERROR, "%s: error allocating VTCM", __func__); + return AEE_ENOMEMORY; + } + *dsp_va = static_cast(reinterpret_cast(vtcm_base)); + FARF(ALWAYS, "%s: allocated VTCM addr=0x%p", __func__, vtcm_base); + return AEE_SUCCESS; +} + +/*! + * \brief Free VTCM memory. + * + * \param handle Domain channel handle. + * \param dsp_va VTCM address to free. + * + * \return 0 on success, negative value on error. 
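From the APPS side, the VTCM entry points (tvm_remote_alloc_vtcm above and tvm_remote_free_vtcm below) are reached through the generated stub library; `HexagonTarget::AllocVtcm` and `FreeVtcm` later in this patch resolve and call them via dlopen. A hypothetical direct round trip, with illustrative size and alignment values:

```
// Hypothetical host-side round trip through the generated stub functions.
// 'handle' must be an open domain channel; 64kB / 2kB are example values only.
int vtcm_round_trip(remote_handle64 handle) {
  unsigned dsp_va = 0;
  if (tvm_remote_alloc_vtcm(handle, 64 * 1024, 2048, &dsp_va) != 0) return -1;
  // ... hand dsp_va to kernels that expect their buffers in VTCM ...
  return tvm_remote_free_vtcm(handle, dsp_va);
}
```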
+ */ +int tvm_remote_free_vtcm(remote_handle64 handle, unsigned dsp_va) { + FARF(ALWAYS, "%s: dsp_va=0x%08x", __func__, dsp_va); + void* vtcm_base = reinterpret_cast(dsp_va); + int rc = HAP_release_VTCM(vtcm_base); + if (rc != AEE_SUCCESS) { + FARF(ERROR, "%s: error freeing VTCM, rc=%08x", __func__, rc); + } + return rc; +} diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc new file mode 100644 index 000000000000..c0f6f22172c0 --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_remote_nd_imp.cc @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include + +#define FARF_ERROR 1 +#include "AEEStdDef.h" +#include "AEEStdErr.h" +#include "HAP_farf.h" +#include "HAP_mem.h" +#include "HAP_perf.h" +#include "qurt.h" +#include "tvm_hvx.h" +#include "tvm_remote_nd.h" + +struct msg_call { + uint32_t func_va; + uint32_t scalar_num; + uint32_t stack_num; + uint32_t data[]; +} __attribute__((packed)); + +__attribute__((naked)) uint32_t launcher(volatile msg_call* mc, uint64_t* pcc) { + __asm__( + "// This function is intentionally written to be readable, \n" + "// rather than fast. \n" + "// r0 = value of 'volatile msg_call *mc' \n" + "// r1 = address where to store the program cycle count \n" + + "// In this packet the store happens before the allocframe so \n" + "// the offset added to r29 must reflect that the r29 has not \n" + "// yet been updated (stack grows towards decreasing addresses):\n" + "// r29 before allocframe --. \n" + "// [ r17:16 ] [ r19:18 ] [ r21:20 ] [ FP/LR ] \n" + "// `-- r29 after allocframe increasing addresses --> \n" + "{ memd(r29+#-16) = r21:20 \n" + " allocframe(#24) } \n" + "{ memd(r29+#0) = r17:16 \n" + " memd(r29+#8) = r19:18 } \n" + "{ r17:16 = combine(r1,r0) \n" + " r18 = r29 \n" + " r1 = memw(r0+#4) // scalar_num \n" + " r2 = memw(r0+#8) } // stack_num \n" + "// If there are no stack values, skip the stack setup. \n" + "{ p0 = cmp.eq(r2,#0) \n" + " if (p0.new) jump:t .Llauncher1 } \n" + + "// Allocate space on the stack. Let r2 = needed space \n" + "// rounded up to a multiple of 8. \n" + "{ loop0(.Llauncher0,r2) \n" + " r2 = asl(r2,#2) } \n" + "{ r2 = add(r2,#4) } \n" + "{ r2 = clrbit(r2,#2) } \n" + "{ r29 = sub(r29,r2) } \n" + + "// Copy stack contents onto the stack. Stack contents start \n" + "// at r3 = r0 + offsetof(data) + scalar_num*4 \n" + "{ r3 = addasl(r0,r1,#2) \n" + " r4 = r29 } \n" + "{ r3 = add(r3,#12) } // offsetof(data) \n" + ".Llauncher0: \n" + "{ r5 = memw(r3++#4) \n" + " memw(r4++#4) = r5.new } :endloop0 \n" + + "// Load registers. 
Some of the loaded data may actually be \n" + "// values from the stack part of 'data', but it's not an issue.\n" + ".Llauncher1: \n" + "{ r0 = memw(r16+#12) // mc + offsetof(data) \n" + " r1 = memw(r16+#16) } \n" + "{ r2 = memw(r16+#20) \n" + " r3 = memw(r16+#24) } \n" + "{ r4 = memw(r16+#28) \n" + " r5 = memw(r16+#32) } \n" + + "// Call. \n" + "{ r6 = memw(r16+#0) \n" + " r21:20 = upcycle } \n" + "{ callr r6 } \n" + + "// Restore stack pointer (free up r18), calculate cycle count. \n" + "{ r29 = r18 \n" + " r19:18 = upcycle } \n" + "{ r19:18 = sub(r19:18, r21:20) } \n" + + "// Store pcount, restore non-volatile registers, and return. \n" + "{ memd(r17+#0) = r19:18 \n" + " r21:20 = memd(r29+#16) } \n" + "{ r19:18 = memd(r29+#8) \n" + " r17:16 = memd(r29+#0) } \n" + "{ dealloc_return } // implicit-use r1:0 \n"); +} + +extern "C" { +#pragma weak __wrap_pthread_create +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, + void* (*start)(void*), void* restrict arg) { + FARF(ERROR, "Wrong %s called", __func__); + abort(); +} +} + +static void* lib_rt = nullptr; +static void* lib_thread = nullptr; + +/*! + * \brief Perform initialization. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_nd_open() { + lib_thread = dlopen("libtvm_wrap_pthread.so", RTLD_NOW | RTLD_GLOBAL); + if (lib_thread == nullptr) { + FARF(ERROR, "%s: dlopen failed for libtvm_wrap_pthread.so: %s", __func__, dlerror()); + return AEE_EUNABLETOLOAD; + } + + lib_rt = dlopen("libtvm_runtime.so", RTLD_NOW | RTLD_GLOBAL); + if (lib_rt == nullptr) { + FARF(ERROR, "%s: dlopen failed for libtvm_runtime.so: %s", __func__, dlerror()); + return AEE_EUNABLETOLOAD; + } + return AEE_SUCCESS; +} + +/*! + * \brief Perform cleanup. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_nd_close() { + if (lib_thread != nullptr) { + dlclose(lib_thread); + lib_thread = nullptr; + } + if (lib_rt != nullptr) { + dlclose(lib_rt); + lib_rt = nullptr; + } + return AEE_SUCCESS; +} + +/*! + * \brief Dummy function. + * + * \param handle Domain channel handle. + * + * \return This function always returns 0. + * + * This function is present as a workaround. See comment at the call site + * in hexagon_device_target.cc. + */ +int tvm_remote_nd_call_mmap64() { return AEE_SUCCESS; } + +/*! + * \brief Load a shared library. + * + * \param soname Name of the shared library. + * \param soname_len Length of the name. + * \param lib_ptr Where to store the handle of the loaded libarary. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_nd_load_library(const char* soname, int soname_len, + tvm_remote_nd_handle_t* lib_ptr) { + // We need to use RTLD_NOW, the libraries we build for Hexagon + // offloading do not support lazy binding. + FARF(ALWAYS, "%s: %s", __func__, soname); + if (void* lib = dlopen(soname, RTLD_GLOBAL | RTLD_NOW)) { + *lib_ptr = reinterpret_cast(lib); + return AEE_SUCCESS; + } + FARF(ERROR, "%s: dlopen failed: %s", __func__, dlerror()); + return AEE_EUNKNOWN; +} + +/*! + * \brief Resolve symbol name to an address. + * + * \param lib Handle of the shared library with the symbol. + * \param name Symbol name. + * \param name_len Length of the name. + * \param sym_ptr Where to store the resolved address. + * + * \return 0 on success, negative value on error. 
+ */ +int tvm_remote_nd_get_symbol(tvm_remote_nd_handle_t lib, const char* name, int name_len, + tvm_remote_nd_handle_t* sym_ptr) { + FARF(ALWAYS, "%s: name=%s", __func__, name); + if (void* p = dlsym(reinterpret_cast(lib), name)) { + *sym_ptr = reinterpret_cast(p); + return AEE_SUCCESS; + } + + FARF(ERROR, "%s: dlsym failed: %s", __func__, dlerror()); + return AEE_EUNKNOWN; +} + +static void print_msg_call(const msg_call& mc) { + FARF(ALWAYS, "device: launching %x scalar_num:%d stack_num:%d", mc.func_va, mc.scalar_num, + mc.stack_num); + for (unsigned i = 0; i != mc.scalar_num; ++i) { + FARF(ALWAYS, "scalar_data[%d] %x", i, mc.data[i]); + } + for (unsigned i = 0; i != mc.stack_num; ++i) { + FARF(ALWAYS, "stack_data[%d] %x", i, mc.data[mc.scalar_num + i]); + } +} + +/*! + * \brief Call the specified function. + * + * \param lib Handle of the library containing + * the function to call. + * \param symbol Address of the function to call. + * \param scalar Address of values to pass in registers. + * \param scalar_len Number of values to pass in registers. + * \param stack Address of values to pass on stack. + * \param stack_len Number of values to pass on stack. + * + * \param scalar_in_octet Address of the incoming scalar buffer. + * \param scalar_in_octet_len Length of the incoming scalar buffer. + * \param scalar_out_octet Address of the outgoing scalar buffer. + * \param scalar_out_octet_len Length of the outgoing scalar buffer. + * \param stack_in_octet Address of the incoming stack buffer. + * \param stack_in_octet_len Length of the incoming stack buffer. + * \param stack_out_octet Address of the outgoing stack buffer. + * \param stack_out_octet_len Length of the outgoing stack buffer. + * + * \param pcycles Pointer to where to store cycle count. + * \param time_usec Pointer to where to store time in usec. + * + * \return 0 on success, negative value on error. + * + * The 8 "octet" arguments in this function are used for cache operations + * only. They are not used for procesing. + */ +int tvm_remote_nd_kernel(tvm_remote_nd_handle_t lib, tvm_remote_nd_handle_t symbol, + const int* scalar, int scalar_len, const int* stack, int stack_len, + const tvm_remote_nd_buffer* scalar_in_octet, int scalar_in_octet_len, + tvm_remote_nd_buffer* scalar_out_octet, int scalar_out_octet_len, + const tvm_remote_nd_buffer* stack_in_octet, int stack_in_octet_len, + tvm_remote_nd_buffer* stack_out_octet, int stack_out_octet_len, + uint64* pcycles, uint64* time_usec) { + hvx::config_t hvx_info = {0}; + hvx::prepare_mt_job(&hvx_info); + + int lock_result; + // Check if HVX units are available + if (hvx_info.num_reserved > 0) { + lock_result = hvx::lock(hvx::MODE_128B); + if (lock_result < 0) { + FARF(ERROR, "%s: HVX locking failed lock_result=%d num_reserved=%d", __func__, lock_result, + hvx_info.num_reserved); + } else { + FARF(ALWAYS, "%s: HVX lock successful lock_result=%d", __func__, lock_result); + } + } else { + FARF(ERROR, "%s: there are no HVX units available", __func__); + } + + struct msg_call* mc = (struct msg_call*)malloc(sizeof(uint32_t) * (3 + scalar_len + stack_len)); + if (mc == nullptr) { + FARF(ERROR, "%s: failed to allocate memory for mc", __func__); + return AEE_ENOMEMORY; + } + + int32_t* mc_ptr = reinterpret_cast(mc); + // Scalar buffers come first. 
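The loop below packs the scalar words and then the stack words into `mc->data`; `launcher()` (defined above) reloads the first six words into r0..r5 and copies the stack portion onto the real stack before jumping to `func_va`. A rough, illustrative C++ rendering of the register-argument path only; the real path must remain in assembly to control registers, the stack frame, and the cycle counter:

```
// Illustrative rendering of launcher()'s register-argument path. Assumes
// mc->data holds at least six words; the assembly additionally spills
// data[scalar_num..scalar_num+stack_num) to the stack and measures pcycles,
// which cannot be expressed portably here.
static uint32_t launcher_sketch(volatile msg_call* mc) {
  using kernel_fn = int (*)(uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
  kernel_fn fn = reinterpret_cast<kernel_fn>(static_cast<uintptr_t>(mc->func_va));
  const volatile uint32_t* d = mc->data;
  // As in the assembly, unused argument slots simply carry whatever happens
  // to be in data[]; the callee ignores them.
  return fn(d[0], d[1], d[2], d[3], d[4], d[5]);
}
```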
+ int k = 3; + for (int i = 0; i < scalar_len; i++, k++) { + *(mc_ptr + k) = static_cast(scalar[i]); + } + + for (int i = 0; i < stack_len; i++, k++) { + *(mc_ptr + k) = static_cast(stack[i]); + } + + mc->scalar_num = scalar_len; + mc->stack_num = stack_len; + mc->func_va = symbol; + print_msg_call(*mc); + uint64_t start_time = HAP_perf_get_time_us(); + int result = launcher(mc, pcycles); + *time_usec = HAP_perf_get_time_us() - start_time; + FARF(ALWAYS, "kernel execution: %llu pcycles %llu usec", *pcycles, *time_usec); + if (lock_result > 0) hvx::unlock(); + hvx::cleanup_mt_job(&hvx_info); + if (mc) free(mc); + return result; +} + +/*! + * \brief Release previously loaded shared object. + * + * \param lib Handle of shared library to release. + * + * \return 0 on success, negative value on error. + */ +int tvm_remote_nd_release_library(tvm_remote_nd_handle_t lib) { + // FARF(ALWAYS, "tvm_remote_nd_release_library begin "); + dlclose(reinterpret_cast(lib)); + FARF(ALWAYS, "tvm_remote_nd_release_library done "); + return 0; +} diff --git a/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc b/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc new file mode 100644 index 000000000000..d26073af8ae1 --- /dev/null +++ b/src/runtime/hexagon/target/fastrpc/src/tvm_wrap_pthread.cc @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Implement a wrapper around pthread_create that sets the thread stack + * size to a chosen value. + * + * TVM runtime uses std::thread, but the C++ standard does not provide + * any means of controlling thread attributes (like stack size). Because + * of that, any thread created by the std::thread constructor will use + * default attributes. The default stack size for a thread in QuRT is 16kB. + * This has proven to be insufficient in the past, so we need to increase + * it. + * When libtvm_runtime.so is linked, a linker flag --wrap=pthread_create + * is used, which causes the linker to rename all uses of pthread_create + * with references to __wrap_pthread_create. This file implements the + * __wrap function to set the larger stack size and call the actual + * pthread_create. The call to pthread_create here must not be renamed, + * so this function cannot be included in the TVM runtime binary. + * Instead, it's implemented in a separate shared library. + */ + +#include + +#include "HAP_farf.h" + +static constexpr size_t kThreadStackSize = 128 * 1024; // 128kB + +// Make sure the function has C linkage. 
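For context on the comment above: with GNU ld, `--wrap=pthread_create` redirects every reference to `pthread_create` to `__wrap_pthread_create` and exposes the original symbol as `__real_pthread_create`. A generic sketch of that pattern follows; it is not the TVM build, which instead places the wrapper in its own shared library so that its call to the real `pthread_create` is left untouched, as implemented below.

```
// Generic --wrap pattern, built with: -Wl,--wrap=pthread_create
// (sketch only; the TVM implementation below takes a different route).
#include <pthread.h>

extern "C" int __real_pthread_create(pthread_t*, const pthread_attr_t*, void* (*)(void*), void*);

extern "C" int __wrap_pthread_create(pthread_t* thread, const pthread_attr_t* attr,
                                     void* (*start)(void*), void* arg) {
  // ... adjust 'attr' here, e.g. enlarge the stack ...
  return __real_pthread_create(thread, attr, start, arg);
}
```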
+extern "C" { +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, + void* (*start)(void*), void* restrict arg); +} + +int __wrap_pthread_create(pthread_t* restrict thread, const pthread_attr_t* restrict attr, + void* (*start)(void*), void* restrict arg) { + pthread_attr_t def_attr; + if (attr == nullptr) { + if (int rc = pthread_attr_init(&def_attr)) { + FARF(ERROR, "pthread_attr_init failed: rc=%08x", rc); + return rc; + } + if (int rc = pthread_attr_setstacksize(&def_attr, kThreadStackSize)) { + FARF(ERROR, "pthread_attr_setstacksize failed: rc=%08x", rc); + return rc; + } + attr = &def_attr; + } + size_t stack_size = 0; + if (int rc = pthread_attr_getstacksize(attr, &stack_size)) { + FARF(ERROR, "pthread_attr_setstacksize failed: rc=%08x", rc); + return rc; + } + FARF(ALWAYS, "launching thread with stack_size=%zu", stack_size); + int t = pthread_create(thread, attr, start, arg); + if (int rc = pthread_attr_destroy(&def_attr)) { + FARF(ERROR, "pthread_attr_destroy failed (after pthread_create): rc=%08x", rc); + } + return t; +} diff --git a/src/runtime/hexagon/target/fastrpc/tvm_hexagon_remote.h b/src/runtime/hexagon/target/fastrpc/tvm_hexagon_remote.h deleted file mode 100644 index bc8766c63db2..000000000000 --- a/src/runtime/hexagon/target/fastrpc/tvm_hexagon_remote.h +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_TVM_HEXAGON_REMOTE_H_ -#define TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_TVM_HEXAGON_REMOTE_H_ -/// @file tvm_hexagon_remote.idl -/// IDL to offload TVM kernels to Hexagon from APPS for multi-domains -#include "AEEStdDef.h" -#include "remote.h" -#ifndef __QAIC_HEADER -#define __QAIC_HEADER(ff) ff -#endif // __QAIC_HEADER - -#ifndef __QAIC_HEADER_EXPORT -#define __QAIC_HEADER_EXPORT -#endif // __QAIC_HEADER_EXPORT - -#ifndef __QAIC_HEADER_ATTRIBUTE -#define __QAIC_HEADER_ATTRIBUTE -#endif // __QAIC_HEADER_ATTRIBUTE - -#ifndef __QAIC_IMPL -#define __QAIC_IMPL(ff) ff -#endif // __QAIC_IMPL - -#ifndef __QAIC_IMPL_EXPORT -#define __QAIC_IMPL_EXPORT -#endif // __QAIC_IMPL_EXPORT - -#ifndef __QAIC_IMPL_ATTRIBUTE -#define __QAIC_IMPL_ATTRIBUTE -#endif // __QAIC_IMPL_ATTRIBUTE -#ifdef __cplusplus -extern "C" { -#endif -/** - * Opens the handle in the specified domain. If this is the first - * handle, this creates the session. Typically this means opening - * the device, aka open("/dev/adsprpc-smd"), then calling ioctl - * device APIs to create a PD on the DSP to execute our code in, - * then asking that PD to dlopen the .so and dlsym the skel function. 
- * - * @param uri, _URI"&_dom=aDSP" - * _URI is a QAIC generated uri, or - * "file:///?_skel_handle_invoke&_modver=1.0" - * If the _dom parameter is not present, _dom=DEFAULT is assumed - * but not forwarded. - * Reserved uri keys: - * [0]: first unamed argument is the skel invoke function - * _dom: execution domain name, _dom=mDSP/aDSP/DEFAULT - * _modver: module version, _modver=1.0 - * _*: any other key name starting with an _ is reserved - * Unknown uri keys/values are forwarded as is. - * @param h, resulting handle - * @retval, 0 on success - */ -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_open)( - const char* uri, remote_handle64* h) __QAIC_HEADER_ATTRIBUTE; -/** - * Closes a handle. If this is the last handle to close, the session - * is closed as well, releasing all the allocated resources. - - * @param h, the handle to close - * @retval, 0 on success, should always succeed - */ -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_close)( - remote_handle64 h) __QAIC_HEADER_ATTRIBUTE; -typedef struct _tvm_hexagon_remote_buffer__seq_octet - _tvm_hexagon_remote_buffer__seq_octet; -typedef _tvm_hexagon_remote_buffer__seq_octet tvm_hexagon_remote_buffer; -struct _tvm_hexagon_remote_buffer__seq_octet { - unsigned char* data; - int dataLen; -}; -typedef unsigned int tvm_hexagon_remote_handle_t; -typedef uint64 tvm_hexagon_remote_scalar_t; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_load_library)( - remote_handle64 _h, const char* soname, int sonameLen, const char* code, - int codeLen, - tvm_hexagon_remote_handle_t* module_ptr) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_get_symbol)( - remote_handle64 _h, tvm_hexagon_remote_handle_t module_ptr, - const char* name, int nameLen, - tvm_hexagon_remote_handle_t* sym_ptr) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_kernel)( - remote_handle64 _h, tvm_hexagon_remote_handle_t module_ptr, - tvm_hexagon_remote_handle_t symbol, int* scalar, int scalarLen, int* stack, - int stackLen, const tvm_hexagon_remote_buffer* scalar_in_octet, - int scalar_in_octetLen, tvm_hexagon_remote_buffer* scalar_out_octet, - int scalar_out_octetLen, const tvm_hexagon_remote_buffer* stack_in_octet, - int stack_in_octetLen, tvm_hexagon_remote_buffer* stack_out_octet, - int stack_out_octetLen, uint64* pcycles, - uint64* time_usec) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_release_library)( - remote_handle64 _h, - tvm_hexagon_remote_handle_t module_ptr) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_alloc_vtcm)( - remote_handle64 _h, unsigned int size, unsigned int align, - unsigned int* dsp_va) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_free_vtcm)( - remote_handle64 _h, unsigned int dsp_va) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_call_mmap64)( - remote_handle64 _h) __QAIC_HEADER_ATTRIBUTE; -#ifndef tvm_hexagon_remote_URI -#define tvm_hexagon_remote_URI \ - "file:///" \ - "libtvm_hexagon_remote_skel.so?tvm_hexagon_remote_skel_handle_invoke&_" \ - "modver=1.0" -#endif /*tvm_hexagon_remote_URI*/ -#ifdef __cplusplus -} -#endif -#endif // TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_TVM_HEXAGON_REMOTE_H_ diff --git a/src/runtime/hexagon/target/fastrpc/tvm_hexagon_remote_nd.h b/src/runtime/hexagon/target/fastrpc/tvm_hexagon_remote_nd.h deleted file mode 100644 index bb35bd30f679..000000000000 --- 
a/src/runtime/hexagon/target/fastrpc/tvm_hexagon_remote_nd.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_TVM_HEXAGON_REMOTE_ND_H_ -#define TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_TVM_HEXAGON_REMOTE_ND_H_ -/// @file tvm_hexagon_remote_nd.idl -/// IDL to offload TVM kernels to Hexagon from APPS for non-domains -#include "AEEStdDef.h" -#include "remote.h" -#ifndef __QAIC_HEADER -#define __QAIC_HEADER(ff) ff -#endif // __QAIC_HEADER - -#ifndef __QAIC_HEADER_EXPORT -#define __QAIC_HEADER_EXPORT -#endif // __QAIC_HEADER_EXPORT - -#ifndef __QAIC_HEADER_ATTRIBUTE -#define __QAIC_HEADER_ATTRIBUTE -#endif // __QAIC_HEADER_ATTRIBUTE - -#ifndef __QAIC_IMPL -#define __QAIC_IMPL(ff) ff -#endif // __QAIC_IMPL - -#ifndef __QAIC_IMPL_EXPORT -#define __QAIC_IMPL_EXPORT -#endif // __QAIC_IMPL_EXPORT - -#ifndef __QAIC_IMPL_ATTRIBUTE -#define __QAIC_IMPL_ATTRIBUTE -#endif // __QAIC_IMPL_ATTRIBUTE -#ifdef __cplusplus -extern "C" { -#endif -typedef struct _tvm_hexagon_remote_nd_buffer__seq_octet - _tvm_hexagon_remote_nd_buffer__seq_octet; -typedef _tvm_hexagon_remote_nd_buffer__seq_octet tvm_hexagon_remote_nd_buffer; -struct _tvm_hexagon_remote_nd_buffer__seq_octet { - unsigned char* data; - int dataLen; -}; -typedef unsigned int tvm_hexagon_remote_nd_handle_t; -typedef uint64 tvm_hexagon_remote_nd_scalar_t; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_nd_open)(void) - __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_nd_close)(void) - __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_nd_load_library)( - const char* soname, int sonameLen, const char* code, int codeLen, - tvm_hexagon_remote_nd_handle_t* module_ptr) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_nd_get_symbol)( - tvm_hexagon_remote_nd_handle_t module_ptr, const char* name, int nameLen, - tvm_hexagon_remote_nd_handle_t* sym_ptr) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_nd_kernel)( - tvm_hexagon_remote_nd_handle_t module_ptr, - tvm_hexagon_remote_nd_handle_t symbol, int* scalar, int scalarLen, - int* stack, int stackLen, - const tvm_hexagon_remote_nd_buffer* scalar_in_octet, - int scalar_in_octetLen, tvm_hexagon_remote_nd_buffer* scalar_out_octet, - int scalar_out_octetLen, - const tvm_hexagon_remote_nd_buffer* stack_in_octet, int stack_in_octetLen, - tvm_hexagon_remote_nd_buffer* stack_out_octet, int stack_out_octetLen, - uint64* pcycles, uint64* time_usec) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int __QAIC_HEADER(tvm_hexagon_remote_nd_release_library)( - tvm_hexagon_remote_nd_handle_t module_ptr) __QAIC_HEADER_ATTRIBUTE; -__QAIC_HEADER_EXPORT int 
__QAIC_HEADER(tvm_hexagon_remote_nd_call_mmap64)(void) - __QAIC_HEADER_ATTRIBUTE; -#ifdef __cplusplus -} -#endif -#endif // TVM_RUNTIME_HEXAGON_TARGET_FASTRPC_TVM_HEXAGON_REMOTE_ND_H_ diff --git a/src/runtime/hexagon/target/hexagon_device_target.cc b/src/runtime/hexagon/target/hexagon_device_target.cc index 00ca49ea9797..ee326ca0b159 100644 --- a/src/runtime/hexagon/target/hexagon_device_target.cc +++ b/src/runtime/hexagon/target/hexagon_device_target.cc @@ -29,7 +29,7 @@ #include "../hexagon_module.h" #include "AEEStdErr.h" -#include "fastrpc/tvm_hexagon_remote.h" +#include "fastrpc/include/tvm_remote.h" #include "hexagon_dsprpcapi.h" #include "hexagon_stubapi.h" #include "hexagon_target_log.h" @@ -45,10 +45,8 @@ // The downside is that the format string must be given as a string literal, // but it seems to be a minor issue. #define VA_EXPANDER(...) , ##__VA_ARGS__ -#define TVM_LOGD_HT(fmt, ...) \ - TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) -#define TVM_LOGE_HT(fmt, ...) \ - TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) +#define TVM_LOGD_HT(fmt, ...) TVM_LOGD("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) +#define TVM_LOGE_HT(fmt, ...) TVM_LOGE("HexagonTarget::%s: " fmt, __func__ VA_EXPANDER(__VA_ARGS__)) namespace tvm { namespace runtime { @@ -74,8 +72,7 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { unsigned stack_num) final; private: - std::pair AddAddrMapping(const void* dsp_addr, - void* apps_addr, size_t size); + std::pair AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size); std::pair GetAppsAddr(const void* dsp_addr, bool exact) const; void RemoveAddrMapping(const void* dsp_addr); int OpenDomainChannel(bool set_unsigned_pd); @@ -88,7 +85,7 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { // in apps's pointers, i.e. sizeof_dsp(void*) <= sizeof_apps(void*). std::map> dsp_to_apps_; remote_handle64 domain_channel_handle_ = AEE_EUNKNOWN; - tvm_hexagon_remote_handle_t module_pointer_ = AEE_EUNKNOWN; + tvm_remote_handle_t module_pointer_ = AEE_EUNKNOWN; uint64_t count_channel_open_ = 0; // Global lock, used for all critical sections. This can be refined // in the future. 
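As a quick illustration of the `VA_EXPANDER` trick used in the logging macros above, and of why the format string has to be a string literal (it is concatenated with the `"HexagonTarget::%s: "` prefix at compile time):

```
// Assuming TVM_LOGD is printf-style, the macros above expand roughly as:
//   TVM_LOGD_HT("rc=0x%x", rc);  ->  TVM_LOGD("HexagonTarget::%s: rc=0x%x", __func__, rc);
//   TVM_LOGD_HT("done");         ->  TVM_LOGD("HexagonTarget::%s: done", __func__);
// The '##' before __VA_ARGS__ removes the trailing comma in the second case,
// so the same macro works with and without extra arguments.
```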
@@ -102,24 +99,19 @@ class HexagonTarget : public tvm::runtime::hexagon::Device { void* const HexagonTarget::vtcm_mark_ = reinterpret_cast(~0); -std::shared_ptr CreateHexagonTarget() { - return std::make_shared(); -} +std::shared_ptr CreateHexagonTarget() { return std::make_shared(); } -std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, - void* apps_addr, +std::pair HexagonTarget::AddAddrMapping(const void* dsp_addr, void* apps_addr, size_t size) { crit_section_.lock(); auto p = dsp_to_apps_.insert({dsp_addr, {apps_addr, size}}); crit_section_.unlock(); if (!p.second) { - TVM_LOGE_HT( - "failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", - dsp_addr, apps_addr, size); + TVM_LOGE_HT("failed to insert address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, + apps_addr, size); return std::make_pair(nullptr, 0); } - TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, - apps_addr, size); + TVM_LOGD_HT("added address mapping: dsp:%p -> apps:%p, size:%zu", dsp_addr, apps_addr, size); return p.first->second; } @@ -135,8 +127,7 @@ void HexagonTarget::RemoveAddrMapping(const void* dsp_addr) { crit_section_.unlock(); } -std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, - bool exact) const { +std::pair HexagonTarget::GetAppsAddr(const void* dsp_addr, bool exact) const { struct AutoUnlock { explicit AutoUnlock(std::mutex& m) : m(m) {} ~AutoUnlock() { m.unlock(); } @@ -192,16 +183,14 @@ int HexagonTarget::OpenDomainChannel(bool use_unsigned_pd) { data.domain = CDSP_DOMAIN_ID; int rc = rsc_ptr(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data)); if (rc != AEE_SUCCESS) { - TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", - rc); + TVM_LOGE_HT("remote_session_control failed rc=%08x for unsigned PD", rc); } } } else { TVM_LOGD_HT("remote_session_control not available"); } - int rc = stub_api->tvm_hexagon_remote_open( - tvm_hexagon_remote_URI "&_dom=cdsp", &domain_channel_handle_); + int rc = stub_api->tvm_remote_open(tvm_remote_URI "&_dom=cdsp", &domain_channel_handle_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to open channel rc=0x%x", rc); } else { @@ -216,7 +205,7 @@ int HexagonTarget::CloseDomainChannel() { const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_hexagon_remote_close(domain_channel_handle_); + int rc = stub_api->tvm_remote_close(domain_channel_handle_); if (rc == AEE_SUCCESS) { domain_channel_handle_ = AEE_EUNKNOWN; stub_api->rpcmem_deinit_ptr()(); @@ -231,8 +220,7 @@ void HexagonTarget::ReleaseLibrary() { crit_section_.lock(); if (module_pointer_ != AEE_EUNKNOWN) { const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_hexagon_remote_release_library( - domain_channel_handle_, module_pointer_); + int rc = stub_api->tvm_remote_release_library(domain_channel_handle_, module_pointer_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to unload device library rc=0x%x", rc); } else { @@ -267,23 +255,20 @@ void* HexagonTarget::Alloc(unsigned size, unsigned align) { // thread then remote_mmap64 fails. FastRPC expects one call to be made to // DSP before calling remote_map64. Hence this call is needed for now untill // FastRPC comes up with a fix. 
- int rc_call_mmap_64 = - stub_api->tvm_hexagon_remote_call_mmap64(domain_channel_handle_); + int rc_call_mmap_64 = stub_api->tvm_remote_call_mmap64(domain_channel_handle_); if (rc_call_mmap_64 != AEE_SUCCESS) { TVM_LOGE_HT("mmap64 failed for domain channel %lu", domain_channel_handle_); return nullptr; } - void* mem = - stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); + void* mem = stub_api->rpcmem_alloc_ptr()(RPCMEM_HEAP, RPCMEM_DEFAULT_FLAGS, size); if (mem == nullptr) { TVM_LOGE_HT("mem alloc failed for size=0x%x alignment=0x%x", size, align); return nullptr; } int mem_fd = stub_api->rpcmem_to_fd_ptr()(mem); uintptr_t dsp_va = 0; - int rc = dsp_api->remote_mmap64_ptr()( - mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); + int rc = dsp_api->remote_mmap64_ptr()(mem_fd, 0, reinterpret_cast(mem), size, &dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT( "buffer mapping failed for remote_map64 fd=0x%x rc=0x%x " @@ -312,8 +297,7 @@ void HexagonTarget::Free(void* ptr) { auto aa = GetAppsAddr(ptr, true); if (aa.first == nullptr) return; - int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), - aa.second); + int rc = dsp_api->remote_munmap64_ptr()(reinterpret_cast(ptr), aa.second); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("buffer unmapping failed rc=0x%x", rc); } @@ -325,8 +309,7 @@ void* HexagonTarget::AllocVtcm(unsigned size, unsigned align) { const StubAPI* stub_api = StubAPI::Global(); unsigned int dsp_va = 0; - int rc = stub_api->tvm_hexagon_remote_alloc_vtcm(domain_channel_handle_, - size, align, &dsp_va); + int rc = stub_api->tvm_remote_alloc_vtcm(domain_channel_handle_, size, align, &dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("VTCM allocation failed size=%u, align=%u", size, align); return nullptr; @@ -342,16 +325,14 @@ void HexagonTarget::FreeVtcm(void* ptr) { TVM_LOGD_HT("%s:Calling vtcm free. 
ptr=%p", __func__, ptr); uintptr_t dsp_va = reinterpret_cast(ptr); - int rc = - stub_api->tvm_hexagon_remote_free_vtcm(domain_channel_handle_, dsp_va); + int rc = stub_api->tvm_remote_free_vtcm(domain_channel_handle_, dsp_va); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("VTCM deallocation failed"); } TVM_LOGD_HT("Done VTCM free from HexagonTarget::FreeVtcm"); } -void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, - unsigned len) { +void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, unsigned len) { auto aa_src = GetAppsAddr(src, false); auto aa_dst = GetAppsAddr(dst, false); if (aa_src.first == vtcm_mark_ || aa_dst.first == vtcm_mark_) { @@ -365,21 +346,22 @@ void HexagonTarget::CopyDeviceToDevice(void* dst, const void* src, if (aa_src.second < len) { TVM_LOGD_HT( "specified length:%u larger than source buffer size:%zu, copy " - "truncated", len, aa_src.second); + "truncated", + len, aa_src.second); } if (aa_dst.second < len) { TVM_LOGD_HT( "specified length:%u larger than dest buffer size:%zu, copy " - "truncated", len, aa_dst.second); + "truncated", + len, aa_dst.second); } len = std::min({size_t(len), aa_src.second, aa_dst.second}); - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", - src, aa_src.first, dst, aa_dst.first, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> dsp:%p(apps:%p), len:%u", src, aa_src.first, dst, + aa_dst.first, len); std::memcpy(aa_dst.first, aa_src.first, len); } -void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, - unsigned len) { +void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, unsigned len) { auto aa = GetAppsAddr(src, false); if (aa.first == vtcm_mark_) { TVM_LOGE_HT("VTCM address. Copy operation not supported"); @@ -390,18 +372,14 @@ void HexagonTarget::CopyDeviceToHost(void* host_dst, const void* src, return; } if (aa.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than buffer size:%zu, copy truncated", len, - aa.second); + TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); len = aa.second; } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, - host_dst, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) -> apps:%p, len:%u", src, aa.first, host_dst, len); std::memcpy(host_dst, aa.first, len); } -void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, - unsigned len) { +void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, unsigned len) { auto aa = GetAppsAddr(dst, false); if (aa.first == vtcm_mark_) { TVM_LOGE_HT("VTCM address. 
Copy operation not supported"); @@ -412,13 +390,10 @@ void HexagonTarget::CopyHostToDevice(void* dst, const void* host_src, return; } if (aa.second < len) { - TVM_LOGD_HT( - "specified length:%u larger than buffer size:%zu, copy truncated", len, - aa.second); + TVM_LOGD_HT("specified length:%u larger than buffer size:%zu, copy truncated", len, aa.second); len = aa.second; } - TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, - host_src, len); + TVM_LOGD_HT("copy, dsp:%p(apps:%p) <- apps:%p, len:%u", dst, aa.first, host_src, len); std::memcpy(aa.first, host_src, len); } @@ -427,8 +402,7 @@ void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { int rc_oc = OpenDomainChannel(/*use_unsigned_pd*/ unsigned_pd); crit_section_.unlock(); if (rc_oc != AEE_SUCCESS) { - TVM_LOGE_HT("loading of %s failed: unable to open domain channel", - data.c_str()); + TVM_LOGE_HT("loading of %s failed: unable to open domain channel", data.c_str()); return nullptr; } @@ -438,9 +412,8 @@ void* HexagonTarget::Load(const std::string& data, const std::string& fmt) { crit_section_.lock(); TVM_LOGD_HT("loading library %s ", data.c_str()); const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_hexagon_remote_load_library( - domain_channel_handle_, data.c_str(), data.size() + 1, data.c_str(), - data.size() + 1, &module_pointer_); + int rc = stub_api->tvm_remote_load_library(domain_channel_handle_, data.c_str(), data.size() + 1, + &module_pointer_); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to load device library rc=0x%x", rc); } @@ -470,11 +443,10 @@ void HexagonTarget::Unload(void* mod) { void* HexagonTarget::Resolve(const std::string& sym) { const StubAPI* stub_api = StubAPI::Global(); - tvm_hexagon_remote_handle_t pf; + tvm_remote_handle_t pf; TVM_LOGD_HT("resolving symbol %s", sym.c_str()); - int rc = stub_api->tvm_hexagon_remote_get_symbol( - domain_channel_handle_, module_pointer_, sym.c_str(), sym.size() + 1, - &pf); + int rc = stub_api->tvm_remote_get_symbol(domain_channel_handle_, module_pointer_, sym.c_str(), + sym.size() + 1, &pf); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to get symbol from CDSP rc=0x%x", rc); return nullptr; @@ -484,27 +456,21 @@ void* HexagonTarget::Resolve(const std::string& sym) { return addr; } -void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, - uint32_t* stack, unsigned stack_num) { +void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, uint32_t* stack, + unsigned stack_num) { uint64 pcycles = 0, execution_time_usec = 0; - auto scalar_octet = std::unique_ptr( - new tvm_hexagon_remote_buffer[scalar_num]); - auto stack_octet = std::unique_ptr( - new tvm_hexagon_remote_buffer[stack_num]); + auto scalar_octet = std::unique_ptr(new tvm_remote_buffer[scalar_num]); + auto stack_octet = std::unique_ptr(new tvm_remote_buffer[stack_num]); TVM_LOGD_HT("scalars=%p, stack=%p", scalar, stack); if (scalar_octet == nullptr || stack_octet == nullptr) { TVM_LOGE_HT("mem alloc failed for scalar/stack octets"); return; } - std::memset(scalar_octet.get(), 0, - scalar_num * sizeof(tvm_hexagon_remote_buffer)); - std::memset(stack_octet.get(), 0, - stack_num * sizeof(tvm_hexagon_remote_buffer)); + std::memset(scalar_octet.get(), 0, scalar_num * sizeof(tvm_remote_buffer)); + std::memset(stack_octet.get(), 0, stack_num * sizeof(tvm_remote_buffer)); - auto ProcessInputs = [this](uint32_t* inputs, - tvm_hexagon_remote_buffer* buffers, - unsigned num) { + auto ProcessInputs = [this](uint32_t* inputs, 
tvm_remote_buffer* buffers, unsigned num) { for (unsigned i = 0; i != num; ++i) { void* ptr = reinterpret_cast(static_cast(inputs[i])); auto aa = GetAppsAddr(ptr, false); @@ -533,20 +499,18 @@ void HexagonTarget::Call(void* func, uint32_t* scalar, unsigned scalar_num, TVM_LOGD_HT("%s", ToString(" stack", stack, stack_num).c_str()); const StubAPI* stub_api = StubAPI::Global(); - int rc = stub_api->tvm_hexagon_remote_kernel( + int rc = stub_api->tvm_remote_kernel( domain_channel_handle_, module_pointer_, - static_cast( - reinterpret_cast(func)), - reinterpret_cast(scalar), scalar_num, - reinterpret_cast(stack), stack_num, scalar_octet.get(), scalar_num, - scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, + static_cast(reinterpret_cast(func)), + reinterpret_cast(scalar), scalar_num, reinterpret_cast(stack), stack_num, + scalar_octet.get(), scalar_num, scalar_octet.get(), scalar_num, stack_octet.get(), stack_num, stack_octet.get(), stack_num, &pcycles, &execution_time_usec); if (rc != AEE_SUCCESS) { TVM_LOGE_HT("failed to run kernel on CDSP rc=0x%x", rc); } else { - TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", - pcycles, execution_time_usec, scalar_num); + TVM_LOGD_HT("kernel execution: %llu pcycles, %llu usec, scalar_num=%d", pcycles, + execution_time_usec, scalar_num); } } diff --git a/src/runtime/hexagon/target/hexagon_stubapi.cc b/src/runtime/hexagon/target/hexagon_stubapi.cc index 3600640e89b7..2ed33471b98f 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.cc +++ b/src/runtime/hexagon/target/hexagon_stubapi.cc @@ -41,31 +41,30 @@ StubAPI::StubAPI() { TVM_LOGD("ADSP subsystem present"); } - constexpr auto domain_lib_name = "libtvm_hexagon_remote_stub.so"; - constexpr auto nondomain_lib_name = "libtvm_hexagon_remote_nd_stub.so"; + constexpr auto domain_lib_name = "libtvm_remote_stub.so"; + constexpr auto nondomain_lib_name = "libtvm_remote_nd_stub.so"; - const char* lib_name = - enable_domains_ ? domain_lib_name : nondomain_lib_name; + const char* lib_name = enable_domains_ ? 
domain_lib_name : nondomain_lib_name; CHECK(lib_handle_ = dlopen(lib_name, RTLD_LAZY | RTLD_LOCAL)); #define RESOLVE(fn) p##fn##_ = GetSymbol(#fn) if (enable_domains_) { - RESOLVE(tvm_hexagon_remote_load_library); - RESOLVE(tvm_hexagon_remote_release_library); - RESOLVE(tvm_hexagon_remote_get_symbol); - RESOLVE(tvm_hexagon_remote_kernel); - RESOLVE(tvm_hexagon_remote_open); - RESOLVE(tvm_hexagon_remote_close); - RESOLVE(tvm_hexagon_remote_alloc_vtcm); - RESOLVE(tvm_hexagon_remote_free_vtcm); - RESOLVE(tvm_hexagon_remote_call_mmap64); + RESOLVE(tvm_remote_load_library); + RESOLVE(tvm_remote_release_library); + RESOLVE(tvm_remote_get_symbol); + RESOLVE(tvm_remote_kernel); + RESOLVE(tvm_remote_open); + RESOLVE(tvm_remote_close); + RESOLVE(tvm_remote_alloc_vtcm); + RESOLVE(tvm_remote_free_vtcm); + RESOLVE(tvm_remote_call_mmap64); } else { - RESOLVE(tvm_hexagon_remote_nd_load_library); - RESOLVE(tvm_hexagon_remote_nd_release_library); - RESOLVE(tvm_hexagon_remote_nd_get_symbol); - RESOLVE(tvm_hexagon_remote_nd_kernel); - RESOLVE(tvm_hexagon_remote_nd_open); - RESOLVE(tvm_hexagon_remote_nd_call_mmap64); + RESOLVE(tvm_remote_nd_load_library); + RESOLVE(tvm_remote_nd_release_library); + RESOLVE(tvm_remote_nd_get_symbol); + RESOLVE(tvm_remote_nd_kernel); + RESOLVE(tvm_remote_nd_open); + RESOLVE(tvm_remote_nd_call_mmap64); } RESOLVE(rpcmem_init); diff --git a/src/runtime/hexagon/target/hexagon_stubapi.h b/src/runtime/hexagon/target/hexagon_stubapi.h index ef3dcfdbcc79..5213b6d0d7af 100644 --- a/src/runtime/hexagon/target/hexagon_stubapi.h +++ b/src/runtime/hexagon/target/hexagon_stubapi.h @@ -28,8 +28,8 @@ #include -#include "fastrpc/tvm_hexagon_remote.h" -#include "fastrpc/tvm_hexagon_remote_nd.h" +#include "fastrpc/include/tvm_remote.h" +#include "fastrpc/include/tvm_remote_nd.h" namespace tvm { namespace runtime { @@ -39,15 +39,15 @@ namespace hexagon { * Unify the handling of domain and non-domain functions. * * In most cases, for a function "foo", the domain version will be called - * "tvm_hexagon_remote_foo", and the non-domain version will have "nd_foo". + * "tvm_remote_foo", and the non-domain version will have "nd_foo". * The interfaces will be the same, except: * - the domain version will take "remote_handle64" as the first parameter, * while the non-domain version will not: - * int tvm_hexagon_remote_foo (remote_handle64 h, param1, param2, ...); - * int tvm_hexagon_remote_nd_foo (param1, param2, ...); + * int tvm_remote_foo (remote_handle64 h, param1, param2, ...); + * int tvm_remote_nd_foo (param1, param2, ...); * - any parameter of type "buffer" in the IDL, will be converted into a - * type "tvm_hexagon_remote_buffer" for domain functions, and into - * "tvm_hexagon_remote_nd_buffer" for non-domain functions. These two + * type "tvm_remote_buffer" for domain functions, and into + * "tvm_remote_nd_buffer" for non-domain functions. These two * types are identical, but since they are declared in two different IDLs, * they get different names. * @@ -55,32 +55,32 @@ namespace hexagon { * since the pointee types are different, this is enough to create a * difference in the function signatures even if the "remote_handle64" * parameter is ignored. For this reason, in all function types, the - * types "tvm_hexagon_remote_buffer *" and "tvm_hexagon_remote_nd_buffer *", + * types "tvm_remote_buffer *" and "tvm_remote_nd_buffer *", * both const and non-const, are replaced with "void *", with the * corresponding const-qualification. 
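// [Editor's note] A minimal, self-contained sketch of the RESOLVE(fn) pattern used in the
// StubAPI constructor above: dlopen the chosen stub library once, then resolve each remote
// entry point by its own name. The real constructor goes through a GetSymbol helper; here we
// call dlsym directly for brevity. The library and symbol names (libexample_stub.so,
// remote_add) are hypothetical and only illustrate the pattern.
#include <dlfcn.h>
#include <stdio.h>

typedef int remote_add_t(int, int);
static remote_add_t* p_remote_add_ = NULL;

#define RESOLVE_SYM(handle, fn) (p_##fn##_ = (fn##_t*)dlsym((handle), #fn))

int resolve_stub_symbols(void) {
  void* handle = dlopen("libexample_stub.so", RTLD_LAZY | RTLD_LOCAL);
  if (handle == NULL) {
    fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return -1;
  }
  if (RESOLVE_SYM(handle, remote_add) == NULL) {
    fprintf(stderr, "dlsym(remote_add) failed: %s\n", dlerror());
    return -1;
  }
  return 0;
}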
This is done by the templates * "replace_pointee_type" and "map_tuple_element" below. * * The following functions are subject to the uniform handling: * - * tvm_hexagon_remote_load_library (remote_handle64 h, p1, p2, ...) - * tvm_hexagon_remote_release_library - * tvm_hexagon_remote_get_symbol - * tvm_hexagon_remote_kernel - * tvm_hexagon_remote_close - * tvm_hexagon_remote_alloc_vtcm - * tvm_hexagon_remote_free_vtcm + * tvm_remote_load_library (remote_handle64 h, p1, p2, ...) + * tvm_remote_release_library + * tvm_remote_get_symbol + * tvm_remote_kernel + * tvm_remote_close + * tvm_remote_alloc_vtcm + * tvm_remote_free_vtcm * - * tvm_hexagon_remote_nd_load_library (p1, p2, ...) - * tvm_hexagon_remote_nd_release_library - * tvm_hexagon_remote_nd_get_symbol - * tvm_hexagon_remote_nd_kernel - * tvm_hexagon_remote_nd_close + * tvm_remote_nd_load_library (p1, p2, ...) + * tvm_remote_nd_release_library + * tvm_remote_nd_get_symbol + * tvm_remote_nd_kernel + * tvm_remote_nd_close * * The "open" functions differ in their parameters in different ways, and * need to be handled individually. * - * tvm_hexagon_remote_open - * tvm_hexagon_remote_nd_open + * tvm_remote_open + * tvm_remote_nd_open */ namespace { @@ -157,35 +157,34 @@ class StubAPI { private: // Create types for each remote function. For functions that take - // a pointer to tvm_hexagon_remote_buffer or tvm_hexagon_remote_nd_buffer, + // a pointer to tvm_remote_buffer or tvm_remote_nd_buffer, // replace that pointer with pointer to void to make pointers to these // two types identical in the function types created below. - // For example, int foo(tvm_hexagon_remote_buffer*) and - // int bar(tvm_hexagon_remote_nd_buffer*) should both have the same type. -#define MAPTYPE(fn, ty) \ - using fn##_t = typename map_func_type::type; - MAPTYPE(tvm_hexagon_remote_load_library, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_release_library, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_get_symbol, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_kernel, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_close, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_alloc_vtcm, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_free_vtcm, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_call_mmap64, tvm_hexagon_remote_buffer) - - MAPTYPE(tvm_hexagon_remote_nd_load_library, tvm_hexagon_remote_nd_buffer) - MAPTYPE(tvm_hexagon_remote_nd_release_library, tvm_hexagon_remote_nd_buffer) - MAPTYPE(tvm_hexagon_remote_nd_get_symbol, tvm_hexagon_remote_nd_buffer) - MAPTYPE(tvm_hexagon_remote_nd_kernel, tvm_hexagon_remote_nd_buffer) - MAPTYPE(tvm_hexagon_remote_nd_close, tvm_hexagon_remote_buffer) - MAPTYPE(tvm_hexagon_remote_nd_call_mmap64, tvm_hexagon_remote_buffer) + // For example, int foo(tvm_remote_buffer*) and + // int bar(tvm_remote_nd_buffer*) should both have the same type. 
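// [Editor's note] A compilable sketch (not the TVM implementation, whose templates are named
// replace_pointee_type / map_tuple_element / map_func_type) of the idea described in the
// comments above: rewrite any parameter of type Buf* (const or not) to void*, so two
// otherwise-identical signatures that use different buffer structs collapse to one function
// type. The names replace_ptr, unify_sig, buf_a, and buf_b are illustrative only.
#include <type_traits>

template <typename T, typename Buf>
struct replace_ptr { using type = T; };
template <typename Buf>
struct replace_ptr<Buf*, Buf> { using type = void*; };
template <typename Buf>
struct replace_ptr<const Buf*, Buf> { using type = const void*; };

template <typename F, typename Buf>
struct unify_sig;
template <typename R, typename... Args, typename Buf>
struct unify_sig<R(Args...), Buf> {
  using type = R(typename replace_ptr<Args, Buf>::type...);
};

struct buf_a {};  // stands in for the domain buffer struct
struct buf_b {};  // stands in for the non-domain buffer struct
static_assert(std::is_same<unify_sig<int(buf_a*, int), buf_a>::type,
                           unify_sig<int(buf_b*, int), buf_b>::type>::value,
              "both signatures collapse to int(void*, int)");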
+#define MAPTYPE(fn, ty) using fn##_t = typename map_func_type::type; + MAPTYPE(tvm_remote_load_library, tvm_remote_buffer) + MAPTYPE(tvm_remote_release_library, tvm_remote_buffer) + MAPTYPE(tvm_remote_get_symbol, tvm_remote_buffer) + MAPTYPE(tvm_remote_kernel, tvm_remote_buffer) + MAPTYPE(tvm_remote_close, tvm_remote_buffer) + MAPTYPE(tvm_remote_alloc_vtcm, tvm_remote_buffer) + MAPTYPE(tvm_remote_free_vtcm, tvm_remote_buffer) + MAPTYPE(tvm_remote_call_mmap64, tvm_remote_buffer) + + MAPTYPE(tvm_remote_nd_load_library, tvm_remote_nd_buffer) + MAPTYPE(tvm_remote_nd_release_library, tvm_remote_nd_buffer) + MAPTYPE(tvm_remote_nd_get_symbol, tvm_remote_nd_buffer) + MAPTYPE(tvm_remote_nd_kernel, tvm_remote_nd_buffer) + MAPTYPE(tvm_remote_nd_close, tvm_remote_buffer) + MAPTYPE(tvm_remote_nd_call_mmap64, tvm_remote_buffer) #undef MAPTYPE // For remote functions whose prototypes differ significantly between // the domain and non-domain versions, create the types directly. #define DECLTYPE(fn) using fn##_t = decltype(::fn); - DECLTYPE(tvm_hexagon_remote_open) - DECLTYPE(tvm_hexagon_remote_nd_open) + DECLTYPE(tvm_remote_open) + DECLTYPE(tvm_remote_nd_open) DECLTYPE(rpcmem_init) DECLTYPE(rpcmem_deinit) @@ -196,8 +195,7 @@ class StubAPI { public: template - int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, - Ts... args) const { + int invoke(Fd func_d, Fnd func_nd, remote_handle64 handle, Ts... args) const { if (enable_domains_) { return func_d(handle, args...); } @@ -214,16 +212,15 @@ class StubAPI { #define CONCAT_STR_FOR_REAL(a, b) a##b #define CONCAT_STR(a, b) CONCAT_STR_FOR_REAL(a, b) -#define FUNC(name) CONCAT_STR(tvm_hexagon_remote_, name) -#define FUNC_D(name) CONCAT_STR(tvm_hexagon_remote_, name) -#define FUNC_ND(name) CONCAT_STR(tvm_hexagon_remote_nd_, name) +#define FUNC(name) CONCAT_STR(tvm_remote_, name) +#define FUNC_D(name) CONCAT_STR(tvm_remote_, name) +#define FUNC_ND(name) CONCAT_STR(tvm_remote_nd_, name) #define PTRNAME(fn) CONCAT_STR(p, CONCAT_STR(fn, _)) -#define DECLFUNC(name) \ - template \ - int FUNC(name)(remote_handle64 handle, Ts... args) const { \ - return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, \ - args...); \ +#define DECLFUNC(name) \ + template \ + int FUNC(name)(remote_handle64 handle, Ts... 
args) const { \ + return invoke(PTRNAME(FUNC_D(name)), PTRNAME(FUNC_ND(name)), handle, args...); \ } #define DECLFUNC_D(name) \ @@ -254,11 +251,11 @@ class StubAPI { #undef DECLSFUNC #undef DECLFUNC_D - int tvm_hexagon_remote_open(const char* uri, remote_handle64* handle) const { + int tvm_remote_open(const char* uri, remote_handle64* handle) const { if (enable_domains_) { - return PTRNAME(tvm_hexagon_remote_open)(uri, handle); + return PTRNAME(tvm_remote_open)(uri, handle); } - return PTRNAME(tvm_hexagon_remote_nd_open)(); + return PTRNAME(tvm_remote_nd_open)(); } static const StubAPI* Global(); @@ -268,23 +265,23 @@ class StubAPI { void* lib_handle_ = nullptr; #define DECLPTR(fn) fn##_t* PTRNAME(fn) = nullptr - DECLPTR(tvm_hexagon_remote_load_library); - DECLPTR(tvm_hexagon_remote_release_library); - DECLPTR(tvm_hexagon_remote_get_symbol); - DECLPTR(tvm_hexagon_remote_kernel); - DECLPTR(tvm_hexagon_remote_open); - DECLPTR(tvm_hexagon_remote_close); - DECLPTR(tvm_hexagon_remote_alloc_vtcm); - DECLPTR(tvm_hexagon_remote_free_vtcm); - DECLPTR(tvm_hexagon_remote_call_mmap64); - - DECLPTR(tvm_hexagon_remote_nd_load_library); - DECLPTR(tvm_hexagon_remote_nd_release_library); - DECLPTR(tvm_hexagon_remote_nd_get_symbol); - DECLPTR(tvm_hexagon_remote_nd_kernel); - DECLPTR(tvm_hexagon_remote_nd_open); - DECLPTR(tvm_hexagon_remote_nd_close); - DECLPTR(tvm_hexagon_remote_nd_call_mmap64); + DECLPTR(tvm_remote_load_library); + DECLPTR(tvm_remote_release_library); + DECLPTR(tvm_remote_get_symbol); + DECLPTR(tvm_remote_kernel); + DECLPTR(tvm_remote_open); + DECLPTR(tvm_remote_close); + DECLPTR(tvm_remote_alloc_vtcm); + DECLPTR(tvm_remote_free_vtcm); + DECLPTR(tvm_remote_call_mmap64); + + DECLPTR(tvm_remote_nd_load_library); + DECLPTR(tvm_remote_nd_release_library); + DECLPTR(tvm_remote_nd_get_symbol); + DECLPTR(tvm_remote_nd_kernel); + DECLPTR(tvm_remote_nd_open); + DECLPTR(tvm_remote_nd_close); + DECLPTR(tvm_remote_nd_call_mmap64); #undef DECLPTR // "System" functions. diff --git a/src/runtime/hexagon/target/hexagon_target_log.h b/src/runtime/hexagon/target/hexagon_target_log.h index ae09503cd35b..c7684fc56197 100644 --- a/src/runtime/hexagon/target/hexagon_target_log.h +++ b/src/runtime/hexagon/target/hexagon_target_log.h @@ -23,18 +23,12 @@ #include -#define TVM_LOGV(...) \ - __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) -#define TVM_LOGD(...) \ - __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) -#define TVM_LOGI(...) \ - __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) -#define TVM_LOGW(...) \ - __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) -#define TVM_LOGE(...) \ - __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) -#define TVM_LOGF(...) \ - __android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) +#define TVM_LOGV(...) __android_log_print(ANDROID_LOG_VERBOSE, "TVM", ##__VA_ARGS__) +#define TVM_LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, "TVM", ##__VA_ARGS__) +#define TVM_LOGI(...) __android_log_print(ANDROID_LOG_INFO, "TVM", ##__VA_ARGS__) +#define TVM_LOGW(...) __android_log_print(ANDROID_LOG_WARN, "TVM", ##__VA_ARGS__) +#define TVM_LOGE(...) __android_log_print(ANDROID_LOG_ERROR, "TVM", ##__VA_ARGS__) +#define TVM_LOGF(...) 
__android_log_print(ANDROID_LOG_FATAL, "TVM", ##__VA_ARGS__) #endif // __ANDROID__ #endif // TVM_RUNTIME_HEXAGON_TARGET_HEXAGON_TARGET_LOG_H_ diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index 306a7e990516..7c3323c56229 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -21,13 +21,15 @@ * \file module_util.cc * \brief Utilities for module. */ +#include "library_module.h" + #include #include #include + #include -#include #include -#include "library_module.h" +#include namespace tvm { namespace runtime { @@ -35,22 +37,16 @@ namespace runtime { // Library module that exposes symbols from a library. class LibraryModuleNode final : public ModuleNode { public: - explicit LibraryModuleNode(ObjectPtr lib) - : lib_(lib) { - } + explicit LibraryModuleNode(ObjectPtr lib) : lib_(lib) {} - const char* type_key() const final { - return "library"; - } + const char* type_key() const final { return "library"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - lib_->GetSymbol(runtime::symbol::tvm_module_main)); - CHECK(entry_name!= nullptr) + const char* entry_name = + reinterpret_cast(lib_->GetSymbol(runtime::symbol::tvm_module_main)); + CHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; faddr = reinterpret_cast(lib_->GetSymbol(entry_name)); } else { @@ -70,35 +66,27 @@ class LibraryModuleNode final : public ModuleNode { class ModuleInternal { public: // Get mutable reference of imports. - static std::vector* GetImportsAddr(ModuleNode* node) { - return &(node->imports_); - } + static std::vector* GetImportsAddr(ModuleNode* node) { return &(node->imports_); } }; -PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, - const ObjectPtr& sptr_to_self) { +PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - TVMValue ret_value; - int ret_type_code = kTVMNullptr; - int ret = (*faddr)( - const_cast(args.values), - const_cast(args.type_codes), - args.num_args, - &ret_value, - &ret_type_code); - CHECK_EQ(ret, 0) << TVMGetLastError(); - if (ret_type_code != kTVMNullptr) { - *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); - } - }); + TVMValue ret_value; + int ret_type_code = kTVMNullptr; + int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), + args.num_args, &ret_value, &ret_type_code); + CHECK_EQ(ret, 0) << TVMGetLastError(); + if (ret_type_code != kTVMNullptr) { + *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); + } + }); } void InitContextFunctions(std::function fgetsymbol) { - #define TVM_INIT_CONTEXT_FUNC(FuncName) \ - if (auto *fp = reinterpret_cast \ - (fgetsymbol("__" #FuncName))) { \ - *fp = FuncName; \ - } +#define TVM_INIT_CONTEXT_FUNC(FuncName) \ + if (auto* fp = reinterpret_cast(fgetsymbol("__" #FuncName))) { \ + *fp = FuncName; \ + } // Initialize the functions TVM_INIT_CONTEXT_FUNC(TVMFuncCall); TVM_INIT_CONTEXT_FUNC(TVMAPISetLastError); @@ -108,7 +96,7 @@ void InitContextFunctions(std::function fgetsymbol) { TVM_INIT_CONTEXT_FUNC(TVMBackendParallelLaunch); TVM_INIT_CONTEXT_FUNC(TVMBackendParallelBarrier); - #undef TVM_INIT_CONTEXT_FUNC +#undef TVM_INIT_CONTEXT_FUNC } /*! 
@@ -123,10 +111,10 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { uint64_t nbytes = 0; for (size_t i = 0; i < sizeof(nbytes); ++i) { uint64_t c = mblob[i]; - nbytes |= (c & 0xffUL) << (i * 8); + nbytes |= (c & 0xffUL) << (i * 8); } - dmlc::MemoryFixedSizeStream fs( - const_cast(mblob + sizeof(nbytes)), static_cast(nbytes)); + dmlc::MemoryFixedSizeStream fs(const_cast(mblob + sizeof(nbytes)), + static_cast(nbytes)); dmlc::Stream* stream = &fs; uint64_t size; CHECK(stream->Read(&size)); @@ -147,9 +135,7 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { } else { std::string fkey = "runtime.module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented."; Module m = (*f)(static_cast(stream)); modules.emplace_back(m); } @@ -180,14 +166,11 @@ runtime::Module ProcessModuleBlob(const char* mblob, ObjectPtr lib) { } Module CreateModuleFromLibrary(ObjectPtr lib) { - InitContextFunctions([lib](const char* fname) { - return lib->GetSymbol(fname); - }); + InitContextFunctions([lib](const char* fname) { return lib->GetSymbol(fname); }); auto n = make_object(lib); // Load the imported modules const char* dev_mblob = - reinterpret_cast( - lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); + reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_dev_mblob)); Module root_mod; if (dev_mblob != nullptr) { root_mod = ProcessModuleBlob(dev_mblob, lib); @@ -197,8 +180,7 @@ Module CreateModuleFromLibrary(ObjectPtr lib) { } // allow lookup of symbol from root (so all symbols are visible). - if (auto *ctx_addr = - reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { + if (auto* ctx_addr = reinterpret_cast(lib->GetSymbol(runtime::symbol::tvm_module_ctx))) { *ctx_addr = root_mod.operator->(); } diff --git a/src/runtime/library_module.h b/src/runtime/library_module.h index 61e62661f149..91918c1ccaa3 100644 --- a/src/runtime/library_module.h +++ b/src/runtime/library_module.h @@ -24,9 +24,10 @@ #ifndef TVM_RUNTIME_LIBRARY_MODULE_H_ #define TVM_RUNTIME_LIBRARY_MODULE_H_ -#include -#include #include +#include +#include + #include namespace tvm { @@ -47,7 +48,7 @@ class Library : public Object { * \param name The name of the symbol. * \return The symbol. */ - virtual void *GetSymbol(const char* name) = 0; + virtual void* GetSymbol(const char* name) = 0; // NOTE: we do not explicitly create an type index and type_key here for libary. // This is because we do not need dynamic type downcasting. 
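// [Editor's note] A small worked example of the little-endian length header that
// ProcessModuleBlob above reconstructs byte by byte with
//   nbytes |= (c & 0xffUL) << (i * 8);
// It assumes only what the runtime code above assumes: the first sizeof(uint64_t) bytes of
// the blob store the payload length with the least-significant byte first.
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

static uint64_t ReadLE64(const unsigned char* p) {
  uint64_t v = 0;
  for (size_t i = 0; i < sizeof(v); ++i) {
    v |= (uint64_t)(p[i] & 0xff) << (i * 8);
  }
  return v;
}

int main(void) {
  // 0x0102 == 258, stored least-significant byte first.
  const unsigned char header[8] = {0x02, 0x01, 0, 0, 0, 0, 0, 0};
  assert(ReadLE64(header) == 258);
  return 0;
}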
}; @@ -77,4 +78,4 @@ void InitContextFunctions(std::function fgetsymbol); Module CreateModuleFromLibrary(ObjectPtr lib); } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ +#endif // TVM_RUNTIME_LIBRARY_MODULE_H_ diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index 22f2e9aa0909..451c0e88fcb0 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -24,11 +24,13 @@ #ifndef TVM_RUNTIME_META_DATA_H_ #define TVM_RUNTIME_META_DATA_H_ -#include #include +#include #include + #include #include + #include "runtime_base.h" namespace tvm { @@ -40,10 +42,10 @@ struct FunctionInfo { std::vector arg_types; std::vector thread_axis_tags; - void Save(dmlc::JSONWriter *writer) const; - void Load(dmlc::JSONReader *reader); - void Save(dmlc::Stream *writer) const; - bool Load(dmlc::Stream *reader); + void Save(dmlc::JSONWriter* writer) const; + void Load(dmlc::JSONReader* reader); + void Save(dmlc::Stream* writer) const; + bool Load(dmlc::Stream* reader); }; } // namespace runtime } // namespace tvm diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 8a7c9fe53018..ca369d46e5ba 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -24,21 +24,22 @@ #ifndef TVM_RUNTIME_METAL_METAL_COMMON_H_ #define TVM_RUNTIME_METAL_METAL_COMMON_H_ +#import #import -#import #import -#import +#import #import #import - +#include #include -#include #include -#include +#include + +#include #include #include #include -#include + #include "../workspace_pool.h" namespace tvm { @@ -64,14 +65,14 @@ class MetalWorkspace final : public DeviceAPI { // Get command queue for given context. id GetCommandQueue(TVMContext ctx) { CHECK_EQ(ctx.device_type, kDLMetal); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid Metal device_id=" << ctx.device_id; return queues[ctx.device_id]; } // Get device for given context id GetDevice(TVMContext ctx) { CHECK_EQ(ctx.device_type, kDLMetal); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < devices.size()) << "Invalid Metal device_id=" << ctx.device_id; return devices[ctx.device_id]; } @@ -81,19 +82,10 @@ class MetalWorkspace final : public DeviceAPI { // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_size, - void* to, - size_t to_size, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; @@ -112,8 +104,7 @@ class MetalThreadEntry { /*! 
\brief workspace pool */ WorkspacePool pool; // constructor - MetalThreadEntry() - : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { + MetalThreadEntry() : pool(static_cast(kDLMetal), MetalWorkspace::Global()) { context.device_id = 0; context.device_type = static_cast(kDLMetal); } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index a49f8a5cfc96..3bad2c3e9deb 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -20,8 +20,8 @@ /*! * \file metal_device_api.mm */ -#include #include +#include #include "metal_common.h" namespace tvm { @@ -29,25 +29,21 @@ namespace metal { const std::shared_ptr& MetalWorkspace::Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } -void MetalWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void MetalWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { this->Init(); size_t index = static_cast(ctx.device_id); if (kind == kExist) { - *rv = int(index< devices.size()); + *rv = int(index < devices.size()); return; } - CHECK_LT(index, devices.size()) - << "Invalid device id " << index; + CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { case kMaxThreadsPerBlock: { - *rv = static_cast( - [devices[ctx.device_id] maxThreadsPerThreadgroup].width); + *rv = static_cast([devices[ctx.device_id] maxThreadsPerThreadgroup].width); break; } case kWarpSize: { @@ -55,14 +51,22 @@ *rv = 1; break; } - case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: return; - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; - case kExist: break; - case kGcnArch: return; + case kMaxSharedMemoryPerBlock: + return; + case kComputeVersion: + return; + case kDeviceName: + return; + case kMaxClockRate: + return; + case kMultiProcessorCount: + return; + case kMaxThreadDimensions: + return; + case kExist: + break; + case kGcnArch: + return; } } @@ -87,22 +91,13 @@ kernel void CopyKernel( // But we keep this code. 
int GetWarpSize(id dev) { NSError* error_msg = nil; - id lib = - [dev - newLibraryWithSource: - [NSString stringWithUTF8String:kDummyKernel] - options:nil - error:&error_msg]; + id lib = [dev newLibraryWithSource:[NSString stringWithUTF8String:kDummyKernel] + options:nil + error:&error_msg]; CHECK(lib != nil) << [[error_msg localizedDescription] UTF8String]; - id f = - [lib - newFunctionWithName: - [NSString stringWithUTF8String:"CopyKernel"]]; - CHECK(f!= nil); - id state = - [dev - newComputePipelineStateWithFunction:f - error:&error_msg]; + id f = [lib newFunctionWithName:[NSString stringWithUTF8String:"CopyKernel"]]; + CHECK(f != nil); + id state = [dev newComputePipelineStateWithFunction:f error:&error_msg]; CHECK(state != nil) << [[error_msg localizedDescription] UTF8String]; return static_cast(state.threadExecutionWidth); } @@ -123,20 +118,19 @@ int GetWarpSize(id dev) { initialized_ = true; if (devices.size() != 0) return; #if TARGET_OS_IPHONE - // on iPhone - id d = MTLCreateSystemDefaultDevice(); + // on iPhone + id d = MTLCreateSystemDefaultDevice(); + devices.push_back([d retain]); + queues.push_back([[d newCommandQueue] retain]); +#else + NSArray >* devs = MTLCopyAllDevices(); + for (size_t i = 0; i < devs.count; ++i) { + id d = [devs objectAtIndex:i]; devices.push_back([d retain]); queues.push_back([[d newCommandQueue] retain]); -#else - NSArray>* devs = MTLCopyAllDevices(); - for (size_t i = 0; i < devs.count; ++i) { - id d = [devs objectAtIndex:i]; - devices.push_back([d retain]); - queues.push_back([[d newCommandQueue] retain]); - LOG(INFO) << "Intializing Metal device " << i - << ", name=" << [d.name UTF8String]; - warp_size.push_back(GetWarpSize(d)); - } + LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; + warp_size.push_back(GetWarpSize(d)); + } #endif } @@ -144,8 +138,8 @@ int GetWarpSize(id dev) { MetalThreadEntry::ThreadLocal()->context.device_id = ctx.device_id; } -void* MetalWorkspace::AllocDataSpace( - TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { +void* MetalWorkspace::AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) { this->Init(); id dev = GetDevice(ctx); // GPU memory only @@ -157,9 +151,7 @@ int GetWarpSize(id dev) { storage_mode = MTLResourceStorageModeManaged; #endif */ - id buf = [ - dev newBufferWithLength:nbytes - options:storage_mode]; + id buf = [dev newBufferWithLength:nbytes options:storage_mode]; CHECK(buf != nil); return (__bridge void*)([buf retain]); } @@ -169,14 +161,9 @@ int GetWarpSize(id dev) { CFRelease(ptr); } -void MetalWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { this->Init(); CHECK(stream == nullptr); @@ -188,65 +175,54 @@ int GetWarpSize(id dev) { int to_dev_type = static_cast(ctx_to.device_type); if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) { - CHECK_EQ(ctx_from.device_id, ctx_to.device_id) - << "Metal disallow cross device copy."; + CHECK_EQ(ctx_from.device_id, ctx_to.device_id) << "Metal disallow cross device copy."; id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:(__bridge id)(from) - sourceOffset:from_offset - toBuffer:(__bridge id)(to) - destinationOffset:to_offset - 
size:size]; + sourceOffset:from_offset + toBuffer:(__bridge id)(to)destinationOffset:to_offset + size:size]; [encoder endEncoding]; [cb commit]; } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) { // copy to a local buffer before get into global buffer. id from_buf = (__bridge id)(from); if (from_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal() - ->GetTempBuffer(ctx_from, size); + id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_from, size); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:from_buf - sourceOffset:from_offset - toBuffer:temp - destinationOffset:0 - size:size]; + sourceOffset:from_offset + toBuffer:temp + destinationOffset:0 + size:size]; [encoder endEncoding]; [cb commit]; [cb waitUntilCompleted]; - memcpy(static_cast(to) + to_offset, - static_cast([temp contents]), - size); + memcpy(static_cast(to) + to_offset, static_cast([temp contents]), size); } else { memcpy(static_cast(to) + to_offset, - static_cast([from_buf contents]) + from_offset, - size); + static_cast([from_buf contents]) + from_offset, size); } } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) { id to_buf = (__bridge id)(to); if (to_buf.storageMode != MTLStorageModeShared) { - id temp = MetalThreadEntry::ThreadLocal() - ->GetTempBuffer(ctx_to, size); - memcpy([temp contents], - static_cast(from) + from_offset, - size); + id temp = MetalThreadEntry::ThreadLocal()->GetTempBuffer(ctx_to, size); + memcpy([temp contents], static_cast(from) + from_offset, size); id encoder = [cb blitCommandEncoder]; [encoder copyFromBuffer:temp - sourceOffset:0 - toBuffer:to_buf - destinationOffset:to_offset - size:size]; + sourceOffset:0 + toBuffer:to_buf + destinationOffset:to_offset + size:size]; [encoder endEncoding]; [cb commit]; [cb waitUntilCompleted]; } else { memcpy(static_cast([to_buf contents]) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); } } else { LOG(FATAL) << "Expect copy from/to Metal or between Metal" - << ", from=" << from_dev_type - << ", to=" << to_dev_type; + << ", from=" << from_dev_type << ", to=" << to_dev_type; } } @@ -259,9 +235,7 @@ int GetWarpSize(id dev) { [cb waitUntilCompleted]; } -void* MetalWorkspace::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* MetalWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); } @@ -279,30 +253,25 @@ int GetWarpSize(id dev) { if (temp_buffer_.size() <= static_cast(ctx.device_id)) { temp_buffer_.resize(ctx.device_id + 1, nil); } - if (temp_buffer_[ctx.device_id] == nil || - temp_buffer_[ctx.device_id].length < size) { + if (temp_buffer_[ctx.device_id] == nil || temp_buffer_[ctx.device_id].length < size) { id dev = MetalWorkspace::Global()->GetDevice(ctx); if (temp_buffer_[ctx.device_id] != nil) { [temp_buffer_[ctx.device_id] release]; } - temp_buffer_[ctx.device_id] = [ - [dev newBufferWithLength:size - options:MTLStorageModeShared] retain]; + temp_buffer_[ctx.device_id] = [[dev newBufferWithLength:size + options:MTLStorageModeShared] retain]; } return temp_buffer_[ctx.device_id]; } typedef dmlc::ThreadLocalStore MetalThreadStore; -MetalThreadEntry* MetalThreadEntry::ThreadLocal() { - return MetalThreadStore::Get(); -} +MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* 
ptr = MetalWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MetalWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace metal } // namespace runtime diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index 0d2d429fcf61..77cdf64df8bc 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -44,11 +46,8 @@ static constexpr const int kMetalMaxNumDevice = 32; * \param fmap The map function information map of each function. * \param source Optional, source file */ -Module MetalModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 41269b9f1a5d..9bdebf3d06c1 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -20,18 +20,18 @@ /*! * \file metal_module.cc */ +#include "metal_module.h" #include -#include #include +#include #include -#include #include -#include "metal_module.h" -#include "metal_common.h" +#include +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "metal_common.h" namespace tvm { namespace runtime { @@ -39,27 +39,18 @@ // Module to support thread-safe multi-GPU execution. 
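// [Editor's note] A minimal sketch of the per-device, lazily populated table described in
// the comment that continues below: each device id owns a slot, the expensive build step
// runs only on first use, and a mutex keeps concurrent lookups safe. kMaxDevices mirrors
// kMetalMaxNumDevice above; BuildForDevice is a placeholder for the real work (compiling a
// Metal library and pipeline state), not TVM code.
#include <array>
#include <mutex>
#include <string>

constexpr int kMaxDevices = 32;

class LazyPerDeviceCache {
 public:
  const std::string& Get(int device_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    std::string& entry = table_[device_id];
    if (entry.empty()) {
      entry = BuildForDevice(device_id);  // runs once per device
    }
    return entry;
  }

 private:
  std::string BuildForDevice(int device_id) {
    // Stand-in for newLibraryWithSource / newComputePipelineStateWithFunction.
    return "compiled-for-device-" + std::to_string(device_id);
  }
  std::array<std::string, kMaxDevices> table_;
  std::mutex mutex_;
};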
// The runtime will contain a per-device module table // The modules will be lazily loaded -class MetalModuleNode final :public runtime::ModuleNode { +class MetalModuleNode final : public runtime::ModuleNode { public: - explicit MetalModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) - : data_(data), fmt_(fmt), fmap_(fmap), source_(source) { - } - const char* type_key() const final { - return "metal"; - } + explicit MetalModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) + : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} + const char* type_key() const final { return "metal"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -81,8 +72,7 @@ void SaveToBinary(dmlc::Stream* stream) final { } } // get a from primary context in device_id - id GetPipelineState( - size_t device_id, const std::string& func_name) { + id GetPipelineState(size_t device_id, const std::string& func_name) { metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get(); CHECK_LT(device_id, w->devices.size()); // start lock scope. @@ -97,53 +87,43 @@ void SaveToBinary(dmlc::Stream* stream) final { NSError* err_msg = nil; if (e.lib == nil) { if (fmt_ == "metal") { - MTLCompileOptions *opts = [MTLCompileOptions alloc]; + MTLCompileOptions* opts = [MTLCompileOptions alloc]; // Use the Metal 1.2 for now. opts.languageVersion = MTLLanguageVersion1_2; opts.fastMathEnabled = YES; // opts = nil; - e.lib = [ - w->devices[device_id] - newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] - options:opts - error:&err_msg]; + e.lib = [w->devices[device_id] + newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] + options:opts + error:&err_msg]; [opts dealloc]; if (e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" - << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } if (err_msg != nil) { - LOG(INFO) << "Warning: " - << [[err_msg localizedDescription] UTF8String]; + LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String]; } } else { // Build from library. 
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); - auto data = dispatch_data_create( - data_.c_str(), data_.length(), q, ^{}); - e.lib = [ - w->devices[device_id] - newLibraryWithData:data - error:&err_msg]; + auto data = dispatch_data_create(data_.c_str(), data_.length(), q, + ^{ + }); + e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg]; if (err_msg != nil || e.lib == nil) { - LOG(FATAL) << "Fail to compile metal lib:" - << [[err_msg localizedDescription] UTF8String]; + LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String]; } } [e.lib retain]; } - id f = [ - e.lib - newFunctionWithName: - [NSString stringWithUTF8String:func_name.c_str()]]; + id f = + [e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]]; CHECK(f != nil) << "cannot find function " << func_name; id state = - [w->devices[device_id] - newComputePipelineStateWithFunction:f - error:&err_msg]; - CHECK(state != nil) - << "cannot get state:" << " for function " << func_name - << [[err_msg localizedDescription] UTF8String]; + [w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg]; + CHECK(state != nil) << "cannot get state:" + << " for function " << func_name + << [[err_msg localizedDescription] UTF8String]; // The state.threadExecutionWidth can change dynamically according // to the resource constraint in kernel, so it is not strictly hold // Turn of warp aware optimziation for now. @@ -162,7 +142,7 @@ void SaveToBinary(dmlc::Stream* stream) final { ~DeviceEntry() { if (lib != nil) [lib release]; - for (auto &&kv : smap) { + for (auto&& kv : smap) { [kv.second release]; } } @@ -185,11 +165,8 @@ void SaveToBinary(dmlc::Stream* stream) final { class MetalWrappedFunc { public: // initialize the METAL function. 
- void Init(MetalModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_buffer_args, - size_t num_pack_args, + void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { w_ = metal::MetalWorkspace::Global().get(); m_ = m; @@ -204,9 +181,7 @@ void Init(MetalModuleNode* m, scache_[dev_id] = m->GetPipelineState(dev_id, func_name); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - const ArgUnion* pack_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal(); int device_id = t->context.device_id; if (scache_[device_id] == nil) { @@ -223,16 +198,13 @@ void operator()(TVMArgs args, } if (num_pack_args_ != 0) { [encoder setBytes:pack_args - length:num_pack_args_ * sizeof(ArgUnion) - atIndex:num_buffer_args_]; + length:num_pack_args_ * sizeof(ArgUnion) + atIndex:num_buffer_args_]; } // launch - MTLSize dimGrid = MTLSizeMake( - wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); - MTLSize dimBlock = MTLSizeMake( - wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); - [encoder dispatchThreadgroups: dimGrid - threadsPerThreadgroup: dimBlock]; + MTLSize dimGrid = MTLSizeMake(wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); + MTLSize dimBlock = MTLSizeMake(wl.block_dim(0), wl.block_dim(1), wl.block_dim(2)); + [encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock]; [encoder endEncoding]; [cb commit]; } @@ -257,36 +229,29 @@ void operator()(TVMArgs args, ThreadAxisConfig thread_axis_cfg_; }; -PackedFunc MetalModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc MetalModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; MetalWrappedFunc f; size_t num_buffer_args = NumBufferArgs(info.arg_types); - f.Init(this, sptr_to_self, name, - num_buffer_args, info.arg_types.size() - num_buffer_args, + f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args, info.thread_axis_tags); return PackFuncNonBufferArg(f, info.arg_types); } -Module MetalModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { metal::MetalWorkspace::Global()->Init(); auto n = make_object(data, fmt, fmap, source); return Module(n); } // Load module from module. 
-Module MetalModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module MetalModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -307,10 +272,8 @@ Module MetalModuleLoadBinary(void* strm) { return MetalModuleCreate(data, fmt, fmap, ""); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal") -.set_body_typed(MetalModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_metal").set_body_typed(MetalModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal") -.set_body_typed(MetalModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_metal").set_body_typed(MetalModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s b/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s index 300deb8079a0..f5720f4d7b28 100644 --- a/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s +++ b/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s @@ -17,11 +17,6 @@ * under the License. */ -/*! - * \file utvm_init.s - * \brief uTVM init definition for STM32F746XX-series boards - */ - .syntax unified .cpu cortex-m7 .fpu softvfp diff --git a/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c b/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c index 1b8376150fce..ae2b1994df12 100644 --- a/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c +++ b/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c @@ -29,100 +29,49 @@ extern "C" { #include #include "utvm_runtime.h" +// NOTE: This expects ST CMSIS to be in your include path. +// Download STM32CubeF7 here: +// https://www.st.com/content/st_com/en/products/embedded-software/mcu-mpu-embedded-software/stm32-embedded-software/stm32cube-mcu-mpu-packages/stm32cubef7.html +// and add Drivers/CMSIS to your C include path. +#include "Device/ST/STM32F7xx/Include/stm32f746xx.h" -// There are two implementations of cycle counters on the STM32F7X: SysTick and -// CYCCNT. SysTick is preferred, as it gives better error handling, but the -// counter is only 24 bits wide. If a larger timer is needed, use the CYCCNT -// implementation, which has a 32-bit counter. 
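// [Editor's note] The TIM2-based timer introduced below reads the raw 32-bit counter and
// splits it into whole milliseconds plus leftover microseconds using utvm_SystemCoreClock
// (216 MHz, i.e. 216000 ticks per ms and 216 ticks per us). A quick host-side check of that
// arithmetic, under those same assumptions:
#include <assert.h>
#include <stdint.h>

static uint32_t TicksToMicros(uint32_t ticks, uint32_t clock_hz) {
  uint32_t millis = ticks / (clock_hz / 1000);
  uint32_t micros = (ticks - millis * (clock_hz / 1000)) / (clock_hz / 1000000);
  return millis * 1000 + micros;
}

int main(void) {
  const uint32_t clk = 216000000u;
  assert(TicksToMicros(216000u, clk) == 1000u);       // exactly 1 ms
  assert(TicksToMicros(216u * 1500u, clk) == 1500u);  // 1.5 ms -> 1500 us
  assert(TicksToMicros(108u, clk) == 0u);             // below 1 us resolution rounds down
  return 0;
}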
-#define USE_SYSTICK - -#ifdef USE_SYSTICK - -#define SYST_CSR (*((volatile uint32_t *) 0xE000E010)) -#define SYST_RVR (*((volatile uint32_t *) 0xE000E014)) -#define SYST_CVR (*((volatile uint32_t *) 0xE000E018)) -#define SYST_CALIB (*((volatile uint32_t *) 0xE000E01C)) - -#define SYST_CSR_ENABLE 0 -#define SYST_CSR_TICKINT 1 -#define SYST_CSR_CLKSOURCE 2 -#define SYST_COUNTFLAG 16 - -#define SYST_CALIB_NOREF 31 -#define SYST_CALIB_SKEW 30 - -uint32_t start_time = 0; -uint32_t stop_time = 0; +#define utvm_SystemCoreClock 216000000UL int32_t UTVMTimerStart() { - SYST_CSR = (1 << SYST_CSR_ENABLE) | (1 << SYST_CSR_CLKSOURCE); - // wait until timer starts - while (SYST_CVR == 0) {} - start_time = SYST_CVR; - return 0; -} - -void UTVMTimerStop() { - SYST_CSR = 0; - stop_time = SYST_CVR; -} - -void UTVMTimerReset() { - SYST_CSR = 0; - // maximum reload value (24-bit) - SYST_RVR = (~((uint32_t) 0)) >> 8; - SYST_CVR = 0; + UTVMTimerReset(); + TIM2->CR1 = TIM_CR1_CEN; // Start counter + return UTVM_ERR_OK; } -uint32_t UTVMTimerRead() { - if (SYST_CSR & SYST_COUNTFLAG) { - TVMAPISetLastError("timer overflowed"); - return -1; - } else { - return start_time - stop_time; +uint32_t UTVMTimerStop(int32_t* err) { + TIM2->CR1 &= TIM_CR1_CEN; + if (TIM2->SR & TIM_SR_UIF_Msk) { + *err = UTVM_ERR_TIMER_OVERFLOW; + return 0; } + *err = UTVM_ERR_OK; + uint32_t tim_cnt = TIM2->CNT; + uint32_t millis = tim_cnt / (utvm_SystemCoreClock / 1000); + uint32_t micros = + (tim_cnt - (millis * (utvm_SystemCoreClock / 1000))) / (utvm_SystemCoreClock / 1000000); + return millis * 1000 + micros; } -#else // !USE_SYSTICK - -#define DWT_CTRL (*((volatile uint32_t *) 0xE0001000)) -#define DWT_CYCCNT (*((volatile uint32_t *) 0xE0001004)) - -#define DWT_CTRL_NOCYCCNT 25 -#define DWT_CTRL_CYCCNTENA 0 - -uint32_t start_time = 0; -uint32_t stop_time = 0; - void UTVMTimerReset() { - DWT_CYCCNT = 0; -} - -int32_t UTVMTimerStart() { - if (DWT_CTRL & DWT_CTRL_NOCYCCNT) { - TVMAPISetLastError("cycle counter not implemented on device"); - return -1; + RCC->APB1RSTR |= RCC_APB1RSTR_TIM2RST; // Hold TIM2 in reset + RCC->DCKCFGR1 = (RCC->DCKCFGR1 & ~RCC_DCKCFGR1_TIMPRE_Msk); // disable 2x clock boost to TIM2 + RCC->CFGR = (RCC->CFGR & ~RCC_CFGR_PPRE1_Msk); // No AHB clock division to APB1 (1:1). + RCC->APB1ENR |= RCC_APB1ENR_TIM2EN; // Enable TIM2 clock. + RCC->APB1RSTR &= ~RCC_APB1RSTR_TIM2RST; // Exit TIM2 reset. + + DBGMCU->APB1FZ |= DBGMCU_APB1_FZ_DBG_TIM2_STOP; // stop TIM2 clock during debug halt. + TIM2->ARR = 0xffffffff; + if (TIM2->SR & TIM_SR_UIF_Msk) { + for (;;) { + } } - start_time = DWT_CYCCNT; - DWT_CTRL |= (1 << DWT_CTRL_CYCCNTENA); } -void UTVMTimerStop() { - stop_time = DWT_CYCCNT; - DWT_CTRL &= ~(1 << DWT_CTRL_CYCCNTENA); -} - -int32_t UTVMTimerRead() { - if (stop_time > stop_time) { - return stop_time - start_time; - } else { - uint32_t largest = ~0; - return (largest - start_time) + stop_time; - } -} - -#endif // USE_SYSTICK - #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/src/runtime/micro/device/host/utvm_timer.c b/src/runtime/micro/device/host/utvm_timer.c index 56a36ebae86d..887b15c8b25a 100644 --- a/src/runtime/micro/device/host/utvm_timer.c +++ b/src/runtime/micro/device/host/utvm_timer.c @@ -22,26 +22,15 @@ * \brief uTVM timer API stubs for the host emulated device */ -#ifdef __cplusplus -extern "C" { -#endif +#include #include "utvm_runtime.h" // TODO(weberlo): use this? 
https://stackoverflow.com/questions/5141960/get-the-current-time-in-c -int32_t UTVMTimerStart() { - return 0; -} - -void UTVMTimerStop() { } +int32_t UTVMTimerStart() { return UTVM_ERR_OK; } -void UTVMTimerReset() { } - -uint32_t UTVMTimerRead() { - return 1; +uint32_t UTVMTimerStop(int32_t* err) { + *err = UTVM_ERR_OK; + return 0; } - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif diff --git a/src/runtime/micro/device/riscv_spike/utvm_init.s b/src/runtime/micro/device/riscv_spike/utvm_init.s new file mode 100644 index 000000000000..68662cce97e7 --- /dev/null +++ b/src/runtime/micro/device/riscv_spike/utvm_init.s @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +UTVMInit: + /* set stack pointer */ + la sp, _utvm_stack_pointer_init + call UTVMMain diff --git a/src/tir/pass/detect_device.cc b/src/runtime/micro/device/riscv_spike/utvm_timer.c similarity index 68% rename from src/tir/pass/detect_device.cc rename to src/runtime/micro/device/riscv_spike/utvm_timer.c index ee3a2e23b487..78c811979d43 100644 --- a/src/tir/pass/detect_device.cc +++ b/src/runtime/micro/device/riscv_spike/utvm_timer.c @@ -18,21 +18,23 @@ */ /*! - * \file detect_device.cc + * \file utvm_timer.c + * \brief uTVM timer API stubs for Spike */ -#include -#include "ir_util.h" +#ifdef __cplusplus +extern "C" { +#endif -namespace tvm { -namespace tir { -Stmt DecorateDeviceScope(Stmt stmt) { - Stmt body = AttrStmtNode::make(make_zero(DataType::Int(32)), - tir::attr::device_scope, - 0, - stmt); - return body; +#include "utvm_runtime.h" + +int32_t UTVMTimerStart() { return UTVM_ERR_OK; } + +uint32_t UTVMTimerStop(int32_t* err) { + *err = UTVM_ERR_OK; + return 0; } -} // namespace tir -} // namespace tvm +#ifdef __cplusplus +} // TVM_EXTERN_C +#endif diff --git a/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c b/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c index a8c600ed347b..64b5908e6c1c 100644 --- a/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c +++ b/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c @@ -29,16 +29,17 @@ #ifdef __cplusplus extern "C" { #endif -#include #include +#include -void *(*TVMBackendAllocWorkspace_)(int, int, uint64_t, int, int) = - (void *(*)(int, int, uint64_t, int, int)) NULL; -int (*TVMBackendFreeWorkspace_)(int, int, void*) = (int (*)(int, int, void*)) NULL; -void (*TVMAPISetLastError_)(const char*) = (void (*)(const char*)) NULL; +// TODO(weberlo, areusch): compiler errors say volatile qualifier is discarded. +// should we just get rid of em? 
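// [Editor's note] A self-contained model of the redirection scheme below: generated device
// code calls a fixed wrapper symbol, and the wrapper forwards through a global function
// pointer that is normally patched at load time by the host. Here the "patch" is done from
// main() for illustration; ExampleAllocWorkspace and example_alloc_impl are hypothetical
// stand-ins for TVMBackendAllocWorkspace and the real runtime implementation.
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Pointer slot to be patched at load time (volatile, mirroring the declarations below).
static void* (*volatile example_alloc_)(int, int, uint64_t, int, int) = NULL;

// Fixed entry point the device library links against; it only forwards the call.
static void* ExampleAllocWorkspace(int device_type, int device_id, uint64_t size,
                                   int dtype_code_hint, int dtype_bits_hint) {
  return (*example_alloc_)(device_type, device_id, size, dtype_code_hint, dtype_bits_hint);
}

// Host-side implementation used only for this example.
static unsigned char arena[64];
static void* example_alloc_impl(int dt, int id, uint64_t size, int code, int bits) {
  (void)dt; (void)id; (void)code; (void)bits;
  return size <= sizeof(arena) ? (void*)arena : NULL;
}

int main(void) {
  example_alloc_ = example_alloc_impl;  // the load-time "patch" step
  assert(ExampleAllocWorkspace(1, 0, 16, 2, 8) == (void*)arena);
  return 0;
}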
+void* (*volatile TVMBackendAllocWorkspace_)(int, int, uint64_t, int, int) = NULL; +int (*volatile TVMBackendFreeWorkspace_)(int, int, void*) = NULL; +void (*volatile TVMAPISetLastError_)(const char*) = NULL; -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, - int dtype_code_hint, int dtype_bits_hint) { +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, + int dtype_bits_hint) { return (*TVMBackendAllocWorkspace_)(device_type, device_id, size, dtype_code_hint, dtype_bits_hint); } @@ -47,8 +48,41 @@ int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { return (*TVMBackendFreeWorkspace_)(device_type, device_id, ptr); } -void TVMAPISetLastError(const char* msg) { - (*TVMAPISetLastError_)(msg); +void TVMAPISetLastError(const char* msg) { (*TVMAPISetLastError_)(msg); } + +void* memset(void* s, int c, size_t n) { + char* p = (char*)s; // NOLINT(readability/casting): linter is configured for c++ + while (n > 0) { + *p = (char)c; // NOLINT(readability/casting): linter is configured for c++ + p++; + n--; + } + return s; +} + +void* memmove(void* to, const void* from, size_t n) { + // TODO(weberlo, areusch): will need to factor memmove calls into workspace size calculation + // NOLINTNEXTLINE(readability/casting): linter is configured for c++ + char* temp = (char*)TVMBackendAllocWorkspace(1, 1, (uint64_t)n, 2, 8); + if (temp == NULL) { + return NULL; + } + + const char* from_pp = (char*)from; // NOLINT(readability/casting): linter is configured for c++ + for (size_t i = 0; i < n; i++) { + temp[i] = from_pp[i]; + } + char* to_pp = (char*)to; // NOLINT(readability/casting): linter is configured for c++ + for (size_t i = 0; i < n; i++) { + to_pp[i] = temp[i]; + } + + // NOLINTNEXTLINE(readability/casting): linter is configured for c++ + if (TVMBackendFreeWorkspace(1, (uint64_t)1, (void*)temp) != 0) { + return NULL; + } + + return to; } #ifdef __cplusplus diff --git a/src/runtime/micro/host_driven/utvm_runtime.c b/src/runtime/micro/host_driven/utvm_runtime.c index a4de495a185c..398a08a014e0 100644 --- a/src/runtime/micro/host_driven/utvm_runtime.c +++ b/src/runtime/micro/host_driven/utvm_runtime.c @@ -34,89 +34,151 @@ extern "C" { #include "utvm_runtime.h" -// Task pointers must be patched before calling a function. -UTVMTask utvm_task = { - .func = NULL, - .arg_values = NULL, - .arg_type_codes = NULL, - .num_args = 0, -}; - -size_t utvm_word_size = 0; // NOLINT(*) +// TODO(weberlo, areusch): move defines into header +// TODO(weberlo, areusch): unify TASK_QUEUE_SIZE and MicroSession::kTaskQueueCapacity. +#define TASK_QUEUE_SIZE 20 +volatile UTVMTask utvm_tasks[TASK_QUEUE_SIZE] = {}; +volatile uint32_t utvm_num_tasks = 0; +volatile uint32_t utvm_task_times[TASK_QUEUE_SIZE] = {}; // These pointers are patched at load time to point to the workspace section. -char* utvm_workspace_start = NULL; // NOLINT(*) -char* utvm_workspace_end = NULL; // NOLINT(*) -char* utvm_workspace_curr = NULL; // NOLINT(*) +volatile char* utvm_workspace_start = NULL; // NOLINT(*) +volatile char* utvm_workspace_end = NULL; // NOLINT(*) +volatile char* utvm_workspace_curr = NULL; // NOLINT(*) +#define MAX_WS_ALLOCS 10 +volatile char* utvm_alloc_ends[MAX_WS_ALLOCS] = {}; // NOLINT(*) +volatile uint32_t utvm_alloc_idx = 0; // Keep track of how many active allocations there are on the workspace. 
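// [Editor's note] A minimal host-side model of the workspace scheme implemented further
// below: each allocation bumps a cursor by a word-rounded size, the end of the allocation is
// pushed on a small stack, and a free marks its entry so the cursor can rewind past any
// freed entries at the top. WORD, WS_SIZE, and MAX_ALLOCS are illustrative values, not the
// device's; the real runtime also records error codes, which are omitted here.
#include <assert.h>
#include <stddef.h>

#define WORD 4
#define WS_SIZE 64
#define MAX_ALLOCS 8

static char ws[WS_SIZE];
static char* curr = ws;
static char* starts[MAX_ALLOCS];
static char* ends_[MAX_ALLOCS];
static size_t num_allocs = 0;

static void* ws_alloc(size_t size) {
  size_t bytes = ((size + WORD - 1) / WORD) * WORD;  // round up to the word size
  if (curr + bytes > ws + WS_SIZE || num_allocs == MAX_ALLOCS) return NULL;
  starts[num_allocs] = curr;
  curr += bytes;
  ends_[num_allocs] = curr;  // remember where this allocation ends
  num_allocs++;
  return starts[num_allocs - 1];
}

static void ws_free(void* ptr) {
  for (size_t i = num_allocs; i-- > 0;) {
    if (starts[i] == (char*)ptr) { starts[i] = NULL; break; }  // mark as freed
  }
  // Rewind the cursor past any freed allocations at the top of the stack.
  while (num_allocs > 0 && starts[num_allocs - 1] == NULL) num_allocs--;
  curr = (num_allocs == 0) ? ws : ends_[num_allocs - 1];
}

int main(void) {
  void* a = ws_alloc(5);  // rounds to 8 bytes
  void* b = ws_alloc(4);
  ws_free(b);             // freeing the top allocation rewinds the cursor
  assert(curr == (char*)a + 8);
  ws_free(a);
  assert(curr == ws);
  return 0;
}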
-size_t utvm_num_active_allocs = 0; +volatile uint32_t utvm_num_active_allocs = 0; + +volatile uint32_t utvm_word_size = 0; -const char* utvm_last_error = NULL; // NOLINT(*) -int32_t utvm_return_code = 0; // NOLINT(*) +volatile int32_t utvm_last_error = 0; // NOLINT(*) -uint32_t utvm_task_time = 0; +volatile uint32_t utvm_done = 0; // Gets called by UTVMInit, after device-specific initialization is finished. void UTVMMain() { + utvm_done = 0; + // loss of precision should be fine here, since we only care about the lower bits + if (((uint32_t)utvm_workspace_start) % utvm_word_size) { + utvm_last_error = UTVM_ERR_WS_UNALIGNED_START; + UTVMDone(); + return; + } utvm_workspace_curr = utvm_workspace_start; utvm_num_active_allocs = 0; - utvm_last_error = NULL; // NOLINT(*) - utvm_return_code = 0; - utvm_task_time = 0; - UTVMTimerReset(); - int32_t err = UTVMTimerStart(); - if (err < 0) { - utvm_return_code = err; - UTVMDone(); + utvm_alloc_idx = 0; + utvm_last_error = UTVM_ERR_NOT_FINISHED; + for (uint32_t i = 0; i < utvm_num_tasks; i++) { + int32_t err = UTVM_ERR_OK; + utvm_task_times[i] = 0; + err = UTVMTimerStart(); + if (err < 0) { + utvm_last_error = err; + UTVMDone(); + return; + } + err = utvm_tasks[i].func((void*)utvm_tasks[i].arg_values, // NOLINT(*) + (void*)utvm_tasks[i].arg_type_codes, // NOLINT(*) + utvm_tasks[i].num_args); + if (err < 0) { + UTVMDone(); + return; + } + utvm_task_times[i] = UTVMTimerStop(&err); + if (err < 0) { + utvm_last_error = err; + UTVMDone(); + return; + } + } + if (utvm_last_error == UTVM_ERR_NOT_FINISHED) { + utvm_last_error = UTVM_ERR_OK; } - utvm_return_code = utvm_task.func( - (void*) utvm_task.arg_values, // NOLINT(*) - (void*) utvm_task.arg_type_codes, // NOLINT(*) - utvm_task.num_args); - UTVMTimerStop(); - utvm_task_time = UTVMTimerRead(); UTVMDone(); } // We use a dummy function to signal execution is finished for device // backends which require breakpoints. -void UTVMDone() { } +void __attribute__((noinline)) UTVMDone() { + utvm_done = 1; +#ifndef UTVM_TARGET_HOST + for (;;) { + } +#endif +} -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, - int dtype_code_hint, int dtype_bits_hint) { - // Align up to 8 bytes. - utvm_workspace_curr += - (utvm_word_size - ((uintptr_t) utvm_workspace_curr % utvm_word_size)) % utvm_word_size; // NOLINT(*) - if (utvm_workspace_curr + size > utvm_workspace_end) { +#define ALIGNED_UP(x, word_size) \ + ((((word_size) - (((uintptr_t)(x)) % (word_size))) % (word_size)) + (x)) + +void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, + int dtype_bits_hint) { + if (size == 0) { + utvm_last_error = UTVM_ERR_WS_ZERO_SIZE_ALLOC; + return NULL; + } + size_t alloc_requested_bytes = size; + size_t alloc_size_words = (alloc_requested_bytes + utvm_word_size - 1) / utvm_word_size; + size_t alloc_size_bytes = alloc_size_words * utvm_word_size; + + // Align up to the target word size. + if (utvm_workspace_curr + alloc_size_bytes > utvm_workspace_end) { // Out of space in workspace. + utvm_last_error = UTVM_ERR_WS_OUT_OF_SPACE; return NULL; } - void* ret_ptr = (void*) utvm_workspace_curr; // NOLINT(*) - utvm_workspace_curr += size; + if (utvm_alloc_idx == MAX_WS_ALLOCS - 1) { + // Exceeded number of allocs we can keep track of. 
+ utvm_last_error = UTVM_ERR_WS_TOO_MANY_ALLOCS; + return NULL; + } + void* ret_ptr = (void*)utvm_workspace_curr; // NOLINT(*) + utvm_workspace_curr = utvm_workspace_curr + alloc_size_bytes; + // store the *end* of the alloc, so we can restore the WS pointer when freeing + utvm_alloc_ends[utvm_alloc_idx] = utvm_workspace_curr; + utvm_alloc_idx++; utvm_num_active_allocs++; return ret_ptr; } int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - utvm_num_active_allocs--; - if (utvm_num_active_allocs < 0) { + // TODO(weberlo, areusch): add dev type check + if (utvm_num_active_allocs == 0) { TVMAPISetLastError("free called with no active workspace allocations"); // Reset allocations and workspace (for future task executions). utvm_num_active_allocs = 0; utvm_workspace_curr = utvm_workspace_start; + utvm_last_error = UTVM_ERR_WS_DOUBLE_FREE; return -1; - } else if (utvm_num_active_allocs == 0) { - // No more allocations. Reset workspace. - utvm_workspace_curr = utvm_workspace_start; - return 0; } else { + utvm_num_active_allocs--; + if (ptr == utvm_workspace_start) { + // it's the first allocation + utvm_alloc_ends[0] = NULL; + } else { + for (uint32_t i = utvm_alloc_idx - 1; i >= 0; i--) { + if (utvm_alloc_ends[i] == ptr) { + utvm_alloc_ends[i + 1] = NULL; + break; + } + } + } + while (utvm_alloc_idx > 0 && utvm_alloc_ends[utvm_alloc_idx - 1] == NULL) { + utvm_alloc_idx--; + } + if (utvm_alloc_idx == 0) { + utvm_workspace_curr = utvm_workspace_start; + } else { + // TODO(weberlo, areusch): could you possibly have utvm_alloc_idx pointing to a NULL entry in + // this branch? + utvm_workspace_curr = utvm_alloc_ends[utvm_alloc_idx - 1]; + } return 0; } } -void TVMAPISetLastError(const char* msg) { - utvm_last_error = msg; -} +void TVMAPISetLastError(const char* msg) {} #ifdef __cplusplus } // TVM_EXTERN_C diff --git a/src/runtime/micro/host_driven/utvm_runtime.h b/src/runtime/micro/host_driven/utvm_runtime.h index c364ecf40792..8758c3ad89a1 100644 --- a/src/runtime/micro/host_driven/utvm_runtime.h +++ b/src/runtime/micro/host_driven/utvm_runtime.h @@ -29,8 +29,10 @@ extern "C" { #endif #include -#include #include +#include + +#include "utvm_runtime_enum.h" /*! * \brief Task structure for uTVM @@ -46,20 +48,46 @@ typedef struct { int32_t num_args; } UTVMTask; +/*! + * \brief microTVM processor startup. + * Expected to reset the stack pointer, configure any hardware required to support the CRT + * (i.e. FPU), and then jump to UTVMMain. + */ extern void UTVMInit(); -extern void UTVMTimerReset(); - +/*! + * \brief Start the on-device timer. + * \return UTVMReturnCode indicating the outcome of the operation. + */ extern int32_t UTVMTimerStart(); -extern void UTVMTimerStop(); - -extern uint32_t UTVMTimerRead(); +/*! + * \brief Stop the on-device timer. + * TODO(areusch): Use an SI specification of timer units here. + * \param err Receives a UTVMReturnCode indicating the outcome of the operation. + * \return elapsed time since UTVMTimerStart returned, in device timer ticks. + */ +extern uint32_t UTVMTimerStop(int32_t* err); +/*! + * \brief Main entry point for UTVM runtime. + * Waits for "go" signal, then executes tasks and reports result. Should never return. + */ void UTVMMain(); +/*! + * \brief Function entered when UTVMMain is complete. + * Should never return. The host sets a breakpoint here to detect end of computation. + */ void UTVMDone(); +// GCC -O3 begins to inject memset and memmove calls, so we provide impls in +// the runtime for this case and for general usage. 
+ +void* memset(void* s, int c, size_t n); + +void* memmove(void* to, const void* from, size_t n); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/src/runtime/micro/host_driven/utvm_runtime_enum.h b/src/runtime/micro/host_driven/utvm_runtime_enum.h new file mode 100644 index 000000000000..17f803612cb9 --- /dev/null +++ b/src/runtime/micro/host_driven/utvm_runtime_enum.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utvm_runtime_enum.h + * \brief Defines constants used both on the host and on device. + */ +#ifndef TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_ +#define TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief TODO + */ +enum UTVMReturnCode { + UTVM_ERR_OK = 0, + UTVM_ERR_NOT_FINISHED = -1, + UTVM_ERR_TIMER_NOT_IMPLEMENTED = -2, + UTVM_ERR_TIMER_OVERFLOW = -3, + UTVM_ERR_WS_DOUBLE_FREE = -4, + UTVM_ERR_WS_OUT_OF_SPACE = -5, + UTVM_ERR_WS_TOO_MANY_ALLOCS = -6, + UTVM_ERR_WS_ZERO_SIZE_ALLOC = -7, + UTVM_ERR_WS_UNALIGNED_START = -8, + UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE = -9, +}; + +#ifdef __cplusplus +} // TVM_EXTERN_C +#endif + +#endif // TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_ diff --git a/src/runtime/micro/host_low_level_device.cc b/src/runtime/micro/host_low_level_device.cc index a24994a2a0e5..7c3e7a2abad8 100644 --- a/src/runtime/micro/host_low_level_device.cc +++ b/src/runtime/micro/host_low_level_device.cc @@ -23,10 +23,12 @@ */ #include + #include #include -#include "micro_common.h" + #include "low_level_device.h" +#include "micro_common.h" namespace tvm { namespace runtime { @@ -43,38 +45,35 @@ class HostLowLevelDevice final : public LowLevelDevice { * \brief constructor to initialize on-host memory region to act as device * \param num_bytes size of the emulated on-device memory region */ - explicit HostLowLevelDevice(size_t num_bytes, void** base_addr) : size_(num_bytes) { + explicit HostLowLevelDevice(size_t num_bytes, TargetPtr* base_addr) : size_(num_bytes) { size_t size_in_pages = (num_bytes + kPageSize - 1) / kPageSize; // TODO(weberlo): Set permissions per section (e.g., read-write perms for // the heap, execute perms for text, etc.). int mmap_prot = PROT_READ | PROT_WRITE | PROT_EXEC; int mmap_flags = MAP_ANONYMOUS | MAP_PRIVATE; base_addr_ = mmap(nullptr, size_in_pages * kPageSize, mmap_prot, mmap_flags, -1, 0); - *base_addr = base_addr_; + *base_addr = + TargetPtr(TargetWordSize(sizeof(size_t) * 8), reinterpret_cast(base_addr_)); } /*! 
* \brief destructor to deallocate on-host device region */ - virtual ~HostLowLevelDevice() { - munmap(base_addr_, size_); - } + virtual ~HostLowLevelDevice() { munmap(base_addr_, size_); } - void Read(DevPtr addr, void* buf, size_t num_bytes) { + void Read(TargetPtr addr, void* buf, size_t num_bytes) { std::memcpy(buf, addr.cast_to(), num_bytes); } - void Write(DevPtr addr, const void* buf, size_t num_bytes) { + void Write(TargetPtr addr, const void* buf, size_t num_bytes) { std::memcpy(addr.cast_to(), buf, num_bytes); } - void Execute(DevPtr func_addr, DevPtr breakpoint_addr) { - reinterpret_cast(func_addr.value().val64)(); + void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) { + reinterpret_cast(func_addr.value().uint64())(); } - const char* device_type() const final { - return "host"; - } + const char* device_type() const final { return "host"; } private: /*! \brief base address of the micro device memory region */ @@ -83,9 +82,9 @@ class HostLowLevelDevice final : public LowLevelDevice { size_t size_; }; -const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, void** base_addr) { - std::shared_ptr lld = - std::make_shared(num_bytes, base_addr); +const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, + TargetPtr* base_addr) { + std::shared_ptr lld = std::make_shared(num_bytes, base_addr); return lld; } diff --git a/src/runtime/micro/low_level_device.h b/src/runtime/micro/low_level_device.h index 3158e2fe20de..6cc0e1dc5af0 100644 --- a/src/runtime/micro/low_level_device.h +++ b/src/runtime/micro/low_level_device.h @@ -45,9 +45,7 @@ class LowLevelDevice { * \param buffer on-host buffer to be read into * \param num_bytes number of bytes to read */ - virtual void Read(DevPtr addr, - void* buffer, - size_t num_bytes) = 0; + virtual void Read(TargetPtr addr, void* buffer, size_t num_bytes) = 0; /*! * \brief writes num_bytes from buffer to device memory at addr @@ -55,16 +53,14 @@ class LowLevelDevice { * \param buffer host buffer to write from * \param num_bytes number of bytes to write */ - virtual void Write(DevPtr addr, - const void* buffer, - size_t num_bytes) = 0; + virtual void Write(TargetPtr addr, const void* buffer, size_t num_bytes) = 0; /*! * \brief starts execution of device at func_addr * \param func_addr offset of the init stub function * \param breakpoint_addr address at which to stop function execution */ - virtual void Execute(DevPtr func_addr, DevPtr breakpoint_addr) = 0; + virtual void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) = 0; /*! * \brief getter function for low-level device type @@ -78,7 +74,8 @@ class LowLevelDevice { * \param num_bytes size of the memory region * \param base_addr pointer to write the host device's resulting base address into */ -const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, void** base_addr); +const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, + TargetPtr* base_addr); /*! 
* \brief connect to OpenOCD and create an OpenOCD low-level device diff --git a/src/runtime/micro/micro_common.cc b/src/runtime/micro/micro_common.cc index 632b6048b182..eba77f3dadbc 100644 --- a/src/runtime/micro/micro_common.cc +++ b/src/runtime/micro/micro_common.cc @@ -22,65 +22,65 @@ * \brief common utilties for uTVM */ +#include "micro_common.h" + #include #include + +#include #include -#include #include -#include -#include "micro_session.h" -#include "micro_common.h" +#include + #include "low_level_device.h" +#include "micro_session.h" namespace tvm { namespace runtime { const char* SectionToString(SectionKind section) { switch (section) { - case SectionKind::kText: return "text"; - case SectionKind::kRodata: return "rodata"; - case SectionKind::kData: return "data"; - case SectionKind::kBss: return "bss"; - case SectionKind::kArgs: return "args"; - case SectionKind::kHeap: return "heap"; - case SectionKind::kWorkspace: return "workspace"; - case SectionKind::kStack: return "stack"; - default: return ""; + case SectionKind::kText: + return "text"; + case SectionKind::kRodata: + return "rodata"; + case SectionKind::kData: + return "data"; + case SectionKind::kBss: + return "bss"; + case SectionKind::kArgs: + return "args"; + case SectionKind::kHeap: + return "heap"; + case SectionKind::kWorkspace: + return "workspace"; + case SectionKind::kStack: + return "stack"; + default: + return ""; } } -std::string RelocateBinarySections( - const std::string& binary_path, - size_t word_size, - DevPtr text_start, - DevPtr rodata_start, - DevPtr data_start, - DevPtr bss_start, - DevPtr stack_end, - const std::string& toolchain_prefix) { +std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, + TargetPtr text_start, TargetPtr rodata_start, + TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, + const std::string& toolchain_prefix) { const auto* f = Registry::Get("tvm_callback_relocate_binary"); - CHECK(f != nullptr) - << "Require tvm_callback_relocate_binary to exist in registry"; - std::string relocated_bin = (*f)(binary_path, - word_size, - text_start.cast_to(), - rodata_start.cast_to(), - data_start.cast_to(), - bss_start.cast_to(), - stack_end.cast_to(), - toolchain_prefix); + CHECK(f != nullptr) << "Require tvm_callback_relocate_binary to exist in registry"; + std::string relocated_bin = + (*f)(binary_path, word_size.bytes(), text_start.cast_to(), + rodata_start.cast_to(), data_start.cast_to(), + bss_start.cast_to(), stack_end.cast_to(), toolchain_prefix); return relocated_bin; } -std::string ReadSection(const std::string& binary, - SectionKind section, +std::string ReadSection(const std::string& binary, SectionKind section, const std::string& toolchain_prefix) { CHECK(section == SectionKind::kText || section == SectionKind::kRodata || section == SectionKind::kData || section == SectionKind::kBss) << "ReadSection requires section to be one of text, rodata, data, or bss."; const auto* f = Registry::Get("tvm_callback_read_binary_section"); - CHECK(f != nullptr) - << "Require tvm_callback_read_binary_section to exist in registry"; + CHECK(f != nullptr) << "Require tvm_callback_read_binary_section to exist in registry"; TVMByteArray arr; arr.data = &binary[0]; arr.size = binary.length(); @@ -88,18 +88,43 @@ std::string ReadSection(const std::string& binary, return section_contents; } -size_t GetSectionSize(const std::string& binary_path, - SectionKind section, - const std::string& toolchain_prefix, - size_t align) { +size_t GetSectionSize(const 
std::string& binary_path, SectionKind section, + const std::string& toolchain_prefix, TargetWordSize word_size) { CHECK(section == SectionKind::kText || section == SectionKind::kRodata || section == SectionKind::kData || section == SectionKind::kBss) << "GetSectionSize requires section to be one of text, rodata, data, or bss."; const auto* f = Registry::Get("tvm_callback_get_section_size"); - CHECK(f != nullptr) - << "Require tvm_callback_get_section_size to exist in registry"; + CHECK(f != nullptr) << "Require tvm_callback_get_section_size to exist in registry"; int size = (*f)(binary_path, SectionToString(section), toolchain_prefix); - return UpperAlignValue(size, align); + return UpperAlignValue(size, word_size.bytes()); +} + +std::ostream& operator<<(std::ostream& os, const TargetVal& v) { + std::ios_base::fmtflags f(os.flags()); + os << std::dec << "0x"; + switch (v.width_bits()) { + case 8: + os << uint8_t(v.uint32()); + break; + case 16: + os << uint16_t(v.uint32()); + break; + case 32: + os << v.uint32(); + break; + case 64: + os << v.uint64(); + break; + default: + os << (v.uint64() & ((1 << v.width_bits()) - 1)); + } + os.flags(f); + return os; +} + +std::ostream& operator<<(std::ostream& os, const TargetPtr& v) { + os << "*" << v.value_; + return os; } } // namespace runtime diff --git a/src/runtime/micro/micro_common.h b/src/runtime/micro/micro_common.h index 4a0189b3e89e..2c4684b357a8 100644 --- a/src/runtime/micro/micro_common.h +++ b/src/runtime/micro/micro_common.h @@ -24,12 +24,12 @@ #define TVM_RUNTIME_MICRO_MICRO_COMMON_H_ #include - #include #include #include #include +#include namespace tvm { namespace runtime { @@ -52,28 +52,111 @@ enum class SectionKind : size_t { kNumKinds, }; -/*! \brief union for storing values on varying target word sizes */ -union TargetVal { - /*! \brief 32-bit pointer */ - uint32_t val32; - /*! \brief 64-bit pointer */ - uint64_t val64; +/*! \brief data type for word sizes */ +class TargetWordSize { + public: + explicit TargetWordSize(size_t word_size_bits) : word_size_bits_{word_size_bits} { + CHECK(word_size_bits == 32 || word_size_bits == 64) + << "only 32-bit and 64-bit are supported now"; + } + + size_t bytes() const { return word_size_bits_ / 8; } + + size_t bits() const { return word_size_bits_; } + + private: + size_t word_size_bits_; }; -/*! \brief absolute device address */ -class DevPtr { +/*! \brief class for storing values on varying target word sizes */ +class TargetVal { + private: + size_t width_bits_; + uint64_t value_; + public: - /*! \brief construct a device address with value `value` */ - explicit DevPtr(std::uintptr_t value) : value_(TargetVal { .val64 = value }) {} + /*! \brief construct a TargetVal matching the size of the given integral argument */ + template ::value, T>::type> + explicit constexpr TargetVal(T value) : TargetVal(sizeof(T) * 8, value) {} + + /*! \brief construct an uninitialized value */ + TargetVal() : width_bits_{0}, value_{0} {} + + /*! 
\brief construct a TargetVal with explicit size and value */ + TargetVal(size_t width_bits, uint64_t value) : width_bits_{width_bits} { + CHECK(width_bits >= 8 && width_bits <= 64 && (width_bits & (width_bits - 1)) == 0) + << "width_bits must be a power of 2 in [8, 64], got " << width_bits; + value_ = value & Bitmask(); + } + + bool IsInitialized() const { return width_bits_ != 0; } + + size_t width_bits() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + return width_bits_; + } + + uint64_t Bitmask() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + + if (width_bits_ == 64) { + return ~0UL; + } else { + return (1UL << width_bits_) - 1; + } + } + + uint32_t uint32() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + CHECK(width_bits_ <= 32) << "TargetVal: requested 32-bit value, actual width is " + << width_bits_; + return uint32_t(value_ & Bitmask()); + } - /*! \brief default constructor */ - DevPtr() : value_(TargetVal { .val64 = 0 }) {} + uint64_t uint64() const { + CHECK(IsInitialized()) << "TargetVal is not initialized"; + return value_; + } + + TargetVal& operator=(const TargetVal& other) { + CHECK(other.IsInitialized()) << "Cannot assign an uninitialized TargetVal"; + + if (!IsInitialized()) { + width_bits_ = other.width_bits_; + } + + CHECK(width_bits_ >= other.width_bits_) + << "Cannot assign TargetVal with width " << other.width_bits_ + << "bits to TargetVal with width " << width_bits_ << "bits"; + + value_ = other.value_ & Bitmask(); + return *this; + } + + private: + friend std::ostream& operator<<(std::ostream& os, const TargetVal& v); +}; + +// TODO(weberlo, areusch): just get rid of `TargetPtr`. +/*! \brief absolute device address */ +class TargetPtr { + public: + /*! \brief construct a device address with variable-length value `value` */ + TargetPtr(TargetWordSize word_size, std::uint64_t value) + : value_(TargetVal(word_size.bits(), value)) {} /*! \brief construct a null address */ - explicit DevPtr(std::nullptr_t value) : value_(TargetVal { .val64 = 0 }) {} + TargetPtr(TargetWordSize word_size, std::nullptr_t value) + : value_{TargetVal(word_size.bits(), 0)} {} + + /*! \brief construct an uninitialized pointer whose word_size can be changed once */ + TargetPtr() = default; + + /*! \brief construct a device address using the given TargetVal */ + explicit TargetPtr(const TargetVal& value) : value_{value} {} /*! \brief destructor */ - ~DevPtr() {} + ~TargetPtr() {} /*! * \brief get value of pointer @@ -86,39 +169,43 @@ class DevPtr { * \return casted result */ template - T cast_to() const { return reinterpret_cast(value_.val64); } + T cast_to() const { + return reinterpret_cast(value_.uint64()); + } /*! \brief check if location is null */ - bool operator==(std::nullptr_t) const { return value_.val64 == 0; } + bool operator==(std::nullptr_t) const { return value_.uint64() == 0; } /*! \brief check if location is not null */ - bool operator!=(std::nullptr_t) const { return value_.val64 != 0; } + bool operator!=(std::nullptr_t) const { return value_.uint64() != 0; } /*! \brief add an integer to this absolute address to get a larger absolute address */ - DevPtr operator+(size_t n) const { - return DevPtr(value_.val64 + n); + TargetPtr operator+(size_t n) const { + return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() + n); } /*! 
\brief mutably add an integer to this absolute address */ - DevPtr& operator+=(size_t n) { - value_.val64 += n; + TargetPtr& operator+=(size_t n) { + value_ = TargetVal(value_.width_bits(), value_.uint64() + n); return *this; } /*! \brief subtract an integer from this absolute address to get a smaller absolute address */ - DevPtr operator-(size_t n) const { - return DevPtr(value_.val64 - n); + TargetPtr operator-(size_t n) const { + return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() - n); } /*! \brief mutably subtract an integer from this absolute address */ - DevPtr& operator-=(size_t n) { - value_.val64 -= n; + TargetPtr& operator-=(size_t n) { + value_ = TargetVal(value_.width_bits(), value_.uint64() - n); return *this; } private: /*! \brief raw value storing the pointer */ TargetVal value_; + + friend std::ostream& operator<<(std::ostream& os, const TargetPtr& v); }; /*! @@ -136,8 +223,8 @@ class SymbolMap { * \param binary contents of binary object file * \param toolchain_prefix prefix of compiler toolchain to use */ - SymbolMap(const std::string& binary, - const std::string& toolchain_prefix) { + SymbolMap(const std::string& binary, const std::string& toolchain_prefix, + TargetWordSize word_size) { const auto* f = Registry::Get("tvm_callback_get_symbol_map"); CHECK(f != nullptr) << "require tvm_callback_get_symbol_map to exist in registry"; TVMByteArray arr; @@ -152,7 +239,7 @@ class SymbolMap { stream >> name; stream >> std::hex >> addr; while (stream) { - map_[name] = DevPtr(addr); + map_.emplace(std::make_pair(name, TargetPtr(word_size, addr))); stream >> name; stream >> std::hex >> addr; } @@ -163,25 +250,29 @@ class SymbolMap { * \param name name of the symbol * \return on-device offset of the symbol */ - DevPtr operator[](const std::string& name) const { + TargetPtr operator[](const std::string& name) const { auto result = map_.find(name); CHECK(result != map_.end()) << "\"" << name << "\" not in symbol map"; return result->second; } - bool HasSymbol(const std::string& name) const { - return map_.find(name) != map_.end(); + bool HasSymbol(const std::string& name) const { return map_.find(name) != map_.end(); } + + void Dump(std::ostream& stream) const { + for (auto e : map_) { + stream << "Entry:" << e.first << std::endl; + } } private: /*! \brief backing map */ - std::unordered_map map_; + std::unordered_map map_; }; /*! \brief struct containing start and size of a device memory region */ struct DevMemRegion { /*! \brief section start offset */ - DevPtr start; + TargetPtr start; /*! \brief size of section */ size_t size; }; @@ -237,15 +328,10 @@ const char* SectionToString(SectionKind section); * \param toolchain_prefix prefix of compiler toolchain to use * \return relocated binary file contents */ -std::string RelocateBinarySections( - const std::string& binary_path, - size_t word_size, - DevPtr text_start, - DevPtr rodata_start, - DevPtr data_start, - DevPtr bss_start, - DevPtr stack_end, - const std::string& toolchain_prefix); +std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, + TargetPtr text_start, TargetPtr rodata_start, + TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, + const std::string& toolchain_prefix); /*! 
* \brief reads section from binary @@ -254,8 +340,7 @@ std::string RelocateBinarySections( * \param toolchain_prefix prefix of compiler toolchain to use * \return contents of the section */ -std::string ReadSection(const std::string& binary, - SectionKind section, +std::string ReadSection(const std::string& binary, SectionKind section, const std::string& toolchain_prefix); /*! @@ -263,13 +348,11 @@ std::string ReadSection(const std::string& binary, * \param binary input binary contents * \param section section type * \param toolchain_prefix prefix of compiler toolchain to use - * \param align alignment of the returned size (default: 8) + * \param word_size word size of the target, for alignment * \return size of the section if it exists, 0 otherwise */ -size_t GetSectionSize(const std::string& binary_name, - SectionKind section, - const std::string& toolchain_prefix, - size_t align); +size_t GetSectionSize(const std::string& binary_name, SectionKind section, + const std::string& toolchain_prefix, TargetWordSize word_size); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_device_api.cc b/src/runtime/micro/micro_device_api.cc index 3d0a6889c4f7..68480786ac87 100644 --- a/src/runtime/micro/micro_device_api.cc +++ b/src/runtime/micro/micro_device_api.cc @@ -21,9 +21,10 @@ * \file micro_device_api.cc */ -#include -#include #include +#include +#include + #include "../workspace_pool.h" #include "micro_session.h" @@ -35,7 +36,7 @@ namespace runtime { class MicroDeviceAPI final : public DeviceAPI { public: /*! \brief constructor */ - MicroDeviceAPI() { } + MicroDeviceAPI() {} void SetDevice(TVMContext ctx) final {} @@ -45,100 +46,93 @@ class MicroDeviceAPI final : public DeviceAPI { } } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { ObjectPtr& session = MicroSession::Current(); - void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to(); + TargetPtr data = session->AllocateInSection(SectionKind::kHeap, nbytes); CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap"; - MicroDevSpace* dev_space = new MicroDevSpace(); - dev_space->data = data; - dev_space->session = session; - return static_cast(dev_space); + return reinterpret_cast(new MicroDevSpace{data, session}); } void FreeDataSpace(TVMContext ctx, void* ptr) final { MicroDevSpace* dev_space = static_cast(ptr); - dev_space->session->FreeInSection( - SectionKind::kHeap, DevPtr(reinterpret_cast(dev_space->data))); + dev_space->session->FreeInSection(SectionKind::kHeap, dev_space->data); delete dev_space; } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { std::tuple type_from_to(ctx_from.device_type, ctx_to.device_type); if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) { // Copying from the device to the device. 
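      // Every `void*` handle that reaches this API is really a MicroDevSpace* created by
      // AllocDataSpace above: it pairs the on-device address (a TargetPtr) with a reference to
      // the owning MicroSession, keeping the session alive while the buffer exists. Each branch
      // below unwraps that handle, flushes the pending task queue so device memory is in a
      // consistent state, and then moves the bytes through the session's low-level device.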
- MicroDevSpace* from_space = static_cast(const_cast(from)); MicroDevSpace* to_space = static_cast(const_cast(to)); CHECK(from_space->session == to_space->session) - << "attempt to copy data between different micro sessions (" - << from_space->session.get() + << "attempt to copy data between different micro sessions (" << from_space->session.get() << " != " << to_space->session.get() << ")"; CHECK(ctx_from.device_id == ctx_to.device_id) - << "can only copy between the same micro device"; + << "can only copy between the same micro device"; ObjectPtr& session = from_space->session; + // flush all pending tasks to ensure data is consistent + session->FlushTaskQueue(); const std::shared_ptr& lld = session->low_level_device(); - DevPtr from_dev_addr = GetDevLoc(from_space, from_offset); - DevPtr to_dev_addr = GetDevLoc(to_space, to_offset); + TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset); + TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset); std::vector buffer(size); lld->Read(from_dev_addr, static_cast(buffer.data()), size); lld->Write(to_dev_addr, static_cast(buffer.data()), size); + } else if (type_from_to == std::make_tuple(kDLMicroDev, kDLCPU)) { // Reading from the device. - MicroDevSpace* from_space = static_cast(const_cast(from)); ObjectPtr& session = from_space->session; + // flush all pending tasks to ensure data is consistent + session->FlushTaskQueue(); const std::shared_ptr& lld = session->low_level_device(); - DevPtr from_dev_addr = GetDevLoc(from_space, from_offset); + TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset); void* to_host_ptr = GetHostLoc(to, to_offset); lld->Read(from_dev_addr, to_host_ptr, size); + } else if (type_from_to == std::make_tuple(kDLCPU, kDLMicroDev)) { // Writing to the device. - MicroDevSpace* to_space = static_cast(const_cast(to)); ObjectPtr& session = to_space->session; + // flush all pending tasks to ensure data is consistent + session->FlushTaskQueue(); const std::shared_ptr& lld = session->low_level_device(); void* from_host_ptr = GetHostLoc(from, from_offset); - DevPtr to_dev_addr = GetDevLoc(to_space, to_offset); + TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset); lld->Write(to_dev_addr, from_host_ptr, size); + } else { LOG(FATAL) << "Expect copy from/to micro device or between micro device\n"; } } void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + MicroSession::Current()->FlushTaskQueue(); } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { + CHECK(false) << "the on-device workspace allocator isn't aware of this function"; ObjectPtr& session = MicroSession::Current(); - void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to(); - CHECK(data != nullptr) << "unable to allocate " << size << " bytes on device workspace"; - MicroDevSpace* dev_space = new MicroDevSpace(); - dev_space->data = data; - dev_space->session = session; - return static_cast(dev_space); + TargetPtr data = session->AllocateInSection(SectionKind::kWorkspace, size); + CHECK(data.value().uint64() != 0) + << "unable to allocate " << size << " bytes on device workspace"; + return static_cast(new MicroDevSpace{data, session}); } void FreeWorkspace(TVMContext ctx, void* data) final { + CHECK(false) << "the on-device workspace allocator isn't aware of this function"; MicroDevSpace* dev_space = static_cast(data); ObjectPtr& session = dev_space->session; - session->FreeInSection(SectionKind::kWorkspace, - DevPtr(reinterpret_cast(dev_space->data))); + 
session->FreeInSection(SectionKind::kWorkspace, dev_space->data); delete dev_space; } @@ -152,9 +146,7 @@ class MicroDeviceAPI final : public DeviceAPI { } private: - DevPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { - return DevPtr(reinterpret_cast(dev_space->data) + offset); - } + TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { return dev_space->data + offset; } void* GetHostLoc(const void* ptr, size_t offset) { return reinterpret_cast(reinterpret_cast(ptr) + offset); @@ -162,10 +154,9 @@ class MicroDeviceAPI final : public DeviceAPI { }; // register device that can be obtained from Python frontend -TVM_REGISTER_GLOBAL("device_api.micro_dev") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MicroDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = MicroDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_module.cc b/src/runtime/micro/micro_module.cc index 50cee34be4a6..b4770ec6f934 100644 --- a/src/runtime/micro/micro_module.cc +++ b/src/runtime/micro/micro_module.cc @@ -21,15 +21,17 @@ * \file micro_module.cc */ -#include #include #include -#include +#include + #include -#include "micro_session.h" +#include + +#include "../pack_args.h" #include "low_level_device.h" #include "micro_common.h" -#include "../pack_args.h" +#include "micro_session.h" namespace tvm { namespace runtime { @@ -42,18 +44,17 @@ class MicroModuleNode final : public ModuleNode { ~MicroModuleNode() {} - const char* type_key() const final { - return "micro"; - } + const char* type_key() const final { return "micro"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; /*! * \brief initializes module by establishing device connection and loads binary * \param binary_path path of the binary to be loaded */ void InitMicroModule(const std::string& binary_path) { + // std::cout << "[MicroModuleNode::InitMicroModule]" << std::endl; + // std::cout << " start" << std::endl; session_ = MicroSession::Current(); symbol_map_ = session_->LoadBinary(binary_path, true).symbol_map; } @@ -66,27 +67,25 @@ class MicroModuleNode final : public ModuleNode { class MicroWrappedFunc { public: - MicroWrappedFunc(ObjectPtr session, - DevPtr func_ptr) { + MicroWrappedFunc(ObjectPtr session, TargetPtr func_ptr) { session_ = session; func_ptr_ = func_ptr; } void operator()(TVMArgs args, TVMRetValue* rv) const { - *rv = session_->PushToExecQueue(func_ptr_, args); + session_->PushToTaskQueue(func_ptr_, args); } private: /*! \brief reference to the session for this function (to keep the session alive) */ ObjectPtr session_; /*! 
\brief offset of the function to be called */ - DevPtr func_ptr_; + TargetPtr func_ptr_; }; -PackedFunc MicroModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { - DevPtr func_ptr; +PackedFunc MicroModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { + TargetPtr func_ptr; if (name == tvm::runtime::symbol::tvm_module_main) { if (symbol_map_.HasSymbol(tvm::runtime::symbol::tvm_module_main)) { func_ptr = symbol_map_[tvm::runtime::symbol::tvm_module_main]; @@ -102,10 +101,10 @@ PackedFunc MicroModuleNode::GetFunction( // register loadfile function to load module from Python frontend TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->InitMicroModule(args[0]); - *rv = runtime::Module(n); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->InitMicroModule(args[0]); + *rv = runtime::Module(n); + }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_section_allocator.h b/src/runtime/micro/micro_section_allocator.h index 5c75f92737ab..5cafb41bbc4b 100644 --- a/src/runtime/micro/micro_section_allocator.h +++ b/src/runtime/micro/micro_section_allocator.h @@ -23,7 +23,9 @@ #ifndef TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ #define TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ +#include #include + #include "micro_common.h" namespace tvm { @@ -38,16 +40,18 @@ class MicroSectionAllocator { * \brief constructor that specifies section boundaries * \param region location and size of the section on the device */ - explicit MicroSectionAllocator(DevMemRegion region, size_t word_size) - : start_addr_(region.start), - size_(0), - capacity_(region.size), - word_size_(word_size) { - CHECK_EQ(start_addr_.value().val64 % word_size, 0) - << "micro section start not aligned to " << word_size << " bytes"; - CHECK_EQ(capacity_ % word_size, 0) - << "micro section end not aligned to " << word_size << " bytes"; - } + explicit MicroSectionAllocator(std::string section_name, DevMemRegion region, + TargetWordSize word_size) + : section_name_(section_name), + start_addr_(region.start), + size_(0), + capacity_(region.size), + word_size_(word_size) { + CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0) + << "micro section start not aligned to " << word_size.bytes() << " bytes"; + CHECK_EQ(capacity_ % word_size.bytes(), 0) + << "micro section end not aligned to " << word_size.bytes() << " bytes"; + } /*! * \brief destructor @@ -56,17 +60,18 @@ class MicroSectionAllocator { /*! 
* \brief memory allocator - * \param size size of allocated memory in bytes + * \param alloc_size size of allocated memory in bytes * \return pointer to allocated memory region in section, nullptr if out of space */ - DevPtr Allocate(size_t size) { - size_ = UpperAlignValue(size_, word_size_); + TargetPtr Allocate(size_t size) { + size_ = UpperAlignValue(size_, word_size_.bytes()); CHECK(size_ + size < capacity_) - << "cannot alloc " << size << " bytes in section with start_addr " << - start_addr_.cast_to(); - DevPtr alloc_addr = start_addr_ + size_; + << "cannot alloc " << size << " bytes in section \"" << section_name_ + << "\" (start_addr=" << start_addr_.cast_to() << ", used=" << size_ + << ", capacity=" << capacity_ << ")"; + TargetPtr alloc_addr = start_addr_ + size_; size_ += size; - alloc_map_[alloc_addr.value().val64] = size; + alloc_map_[alloc_addr.value().uint64()] = size; return alloc_addr; } @@ -75,10 +80,10 @@ class MicroSectionAllocator { * \param offs offset to allocated memory * \note simple allocator scheme, more complex versions will be implemented later */ - void Free(DevPtr addr) { - CHECK(alloc_map_.find(addr.value().val64) != alloc_map_.end()) - << "freed pointer was never allocated"; - alloc_map_.erase(addr.value().val64); + void Free(TargetPtr addr) { + CHECK(alloc_map_.find(addr.value().uint64()) != alloc_map_.end()) + << "freed pointer was never allocated"; + alloc_map_.erase(addr.value().uint64()); if (alloc_map_.empty()) { size_ = 0; } @@ -87,17 +92,17 @@ class MicroSectionAllocator { /*! * \brief start offset of the memory region managed by this allocator */ - DevPtr start_addr() const { return start_addr_; } + TargetPtr start_addr() const { return start_addr_; } /*! * \brief current end addr of the space being used in this memory region */ - DevPtr curr_end_addr() const { return start_addr_ + size_; } + TargetPtr curr_end_addr() const { return start_addr_ + size_; } /*! * \brief end addr of the memory region managed by this allocator */ - DevPtr max_addr() const { return start_addr_ + capacity_; } + TargetPtr max_addr() const { return start_addr_ + capacity_; } /*! * \brief size of the section @@ -110,14 +115,16 @@ class MicroSectionAllocator { size_t capacity() const { return capacity_; } private: + /*! \brief name of the section (for debugging) */ + std::string section_name_; /*! \brief start address of the section */ - DevPtr start_addr_; + TargetPtr start_addr_; /*! \brief current size of the section */ size_t size_; /*! \brief total storage capacity of the section */ size_t capacity_; /*! \brief number of bytes in a word on the target device */ - size_t word_size_; + TargetWordSize word_size_; /*! 
\brief allocation map for allocation sizes */ std::unordered_map alloc_map_; }; diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 4bdc8ed69797..f458872bfeb0 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -21,13 +21,19 @@ * \file micro_session.cc */ +#include "micro_session.h" + #include +#include #include + +#include +#include #include #include #include #include -#include "micro_session.h" + #include "low_level_device.h" #include "target_data_layout_encoder.h" @@ -41,164 +47,184 @@ struct TVMMicroSessionThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMMicroSessionThreadLocalStore; ObjectPtr& MicroSession::Current() { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); CHECK_GT(entry->session_stack.size(), 0) << "No current session"; return entry->session_stack.top(); } void MicroSession::EnterWithScope(ObjectPtr session) { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); entry->session_stack.push(session); } void MicroSession::ExitWithScope() { - TVMMicroSessionThreadLocalEntry *entry = TVMMicroSessionThreadLocalStore::Get(); + TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); CHECK(!entry->session_stack.empty()); entry->session_stack.pop(); } -MicroSession::MicroSession( - const std::string& comms_method, - const std::string& binary_path, - const std::string& toolchain_prefix, - uint64_t text_start, - size_t text_size, - uint64_t rodata_start, - size_t rodata_size, - uint64_t data_start, - size_t data_size, - uint64_t bss_start, - size_t bss_size, - uint64_t args_start, - size_t args_size, - uint64_t heap_start, - size_t heap_size, - uint64_t workspace_start, - size_t workspace_size, - uint64_t stack_start, - size_t stack_size, - size_t word_size, - bool thumb_mode, - const std::string& server_addr, - int port) - : toolchain_prefix_(toolchain_prefix) - , word_size_(word_size) - , thumb_mode_(thumb_mode) { - CHECK(word_size_ == 4 || word_size_ == 8) << "unsupported word size " << word_size_; +MicroSession::MicroSession(const std::string& comms_method, const std::string& binary_path, + const std::string& toolchain_prefix, uint64_t text_start, + size_t text_size, uint64_t rodata_start, size_t rodata_size, + uint64_t data_start, size_t data_size, uint64_t bss_start, + size_t bss_size, uint64_t args_start, size_t args_size, + uint64_t heap_start, size_t heap_size, uint64_t workspace_start, + size_t workspace_size, uint64_t stack_start, size_t stack_size, + TargetWordSize word_size, bool thumb_mode, bool use_device_timer, + const std::string& server_addr, int port, PackedFunc debug_func) + : toolchain_prefix_(toolchain_prefix), + word_size_(word_size), + thumb_mode_(thumb_mode), + use_device_timer_(use_device_timer), + batch_args_encoder_(args_size, word_size), + debug_func_{debug_func} { if (comms_method == "host") { // TODO(weberlo): move checks to python - CHECK( - text_start == 0 && - rodata_start == 0 && - data_start == 0 && - bss_start == 0 && - args_start == 0 && - heap_start == 0 && - workspace_start == 0 && - stack_start == 0) << "unable to specify section addresses for host device"; - size_t memory_size = - text_size + rodata_size + data_size + bss_size + - args_size + heap_size + workspace_size + stack_size; - void* base_addr; + 
CHECK(text_start == 0 && rodata_start == 0 && data_start == 0 && bss_start == 0 && + args_start == 0 && heap_start == 0 && workspace_start == 0 && stack_start == 0) + << "unable to specify section addresses for host device"; + size_t memory_size = text_size + rodata_size + data_size + bss_size + args_size + heap_size + + workspace_size + stack_size; + TargetPtr base_addr; low_level_device_ = HostLowLevelDeviceCreate(memory_size, &base_addr); - CHECK_EQ(reinterpret_cast(base_addr) % word_size_, 0) - << "base address not aligned to " << word_size_ << " bytes"; - DevPtr curr_addr = DevPtr(reinterpret_cast(base_addr)); - - section_allocators_[0] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = text_size, - }, word_size_); + CHECK_EQ(base_addr.value().uint64() % word_size.bytes(), 0) + << "base address not aligned to " << word_size.bytes() << " bytes"; + TargetPtr curr_addr = base_addr; + + section_allocators_[0] = std::make_shared("text", + DevMemRegion{ + .start = curr_addr, + .size = text_size, + }, + word_size_); curr_addr += text_size; - section_allocators_[1] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = rodata_size, - }, word_size_); + section_allocators_[1] = std::make_shared("rodata", + DevMemRegion{ + .start = curr_addr, + .size = rodata_size, + }, + word_size_); curr_addr += rodata_size; - section_allocators_[2] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = data_size, - }, word_size_); + section_allocators_[2] = std::make_shared("data", + DevMemRegion{ + .start = curr_addr, + .size = data_size, + }, + word_size_); curr_addr += data_size; - section_allocators_[3] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = bss_size, - }, word_size_); + section_allocators_[3] = std::make_shared("bss", + DevMemRegion{ + .start = curr_addr, + .size = bss_size, + }, + word_size_); curr_addr += bss_size; - section_allocators_[4] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = args_size, - }, word_size_); + section_allocators_[4] = std::make_shared("args", + DevMemRegion{ + .start = curr_addr, + .size = args_size, + }, + word_size_); curr_addr += args_size; - section_allocators_[5] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = heap_size, - }, word_size_); + section_allocators_[5] = std::make_shared("heap", + DevMemRegion{ + .start = curr_addr, + .size = heap_size, + }, + word_size_); curr_addr += heap_size; - section_allocators_[6] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = workspace_size, - }, word_size_); + section_allocators_[6] = std::make_shared("workspace", + DevMemRegion{ + .start = curr_addr, + .size = workspace_size, + }, + word_size_); curr_addr += workspace_size; - section_allocators_[7] = std::make_shared(DevMemRegion { - .start = curr_addr, - .size = stack_size, - }, word_size_); + section_allocators_[7] = std::make_shared("stack", + DevMemRegion{ + .start = curr_addr, + .size = stack_size, + }, + word_size_); curr_addr += stack_size; } else if (comms_method == "openocd") { low_level_device_ = OpenOCDLowLevelDeviceCreate(server_addr, port); - section_allocators_[0] = std::make_shared(DevMemRegion { - .start = DevPtr(text_start), - .size = text_size, - }, word_size_); - section_allocators_[1] = std::make_shared(DevMemRegion { - .start = DevPtr(rodata_start), - .size = rodata_size, - }, word_size_); - section_allocators_[2] = std::make_shared(DevMemRegion { - .start = DevPtr(data_start), - .size = data_size, - }, word_size_); - section_allocators_[3] = 
std::make_shared(DevMemRegion { - .start = DevPtr(bss_start), - .size = bss_size, - }, word_size_); - section_allocators_[4] = std::make_shared(DevMemRegion { - .start = DevPtr(args_start), - .size = args_size, - }, word_size_); - section_allocators_[5] = std::make_shared(DevMemRegion { - .start = DevPtr(heap_start), - .size = heap_size, - }, word_size_); - section_allocators_[6] = std::make_shared(DevMemRegion { - .start = DevPtr(workspace_start), - .size = workspace_size, - }, word_size_); - section_allocators_[7] = std::make_shared(DevMemRegion { - .start = DevPtr(stack_start), - .size = stack_size, - }, word_size_); + section_allocators_[0] = + std::make_shared("text", + DevMemRegion{ + .start = TargetPtr(word_size_, text_start), + .size = text_size, + }, + word_size_); + section_allocators_[1] = + std::make_shared("rodata", + DevMemRegion{ + .start = TargetPtr(word_size_, rodata_start), + .size = rodata_size, + }, + word_size_); + section_allocators_[2] = + std::make_shared("data", + DevMemRegion{ + .start = TargetPtr(word_size_, data_start), + .size = data_size, + }, + word_size_); + section_allocators_[3] = + std::make_shared("bss", + DevMemRegion{ + .start = TargetPtr(word_size_, bss_start), + .size = bss_size, + }, + word_size_); + section_allocators_[4] = + std::make_shared("args", + DevMemRegion{ + .start = TargetPtr(word_size_, args_start), + .size = args_size, + }, + word_size_); + section_allocators_[5] = + std::make_shared("heap", + DevMemRegion{ + .start = TargetPtr(word_size_, heap_start), + .size = heap_size, + }, + word_size_); + section_allocators_[6] = + std::make_shared("workspace", + DevMemRegion{ + .start = TargetPtr(word_size_, workspace_start), + .size = workspace_size, + }, + word_size_); + section_allocators_[7] = + std::make_shared("stack", + DevMemRegion{ + .start = TargetPtr(word_size_, stack_start), + .size = stack_size, + }, + word_size_); } else { LOG(FATAL) << "unsupported micro low-level device"; } + TargetPtr args_start_addr = GetAllocator(SectionKind::kArgs)->start_addr(); + batch_args_encoder_.set_start_addr(args_start_addr); + runtime_symbol_map_ = LoadBinary(binary_path, false).symbol_map; // Patch pointers to define the bounds of the workspace section and the word // size (for allocation alignment). 
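  // Concretely, three globals in the runtime binary are patched from the host:
  // utvm_workspace_start and utvm_workspace_end receive the bounds of the workspace section
  // allocator, and utvm_word_size receives word_size.bytes() written as a 32- or 64-bit
  // integer to match the target. On the device, UTVMMain later checks utvm_workspace_start
  // against utvm_word_size and reports UTVM_ERR_WS_UNALIGNED_START if the section start is
  // not word-aligned.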
std::shared_ptr ws_allocator = GetAllocator(SectionKind::kWorkspace); - TargetVal ws_start = ws_allocator->start_addr().value(); - TargetVal ws_end = ws_allocator->max_addr().value(); - TargetVal target_word_size { .val64 = word_size_ }; - if (word_size_ == 4) { - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_start.val32); - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_end.val32); - DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", target_word_size.val32); - } else if (word_size_ == 8) { - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_start.val64); - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_end.val64); - DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", target_word_size.val64); + DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_allocator->start_addr()); + DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_allocator->max_addr()); + if (word_size.bytes() == 4) { + DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint32_t(word_size.bytes())); + } else if (word_size.bytes() == 8) { + DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint64_t(word_size.bytes())); + } else { + CHECK(false) << "Unsupported word size unexpectedly here"; } } @@ -209,59 +235,121 @@ MicroSession::~MicroSession() { low_level_device_ = nullptr; } -double MicroSession::PushToExecQueue(DevPtr func_ptr, const TVMArgs& args) { +void MicroSession::PushToTaskQueue(TargetPtr func_ptr, const TVMArgs& args) { if (thumb_mode_) { + // TODO(areusch): should be |= func_ptr += 1; } + TargetVal func_dev_addr = func_ptr.value(); + + std::tuple arg_field_addrs = EncoderAppend(&batch_args_encoder_, args); + TargetVal arg_values_dev_addr{std::get<0>(arg_field_addrs).value()}; + TargetVal arg_type_codes_dev_addr{std::get<1>(arg_field_addrs).value()}; + + task_queue_.push_back(DevTask{.func = func_dev_addr, + .arg_values = arg_values_dev_addr, + .arg_type_codes = arg_type_codes_dev_addr, + .num_args = args.num_args}); + + if (task_queue_.size() == MicroSession::kTaskQueueCapacity) { + FlushTaskQueue(); + } +} + +void MicroSession::FlushTaskQueue() { + if (task_queue_.size() == 0) { + // nothing to run + return; + } + if (word_size_.bytes() == 4) { + FlushTaskQueuePriv(); + } else if (word_size_.bytes() == 8) { + FlushTaskQueuePriv(); + } +} - // Create an allocator stream for the memory region after the most recent - // allocation in the args section. - DevPtr args_addr = GetAllocator(SectionKind::kArgs)->curr_end_addr(); - TargetDataLayoutEncoder encoder(args_addr, word_size_); - - std::tuple arg_field_addrs = EncoderAppend(&encoder, args); - - // Flush `stream` to device memory. - DevPtr stream_dev_addr = - GetAllocator(SectionKind::kArgs)->Allocate(encoder.buf_size()); - low_level_device()->Write(stream_dev_addr, - reinterpret_cast(encoder.data()), - encoder.buf_size()); - - TargetVal arg_values_dev_addr = std::get<0>(arg_field_addrs).value(); - TargetVal arg_type_codes_dev_addr = std::get<1>(arg_field_addrs).value(); - if (word_size_ == 4) { - UTVMTask32 task = { - .func = func_ptr.value().val32, - .arg_values = arg_values_dev_addr.val32, - .arg_type_codes = arg_type_codes_dev_addr.val32, - .num_args = args.num_args, - }; - // Write the task. 
- DevSymbolWrite(runtime_symbol_map_, "utvm_task", task); - } else if (word_size_ == 8) { - UTVMTask64 task = { - .func = func_ptr.value().val64, - .arg_values = arg_values_dev_addr.val64, - .arg_type_codes = arg_type_codes_dev_addr.val64, - .num_args = args.num_args, - }; - // Write the task. - DevSymbolWrite(runtime_symbol_map_, "utvm_task", task); +template +void MicroSession::FlushTaskQueuePriv() { + std::vector prepped_tasks; + for (const auto& task : task_queue_) { + prepped_tasks.push_back(T(task)); } - DevPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"]; - DevPtr utvm_done_addr = runtime_symbol_map_["UTVMDone"]; + // Flush `args` to device memory. + low_level_device()->Write(batch_args_encoder_.start_addr(), + reinterpret_cast(batch_args_encoder_.data()), + batch_args_encoder_.buf_size()); + + // Flush `tasks` to device memory. + TargetPtr dev_tasks_addr = runtime_symbol_map_["utvm_tasks"]; + low_level_device()->Write(dev_tasks_addr, reinterpret_cast(prepped_tasks.data()), + prepped_tasks.size() * sizeof(T)); + DevSymbolWrite(runtime_symbol_map_, "utvm_num_tasks", prepped_tasks.size()); + + TargetPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"]; + TargetPtr utvm_done_addr = runtime_symbol_map_["UTVMDone"]; if (thumb_mode_) { + // TODO(areusch): should be |= utvm_init_addr += 1; } - low_level_device()->Execute(utvm_init_addr, utvm_done_addr); + bool did_debug = false; + if (debug_func_ != nullptr) { + TVMRetValue rv = debug_func_(); + if (rv.type_code() == kTVMNullptr) { + did_debug = true; + } else { + did_debug = static_cast(rv); + } + + if (did_debug && !use_device_timer_) { + LOG(INFO) << "NOTE: when debugging and use_device_timer == false, reported execution time " + << "will be inaccurate!"; + } + } + + if (!did_debug) { + std::chrono::time_point tbegin, + tend; + tbegin = std::chrono::high_resolution_clock::now(); + low_level_device()->Execute(utvm_init_addr, utvm_done_addr); + tend = std::chrono::high_resolution_clock::now(); + if (!use_device_timer_) { + last_batch_time_ += + std::chrono::duration_cast>(tend - tbegin).count() * 1000; + } + } + // Check if there was an error during execution. If so, log it. CheckDeviceError(); - uint32_t task_time = DevSymbolRead(runtime_symbol_map_, "utvm_task_time"); - GetAllocator(SectionKind::kArgs)->Free(stream_dev_addr); - return static_cast(task_time); + + if (use_device_timer_) { + uint64_t sum = 0; + std::vector times; + times.resize(task_queue_.size()); + low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), + task_queue_.size() * sizeof(uint32_t)); + int i = 0; + for (uint32_t time : times) { + LOG(INFO) << "Time " << i++ << ": " << time; + sum += time; + } + last_batch_time_ += static_cast(sum) / 1e3; + } else { + // TODO(weberlo): Reading internal data structure is hacky. 
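    // Two accounting paths meet here. With use_device_timer_ == true (the branch above), the
    // per-task tick counts read back from utvm_task_times are summed and divided by 1e3 into
    // last_batch_time_, so the reported unit depends on the device timer's tick rate. In this
    // branch the host-side wall-clock measurement taken around Execute() already feeds
    // last_batch_time_, and utvm_task_times is read back only to accumulate raw cycle counts
    // into last_batch_cycles_.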
+ uint64_t sum = 0; + std::vector times; + times.resize(task_queue_.size()); + low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), + task_queue_.size() * sizeof(uint32_t)); + for (uint32_t time : times) { + sum += time; + } + last_batch_cycles_ += static_cast(sum); + } + + batch_args_encoder_.Clear(); + task_queue_.clear(); } BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_dylib_pointers) { @@ -270,32 +358,22 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d DevMemRegion data_section; DevMemRegion bss_section; - text_section.size = GetSectionSize( - binary_path, SectionKind::kText, toolchain_prefix_, word_size_); - rodata_section.size = GetSectionSize( - binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); - data_section.size = GetSectionSize( - binary_path, SectionKind::kData, toolchain_prefix_, word_size_); - bss_section.size = GetSectionSize( - binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); + text_section.size = + GetSectionSize(binary_path, SectionKind::kText, toolchain_prefix_, word_size_); + rodata_section.size = + GetSectionSize(binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); + data_section.size = + GetSectionSize(binary_path, SectionKind::kData, toolchain_prefix_, word_size_); + bss_section.size = GetSectionSize(binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); text_section.start = AllocateInSection(SectionKind::kText, text_section.size); rodata_section.start = AllocateInSection(SectionKind::kRodata, rodata_section.size); data_section.start = AllocateInSection(SectionKind::kData, data_section.size); bss_section.start = AllocateInSection(SectionKind::kBss, bss_section.size); - CHECK(text_section.start != nullptr && rodata_section.start != nullptr && - data_section.start != nullptr && bss_section.start != nullptr) - << "not enough space to load module on device"; std::string relocated_bin = RelocateBinarySections( - binary_path, - word_size_, - text_section.start, - rodata_section.start, - data_section.start, - bss_section.start, - GetAllocator(SectionKind::kStack)->max_addr(), - toolchain_prefix_); + binary_path, word_size_, text_section.start, rodata_section.start, data_section.start, + bss_section.start, GetAllocator(SectionKind::kStack)->max_addr(), toolchain_prefix_); std::string text_contents = ReadSection(relocated_bin, SectionKind::kText, toolchain_prefix_); std::string rodata_contents = ReadSection(relocated_bin, SectionKind::kRodata, toolchain_prefix_); std::string data_contents = ReadSection(relocated_bin, SectionKind::kData, toolchain_prefix_); @@ -305,7 +383,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d low_level_device_->Write(rodata_section.start, &rodata_contents[0], rodata_section.size); low_level_device_->Write(data_section.start, &data_contents[0], data_section.size); low_level_device_->Write(bss_section.start, &bss_contents[0], bss_section.size); - SymbolMap symbol_map {relocated_bin, toolchain_prefix_}; + SymbolMap symbol_map{relocated_bin, toolchain_prefix_, word_size_}; if (patch_dylib_pointers) { // Patch device lib pointers. 
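    // Loaded modules do not link against the runtime binary directly. For each runtime entry
    // point, the generated device library is assumed to expose a function-pointer "hole" named
    // after the entry point with a trailing underscore; PatchImplHole (defined below) looks the
    // hole up in the module's symbol map and writes the runtime's real address into it.
    // Roughly:
    //
    //   void (*TVMAPISetLastError_)(const char*) = 0;  /* hole in the device library */
    //
    // After patching, calls made through that pointer land in the runtime implementation
    // carried by this patch.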
@@ -314,7 +392,7 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d PatchImplHole(symbol_map, "TVMAPISetLastError"); } - return BinaryInfo { + return BinaryInfo{ .text_section = text_section, .rodata_section = rodata_section, .data_section = data_section, @@ -323,13 +401,13 @@ BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_d }; } -std::tuple MicroSession::EncoderAppend( - TargetDataLayoutEncoder* encoder, const TVMArgs& args) { +std::tuple MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, + const TVMArgs& args) { const int* type_codes = args.type_codes; int num_args = args.num_args; - auto tvm_vals_slot = encoder->Alloc(num_args); - auto type_codes_slot = encoder->Alloc(num_args); + auto tvm_vals_alloc = encoder->Alloc(num_args); + auto type_codes_alloc = encoder->Alloc(num_args); for (int i = 0; i < num_args; i++) { switch (type_codes[i]) { @@ -341,12 +419,13 @@ std::tuple MicroSession::EncoderAppend( // order to prevent premature session destruction. void* old_data = base_arr_handle->data; // Mutate the array to unwrap the `data` field. - base_arr_handle->data = reinterpret_cast(old_data)->data; + MicroDevSpace* dev_arr_ptr = reinterpret_cast(old_data); + base_arr_handle->data = reinterpret_cast(dev_arr_ptr->data.value().uint64()); // Now, encode the unwrapped version. void* arr_ptr = nullptr; - if (word_size_ == 4) { + if (word_size_.bytes() == 4) { arr_ptr = EncoderAppend(encoder, *base_arr_handle).cast_to(); - } else if (word_size_ == 8) { + } else if (word_size_.bytes() == 8) { arr_ptr = EncoderAppend(encoder, *base_arr_handle).cast_to(); } // And restore the original wrapped version. @@ -354,7 +433,7 @@ std::tuple MicroSession::EncoderAppend( TVMValue val; val.v_handle = arr_ptr; - tvm_vals_slot.WriteValue(val); + tvm_vals_alloc->WriteValue(val); break; } // TODO(weberlo): Implement `double` and `int64` case. @@ -366,77 +445,103 @@ std::tuple MicroSession::EncoderAppend( break; } } - type_codes_slot.WriteArray(type_codes, num_args); - return std::make_tuple(tvm_vals_slot.start_addr(), type_codes_slot.start_addr()); + type_codes_alloc->WriteArray(type_codes, num_args); + encoder->CheckUnfilledAllocs(); + return std::make_tuple(tvm_vals_alloc->start_addr(), type_codes_alloc->start_addr()); } template -DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) { - auto tvm_arr_slot = encoder->Alloc(); - auto shape_slot = encoder->Alloc(arr.ndim); - +TargetPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) { // `shape` and `strides` are stored on the host, so we need to write them to // the device first. The `data` field is already allocated on the device and // is a device pointer, so we don't need to write it. 
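  // The marshalling below therefore happens in three steps: the shape array is written into the
  // args stream, the strides array is written only when arr.strides is non-null (otherwise a
  // null TargetPtr of the target's word size is recorded), and finally a target-layout copy of
  // the DLTensor (the template parameter T) pointing at those device addresses is written. Its
  // ctx.device_type is switched to kDLCPU before the write, since from the microcontroller's
  // perspective the tensor lives in plain local memory.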
- shape_slot.WriteArray(arr.shape, arr.ndim); - DevPtr shape_dev_addr = shape_slot.start_addr(); - DevPtr strides_dev_addr = DevPtr(nullptr); + auto shape_alloc = encoder->Alloc(arr.ndim); + shape_alloc->WriteArray(arr.shape, arr.ndim); + TargetPtr shape_dev_addr = shape_alloc->start_addr(); + TargetPtr strides_dev_addr = TargetPtr(word_size_, nullptr); if (arr.strides != nullptr) { - auto stride_slot = encoder->Alloc(arr.ndim); - stride_slot.WriteArray(arr.strides, arr.ndim); - strides_dev_addr = stride_slot.start_addr(); + auto stride_alloc = encoder->Alloc(arr.ndim); + stride_alloc->WriteArray(arr.strides, arr.ndim); + strides_dev_addr = stride_alloc->start_addr(); } - T dev_arr( - TargetVal { .val64 = reinterpret_cast(arr.data) }, - arr.ctx, - arr.ndim, - arr.dtype, - shape_dev_addr.value(), - strides_dev_addr.value(), - TargetVal { .val64 = arr.byte_offset }); + T dev_arr(TargetVal{word_size_.bits(), reinterpret_cast(arr.data)}, arr.ctx, arr.ndim, + arr.dtype, shape_dev_addr.value(), strides_dev_addr.value(), + TargetVal{word_size_.bits(), arr.byte_offset}); CHECK(dev_arr.ctx.device_type == static_cast(kDLMicroDev)) - << "attempt to write DLTensor with non-micro device type"; + << "attempt to write DLTensor with non-micro device type"; // Update the device type to CPU, because from the microcontroller's // perspective, it is. dev_arr.ctx.device_type = DLDeviceType::kDLCPU; - tvm_arr_slot.WriteValue(dev_arr); - return tvm_arr_slot.start_addr(); + + auto tvm_arr_alloc = encoder->Alloc(); + tvm_arr_alloc->WriteValue(dev_arr); + return tvm_arr_alloc->start_addr(); } +// TODO(weberlo): switch over entirely to error codes that expand to error +// messages on the host side. void MicroSession::CheckDeviceError() { - int32_t return_code = DevSymbolRead(runtime_symbol_map_, "utvm_return_code"); - - if (return_code) { - std::uintptr_t last_error = - DevSymbolRead(runtime_symbol_map_, "utvm_last_error"); - std::string last_error_str; - if (last_error) { - DevPtr last_err_addr = DevPtr(last_error); - last_error_str = ReadString(last_err_addr); + int32_t last_error = DevSymbolRead(runtime_symbol_map_, "utvm_last_error"); + + if (last_error) { + if (!use_device_timer_ && + (last_error == UTVM_ERR_TIMER_OVERFLOW || last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) { + // these errors don't matter if we're not using the on-device timer + return; + } + std::string err_msg; + switch (last_error) { + case UTVM_ERR_NOT_FINISHED: + err_msg = "execution timed out"; + break; + case UTVM_ERR_TIMER_NOT_IMPLEMENTED: + err_msg = "timer is not implemented for the target device"; + break; + case UTVM_ERR_TIMER_OVERFLOW: + // TODO(weberlo): this should be remedied by using interrupts to accumulate the + // timer into a larger datatype (ARM timers are only 24 bits) + err_msg = "timer overflowed during execution"; + break; + case UTVM_ERR_WS_DOUBLE_FREE: + err_msg = "free called with no active workspace allocations"; + break; + case UTVM_ERR_WS_OUT_OF_SPACE: + err_msg = "ran out of space in workspace section"; + break; + case UTVM_ERR_WS_TOO_MANY_ALLOCS: + err_msg = "exceeded number of allocs the runtime can keep track of"; + break; + case UTVM_ERR_WS_ZERO_SIZE_ALLOC: + err_msg = "attempt to allocate scratchpad of size zero"; + break; + case UTVM_ERR_WS_UNALIGNED_START: + err_msg = "start of workspace section is not word-aligned"; + break; + case UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE: + err_msg = "scratchpad allocation size is not a multiple of the word size"; + break; + default: + err_msg = "unknown error code"; + break; } 
LOG(FATAL) << "error during micro function execution:\n" - << " return code: " << std::dec << return_code << "\n" - << " dev str addr: 0x" << std::hex << last_error << "\n" - << " dev str data: " << last_error_str << std::endl; + << " error ID: " << std::dec << last_error << std::endl + << " error message: " << err_msg; } } void MicroSession::PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name) { - DevPtr runtime_impl_addr = runtime_symbol_map_[func_name]; + TargetPtr runtime_impl_addr = runtime_symbol_map_[func_name]; if (thumb_mode_) { runtime_impl_addr += 1; } std::ostringstream func_name_underscore; func_name_underscore << func_name << "_"; - if (word_size_ == 4) { - DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr.value().val32); - } else if (word_size_ == 8) { - DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr.value().val64); - } + DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr); } -std::string MicroSession::ReadString(DevPtr str_addr) { +std::string MicroSession::ReadString(TargetPtr str_addr) { std::ostringstream result; const size_t buf_size = 256; std::vector buf(buf_size, 0); @@ -454,98 +559,128 @@ std::string MicroSession::ReadString(DevPtr str_addr) { return result.str(); } -DevPtr MicroSession::AllocateInSection(SectionKind type, size_t size) { +TargetPtr MicroSession::AllocateInSection(SectionKind type, size_t size) { return GetAllocator(type)->Allocate(size); } -void MicroSession::FreeInSection(SectionKind type, DevPtr addr) { +void MicroSession::FreeInSection(SectionKind type, TargetPtr addr) { return GetAllocator(type)->Free(addr); } template T MicroSession::DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol) { - DevPtr sym_addr = symbol_map[symbol]; + TargetPtr sym_addr = symbol_map[symbol]; T result; low_level_device()->Read(sym_addr, &result, sizeof(T)); return result; } +void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, + const TargetPtr& ptr) { + if (word_size_.bytes() == 4) { + DevSymbolWrite(symbol_map, symbol, ptr.value().uint32()); + } else if (word_size_.bytes() == 8) { + DevSymbolWrite(symbol_map, symbol, ptr.value().uint64()); + } else { + CHECK(false) << "Unsupported word size unexpectedly here"; + } +} + template -void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, - const std::string& symbol, +void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value) { - DevPtr sym_addr = symbol_map[symbol]; + TargetPtr sym_addr = symbol_map[symbol]; low_level_device()->Write(sym_addr, &value, sizeof(T)); } -PackedFunc MicroSession::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc MicroSession::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { if (name == "enter") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { MicroSession::EnterWithScope(GetObjectPtr(this)); }); } else if (name == "exit") { - return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { - MicroSession::ExitWithScope(); - }); + return PackedFunc( + [sptr_to_self](TVMArgs args, TVMRetValue* rv) { MicroSession::ExitWithScope(); }); + // TODO(weberlo): add a `clear_batch_timer` func + } else if (name == "get_last_batch_time") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchTime(); }); + // TODO(weberlo): remove this func + } else if (name == 
"get_last_batch_cycles") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchCycles(); }); } else { return PackedFunc(); } } +TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator").set_body([](TVMArgs args, TVMRetValue* rv) { + PackedFunc pf = args[0]; + TVMContext ctx = args[1]; + uint64_t number = args[2]; + uint64_t repeat = args[3]; + + auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue* rv) mutable { + TVMRetValue temp; + std::ostringstream os; + + for (unsigned int i = 0; i < repeat; ++i) { + // start timing + CHECK(number < MicroSession::kTaskQueueCapacity) + << "`number` must be less than uTVM task queue capacity"; + for (unsigned int j = 0; j < number; ++j) { + pf.CallPacked(args, &temp); + } + ObjectPtr session = MicroSession::Current(); + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + double time_per_batch = session->GetLastBatchTime() / number; + os.write(reinterpret_cast(&time_per_batch), sizeof(time_per_batch)); + } + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. + *rv = arr; + }; + *rv = PackedFunc(ftimer); +}); + // create micro session and low-level device from Python frontend -TVM_REGISTER_GLOBAL("micro._CreateSession") -.set_body([](TVMArgs args, TVMRetValue* rv) { - const std::string& comms_method = args[0]; - const std::string& binary_path = args[1]; - const std::string& toolchain_prefix = args[2]; - uint64_t text_start = args[3]; - size_t text_size = args[4]; - uint64_t rodata_start = args[5]; - size_t rodata_size = args[6]; - uint64_t data_start = args[7]; - size_t data_size = args[8]; - uint64_t bss_start = args[9]; - size_t bss_size = args[10]; - uint64_t args_start = args[11]; - size_t args_size = args[12]; - uint64_t heap_start = args[13]; - size_t heap_size = args[14]; - uint64_t workspace_start = args[15]; - size_t workspace_size = args[16]; - uint64_t stack_start = args[17]; - size_t stack_size = args[18]; - size_t word_size = args[19]; - bool thumb_mode = args[20]; - const std::string& server_addr = args[21]; - int port = args[22]; - ObjectPtr session = make_object( - comms_method, - binary_path, - toolchain_prefix, - text_start, - text_size, - rodata_start, - rodata_size, - data_start, - data_size, - bss_start, - bss_size, - args_start, - args_size, - heap_start, - heap_size, - workspace_start, - workspace_size, - stack_start, - stack_size, - word_size, - thumb_mode, - server_addr, - port); - *rv = Module(session); - }); +TVM_REGISTER_GLOBAL("micro._CreateSession").set_body([](TVMArgs args, TVMRetValue* rv) { + const std::string& comms_method = args[0]; + const std::string& binary_path = args[1]; + const std::string& toolchain_prefix = args[2]; + uint64_t text_start = args[3]; + size_t text_size = uint64_t(args[4]); + uint64_t rodata_start = args[5]; + size_t rodata_size = uint64_t(args[6]); + uint64_t data_start = args[7]; + size_t data_size = uint64_t(args[8]); + uint64_t bss_start = args[9]; + size_t bss_size = uint64_t(args[10]); + uint64_t args_start = args[11]; + size_t args_size = uint64_t(args[12]); + uint64_t heap_start = args[13]; + size_t heap_size = uint64_t(args[14]); + uint64_t workspace_start = args[15]; + size_t workspace_size = uint64_t(args[16]); + uint64_t stack_start = args[17]; + size_t stack_size = uint64_t(args[18]); + TargetWordSize word_size{uint64_t(args[19])}; + bool thumb_mode = args[20]; + bool use_device_timer = args[21]; + const std::string& server_addr = args[22]; + int port = 
args[23]; + PackedFunc debug_func = args[24]; + ObjectPtr session = make_object( + comms_method, binary_path, toolchain_prefix, text_start, text_size, rodata_start, rodata_size, + data_start, data_size, bss_start, bss_size, args_start, args_size, heap_start, heap_size, + workspace_start, workspace_size, stack_start, stack_size, word_size, thumb_mode, + use_device_timer, server_addr, port, debug_func); + *rv = Module(session); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/micro_session.h b/src/runtime/micro/micro_session.h index 9e844e8b2140..f911cf7dde43 100644 --- a/src/runtime/micro/micro_session.h +++ b/src/runtime/micro/micro_session.h @@ -34,24 +34,25 @@ #ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_ #define TVM_RUNTIME_MICRO_MICRO_SESSION_H_ -#include "micro_common.h" -#include "micro_section_allocator.h" - -#include #include +#include #include #include +#include #include #include -#include #include "low_level_device.h" +#include "micro_common.h" +#include "micro_section_allocator.h" #include "target_data_layout_encoder.h" namespace tvm { namespace runtime { +struct DevTask; + /*! * \brief session for facilitating micro device interaction */ @@ -63,15 +64,15 @@ class MicroSession : public ModuleNode { * \param sptr_to_self The pointer to the module node. * \return The corresponding member function. */ - virtual PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self); + virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); + + // todo having this decoupled from the value in utvm_runtime.c gives me stress dreams + static const size_t kTaskQueueCapacity = 20; /*! * \return The type key of the executor. */ - const char* type_key() const final { - return "MicroSession"; - } + const char* type_key() const final { return "MicroSession"; } /*! 
* \brief creates session by setting up a low-level device and initializing allocators for it @@ -94,35 +95,19 @@ class MicroSession : public ModuleNode { * \param workspace_size workspace section size * \param stack_start stack section start address * \param stack_size stack section size - * \param word_size number of bytes in a word on the target device + * \param word_size_bytes number of bytes in a word on the target device * \param thumb_mode whether the target device requires a thumb-mode bit on function addresses + * \param use_device_timer whether to use the device-side timer for measuring execution time * \param server_addr address of the OpenOCD server to connect to (if `comms_method == "openocd"`) * \param port port of the OpenOCD server to connect to (if `comms_method == "openocd"`) + * \param debug_func the debug function invoked to launch gdb */ - MicroSession( - const std::string& comms_method, - const std::string& binary_path, - const std::string& toolchain_prefix, - uint64_t text_start, - size_t text_size, - uint64_t rodata_start, - size_t rodata_size, - uint64_t data_start, - size_t data_size, - uint64_t bss_start, - size_t bss_size, - uint64_t args_start, - size_t args_size, - uint64_t heap_start, - size_t heap_size, - uint64_t workspace_start, - size_t workspace_size, - uint64_t stack_start, - size_t stack_size, - size_t word_size, - bool thumb_mode, - const std::string& server_addr, - int port); + MicroSession(const std::string& comms_method, const std::string& binary_path, + const std::string& toolchain_prefix, uint64_t text_start, size_t text_size, + uint64_t rodata_start, size_t rodata_size, uint64_t data_start, size_t data_size, + uint64_t bss_start, size_t bss_size, uint64_t args_start, size_t args_size, + uint64_t heap_start, size_t heap_size, uint64_t workspace_start, + size_t workspace_size, uint64_t stack_start, size_t stack_size, + TargetWordSize word_size, bool thumb_mode, bool use_device_timer, + const std::string& server_addr, int port, PackedFunc debug_func); /*! * \brief destructor @@ -137,7 +122,19 @@ class MicroSession : public ModuleNode { * \param args args to the packed function * \return elapsed time during function execution on the device */ - double PushToExecQueue(DevPtr func, const TVMArgs& args); + void PushToTaskQueue(TargetPtr func, const TVMArgs& args); + + /*! + * \brief serialize runtime metadata to the device for enqueued tasks and execute + * \note the elapsed execution time of the batch can be retrieved afterwards via GetLastBatchTime + */ + void FlushTaskQueue(); + + /*! + * \brief word-size-specialized implementation of FlushTaskQueue (e.g. for StructUTVMTask32/StructUTVMTask64) + */ + template + void FlushTaskQueuePriv(); /*! * \brief loads binary onto device @@ -153,36 +150,44 @@ class MicroSession : public ModuleNode { * \param size size of allocated memory in bytes * \return pointer to allocated memory region in section, nullptr if out of space */ - DevPtr AllocateInSection(SectionKind type, size_t size); + TargetPtr AllocateInSection(SectionKind type, size_t size); /*! * \brief free prior allocation from section * \param type type of section to allocate in * \param addr device address of allocated memory */ - void FreeInSection(SectionKind type, DevPtr addr); + void FreeInSection(SectionKind type, TargetPtr addr); /*! * \brief read string from device to host * \param str_addr device address of first character of string * \return host copy of device string that was read */ - std::string ReadString(DevPtr str_addr); + std::string ReadString(TargetPtr str_addr); /*!
- * \brief read value of symbol from device memory - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being read from - * \return value at symbol in memory - */ + * \brief read value of symbol from device memory + * \param symbol_map symbol map to read location of symbol from + * \param symbol name of symbol being read from + * \return value at symbol in memory + */ template T DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol); /*! - * \brief write value into device memory corresponding to symbol - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being written to - * \param value value being written into symbol + * \brief write pointer value into device memory corresponding to symbol + * \param symbol_map symbol map to read location of symbol from + * \param symbol name of symbol being written to + * \param ptr pointer value to write into symbol + */ + void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr); + + /*! + * \brief write value into device memory corresponding to symbol + * \param symbol_map symbol map to read location of symbol from + * \param symbol name of symbol being written to + * \param value value being written into symbol */ template void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value); @@ -196,6 +201,18 @@ class MicroSession : public ModuleNode { return low_level_device_; } + const double GetLastBatchTime() { + double result = last_batch_time_; + last_batch_time_ = 0.0; + return result; + } + + const double GetLastBatchCycles() { + double result = last_batch_cycles_; + last_batch_cycles_ = 0.0; + return result; + } + private: /*! \brief low-level device pointer */ std::shared_ptr low_level_device_; @@ -205,7 +222,7 @@ class MicroSession : public ModuleNode { std::shared_ptr section_allocators_[static_cast(SectionKind::kNumKinds)]; /*! \brief number of bytes in a word on the target device */ - size_t word_size_; + TargetWordSize word_size_; /*! \brief whether the target device requires a thumb-mode bit on function addresses * * ARM and other manufacturers use the lowest bit of a function address to determine @@ -213,8 +230,22 @@ class MicroSession : public ModuleNode { * results in more compact binaries. */ bool thumb_mode_; + /*! \brief whether to use the device-side timer for measuring execution time */ + bool use_device_timer_; /*! \brief symbol map for the device runtime */ SymbolMap runtime_symbol_map_; + /*! \brief queue of tasks to be flushed to the device in the next batch */ + std::vector task_queue_; + // TODO(weberlo): we don't even need an allocator mechanism for the args + // section. there's only ever one allocation. + /*! \brief encoder for the arguments of all tasks in the current batch */ + TargetDataLayoutEncoder batch_args_encoder_; + /*! \brief execution time of the most recently flushed batch of tasks */ + double last_batch_time_; + /*! \brief device cycle count of the most recently flushed batch of tasks */ + double last_batch_cycles_; + /*! \brief the debug function invoked to launch gdb */ + PackedFunc debug_func_; /*! * \brief patches a function pointer in this module to an implementation @@ -228,7 +259,8 @@ class MicroSession : public ModuleNode { * \param args args to be appended * \return device address of the allocated args */ - std::tuple EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMArgs& args); + std::tuple EncoderAppend(TargetDataLayoutEncoder* encoder, + const TVMArgs& args); /*!
* \brief appends a `DLTensor` to the host-side buffer of `encoder` @@ -237,7 +269,7 @@ class MicroSession : public ModuleNode { * \return device address of the allocated `DLTensor` */ template - DevPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr); + TargetPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr); /*! * \brief checks and logs if there was an error during the device's most recent execution @@ -254,15 +286,15 @@ class MicroSession : public ModuleNode { } /*! - * \brief Push a new session context onto the thread-local stack. - * The session on top of the stack is used as the current global session. - */ + * \brief Push a new session context onto the thread-local stack. + * The session on top of the stack is used as the current global session. + */ static void EnterWithScope(ObjectPtr session); /*! - * \brief Pop a session off the thread-local context stack, - * restoring the previous session as the current context. - */ + * \brief Pop a session off the thread-local context stack, + * restoring the previous session as the current context. + */ static void ExitWithScope(); }; @@ -274,7 +306,7 @@ class MicroSession : public ModuleNode { */ struct MicroDevSpace { /*! \brief data being wrapped */ - void* data; + TargetPtr data; /*! \brief shared ptr to session where this data is valid */ ObjectPtr session; }; @@ -283,33 +315,26 @@ struct MicroDevSpace { /*! \brief TVM array for serialization to 32-bit devices */ struct TVMArray32 { - TVMArray32( - TargetVal data, - DLContext ctx, - int32_t ndim, - DLDataType dtype, - TargetVal shape, - TargetVal strides, - TargetVal byte_offset) - : data(data.val32), - ctx(ctx), - ndim(ndim), - pad0(0), - dtype(dtype), - shape(shape.val32), - strides(strides.val32), - pad1(0), - byte_offset(byte_offset.val32), - pad2(0) { } - - /*! \brief opaque pointer to the allocated data */ + TVMArray32(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, + TargetVal strides, TargetVal byte_offset) + : data{data.uint32()}, + ctx{ctx}, + ndim{ndim}, + dtype{dtype}, + shape{shape.uint32()}, + strides{strides.uint32()}, + byte_offset{byte_offset.uint32()} {} + + /*! + * \brief The opaque data pointer points to the allocated data. + * This will be CUDA device pointer or cl_mem handle in OpenCL. + * This pointer is always aligns to 256 bytes as in CUDA. + */ uint32_t data; /*! \brief The device context of the tensor */ DLContext ctx; /*! \brief Number of dimensions */ int32_t ndim; - /*! \brief Padding to enforce struct alignment */ - uint32_t pad0; /*! \brief The data type of the pointer */ DLDataType dtype; /*! \brief The shape of the tensor */ @@ -319,41 +344,31 @@ struct TVMArray32 { * can be NULL, indicating tensor is compact. */ uint32_t strides; - /*! \brief Padding to enforce struct alignment */ - uint32_t pad1; /*! \brief The offset in bytes to the beginning pointer to data */ uint32_t byte_offset; - /*! \brief Padding to enforce struct alignment */ - uint32_t pad2; }; /*! \brief TVM array for serialization to 64-bit devices */ struct TVMArray64 { - TVMArray64( - TargetVal data, - DLContext ctx, - int32_t ndim, - DLDataType dtype, - TargetVal shape, - TargetVal strides, - TargetVal byte_offset) - : data(data.val64), - ctx(ctx), - ndim(ndim), - pad0(0), - dtype(dtype), - shape(shape.val64), - strides(strides.val64), - byte_offset(byte_offset.val64) { } - - /*! 
\brief opaque pointer to the allocated data */ + TVMArray64(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, + TargetVal strides, TargetVal byte_offset) + : data(data.uint64()), + ctx(ctx), + ndim(ndim), + dtype(dtype), + shape(shape.uint64()), + strides(strides.uint64()), + byte_offset(byte_offset.uint64()) {} + /*! + * \brief The opaque data pointer points to the allocated data. + * This will be CUDA device pointer or cl_mem handle in OpenCL. + * This pointer is always aligns to 256 bytes as in CUDA. + */ uint64_t data; /*! \brief The device context of the tensor */ DLContext ctx; /*! \brief Number of dimensions */ int32_t ndim; - /*! \brief Padding to enforce struct alignment */ - uint32_t pad0; /*! \brief The data type of the pointer */ DLDataType dtype; /*! \brief The shape of the tensor */ @@ -367,8 +382,26 @@ struct TVMArray64 { uint64_t byte_offset; }; +/*! \brief MicroTVM task to store in task queue before specializing to word size */ +struct DevTask { + /*! \brief Pointer to function to call for this task */ + TargetVal func; + /*! \brief Array of argument values */ + TargetVal arg_values; + /*! \brief Array of type codes for each argument value */ + TargetVal arg_type_codes; + /*! \brief Number of arguments */ + int32_t num_args; +}; + /*! \brief MicroTVM task for serialization to 32-bit devices */ typedef struct StructUTVMTask32 { + StructUTVMTask32(DevTask task) + : func(task.func.uint32()), + arg_values(task.arg_values.uint32()), + arg_type_codes(task.arg_type_codes.uint32()), + num_args(task.num_args) {} + /*! \brief Pointer to function to call for this task */ uint32_t func; /*! \brief Array of argument values */ @@ -377,10 +410,16 @@ typedef struct StructUTVMTask32 { uint32_t arg_type_codes; /*! \brief Number of arguments */ int32_t num_args; -} UTVMTask32; +} StructUTVMTask32; /*! \brief MicroTVM task for serialization to 64-bit devices */ typedef struct StructUTVMTask64 { + StructUTVMTask64(DevTask task) + : func(task.func.uint64()), + arg_values(task.arg_values.uint64()), + arg_type_codes(task.arg_type_codes.uint64()), + num_args(task.num_args) {} + /*! \brief Pointer to function to call for this task */ uint64_t func; /*! \brief Array of argument values */ @@ -389,7 +428,7 @@ typedef struct StructUTVMTask64 { uint64_t arg_type_codes; /*! \brief Number of arguments */ int32_t num_args; -} UTVMTask64; +} StructUTVMTask64; } // namespace runtime } // namespace tvm diff --git a/src/runtime/micro/openocd_low_level_device.cc b/src/runtime/micro/openocd_low_level_device.cc index e5c83e590c36..610ca8590dd1 100644 --- a/src/runtime/micro/openocd_low_level_device.cc +++ b/src/runtime/micro/openocd_low_level_device.cc @@ -20,11 +20,11 @@ /*! 
* \file openocd_low_level_device.cc */ -#include #include +#include -#include "micro_common.h" #include "low_level_device.h" +#include "micro_common.h" #include "tcl_socket.h" namespace tvm { @@ -40,17 +40,19 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { * \param server_addr address of the OpenOCD server to connect to * \param port port of the OpenOCD server to connect to */ - explicit OpenOCDLowLevelDevice(const std::string& server_addr, - int port) : socket_() { + explicit OpenOCDLowLevelDevice(const std::string& server_addr, int port) : socket_() { server_addr_ = server_addr; port_ = port; socket_.Connect(tvm::support::SockAddr(server_addr_.c_str(), port_)); - socket_.cmd_builder() << "halt 0"; + socket_.cmd_builder() << "reset run"; + socket_.SendCommand(); + + socket_.cmd_builder() << "halt 500"; socket_.SendCommand(); } - void Read(DevPtr addr, void* buf, size_t num_bytes) { + void Read(TargetPtr addr, void* buf, size_t num_bytes) override { if (num_bytes == 0) { return; } @@ -77,18 +79,17 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.cmd_builder() << "array unset output"; socket_.SendCommand(); - socket_.cmd_builder() - << "mem2array output" - << " " << std::dec << kWordSize - << " " << addr.cast_to() - // Round up any request sizes under a byte, since OpenOCD doesn't support - // sub-byte-sized transfers. - << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); + socket_.cmd_builder() << "mem2array output" + << " " << std::dec << kWordSize << " " + << addr.cast_to() + // Round up any request sizes under a byte, since OpenOCD doesn't + // support sub-byte-sized transfers. + << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); socket_.SendCommand(); } { - socket_.cmd_builder() << "ocd_echo $output"; + socket_.cmd_builder() << "return $output"; socket_.SendCommand(); const std::string& reply = socket_.last_reply(); @@ -101,9 +102,8 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { // The response from this command pairs indices with the contents of the // memory at that index. values >> index; - CHECK(index < num_bytes) - << "index " << index << - " out of bounds (length " << num_bytes << ")"; + CHECK(index < num_bytes) << "index " << index << " out of bounds (length " << num_bytes + << ")"; // Read the value into `curr_val`, instead of reading directly into // `buf_iter`, because otherwise it's interpreted as the ASCII value and // not the integral value. 
@@ -119,7 +119,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { } } - void Write(DevPtr addr, const void* buf, size_t num_bytes) { + void Write(TargetPtr addr, const void* buf, size_t num_bytes) override { if (num_bytes == 0) { return; } @@ -162,16 +162,14 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.SendCommand(); } { - socket_.cmd_builder() - << "array2mem input" - << " " << std::dec << kWordSize - << " " << addr.cast_to() - << " " << std::dec << num_bytes; + socket_.cmd_builder() << "array2mem input" + << " " << std::dec << kWordSize << " " << addr.cast_to() << " " + << std::dec << num_bytes; socket_.SendCommand(); } } - void Execute(DevPtr func_addr, DevPtr breakpoint_addr) { + void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) override { socket_.cmd_builder() << "halt 0"; socket_.SendCommand(); @@ -193,9 +191,7 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { socket_.SendCommand(); } - const char* device_type() const final { - return "openocd"; - } + const char* device_type() const final { return "openocd"; } private: /*! \brief socket used to communicate with the device through Tcl */ @@ -207,18 +203,17 @@ class OpenOCDLowLevelDevice final : public LowLevelDevice { /*! \brief number of bytes in a word on the target device (64-bit) */ static const constexpr ssize_t kWordSize = 8; - // NOTE: OpenOCD will call any request larger than this constant an "absurd - // request". + // NOTE: The OS pipe buffer must be able to handle a line long enough to + // print this transfer request. /*! \brief maximum number of bytes allowed in a single memory transfer */ - static const constexpr ssize_t kMemTransferLimit = 64000; + static const constexpr ssize_t kMemTransferLimit = 8000; /*! \brief number of milliseconds to wait for function execution to halt */ - static const constexpr int kWaitTime = 10000; + static const constexpr int kWaitTime = 30000; }; const std::shared_ptr OpenOCDLowLevelDeviceCreate(const std::string& server_addr, int port) { - std::shared_ptr lld = - std::make_shared(server_addr, port); + std::shared_ptr lld = std::make_shared(server_addr, port); return lld; } diff --git a/src/runtime/micro/standalone/minimal_vector.h b/src/runtime/micro/standalone/minimal_vector.h index 4d04e526329f..74bea06ebcfd 100644 --- a/src/runtime/micro/standalone/minimal_vector.h +++ b/src/runtime/micro/standalone/minimal_vector.h @@ -27,7 +27,6 @@ namespace tvm { namespace micro { - // A minimal wrapper, derived from https://github.com/Robbepop/dynarray/, that // supports a minimal subset of the std::vector API with a minimized code size. 
template diff --git a/src/runtime/micro/standalone/utvm_graph_runtime.cc b/src/runtime/micro/standalone/utvm_graph_runtime.cc index 546ed7d4988b..e19ee347a45e 100644 --- a/src/runtime/micro/standalone/utvm_graph_runtime.cc +++ b/src/runtime/micro/standalone/utvm_graph_runtime.cc @@ -20,8 +20,10 @@ #include "utvm_graph_runtime.h" #include + #include #include + #include "picojson.h" namespace tvm { @@ -325,7 +327,7 @@ std::function CreateTVMOp(const DSOModule& module, const TVMOpParam& par } TVMValue; /*typedef*/ enum { kTVMDLTensorHandle = 7U, - } /*TVMTypeCode*/; + } /*TVMArgTypeCode*/; struct OpArgs { DynArray args; DynArray arg_values; diff --git a/src/runtime/micro/standalone/utvm_runtime.cc b/src/runtime/micro/standalone/utvm_runtime.cc index 418443818bf1..73d616b6d482 100644 --- a/src/runtime/micro/standalone/utvm_runtime.cc +++ b/src/runtime/micro/standalone/utvm_runtime.cc @@ -16,15 +16,15 @@ * specific language governing permissions and limitations * under the License. */ +#include "tvm/runtime/micro/standalone/utvm_runtime.h" + #include -#include "tvm/runtime/micro/standalone/utvm_runtime.h" #include "utvm_graph_runtime.h" void* UTVMRuntimeCreate(const char* json, size_t json_len, void* module) { - return new tvm::micro::MicroGraphRuntime( - std::string(json, json + json_len), - reinterpret_cast(module)); + return new tvm::micro::MicroGraphRuntime(std::string(json, json + json_len), + reinterpret_cast(module)); } void UTVMRuntimeDestroy(void* handle) { diff --git a/src/runtime/micro/standalone/utvm_runtime_api.cc b/src/runtime/micro/standalone/utvm_runtime_api.cc index 896ff578da9e..a6ac420feec2 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.cc +++ b/src/runtime/micro/standalone/utvm_runtime_api.cc @@ -20,6 +20,7 @@ #include "utvm_runtime_api.h" #include + #include #include diff --git a/src/runtime/micro/standalone/utvm_runtime_api.h b/src/runtime/micro/standalone/utvm_runtime_api.h index 1b87052840d4..b38aa0a47a8c 100644 --- a/src/runtime/micro/standalone/utvm_runtime_api.h +++ b/src/runtime/micro/standalone/utvm_runtime_api.h @@ -21,6 +21,7 @@ #include #include + #include // The subset of the TVM runtime API that is implemented by the minimal runtime API. diff --git a/src/runtime/micro/target_data_layout_encoder.cc b/src/runtime/micro/target_data_layout_encoder.cc new file mode 100644 index 000000000000..4a87a8f35721 --- /dev/null +++ b/src/runtime/micro/target_data_layout_encoder.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "target_data_layout_encoder.h" + +namespace tvm { +namespace runtime { + +TargetDataLayoutEncoder::Alloc::Alloc(TargetDataLayoutEncoder* parent, size_t start_offset, + size_t size, TargetPtr start_addr) + : parent_(parent), + start_offset_(start_offset), + curr_offset_(0), + size_(size), + start_addr_(start_addr) { + parent_->live_unchecked_allocs_.insert(this); +} + +TargetDataLayoutEncoder::Alloc::~Alloc() { + auto it = parent_->live_unchecked_allocs_.find(this); + if (it != parent_->live_unchecked_allocs_.end()) { + // alloc was not already checked + parent_->live_unchecked_allocs_.erase(it); + if (curr_offset_ != size_) { + parent_->unchecked_alloc_start_offsets_.push_back(start_addr_.value().uint64()); + } + } +} + +void TargetDataLayoutEncoder::Alloc::CheckUnfilled() { + CHECK(curr_offset_ == size_) << "unwritten space in alloc 0x" << std::hex + << start_addr_.value().uint64() << "; curr_offset=0x" << curr_offset_ + << ", size=0x" << size_; +} + +TargetPtr TargetDataLayoutEncoder::Alloc::start_addr() { return start_addr_; } + +size_t TargetDataLayoutEncoder::Alloc::size() { return size_; } + +void TargetDataLayoutEncoder::CheckUnfilledAllocs() { + CHECK(live_unchecked_allocs_.size() > 0) << "No allocs to check"; + if (unchecked_alloc_start_offsets_.size() > 0) { + LOG(ERROR) << "Unchecked allocs were found:"; + for (size_t alloc_start_addr : unchecked_alloc_start_offsets_) { + LOG(ERROR) << " * 0x" << std::hex << alloc_start_addr; + } + CHECK(false) << "Unchecked allocs found during CheckUnfilledAllocs"; + } + + for (class Alloc* s : live_unchecked_allocs_) { + s->CheckUnfilled(); + } + live_unchecked_allocs_.clear(); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/micro/target_data_layout_encoder.h b/src/runtime/micro/target_data_layout_encoder.h index e0275165e774..81587755e3b3 100644 --- a/src/runtime/micro/target_data_layout_encoder.h +++ b/src/runtime/micro/target_data_layout_encoder.h @@ -24,13 +24,17 @@ #ifndef TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ #define TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ +#include +#include #include -#include "host_driven/utvm_runtime.h" + +#include "host_driven/utvm_runtime_enum.h" +#include "micro_common.h" namespace tvm { namespace runtime { -// TODO(weberlo): Handle endianness. +// TODO(weberlo, areusch): Handle endianness. /*! * \brief data encoder for uTVM that builds a host-side buffer @@ -40,152 +44,157 @@ class TargetDataLayoutEncoder { /*! * \brief helper class for writing into `TargetDataLayoutEncoder` */ - template - class Slot { + class Alloc { public: /*! * \brief constructor * \param parent pointer to parent encoder - * \param start_offset start byte offset of the slot in the backing buffer - * \param size size (in bytes) of the memory region allocated for this slot - * \param start_addr start address of the slot in the device's memory + * \param start_offset start byte offset of the alloc in the backing buffer + * \param size size (in bytes) of the memory region allocated for this alloc + * \param start_addr start address of the alloc in the device's memory */ - Slot(TargetDataLayoutEncoder* parent, size_t start_offset, size_t size, DevPtr start_addr); + Alloc(TargetDataLayoutEncoder* parent, size_t start_offset, size_t size, TargetPtr start_addr); - ~Slot(); + ~Alloc(); /*! 
* \brief writes `sizeof(T) * num_elems` bytes of data from `arr` * \param arr array to be read from * \param num_elems number of elements in array */ + template void WriteArray(const T* arr, size_t num_elems); /*! * \brief writes `val` * \param val value to be written */ + template void WriteValue(const T& val); /*! - * \brief returns start address of the slot in device memory + * \brief returns start address of the alloc in device memory * \return device start address */ - DevPtr start_addr(); + TargetPtr start_addr(); /*! - * \brief returns number of bytes allocated for this slot - * \return size of this slot + * \brief returns number of bytes allocated for this alloc + * \return size of this alloc */ size_t size(); + size_t curr_offset() const { return curr_offset_; } + + void CheckUnfilled(); + private: /*! \brief pointer to parent encoder */ TargetDataLayoutEncoder* parent_; - /*! \brief start offset of the slot in the parent's backing parent_buffer */ + /*! \brief start offset of the alloc in the parent's backing parent_buffer */ size_t start_offset_; - /*! \brief current offset relative to the start offset of this slot */ + /*! \brief current offset relative to the start offset of this alloc */ size_t curr_offset_; - /*! \brief size (in bytes) of the memory region allocated for this slot */ + /*! \brief size (in bytes) of the memory region allocated for this alloc */ size_t size_; - /*! \brief start address of the slot in the device's memory */ - DevPtr start_addr_; + /*! \brief start address of the alloc in the device's memory */ + TargetPtr start_addr_; }; /*! * \brief constructor * \param start_addr start address of the encoder in device memory */ - explicit TargetDataLayoutEncoder(DevPtr start_addr, size_t word_size) - : buf_(std::vector()), curr_offset_(0), word_size_(word_size) { - start_addr_ = DevPtr(UpperAlignValue(start_addr.value().val64, word_size_)); - } + explicit TargetDataLayoutEncoder(size_t capacity, TargetWordSize word_size) + : buf_(std::vector()), + curr_offset_(0), + start_addr_(word_size, nullptr), + capacity_(capacity), + word_size_(word_size) {} /*! - * \brief allocates a slot for `sizeof(T) * num_elems` bytes of data + * \brief allocates a alloc for `sizeof(T) * num_elems` bytes of data * \param num_elems number of elements of type `T` being allocated (defaults to 1) - * \return slot of size `sizeof(T) * num_elems` bytes + * \return alloc of size `sizeof(T) * num_elems` bytes */ template - Slot Alloc(size_t num_elems = 1) { - curr_offset_ = UpperAlignValue(curr_offset_, word_size_); + std::unique_ptr Alloc(size_t num_elems = 1) { + curr_offset_ = UpperAlignValue(curr_offset_, word_size_.bytes()); size_t size = sizeof(T) * num_elems; if (curr_offset_ + size > buf_.size()) { buf_.resize(curr_offset_ + size); } - size_t slot_start_offset = curr_offset_; + CHECK(buf_.size() < capacity_) << "out of space in data encoder"; + size_t alloc_start_offset = curr_offset_; curr_offset_ += size; - return Slot(this, slot_start_offset, size, start_addr_ + slot_start_offset); + class Alloc* alloc = + new class Alloc(this, alloc_start_offset, size, start_addr() + alloc_start_offset); + return std::unique_ptr(alloc); + } + + void Clear() { + buf_.clear(); + curr_offset_ = 0; } /*! * \brief returns the array backing the encoder's buffer * \return array backing the encoder's buffer */ - uint8_t* data() { - return buf_.data(); - } + uint8_t* data() { return buf_.data(); } /*! 
* \brief returns current size of the encoder's buffer * \return buffer size */ - size_t buf_size() { - return buf_.size(); + size_t buf_size() const { return buf_.size(); } + + TargetPtr start_addr() const { + CHECK_NE(start_addr_.value().uint64(), 0) << "start addr uninitialized"; + return start_addr_; } + void set_start_addr(TargetPtr start_addr) { + CHECK_EQ(buf_.size(), 0) << "cannot change encoder start addr unless empty"; + start_addr_ = + TargetPtr(word_size_, UpperAlignValue(start_addr.value().uint64(), word_size_.bytes())); + } + + void CheckUnfilledAllocs(); + private: /*! \brief in-memory backing buffer */ std::vector buf_; /*! \brief current offset */ size_t curr_offset_; /*! \brief start address of the encoder in device memory */ - DevPtr start_addr_; + TargetPtr start_addr_; + /*! \brief number of bytes available in device memory */ + size_t capacity_; /*! \brief number of bytes in a word on the target device */ - size_t word_size_; + TargetWordSize word_size_; + /*! \brief Alloc instances that are currently live and have not yet been checked by CheckUnfilledAllocs */ + std::set live_unchecked_allocs_; + /*! \brief start offsets of Alloc instances that were deallocated before CheckUnfilledAllocs ran */ + std::vector unchecked_alloc_start_offsets_; + friend Alloc::~Alloc(); }; template -TargetDataLayoutEncoder::Slot::Slot(TargetDataLayoutEncoder* parent, - size_t start_offset, - size_t size, - DevPtr start_addr) - : parent_(parent), - start_offset_(start_offset), - curr_offset_(0), - size_(size), - start_addr_(start_addr) {} - -template -TargetDataLayoutEncoder::Slot::~Slot() { - CHECK(curr_offset_ == size_) << "unwritten space in slot"; -} - -template -void TargetDataLayoutEncoder::Slot::WriteArray(const T* arr, size_t num_elems) { +void TargetDataLayoutEncoder::Alloc::WriteArray(const T* arr, size_t num_elems) { if (num_elems == 0) return; size_t size = sizeof(T) * num_elems; - CHECK(curr_offset_ + size <= size_) << "not enough space in slot"; + CHECK(curr_offset_ + size <= size_) << "not enough space in alloc"; uint8_t* curr_ptr = &(parent_->data())[start_offset_ + curr_offset_]; std::memcpy(curr_ptr, arr, size); curr_offset_ += size; } template -void TargetDataLayoutEncoder::Slot::WriteValue(const T& val) { +void TargetDataLayoutEncoder::Alloc::WriteValue(const T& val) { WriteArray(&val, 1); } -template -DevPtr TargetDataLayoutEncoder::Slot::start_addr() { - return start_addr_; -} - -template -size_t TargetDataLayoutEncoder::Slot::size() { - return size_; -} - } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ diff --git a/src/runtime/micro/tcl_socket.cc b/src/runtime/micro/tcl_socket.cc index 64dfbf218388..8f482b874260 100644 --- a/src/runtime/micro/tcl_socket.cc +++ b/src/runtime/micro/tcl_socket.cc @@ -20,10 +20,10 @@ /*!
* \file tcl_socket.cc */ -#include - #include "tcl_socket.h" +#include + namespace tvm { namespace runtime { @@ -33,9 +33,7 @@ TclSocket::TclSocket() { reply_buf_.reserve(kReplyBufSize); } -TclSocket::~TclSocket() { - tcp_socket_.Close(); -} +TclSocket::~TclSocket() { tcp_socket_.Close(); } void TclSocket::Connect(tvm::support::SockAddr addr) { CHECK(tcp_socket_.Connect(addr)) << "failed to connect"; @@ -45,8 +43,8 @@ void TclSocket::SendCommand() { const char terminate_token = kCommandTerminateToken; cmd_builder_ << terminate_token; std::string full_cmd = cmd_builder_.str(); - CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) - << "failed to send command"; + + CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) << "failed to send command"; cmd_builder_.str(std::string()); reply_builder_.str(std::string()); @@ -66,8 +64,7 @@ void TclSocket::SendCommand() { CHECK(bytes_read != -1) << "failed to read command reply"; } while (last_read != terminate_token); last_reply_ = reply_builder_.str(); - CHECK_EQ(last_reply_[last_reply_.length()-1], terminate_token) - << "missing command terminator"; + CHECK_EQ(last_reply_[last_reply_.length() - 1], terminate_token) << "missing command terminator"; } } // namespace runtime diff --git a/src/runtime/micro/tcl_socket.h b/src/runtime/micro/tcl_socket.h index 0b23e7f1b07f..4aef2aef36e2 100644 --- a/src/runtime/micro/tcl_socket.h +++ b/src/runtime/micro/tcl_socket.h @@ -66,12 +66,12 @@ class TclSocket { /* * \return string stream for current command being built - */ + */ std::ostringstream& cmd_builder() { return cmd_builder_; } /* * \return reply from most recently sent command - */ + */ const std::string& last_reply() { return last_reply_; } private: diff --git a/src/runtime/module.cc b/src/runtime/module.cc index f03579531ea4..46ef6fab082b 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -22,10 +22,12 @@ * \brief TVM module system */ #include -#include #include -#include +#include + #include +#include + #include "file_util.h" namespace tvm { @@ -36,7 +38,7 @@ void ModuleNode::Import(Module other) { if (!std::strcmp(this->type_key(), "rpc")) { static const PackedFunc* fimport_ = nullptr; if (fimport_ == nullptr) { - fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule"); + fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule"); CHECK(fimport_ != nullptr); } (*fimport_)(GetRef(this), other); @@ -55,8 +57,7 @@ void ModuleNode::Import(Module other) { stack.push_back(next); } } - CHECK(!visited.count(this)) - << "Cyclic dependency detected during import"; + CHECK(!visited.count(this)) << "Cyclic dependency detected during import"; this->imports_.emplace_back(std::move(other)); } @@ -73,29 +74,20 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) return pf; } -Module Module::LoadFromFile(const std::string& file_name, - const std::string& format) { +Module Module::LoadFromFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); - CHECK(fmt.length() != 0) - << "Cannot deduce format of file " << file_name; + CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name; if (fmt == "dll" || fmt == "dylib" || fmt == "dso") { fmt = "so"; } std::string load_f_name = "runtime.module.loadfile_" + fmt; const PackedFunc* f = Registry::Get(load_f_name); - CHECK(f != nullptr) - << "Loader of " << format << "(" - << load_f_name << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << format << "(" << 
load_f_name << ") is not presented."; Module m = (*f)(file_name, format); return m; } -bool Module::IsEmpty() const { - return this->operator->() == nullptr; -} - -void ModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile"; } @@ -118,9 +110,8 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { } if (pf == nullptr) { const PackedFunc* f = Registry::Get(name); - CHECK(f != nullptr) - << "Cannot find function " << name - << " in the imported modules or global registry"; + CHECK(f != nullptr) << "Cannot find function " << name + << " in the imported modules or global registry"; return f; } else { import_cache_.insert(std::make_pair(name, std::make_shared(pf))); @@ -136,10 +127,10 @@ bool RuntimeEnabled(const std::string& target) { f_name = "device_api.gpu"; } else if (target == "cl" || target == "opencl" || target == "sdaccel") { f_name = "device_api.opencl"; - } else if (target == "gl" || target == "opengl") { - f_name = "device_api.opengl"; } else if (target == "mtl" || target == "metal") { f_name = "device_api.metal"; + } else if (target == "tflite") { + f_name = "target.runtime.tflite"; } else if (target == "vulkan") { f_name = "device_api.vulkan"; } else if (target == "stackvm") { @@ -162,47 +153,31 @@ bool RuntimeEnabled(const std::string& target) { return runtime::Registry::Get(f_name) != nullptr; } -TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled") -.set_body_typed(RuntimeEnabled); +TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled); -TVM_REGISTER_GLOBAL("runtime.ModuleGetSource") -.set_body_typed([](Module mod, std::string fmt) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) { return mod->GetSource(fmt); }); -TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize") -.set_body_typed([](Module mod) { +TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) { return static_cast(mod->imports().size()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetImport") -.set_body_typed([](Module mod, int index) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) { return mod->imports().at(index); }); -TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey") -.set_body_typed([](Module mod) { +TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) { return std::string(mod->type_key()); }); -TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile") -.set_body_typed(Module::LoadFromFile); +TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile); TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile") -.set_body_typed([](Module mod, std::string name, std::string fmt) { - mod->SaveToFile(name, fmt); -}); - -TVM_REGISTER_GLOBAL("runtime.IsEmpty") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = args[0].operator Module().IsEmpty(); -}); - -TVM_REGISTER_GLOBAL("runtime.CreateEmptyModule") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module m; - *ret = m; -}); + .set_body_typed([](Module mod, std::string name, std::string fmt) { + mod->SaveToFile(name, fmt); + }); +TVM_REGISTER_OBJECT_TYPE(ModuleNode); } // namespace runtime } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index ac12472a903e..800a9167dadc 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -22,9 +22,10 @@ * \brief NDArray 
container infratructure. */ #include -#include #include #include +#include + #include "runtime_base.h" extern "C" { @@ -45,9 +46,12 @@ inline void VerifyDataType(DLDataType dtype) { // allow uint1 as a special flag for bool. if (dtype.bits == 1 && dtype.code == kDLUInt) return; // allow int1/uint4/int4 - else if (dtype.bits == 1 && dtype.code == kDLInt) return; - else if (dtype.bits == 4 && dtype.code == kDLUInt) return; - else if (dtype.bits == 4 && dtype.code == kDLInt) return; + else if (dtype.bits == 1 && dtype.code == kDLInt) + return; + else if (dtype.bits == 4 && dtype.code == kDLUInt) + return; + else if (dtype.bits == 4 && dtype.code == kDLInt) + return; else CHECK_EQ(dtype.bits % 8, 0); } @@ -65,12 +69,10 @@ void ArrayCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "ArrayCopyFromBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - data, 0, - handle->data, static_cast(handle->byte_offset), - nbytes, cpu_ctx, handle->ctx, handle->dtype, nullptr); + CHECK_EQ(arr_size, nbytes) << "ArrayCopyFromBytes: size mismatch"; + DeviceAPI::Get(handle->ctx) + ->CopyDataFromTo(data, 0, handle->data, static_cast(handle->byte_offset), nbytes, + cpu_ctx, handle->ctx, handle->dtype, nullptr); } void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { @@ -78,12 +80,10 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { cpu_ctx.device_type = kDLCPU; cpu_ctx.device_id = 0; size_t arr_size = GetDataSize(*handle); - CHECK_EQ(arr_size, nbytes) - << "ArrayCopyToBytes: size mismatch"; - DeviceAPI::Get(handle->ctx)->CopyDataFromTo( - handle->data, static_cast(handle->byte_offset), - data, 0, - nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr); + CHECK_EQ(arr_size, nbytes) << "ArrayCopyToBytes: size mismatch"; + DeviceAPI::Get(handle->ctx) + ->CopyDataFromTo(handle->data, static_cast(handle->byte_offset), data, 0, nbytes, + handle->ctx, cpu_ctx, handle->dtype, nullptr); } struct NDArray::Internal { @@ -93,8 +93,8 @@ struct NDArray::Internal { if (ptr->manager_ctx != nullptr) { static_cast(ptr->manager_ctx)->DecRef(); } else if (ptr->dl_tensor.data != nullptr) { - tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx)->FreeDataSpace( - ptr->dl_tensor.ctx, ptr->dl_tensor.data); + tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.ctx) + ->FreeDataSpace(ptr->dl_tensor.ctx, ptr->dl_tensor.data); } delete ptr; } @@ -113,9 +113,7 @@ struct NDArray::Internal { } // Local create function which allocates tensor metadata // but does not allocate space for the data. 
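  // Callers that need storage (e.g. NDArray::Empty below) allocate the data buffer separately via DeviceAPI::AllocDataSpace.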
- static NDArray Create(std::vector shape, - DLDataType dtype, - DLContext ctx) { + static NDArray Create(std::vector shape, DLDataType dtype, DLContext ctx) { VerifyDataType(dtype); // critical zone: construct header @@ -140,13 +138,11 @@ struct NDArray::Internal { ObjectRef::FFIClearAfterMove(&arr); return handle; } - static void FFIDecRef(TVMArrayHandle tensor) { - NDArray::FFIDecRef(tensor); - } + static void FFIDecRef(TVMArrayHandle tensor) { NDArray::FFIDecRef(tensor); } // Container to DLManagedTensor static DLManagedTensor* ToDLPack(TVMArrayHandle handle) { - auto* from = static_cast( - reinterpret_cast(handle)); + auto* from = + static_cast(reinterpret_cast(handle)); return ToDLPack(from); } @@ -168,11 +164,9 @@ struct NDArray::Internal { NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { CHECK(data_ != nullptr); - CHECK(get_mutable()->dl_tensor.strides == nullptr) - << "Can only create view for compact tensor"; + CHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.ctx); - ret.get_mutable()->dl_tensor.byte_offset = - this->get_mutable()->dl_tensor.byte_offset; + ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset; size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor); size_t view_size = GetDataSize(ret.get_mutable()->dl_tensor); CHECK_LE(view_size, curr_size) @@ -184,20 +178,15 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { return ret; } -DLManagedTensor* NDArray::ToDLPack() const { - return Internal::ToDLPack(get_mutable()); -} +DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, - DLDataType dtype, - DLContext ctx) { +NDArray NDArray::Empty(std::vector shape, DLDataType dtype, DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.get_mutable()->dl_tensor); size_t alignment = GetDataAlignment(ret.get_mutable()->dl_tensor); ret.get_mutable()->dl_tensor.data = - DeviceAPI::Get(ret->ctx)->AllocDataSpace( - ret->ctx, size, alignment, ret->dtype); + DeviceAPI::Get(ret->ctx)->AllocDataSpace(ret->ctx, size, alignment, ret->dtype); return ret; } @@ -227,33 +216,28 @@ void NDArray::CopyFromBytes(const void* data, size_t nbytes) { ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes); } -void NDArray::CopyFromTo(const DLTensor* from, - DLTensor* to, - TVMStreamHandle stream) { +void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) { size_t from_size = GetDataSize(*from); size_t to_size = GetDataSize(*to); - CHECK_EQ(from_size, to_size) - << "TVMArrayCopyFromTo: The size must exactly match"; + CHECK_EQ(from_size, to_size) << "TVMArrayCopyFromTo: The size must exactly match"; - CHECK(from->ctx.device_type == to->ctx.device_type - || from->ctx.device_type == kDLCPU - || to->ctx.device_type == kDLCPU - || from->ctx.device_type == kDLCPUPinned - || to->ctx.device_type == kDLCPUPinned) - << "Can not copy across different ctx types directly"; + CHECK(from->ctx.device_type == to->ctx.device_type || from->ctx.device_type == kDLCPU || + to->ctx.device_type == kDLCPU || from->ctx.device_type == kDLCPUPinned || + to->ctx.device_type == kDLCPUPinned) + << "Can not copy across different ctx types directly"; // Use the context that is *not* a cpu context to get the correct device // api manager. 
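  // (the CHECK above guarantees that either both sides share a device type or one side is host/pinned memory, so picking the non-CPU context, if any, selects the right DeviceAPI)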
TVMContext ctx = from->ctx.device_type != kDLCPU ? from->ctx : to->ctx; - DeviceAPI::Get(ctx)->CopyDataFromTo( - from->data, static_cast(from->byte_offset), - to->data, static_cast(to->byte_offset), - from_size, from->ctx, to->ctx, from->dtype, stream); + DeviceAPI::Get(ctx)->CopyDataFromTo(from->data, static_cast(from->byte_offset), to->data, + static_cast(to->byte_offset), from_size, from->ctx, + to->ctx, from->dtype, stream); } -std::vector NDArray::Shape() const { - return get_mutable()->shape_; +std::vector NDArray::Shape() const { return get_mutable()->shape_; } +runtime::DataType NDArray::DataType() const { + return runtime::DataType(get_mutable()->dl_tensor.dtype); } TVM_REGISTER_OBJECT_TYPE(NDArray::Container); @@ -273,14 +257,8 @@ int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_END(); } -int TVMArrayAlloc(const tvm_index_t* shape, - int ndim, - int dtype_code, - int dtype_bits, - int dtype_lanes, - int device_type, - int device_id, - TVMArrayHandle* out) { +int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits, + int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out) { API_BEGIN(); DLDataType dtype; dtype.code = static_cast(dtype_code); @@ -300,43 +278,33 @@ int TVMArrayFree(TVMArrayHandle handle) { API_END(); } -int TVMArrayCopyFromTo(TVMArrayHandle from, - TVMArrayHandle to, - TVMStreamHandle stream) { +int TVMArrayCopyFromTo(TVMArrayHandle from, TVMArrayHandle to, TVMStreamHandle stream) { API_BEGIN(); NDArray::CopyFromTo(from, to, stream); API_END(); } -int TVMArrayFromDLPack(DLManagedTensor* from, - TVMArrayHandle* out) { +int TVMArrayFromDLPack(DLManagedTensor* from, TVMArrayHandle* out) { API_BEGIN(); *out = NDArray::Internal::MoveToFFIHandle(NDArray::FromDLPack(from)); API_END(); } -int TVMArrayToDLPack(TVMArrayHandle from, - DLManagedTensor** out) { +int TVMArrayToDLPack(TVMArrayHandle from, DLManagedTensor** out) { API_BEGIN(); *out = NDArray::Internal::ToDLPack(from); API_END(); } -void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { - (*(dltensor->deleter))(dltensor); -} +void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor) { (*(dltensor->deleter))(dltensor); } -int TVMArrayCopyFromBytes(TVMArrayHandle handle, - void* data, - size_t nbytes) { +int TVMArrayCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); ArrayCopyFromBytes(handle, data, nbytes); API_END(); } -int TVMArrayCopyToBytes(TVMArrayHandle handle, - void* data, - size_t nbytes) { +int TVMArrayCopyToBytes(TVMArrayHandle handle, void* data, size_t nbytes) { API_BEGIN(); ArrayCopyToBytes(handle, data, nbytes); API_END(); diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 0d85b9dab42c..dc5f1ceabbae 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -21,13 +21,16 @@ * \brief Object type management system. 
*/ #include -#include #include +#include + +#include #include #include -#include -#include #include +#include +#include + #include "object_internal.h" #include "runtime_base.h" @@ -75,10 +78,8 @@ class TypeContext { return child_tindex == parent_tindex; } - uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t num_child_slots, + uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex, + uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { std::lock_guard lock(mutex_); auto it = type_key2index_.find(skey); @@ -86,7 +87,8 @@ class TypeContext { return it->second; } // try to allocate from parent's type table. - CHECK_LT(parent_tindex, type_table_.size()); + CHECK_LT(parent_tindex, type_table_.size()) + << " skey= " << skey << "static_index=" << static_tindex; TypeInfo& pinfo = type_table_[parent_tindex]; CHECK_EQ(pinfo.index, parent_tindex); @@ -104,11 +106,9 @@ class TypeContext { allocated_tindex = static_tindex; CHECK_LT(static_tindex, type_table_.size()); CHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) - << "Conflicting static index " << static_tindex - << " between " << type_table_[allocated_tindex].name - << " and " - << skey; - } else if (pinfo.allocated_slots + num_slots < pinfo.num_slots) { + << "Conflicting static index " << static_tindex << " between " + << type_table_[allocated_tindex].name << " and " << skey; + } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) { // allocate the slot from parent's reserved pool allocated_tindex = parent_tindex + pinfo.allocated_slots; // update parent's state @@ -119,8 +119,8 @@ class TypeContext { // allocate new entries. allocated_tindex = type_counter_; type_counter_ += num_slots; - CHECK_LE(type_table_.size(), allocated_tindex); - type_table_.resize(allocated_tindex + 1, TypeInfo()); + CHECK_LE(type_table_.size(), type_counter_); + type_table_.resize(type_counter_, TypeInfo()); } CHECK_GT(allocated_tindex, parent_tindex); // initialize the slot. @@ -128,8 +128,7 @@ class TypeContext { type_table_[allocated_tindex].parent_index = parent_tindex; type_table_[allocated_tindex].num_slots = num_slots; type_table_[allocated_tindex].allocated_slots = 1; - type_table_[allocated_tindex].child_slots_can_overflow = - child_slots_can_overflow; + type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow; type_table_[allocated_tindex].name = skey; type_table_[allocated_tindex].name_hash = std::hash()(skey); // update the key2index mapping. @@ -139,16 +138,14 @@ class TypeContext { std::string TypeIndex2Key(uint32_t tindex) { std::lock_guard lock(mutex_); - CHECK(tindex < type_table_.size() && - type_table_[tindex].allocated_slots != 0) + CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) << "Unknown type index " << tindex; return type_table_[tindex].name; } size_t TypeIndex2KeyHash(uint32_t tindex) { std::lock_guard lock(mutex_); - CHECK(tindex < type_table_.size() && - type_table_[tindex].allocated_slots != 0) + CHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) << "Unknown type index " << tindex; return type_table_[tindex].name_hash; } @@ -161,6 +158,25 @@ class TypeContext { return it->second; } + void Dump(int min_children_count) { + std::vector num_children(type_table_.size(), 0); + // reverse accumulation so we can get total counts in a bottom-up manner. 
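+    // (a type index is always greater than its parent's, so a reverse pass visits every child before its parent)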
+ for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { + if (it->index != 0) { + num_children[it->parent_index] += num_children[it->index] + 1; + } + } + + for (const auto& info : type_table_) { + if (info.index != 0 && num_children[info.index] >= min_children_count) { + std::cerr << '[' << info.index << "] " << info.name + << "\tparent=" << type_table_[info.parent_index].name + << "\tnum_child_slots=" << info.num_slots - 1 + << "\tnum_children=" << num_children[info.index] << std::endl; + } + } + } + static TypeContext* Global() { static TypeContext inst; return &inst; @@ -169,6 +185,7 @@ class TypeContext { private: TypeContext() { type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo()); + type_table_[0].name = "runtime.Object"; } // mutex to avoid registration from multiple threads. std::mutex mutex_; @@ -177,18 +194,15 @@ class TypeContext { std::unordered_map type_key2index_; }; -uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, - uint32_t static_tindex, - uint32_t parent_tindex, - uint32_t num_child_slots, +uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, + uint32_t parent_tindex, uint32_t num_child_slots, bool child_slots_can_overflow) { return TypeContext::Global()->GetOrAllocRuntimeTypeIndex( key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow); } bool Object::DerivedFrom(uint32_t parent_tindex) const { - return TypeContext::Global()->DerivedFrom( - this->type_index_, parent_tindex); + return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex); } std::string Object::TypeIndex2Key(uint32_t tindex) { @@ -203,10 +217,12 @@ uint32_t Object::TypeKey2Index(const std::string& key) { return TypeContext::Global()->TypeKey2Index(key); } +TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) { + return static_cast(ObjectPtrHash()(obj)); +}); -TVM_REGISTER_GLOBAL("runtime.ObjectHash") -.set_body_typed([](ObjectRef obj) { - return static_cast(ObjectHash()(obj)); +TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) { + TypeContext::Global()->Dump(min_child_count); }); } // namespace runtime } // namespace tvm @@ -218,15 +234,27 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { API_END(); } +int TVMObjectRetain(TVMObjectHandle obj) { + API_BEGIN(); + tvm::runtime::ObjectInternal::ObjectRetain(obj); + API_END(); +} + int TVMObjectFree(TVMObjectHandle obj) { API_BEGIN(); tvm::runtime::ObjectInternal::ObjectFree(obj); API_END(); } +int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { + API_BEGIN(); + *is_derived = + tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index); + API_END(); +} + int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); - out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index( - type_key); + out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key); API_END(); } diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 79551309d67c..f255b28ad04c 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -24,8 +24,9 @@ #ifndef TVM_RUNTIME_OBJECT_INTERNAL_H_ #define TVM_RUNTIME_OBJECT_INTERNAL_H_ -#include #include +#include + #include namespace tvm { @@ -37,6 +38,15 @@ namespace runtime { */ class ObjectInternal { public: + /*! + * \brief Retain an object handle. 
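A minimal sketch of how a foreign-language binding might drive the new handle-level entry points added here, assuming the declarations land in c_runtime_api.h alongside the existing object C API; "runtime.Module" is used only as an example of a key that is registered in a stock build.

// Resolve a type key to its runtime index, then ask whether it derives from
// the root object type (index 0, which this patch names "runtime.Object").
#include <tvm/runtime/c_runtime_api.h>
#include <cstdio>

int main() {
  unsigned module_tindex = 0;
  if (TVMObjectTypeKey2Index("runtime.Module", &module_tindex) != 0) return 1;

  int is_derived = 0;
  if (TVMObjectDerivedFrom(module_tindex, /*parent_type_index=*/0, &is_derived) != 0) return 1;
  std::printf("runtime.Module index=%u, derives from runtime.Object: %d\n",
              module_tindex, is_derived);
  return 0;
}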
+ */ + static void ObjectRetain(TVMObjectHandle obj) { + if (obj != nullptr) { + static_cast(obj)->IncRef(); + } + } + /*! * \brief Free an object handle. */ @@ -45,6 +55,15 @@ class ObjectInternal { static_cast(obj)->DecRef(); } } + /*! + * \brief Check of obj derives from the type indicated by type index. + * \param obj The original object. + * \param type_index The type index of interest. + * \return The derivation checking result. + */ + static bool DerivedFrom(const Object* obj, uint32_t type_index) { + return obj->DerivedFrom(type_index); + } /*! * \brief Expose TypeKey2Index * \param type_key The original type key. @@ -68,4 +87,4 @@ class ObjectInternal { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ +#endif // TVM_RUNTIME_OBJECT_INTERNAL_H_ diff --git a/src/runtime/opencl/aocl/aocl_common.h b/src/runtime/opencl/aocl/aocl_common.h index d9251f8aaf53..1b98d4b2d221 100644 --- a/src/runtime/opencl/aocl/aocl_common.h +++ b/src/runtime/opencl/aocl/aocl_common.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_OPENCL_AOCL_AOCL_COMMON_H_ #include + #include "../opencl_common.h" namespace tvm { @@ -44,7 +45,6 @@ class AOCLWorkspace final : public OpenCLWorkspace { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace for AOCL */ class AOCLThreadEntry : public OpenCLThreadEntry { public: diff --git a/src/runtime/opencl/aocl/aocl_device_api.cc b/src/runtime/opencl/aocl/aocl_device_api.cc index 84c29eea33ec..07057ff29716 100644 --- a/src/runtime/opencl/aocl/aocl_device_api.cc +++ b/src/runtime/opencl/aocl/aocl_device_api.cc @@ -20,17 +20,16 @@ /*! * \file aocl_device_api.cc */ -#include #include +#include + #include "aocl_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { - return AOCLThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); } const std::shared_ptr& AOCLWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); @@ -47,15 +46,12 @@ bool AOCLWorkspace::IsOpenCLDevice(TVMContext ctx) { typedef dmlc::ThreadLocalStore AOCLThreadStore; -AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { - return AOCLThreadStore::Get(); -} +AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.aocl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = AOCLWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = AOCLWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/aocl/aocl_module.cc b/src/runtime/opencl/aocl/aocl_module.cc index abda5b179a6a..747188cf7b2d 100644 --- a/src/runtime/opencl/aocl/aocl_module.cc +++ b/src/runtime/opencl/aocl/aocl_module.cc @@ -20,23 +20,24 @@ /*! 
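A sketch of the consuming side of a "device_api.<target>" registration like the AOCL one reflowed above: the registered body stashes a DeviceAPI* in a void*, so the lookup side calls the PackedFunc and casts the result back. This mirrors what the internal device API manager does; the helper name is made up for illustration.

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <dmlc/logging.h>
#include <string>

tvm::runtime::DeviceAPI* GetDeviceAPIByName(const std::string& name) {
  const tvm::runtime::PackedFunc* f = tvm::runtime::Registry::Get("device_api." + name);
  CHECK(f != nullptr) << "device_api." << name << " is not registered";
  void* ptr = (*f)();  // the registered body set *rv to a void* holding the DeviceAPI*
  return static_cast<tvm::runtime::DeviceAPI*>(ptr);
}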
* \file aocl_module.cc */ +#include "aocl_module.h" + #include #include -#include + #include #include +#include + #include "aocl_common.h" -#include "aocl_module.h" namespace tvm { namespace runtime { class AOCLModuleNode : public OpenCLModuleNode { public: - explicit AOCLModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit AOCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} const std::shared_ptr& GetGlobalWorkspace() final; }; @@ -45,18 +46,14 @@ const std::shared_ptr& AOCLModuleNode::GetGlobalWorkspace() return cl::AOCLWorkspace::Global(); } -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } -Module AOCLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module AOCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -66,8 +63,7 @@ Module AOCLModuleLoadFile(const std::string& file_name, return AOCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx") -.set_body_typed(AOCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_aocx").set_body_typed(AOCLModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/aocl/aocl_module.h b/src/runtime/opencl/aocl/aocl_module.h index 70955cc65528..199a94decdd8 100644 --- a/src/runtime/opencl/aocl/aocl_module.h +++ b/src/runtime/opencl/aocl/aocl_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "aocx" * \param fmap The map function information map of each function. */ -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_AOCL_AOCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index 8f9d5d6352ba..a892bff75342 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -24,10 +24,10 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ #define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ +#include #include -#include #include -#include +#include /* There are many OpenCL platforms that do not yet support OpenCL 2.0, * hence we use 1.2 APIs, some of which are now deprecated. 
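The "runtime.module.loadfile_aocx" global registered above is what the generic module loader dispatches to by file format, so a caller never touches AOCLModuleLoadFile directly. A minimal caller-side sketch, assuming the usual Module::LoadFromFile entry point and a placeholder file name:

#include <tvm/runtime/module.h>

tvm::runtime::Module LoadAocx() {
  // Dispatches to the loader registered as "runtime.module.loadfile_aocx".
  return tvm::runtime::Module::LoadFromFile("kernels.aocx", "aocx");
}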
In order @@ -45,73 +45,120 @@ #include #endif +#include #include #include -#include -#include #include -#include "../workspace_pool.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "../workspace_pool.h" namespace tvm { namespace runtime { namespace cl { -static_assert(sizeof(cl_mem) ==sizeof(void*), - "Required to store cl_mem inside void*"); +static_assert(sizeof(cl_mem) == sizeof(void*), "Required to store cl_mem inside void*"); inline const char* CLGetErrorString(cl_int error) { switch (error) { - case CL_SUCCESS: return "CL_SUCCESS"; - case CL_DEVICE_NOT_FOUND: return "CL_DEVICE_NOT_FOUND"; - case CL_DEVICE_NOT_AVAILABLE: return "CL_DEVICE_NOT_AVAILABLE"; - case CL_COMPILER_NOT_AVAILABLE: return "CL_COMPILER_NOT_AVAILABLE"; - case CL_MEM_OBJECT_ALLOCATION_FAILURE: return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; - case CL_OUT_OF_RESOURCES: return "CL_OUT_OF_RESOURCES"; - case CL_OUT_OF_HOST_MEMORY: return "CL_OUT_OF_HOST_MEMORY"; - case CL_PROFILING_INFO_NOT_AVAILABLE: return "CL_PROFILING_INFO_NOT_AVAILABLE"; - case CL_MEM_COPY_OVERLAP: return "CL_MEM_COPY_OVERLAP"; - case CL_IMAGE_FORMAT_MISMATCH: return "CL_IMAGE_FORMAT_MISMATCH"; - case CL_IMAGE_FORMAT_NOT_SUPPORTED: return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; - case CL_BUILD_PROGRAM_FAILURE: return "CL_BUILD_PROGRAM_FAILURE"; - case CL_MAP_FAILURE: return "CL_MAP_FAILURE"; - case CL_INVALID_VALUE: return "CL_INVALID_VALUE"; - case CL_INVALID_DEVICE_TYPE: return "CL_INVALID_DEVICE_TYPE"; - case CL_INVALID_PLATFORM: return "CL_INVALID_PLATFORM"; - case CL_INVALID_DEVICE: return "CL_INVALID_DEVICE"; - case CL_INVALID_CONTEXT: return "CL_INVALID_CONTEXT"; - case CL_INVALID_QUEUE_PROPERTIES: return "CL_INVALID_QUEUE_PROPERTIES"; - case CL_INVALID_COMMAND_QUEUE: return "CL_INVALID_COMMAND_QUEUE"; - case CL_INVALID_HOST_PTR: return "CL_INVALID_HOST_PTR"; - case CL_INVALID_MEM_OBJECT: return "CL_INVALID_MEM_OBJECT"; - case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; - case CL_INVALID_IMAGE_SIZE: return "CL_INVALID_IMAGE_SIZE"; - case CL_INVALID_SAMPLER: return "CL_INVALID_SAMPLER"; - case CL_INVALID_BINARY: return "CL_INVALID_BINARY"; - case CL_INVALID_BUILD_OPTIONS: return "CL_INVALID_BUILD_OPTIONS"; - case CL_INVALID_PROGRAM: return "CL_INVALID_PROGRAM"; - case CL_INVALID_PROGRAM_EXECUTABLE: return "CL_INVALID_PROGRAM_EXECUTABLE"; - case CL_INVALID_KERNEL_NAME: return "CL_INVALID_KERNEL_NAME"; - case CL_INVALID_KERNEL_DEFINITION: return "CL_INVALID_KERNEL_DEFINITION"; - case CL_INVALID_KERNEL: return "CL_INVALID_KERNEL"; - case CL_INVALID_ARG_INDEX: return "CL_INVALID_ARG_INDEX"; - case CL_INVALID_ARG_VALUE: return "CL_INVALID_ARG_VALUE"; - case CL_INVALID_ARG_SIZE: return "CL_INVALID_ARG_SIZE"; - case CL_INVALID_KERNEL_ARGS: return "CL_INVALID_KERNEL_ARGS"; - case CL_INVALID_WORK_DIMENSION: return "CL_INVALID_WORK_DIMENSION"; - case CL_INVALID_WORK_GROUP_SIZE: return "CL_INVALID_WORK_GROUP_SIZE"; - case CL_INVALID_WORK_ITEM_SIZE: return "CL_INVALID_WORK_ITEM_SIZE"; - case CL_INVALID_GLOBAL_OFFSET: return "CL_INVALID_GLOBAL_OFFSET"; - case CL_INVALID_EVENT_WAIT_LIST: return "CL_INVALID_EVENT_WAIT_LIST"; - case CL_INVALID_EVENT: return "CL_INVALID_EVENT"; - case CL_INVALID_OPERATION: return "CL_INVALID_OPERATION"; - case CL_INVALID_GL_OBJECT: return "CL_INVALID_GL_OBJECT"; - case CL_INVALID_BUFFER_SIZE: return "CL_INVALID_BUFFER_SIZE"; - case CL_INVALID_MIP_LEVEL: 
return "CL_INVALID_MIP_LEVEL"; - default: return "Unknown OpenCL error code"; + case CL_SUCCESS: + return "CL_SUCCESS"; + case CL_DEVICE_NOT_FOUND: + return "CL_DEVICE_NOT_FOUND"; + case CL_DEVICE_NOT_AVAILABLE: + return "CL_DEVICE_NOT_AVAILABLE"; + case CL_COMPILER_NOT_AVAILABLE: + return "CL_COMPILER_NOT_AVAILABLE"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: + return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; + case CL_OUT_OF_RESOURCES: + return "CL_OUT_OF_RESOURCES"; + case CL_OUT_OF_HOST_MEMORY: + return "CL_OUT_OF_HOST_MEMORY"; + case CL_PROFILING_INFO_NOT_AVAILABLE: + return "CL_PROFILING_INFO_NOT_AVAILABLE"; + case CL_MEM_COPY_OVERLAP: + return "CL_MEM_COPY_OVERLAP"; + case CL_IMAGE_FORMAT_MISMATCH: + return "CL_IMAGE_FORMAT_MISMATCH"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: + return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; + case CL_BUILD_PROGRAM_FAILURE: + return "CL_BUILD_PROGRAM_FAILURE"; + case CL_MAP_FAILURE: + return "CL_MAP_FAILURE"; + case CL_INVALID_VALUE: + return "CL_INVALID_VALUE"; + case CL_INVALID_DEVICE_TYPE: + return "CL_INVALID_DEVICE_TYPE"; + case CL_INVALID_PLATFORM: + return "CL_INVALID_PLATFORM"; + case CL_INVALID_DEVICE: + return "CL_INVALID_DEVICE"; + case CL_INVALID_CONTEXT: + return "CL_INVALID_CONTEXT"; + case CL_INVALID_QUEUE_PROPERTIES: + return "CL_INVALID_QUEUE_PROPERTIES"; + case CL_INVALID_COMMAND_QUEUE: + return "CL_INVALID_COMMAND_QUEUE"; + case CL_INVALID_HOST_PTR: + return "CL_INVALID_HOST_PTR"; + case CL_INVALID_MEM_OBJECT: + return "CL_INVALID_MEM_OBJECT"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: + return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; + case CL_INVALID_IMAGE_SIZE: + return "CL_INVALID_IMAGE_SIZE"; + case CL_INVALID_SAMPLER: + return "CL_INVALID_SAMPLER"; + case CL_INVALID_BINARY: + return "CL_INVALID_BINARY"; + case CL_INVALID_BUILD_OPTIONS: + return "CL_INVALID_BUILD_OPTIONS"; + case CL_INVALID_PROGRAM: + return "CL_INVALID_PROGRAM"; + case CL_INVALID_PROGRAM_EXECUTABLE: + return "CL_INVALID_PROGRAM_EXECUTABLE"; + case CL_INVALID_KERNEL_NAME: + return "CL_INVALID_KERNEL_NAME"; + case CL_INVALID_KERNEL_DEFINITION: + return "CL_INVALID_KERNEL_DEFINITION"; + case CL_INVALID_KERNEL: + return "CL_INVALID_KERNEL"; + case CL_INVALID_ARG_INDEX: + return "CL_INVALID_ARG_INDEX"; + case CL_INVALID_ARG_VALUE: + return "CL_INVALID_ARG_VALUE"; + case CL_INVALID_ARG_SIZE: + return "CL_INVALID_ARG_SIZE"; + case CL_INVALID_KERNEL_ARGS: + return "CL_INVALID_KERNEL_ARGS"; + case CL_INVALID_WORK_DIMENSION: + return "CL_INVALID_WORK_DIMENSION"; + case CL_INVALID_WORK_GROUP_SIZE: + return "CL_INVALID_WORK_GROUP_SIZE"; + case CL_INVALID_WORK_ITEM_SIZE: + return "CL_INVALID_WORK_ITEM_SIZE"; + case CL_INVALID_GLOBAL_OFFSET: + return "CL_INVALID_GLOBAL_OFFSET"; + case CL_INVALID_EVENT_WAIT_LIST: + return "CL_INVALID_EVENT_WAIT_LIST"; + case CL_INVALID_EVENT: + return "CL_INVALID_EVENT"; + case CL_INVALID_OPERATION: + return "CL_INVALID_OPERATION"; + case CL_INVALID_GL_OBJECT: + return "CL_INVALID_GL_OBJECT"; + case CL_INVALID_BUFFER_SIZE: + return "CL_INVALID_BUFFER_SIZE"; + case CL_INVALID_MIP_LEVEL: + return "CL_INVALID_MIP_LEVEL"; + default: + return "Unknown OpenCL error code"; } } @@ -119,16 +166,13 @@ inline const char* CLGetErrorString(cl_int error) { * \brief Protected OpenCL call * \param func Expression to call. 
*/ -#define OPENCL_CHECK_ERROR(e) \ - { \ - CHECK(e == CL_SUCCESS) \ - << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); \ - } +#define OPENCL_CHECK_ERROR(e) \ + { CHECK(e == CL_SUCCESS) << "OpenCL Error, code=" << e << ": " << cl::CLGetErrorString(e); } -#define OPENCL_CALL(func) \ - { \ - cl_int e = (func); \ - OPENCL_CHECK_ERROR(e); \ +#define OPENCL_CALL(func) \ + { \ + cl_int e = (func); \ + OPENCL_CHECK_ERROR(e); \ } class OpenCLThreadEntry; @@ -172,37 +216,24 @@ class OpenCLWorkspace : public DeviceAPI { // Initialzie the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); - virtual void Init() { - Init("opencl", "gpu"); - } + virtual void Init() { Init("opencl", "gpu"); } // Check whether the context is OpenCL or not. - virtual bool IsOpenCLDevice(TVMContext ctx) { - return ctx.device_type == kDLOpenCL; - } + virtual bool IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == kDLOpenCL; } // get the queue of the context cl_command_queue GetQueue(TVMContext ctx) { CHECK(IsOpenCLDevice(ctx)); this->Init(); - CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) + CHECK(ctx.device_id >= 0 && static_cast(ctx.device_id) < queues.size()) << "Invalid OpenCL device_id=" << ctx.device_id; return queues[ctx.device_id]; } // override device API void SetDevice(TVMContext ctx) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t size, - size_t alignment, - DLDataType type_hint) final; + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final; void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; @@ -217,7 +248,6 @@ class OpenCLWorkspace : public DeviceAPI { static const std::shared_ptr& Global(); }; - /*! 
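A standalone sketch of the error-check pattern these macros reformat, applied to a plain OpenCL query. CL_CALL here is a local stand-in for OPENCL_CALL (which lives in the internal header); only standard OpenCL 1.2 calls are used.

// Evaluate the call once, then fail loudly with the numeric code on error.
#include <CL/cl.h>
#include <cstdio>
#include <cstdlib>

#define CL_CALL(func)                                          \
  {                                                            \
    cl_int e = (func);                                         \
    if (e != CL_SUCCESS) {                                     \
      std::fprintf(stderr, "OpenCL error, code=%d\n", (int)e); \
      std::exit(1);                                            \
    }                                                          \
  }

int main() {
  cl_uint num_platforms = 0;
  CL_CALL(clGetPlatformIDs(0, nullptr, &num_platforms));
  std::printf("platforms: %u\n", num_platforms);
  return 0;
}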
\brief Thread local workspace */ class OpenCLThreadEntry { public: @@ -240,8 +270,7 @@ class OpenCLThreadEntry { context.device_id = 0; context.device_type = device_type; } - OpenCLThreadEntry() - : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {} + OpenCLThreadEntry() : OpenCLThreadEntry(kDLOpenCL, OpenCLWorkspace::Global()) {} // get the global workspace static OpenCLThreadEntry* ThreadLocal(); @@ -260,10 +289,8 @@ class OpenCLModuleNode : public ModuleNode { size_t kernel_id; size_t version; }; - explicit OpenCLModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit OpenCLModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {} // destructor ~OpenCLModuleNode(); @@ -275,20 +302,15 @@ class OpenCLModuleNode : public ModuleNode { const char* type_key() const final { return workspace_->type_key.c_str(); } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; + void SaveToFile(const std::string& file_name, const std::string& format) final; void SaveToBinary(dmlc::Stream* stream) final; std::string GetSource(const std::string& format) final; // Initialize the programs void Init(); // install a new kernel to thread local entry - cl_kernel InstallKernel(cl::OpenCLWorkspace* w, - cl::OpenCLThreadEntry* t, - const std::string& func_name, - const KTRefEntry& e); + cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e); private: // The workspace, need to keep reference to use it in destructor. diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 99d2b0cb24e6..6d9835e6231c 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -20,17 +20,16 @@ /*! 
* \file opencl_device_api.cc */ -#include #include +#include + #include "opencl_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { - return OpenCLThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); } const std::shared_ptr& OpenCLWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); @@ -41,23 +40,21 @@ void OpenCLWorkspace::SetDevice(TVMContext ctx) { GetThreadEntry()->context.device_id = ctx.device_id; } -void OpenCLWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { +void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { this->Init(); size_t index = static_cast(ctx.device_id); if (kind == kExist) { - *rv = static_cast(index< devices.size()); + *rv = static_cast(index < devices.size()); return; } - CHECK_LT(index, devices.size()) - << "Invalid device id " << index; + CHECK_LT(index, devices.size()) << "Invalid device id " << index; switch (kind) { - case kExist: break; + case kExist: + break; case kMaxThreadsPerBlock: { size_t value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, - sizeof(size_t), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), + &value, nullptr)); *rv = static_cast(value); break; } @@ -72,58 +69,55 @@ void OpenCLWorkspace::GetAttr( } case kMaxSharedMemoryPerBlock: { cl_ulong value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_LOCAL_MEM_SIZE, - sizeof(cl_ulong), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(cl_ulong), + &value, nullptr)); *rv = static_cast(value); break; } - case kComputeVersion: return; + case kComputeVersion: + return; case kDeviceName: { char value[128] = {0}; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_NAME, - sizeof(value) - 1, value, nullptr)); + OPENCL_CALL( + clGetDeviceInfo(devices[index], CL_DEVICE_NAME, sizeof(value) - 1, value, nullptr)); *rv = std::string(value); break; } case kMaxClockRate: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, - sizeof(cl_uint), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(cl_uint), + &value, nullptr)); *rv = static_cast(value); break; } case kMultiProcessorCount: { cl_uint value; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, - sizeof(cl_uint), &value, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(cl_uint), + &value, nullptr)); *rv = static_cast(value); break; } case kMaxThreadDimensions: { size_t dims[3]; - OPENCL_CALL(clGetDeviceInfo( - devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, nullptr)); + OPENCL_CALL(clGetDeviceInfo(devices[index], CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(dims), dims, + nullptr)); std::stringstream ss; // use json string to return multiple int values; - ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; *rv = ss.str(); break; } - case kGcnArch: return; + case kGcnArch: + return; } } -void* OpenCLWorkspace::AllocDataSpace( - TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) { +void* OpenCLWorkspace::AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, + DLDataType type_hint) { this->Init(); 
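A standalone sketch of the clGetDeviceInfo pattern GetAttr uses above: pick the first device of the first platform and read the same kind of attributes (device name, maximum work-group size). Failures are simply propagated as a non-zero exit for brevity.

#include <CL/cl.h>
#include <cstdio>

int main() {
  cl_platform_id platform;
  cl_uint nplat = 0;
  if (clGetPlatformIDs(1, &platform, &nplat) != CL_SUCCESS || nplat == 0) return 1;

  cl_device_id device;
  cl_uint ndev = 0;
  if (clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 1, &device, &ndev) != CL_SUCCESS || ndev == 0)
    return 1;

  size_t max_wg = 0;
  clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(max_wg), &max_wg, nullptr);

  char name[128] = {0};
  clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(name) - 1, name, nullptr);

  std::printf("%s: max work-group size = %zu\n", name, max_wg);
  return 0;
}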
CHECK(context != nullptr) << "No OpenCL device"; cl_int err_code; - cl_mem mptr = clCreateBuffer( - this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); + cl_mem mptr = clCreateBuffer(this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); return mptr; } @@ -137,38 +131,27 @@ void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { OPENCL_CALL(clReleaseMemObject(mptr)); } -void OpenCLWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, +void OpenCLWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) { this->Init(); CHECK(stream == nullptr); if (IsOpenCLDevice(ctx_from) && IsOpenCLDevice(ctx_to)) { - OPENCL_CALL(clEnqueueCopyBuffer( - this->GetQueue(ctx_to), - static_cast((void*)from), // NOLINT(*) - static_cast(to), - from_offset, to_offset, size, 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueCopyBuffer(this->GetQueue(ctx_to), + static_cast((void*)from), // NOLINT(*) + static_cast(to), from_offset, to_offset, size, 0, + nullptr, nullptr)); } else if (IsOpenCLDevice(ctx_from) && ctx_to.device_type == kDLCPU) { - OPENCL_CALL(clEnqueueReadBuffer( - this->GetQueue(ctx_from), - static_cast((void*)from), // NOLINT(*) - CL_FALSE, from_offset, size, - static_cast(to) + to_offset, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueReadBuffer(this->GetQueue(ctx_from), + static_cast((void*)from), // NOLINT(*) + CL_FALSE, from_offset, size, static_cast(to) + to_offset, + 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_from))); } else if (ctx_from.device_type == kDLCPU && IsOpenCLDevice(ctx_to)) { - OPENCL_CALL(clEnqueueWriteBuffer( - this->GetQueue(ctx_to), - static_cast(to), - CL_FALSE, to_offset, size, - static_cast(from) + from_offset, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueWriteBuffer(this->GetQueue(ctx_to), static_cast(to), CL_FALSE, + to_offset, size, static_cast(from) + from_offset, + 0, nullptr, nullptr)); OPENCL_CALL(clFinish(this->GetQueue(ctx_to))); } else { LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL"; @@ -180,9 +163,7 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { OPENCL_CALL(clFinish(this->GetQueue(ctx))); } -void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, - size_t size, - DLDataType type_hint) { +void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return GetThreadEntry()->pool.AllocWorkspace(ctx, size); } @@ -192,12 +173,9 @@ void OpenCLWorkspace::FreeWorkspace(TVMContext ctx, void* data) { typedef dmlc::ThreadLocalStore OpenCLThreadStore; -OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { - return OpenCLThreadStore::Get(); -} +OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { return OpenCLThreadStore::Get(); } -std::string GetPlatformInfo( - cl_platform_id pid, cl_platform_info param_name) { +std::string GetPlatformInfo(cl_platform_id pid, cl_platform_info param_name) { size_t ret_size; OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size)); std::string ret; @@ -206,8 +184,7 @@ std::string GetPlatformInfo( return ret; } -std::string GetDeviceInfo( - cl_device_id pid, cl_device_info param_name) { +std::string GetDeviceInfo(cl_device_id pid, cl_device_info param_name) { size_t ret_size; OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, 
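The copy paths above enqueue non-blocking transfers (CL_FALSE) and then clFinish the queue that owns the buffer. A minimal sketch of the same host-to-device-to-host round trip, written as a helper that assumes a context and default in-order command queue already exist:

#include <CL/cl.h>
#include <vector>

bool RoundTripThroughDevice(cl_context ctx, cl_command_queue queue) {
  std::vector<float> src(256, 2.0f), dst(256, 0.0f);
  const size_t nbytes = src.size() * sizeof(float);

  cl_int err = CL_SUCCESS;
  cl_mem buf = clCreateBuffer(ctx, CL_MEM_READ_WRITE, nbytes, nullptr, &err);
  if (err != CL_SUCCESS) return false;

  // Non-blocking enqueues; a default in-order queue executes them in order.
  cl_int w = clEnqueueWriteBuffer(queue, buf, CL_FALSE, 0, nbytes, src.data(), 0, nullptr, nullptr);
  cl_int r = clEnqueueReadBuffer(queue, buf, CL_FALSE, 0, nbytes, dst.data(), 0, nullptr, nullptr);
  clFinish(queue);  // wait for the transfers, as CopyDataFromTo does per direction

  clReleaseMemObject(buf);
  return w == CL_SUCCESS && r == CL_SUCCESS && dst.front() == 2.0f;
}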
nullptr, &ret_size)); std::string ret; @@ -226,8 +203,7 @@ std::vector GetPlatformIDs() { return ret; } -std::vector GetDeviceIDs( - cl_platform_id pid, std::string device_type) { +std::vector GetDeviceIDs(cl_platform_id pid, std::string device_type) { cl_device_type dtype = CL_DEVICE_TYPE_ALL; if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU; if (device_type == "gpu") dtype = CL_DEVICE_TYPE_GPU; @@ -241,10 +217,7 @@ std::vector GetDeviceIDs( return ret; } -bool MatchPlatformInfo( - cl_platform_id pid, - cl_platform_info param_name, - std::string value) { +bool MatchPlatformInfo(cl_platform_id pid, cl_platform_info param_name, std::string value) { if (value.length() == 0) return true; std::string param_value = GetPlatformInfo(pid, param_name); return param_value.find(value) != std::string::npos; @@ -286,25 +259,22 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic return; } cl_int err_code; - this->context = clCreateContext( - nullptr, this->devices.size(), &(this->devices[0]), - nullptr, nullptr, &err_code); + this->context = clCreateContext(nullptr, this->devices.size(), &(this->devices[0]), nullptr, + nullptr, &err_code); OPENCL_CHECK_ERROR(err_code); CHECK_EQ(this->queues.size(), 0U); for (size_t i = 0; i < this->devices.size(); ++i) { cl_device_id did = this->devices[i]; - this->queues.push_back( - clCreateCommandQueue(this->context, did, 0, &err_code)); + this->queues.push_back(clCreateCommandQueue(this->context, did, 0, &err_code)); OPENCL_CHECK_ERROR(err_code); } initialized_ = true; } -TVM_REGISTER_GLOBAL("device_api.opencl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = OpenCLWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index fefde72b9508..95d0481c31d5 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -20,13 +20,16 @@ /*! * \file opencl_module.cc */ +#include "opencl_module.h" + #include #include -#include + #include #include +#include + #include "opencl_common.h" -#include "opencl_module.h" namespace tvm { namespace runtime { @@ -34,12 +37,9 @@ namespace runtime { class OpenCLWrappedFunc { public: // initialize the OpenCL function. - void Init(OpenCLModuleNode* m, - ObjectPtr sptr, - OpenCLModuleNode::KTRefEntry entry, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags) { + void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, + std::string func_name, std::vector arg_size, + const std::vector& thread_axis_tags) { w_ = m->GetGlobalWorkspace().get(); m_ = m; sptr_ = sptr; @@ -49,9 +49,7 @@ class OpenCLWrappedFunc { thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void** void_args) const { + void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const { CHECK(w_->context != nullptr) << "No OpenCL device"; cl::OpenCLThreadEntry* t = w_->GetThreadEntry(); // get the kernel from thread local kernel table. 
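A sketch of the initialization pattern reflowed above in OpenCLWorkspace::Init: one context spanning all selected devices, then one command queue per device, checking the error code after each create call. The helper name and signature are made up for illustration.

#include <CL/cl.h>
#include <vector>

bool CreateContextAndQueues(const std::vector<cl_device_id>& devices, cl_context* out_ctx,
                            std::vector<cl_command_queue>* out_queues) {
  if (devices.empty()) return false;
  cl_int err = CL_SUCCESS;
  cl_context ctx = clCreateContext(nullptr, static_cast<cl_uint>(devices.size()), devices.data(),
                                   nullptr, nullptr, &err);
  if (err != CL_SUCCESS) return false;
  for (cl_device_id dev : devices) {
    cl_command_queue q = clCreateCommandQueue(ctx, dev, 0, &err);  // OpenCL 1.2 API, as above
    if (err != CL_SUCCESS) return false;
    out_queues->push_back(q);
  }
  *out_ctx = ctx;
  return true;
}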
@@ -74,11 +72,8 @@ class OpenCLWrappedFunc { wl.work_size[i] *= wl.work_size[i + 3]; } // launch kernel - OPENCL_CALL(clEnqueueNDRangeKernel( - queue, kernel, work_dim, nullptr, - wl.work_size, - wl.work_size + 3, - 0, nullptr, nullptr)); + OPENCL_CALL(clEnqueueNDRangeKernel(queue, kernel, work_dim, nullptr, wl.work_size, + wl.work_size + 3, 0, nullptr, nullptr)); } private: @@ -119,12 +114,10 @@ const std::shared_ptr& OpenCLModuleNode::GetGlobalWorkspace return cl::OpenCLWorkspace::Global(); } -PackedFunc OpenCLModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc OpenCLModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -143,16 +136,13 @@ PackedFunc OpenCLModuleNode::GetFunction( } } // initialize the wrapped func. - f.Init(this, sptr_to_self, kid_map_.at(name), - name, arg_size, info.thread_axis_tags); + f.Init(this, sptr_to_self, kid_map_.at(name), name, arg_size, info.thread_axis_tags); return PackFuncVoidAddr(f, info.arg_types); } -void OpenCLModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { +void OpenCLModuleNode::SaveToFile(const std::string& file_name, const std::string& format) { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -193,10 +183,8 @@ void OpenCLModuleNode::Init() { } } -cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, - cl::OpenCLThreadEntry* t, - const std::string& func_name, - const KTRefEntry& e) { +cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, + const std::string& func_name, const KTRefEntry& e) { std::lock_guard lock(build_lock_); int device_id = t->context.device_id; if (!device_built_flag_[device_id]) { @@ -210,7 +198,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, OPENCL_CHECK_ERROR(err); } } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") { - const unsigned char* s = (const unsigned char *)data_.c_str(); + const unsigned char* s = (const unsigned char*)data_.c_str(); size_t len = data_.length(); cl_int err; cl_device_id dev = w->devices[device_id]; @@ -226,11 +214,9 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, if (err != CL_SUCCESS) { size_t len; std::string log; - clGetProgramBuildInfo( - program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); + clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len); log.resize(len); - clGetProgramBuildInfo( - program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); + clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr); LOG(FATAL) << "OpenCL build error for device=" << dev << log; } device_built_flag_[device_id] = true; @@ -245,19 +231,15 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, return kernel; } -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module 
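A standalone sketch of the build-log pattern InstallKernel uses above: when clBuildProgram fails, query the log size first, resize a string, then fetch the log into it before reporting the failure.

#include <CL/cl.h>
#include <cstdio>
#include <string>

bool BuildOrDumpLog(cl_program program, cl_device_id dev, const char* options) {
  cl_int err = clBuildProgram(program, 1, &dev, options, nullptr, nullptr);
  if (err == CL_SUCCESS) return true;

  size_t len = 0;
  clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
  std::string log(len, '\0');
  clGetProgramBuildInfo(program, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
  std::fprintf(stderr, "OpenCL build error:\n%s\n", log.c_str());
  return false;
}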
OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } // Load module from module. -Module OpenCLModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module OpenCLModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -278,13 +260,10 @@ Module OpenCLModuleLoadBinary(void* strm) { return OpenCLModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl") -.set_body_typed(OpenCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_cl").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin") -.set_body_typed(OpenCLModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_clbin").set_body_typed(OpenCLModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl") -.set_body_typed(OpenCLModuleLoadBinary); +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opencl").set_body_typed(OpenCLModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 3b7ebb9c1659..77f4b8010779 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. */ -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/sdaccel/sdaccel_common.h b/src/runtime/opencl/sdaccel/sdaccel_common.h index 2100b50678b3..803cbe67b9a7 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_common.h +++ b/src/runtime/opencl/sdaccel/sdaccel_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_COMMON_H_ #include + #include "../opencl_common.h" namespace tvm { @@ -44,7 +45,6 @@ class SDAccelWorkspace final : public OpenCLWorkspace { static const std::shared_ptr& Global(); }; - /*! \brief Thread local workspace for SDAccel*/ class SDAccelThreadEntry : public OpenCLThreadEntry { public: diff --git a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc index 59e8a25c834e..6bac0c916aad 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -20,26 +20,23 @@ /*! * \file sdaccel_device_api.cc */ -#include #include +#include + #include "sdaccel_common.h" namespace tvm { namespace runtime { namespace cl { -OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { - return SDAccelThreadEntry::ThreadLocal(); -} +OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); } const std::shared_ptr& SDAccelWorkspace::Global() { static std::shared_ptr inst = std::make_shared(); return inst; } -void SDAccelWorkspace::Init() { - OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); -} +void SDAccelWorkspace::Init() { OpenCLWorkspace::Init("sdaccel", "accelerator", "Xilinx"); } bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { return ctx.device_type == static_cast(kDLSDAccel); @@ -47,15 +44,12 @@ bool SDAccelWorkspace::IsOpenCLDevice(TVMContext ctx) { typedef dmlc::ThreadLocalStore SDAccelThreadStore; -SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { - return SDAccelThreadStore::Get(); -} +SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.sdaccel") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = SDAccelWorkspace::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = SDAccelWorkspace::Global().get(); + *rv = static_cast(ptr); +}); } // namespace cl } // namespace runtime diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.cc b/src/runtime/opencl/sdaccel/sdaccel_module.cc index 4569ec3946df..b4edca32a998 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_module.cc @@ -20,23 +20,24 @@ /*! 
* \file sdaccel_module.cc */ +#include "sdaccel_module.h" + #include #include -#include + #include #include +#include + #include "sdaccel_common.h" -#include "sdaccel_module.h" namespace tvm { namespace runtime { class SDAccelModuleNode : public OpenCLModuleNode { public: - explicit SDAccelModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) + explicit SDAccelModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} const std::shared_ptr& GetGlobalWorkspace() final; }; @@ -45,18 +46,14 @@ const std::shared_ptr& SDAccelModuleNode::GetGlobalWorkspac return cl::SDAccelWorkspace::Global(); } -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { auto n = make_object(data, fmt, fmap, source); n->Init(); return Module(n); } -Module SDAccelModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module SDAccelModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -77,10 +74,8 @@ Module SDAccelModuleLoadBinary(void* strm) { return SDAccelModuleCreate(data, fmt, fmap, std::string()); } -TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin") -.set_body_typed(SDAccelModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_xclbin").set_body_typed(SDAccelModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin") -.set_body_typed(SDAccelModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_awsxclbin").set_body_typed(SDAccelModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.h b/src/runtime/opencl/sdaccel/sdaccel_module.h index e126291f3f03..322decc4460c 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.h +++ b/src/runtime/opencl/sdaccel/sdaccel_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_ #include + #include -#include #include #include +#include + #include "../../meta_data.h" namespace tvm { @@ -40,11 +42,8 @@ namespace runtime { * \param fmt The format of the data, can be "xclbin", "awsxclbin" * \param fmap The map function information map of each function. */ -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source); +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_SDACCEL_SDACCEL_MODULE_H_ diff --git a/src/runtime/opengl/opengl_common.h b/src/runtime/opengl/opengl_common.h deleted file mode 100644 index 009ea6c9111d..000000000000 --- a/src/runtime/opengl/opengl_common.h +++ /dev/null @@ -1,514 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file opengl_common.h - * \brief OpenGL common header - */ -#ifndef TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ -#define TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ - -#include -#include -#include -#include -#if defined(__APPLE__) -#define GLFW_INCLUDE_GLCOREARB -#endif -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace runtime { -namespace gl { - -// This file contains the following classes. -class GLFunctionPointers; -class OpenGLWorkspace; -class Texture; -class Program; - -inline GLFWglproc GetProcAddress(const char* procname) { - GLFWglproc proc = glfwGetProcAddress(procname); - CHECK(proc != nullptr) << "Cannot get function \"" << procname << "\""; - return proc; -} - -#define SetGLFunctionPointer(NAME) \ - NAME(decltype(NAME)(GetProcAddress("gl" #NAME))) - -/*! - * \brief The function pointers of all OpenGL APIs that are used. - * Must be constructed after creating an OpenGL context. - */ -class GLFunctionPointers { - public: - GLFunctionPointers() - : SetGLFunctionPointer(ActiveTexture), - SetGLFunctionPointer(AttachShader), - SetGLFunctionPointer(BindBuffer), - SetGLFunctionPointer(BindFramebuffer), - SetGLFunctionPointer(BindTexture), - SetGLFunctionPointer(BindVertexArray), - SetGLFunctionPointer(BufferData), - SetGLFunctionPointer(CheckFramebufferStatus), - SetGLFunctionPointer(Clear), - SetGLFunctionPointer(CompileShader), - SetGLFunctionPointer(CreateProgram), - SetGLFunctionPointer(CreateShader), - SetGLFunctionPointer(DeleteFramebuffers), - SetGLFunctionPointer(DeleteProgram), - SetGLFunctionPointer(DeleteShader), - SetGLFunctionPointer(DeleteTextures), - SetGLFunctionPointer(DetachShader), - SetGLFunctionPointer(DrawArrays), - SetGLFunctionPointer(DrawBuffers), - SetGLFunctionPointer(EnableVertexAttribArray), - SetGLFunctionPointer(Finish), - SetGLFunctionPointer(FramebufferTexture2D), - SetGLFunctionPointer(GenBuffers), - SetGLFunctionPointer(GenFramebuffers), - SetGLFunctionPointer(GenTextures), - SetGLFunctionPointer(GenVertexArrays), - SetGLFunctionPointer(GetAttribLocation), - SetGLFunctionPointer(GetError), - SetGLFunctionPointer(GetIntegerv), - SetGLFunctionPointer(GetProgramInfoLog), - SetGLFunctionPointer(GetProgramiv), - SetGLFunctionPointer(GetShaderInfoLog), - SetGLFunctionPointer(GetShaderiv), - SetGLFunctionPointer(GetString), - SetGLFunctionPointer(GetUniformLocation), - SetGLFunctionPointer(LinkProgram), - SetGLFunctionPointer(ReadPixels), - SetGLFunctionPointer(ShaderSource), - SetGLFunctionPointer(TexImage2D), - SetGLFunctionPointer(TexParameteri), - SetGLFunctionPointer(TexSubImage2D), - SetGLFunctionPointer(Uniform1f), - SetGLFunctionPointer(Uniform1i), - SetGLFunctionPointer(UseProgram), - SetGLFunctionPointer(VertexAttribPointer), - SetGLFunctionPointer(Viewport) {} - - void (*ActiveTexture)(GLenum texture); - void 
(*AttachShader)(GLuint program, GLuint shader); - void (*BindBuffer)(GLenum target, GLuint buffer); - void (*BindFramebuffer)(GLenum target, GLuint framebuffer); - void (*BindTexture)(GLenum target, GLuint texture); - void (*BindVertexArray)(GLuint array); - void (*BufferData)(GLenum target, GLsizeiptr size, const GLvoid* data, - GLenum usage); - GLenum (*CheckFramebufferStatus)(GLenum target); - void (*Clear)(GLbitfield mask); - void (*CompileShader)(GLuint shader); - GLuint (*CreateProgram)(); - GLuint (*CreateShader)(GLenum shader_type); - void (*DeleteFramebuffers)(GLsizei n, const GLuint* framebuffers); - void (*DeleteProgram)(GLuint program); - void (*DeleteShader)(GLuint shader); - void (*DeleteTextures)(GLsizei n, const GLuint* textures); - void (*DetachShader)(GLuint program, GLuint shader); - void (*DrawArrays)(GLenum mode, GLint first, GLsizei count); - void (*DrawBuffers)(GLsizei n, const GLenum* bufs); - void (*EnableVertexAttribArray)(GLuint index); - void (*Finish)(); - void (*FramebufferTexture2D)(GLenum target, GLenum attachment, - GLenum textarget, GLuint texture, GLint level); - void (*GenBuffers)(GLsizei n, GLuint* buffers); - void (*GenFramebuffers)(GLsizei n, GLuint* ids); - void (*GenTextures)(GLsizei n, GLuint* textures); - void (*GenVertexArrays)(GLsizei n, GLuint* arrays); - GLint (*GetAttribLocation)(GLuint program, const GLchar* name); - GLenum (*GetError)(); - void (*GetIntegerv)(GLenum pname, GLint* data); - void (*GetProgramInfoLog)(GLuint program, GLsizei maxLength, GLsizei* length, - GLchar* info_log); - void (*GetProgramiv)(GLuint program, GLenum pname, GLint* params); - void (*GetShaderInfoLog)(GLuint shader, GLsizei max_length, GLsizei* length, - GLchar* info_log); - void (*GetShaderiv)(GLuint shader, GLenum pname, GLint* params); - const GLubyte *(*GetString)(GLenum name); - GLint (*GetUniformLocation)(GLuint program, const GLchar* name); - void (*LinkProgram)(GLuint program); - void (*ReadPixels)(GLint x, GLint y, GLsizei width, GLsizei height, - GLenum format, GLenum type, GLvoid* data); - void (*ShaderSource)(GLuint shader, GLsizei count, const GLchar** string, - const GLint* length); - void (*TexImage2D)(GLenum target, GLint level, GLint internal_format, - GLsizei width, GLsizei height, GLint border, GLenum format, - GLenum type, const GLvoid* data); - void (*TexParameteri)(GLenum target, GLenum pname, GLint param); - void (*TexSubImage2D)(GLenum target, GLint level, GLint xoffset, - GLint yoffset, GLsizei width, GLsizei height, - GLenum format, GLenum type, const GLvoid* data); - void (*Uniform1f)(GLint location, GLfloat v0); - void (*Uniform1i)(GLint location, GLint v0); - void (*UseProgram)(GLuint program); - void (*VertexAttribPointer)(GLuint index, GLint size, GLenum type, - GLboolean normalized, GLsizei stride, - const GLvoid* pointer); - void (*Viewport)(GLint x, GLint y, GLsizei width, GLsizei height); -}; - -/*! - * \brief Process global OpenGL workspace. 
- */ -class OpenGLWorkspace final : public DeviceAPI { - public: - ~OpenGLWorkspace() final; - - // override device API - void SetDevice(TVMContext ctx) final; - void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - DLDataType type_hint) final; - void FreeDataSpace(TVMContext ctx, void* ptr) final; - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, - TVMStreamHandle stream) final; - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; - - /*! - * \brief Get the global OpenGL workspace. - * \return The global OpenGL workspace. - */ - static const std::shared_ptr& Global(); - - /*! - * \brief Create an OpenGL program that uses the given fragment shader. - * \param fragment_shader The fragment shader **source**. - * \return The OpenGL program. - */ - Program CreateProgram(const char* fragment_shader_src); - - /*! - * \brief Create an OpenGL texture that stores an array. - * \param type Element type. - * \param nbytes Number of bytes in the array. - * \return The OpenGL texture. - */ - Texture CreateTexture(DLDataType type, size_t nbytes); - - /*! - * \brief Upload user data into a sub-region of an OpenGL texture. - * \param texture The texture to be written to. - * \param begin The index of the first element to be written to. - * \param nelems The number of elements to be written to. - * \param data The user data. - */ - void PutTextureData(Texture* texture, - GLint begin, - GLsizei nelems, - const GLvoid* data); - /*! - * \brief Download a sub-region of an OpenGL texture. - * \param texture The texture to download from. - * \param begin The index of first element to download from. - * \param nelems The number of elements to download from. - * \param data The user buffer. - */ - void GetTextureData(const Texture* texture, - GLint begin, - GLsizei nelems, - GLvoid* data); - - /*! - * \brief Set currently used OpenGL program. - */ - void SetCurrentProgram(const Program& program); - - /*! - * \brief Set uniform values for an OpenGL program. - * Must call SetCurrentProgram before calling this. - * \param program The OpenGL program. - * \param name The uniform argument name. - * \param type The type of the uniform. - * \param value The value to pass in. - */ - void SetUniform(const Program& program, - const std::string& name, - DLDataType type, - void* value); - - /*! - * \brief Set input texture for an OpenGL program. - * Must call SetCurrentProgram before calling this. - * \param program The OpenGL program. - * \param name The texture uniform argument name. - * \param unit The texture unit to use. Each input texture must occupy a - * different unit. - * \param texture The OpenGL texture to pass in. - */ - void SetInputTexture(const Program& program, - const std::string& name, - GLuint unit, - Texture* texture); - - /*! - * \brief Render to a texture. - * \param output The output texture. - */ - void Render(Texture* output); - - private: - friend class Texture; - friend class Program; - - // Global singleton. Hide constructor. 
- OpenGLWorkspace(); - - GLFWwindow* window_; - std::unique_ptr gl; - GLuint vertex_shader_; - static const int kWindowWidth = 640; - static const int kWindowHeight = 480; - struct Vertex { - float x, y; - }; - static constexpr size_t kNumVertices = 6; - static const Vertex vertices[kNumVertices]; - static const char* vertex_shader_text_; - - /*! - * \brief Bind a texture to a "texture unit". - * After calling this function, the "texture unit" becomes "active", and the - * texture is bound to GL_TEXTURE_2D in that "texture unit". - * \param unit The texture unit to activate. - * \param texture The texture to bind. - */ - void BindTextureUnit(GLuint unit, GLuint texture); - - /*! - * \brief Callback in Texture's destructor. - */ - void OnDeleteTexture(GLuint texture); - - /*! - * \brief Callback in Program's destructor. - */ - void OnDeleteProgram(GLuint program); - - /*! - * \brief Check if there is any outstanding OpenGL error. If there is, crash. - */ - void CheckOpenGLError(); - - /*! - * \brief Get the maximum number of texture units. - */ - GLuint NumTextureUnits(); - - /*! - * \brief Create and compile a shader from a source string. - * \param shader_kind The kind of shader. - * Could be GL_VERTEX_SHADER or GL_FRAGMENT_SHADER. - * \param shader_src The source string of the shader. - * \return The compiled shader ID. - */ - GLuint CreateShader(GLenum shader_kind, const char* shader_src); - - /*! - * \brief Create an OpenGL program that uses the given fragment shader. - * \param fragment_shader The **compiled** fragment shader. - * \return The OpenGL program. - */ - Program CreateProgram(GLuint fragment_shader); -}; - -/*! - * \brief An OpenGL program, composed of a vertex shader and a fragment shader. - * In TVM, every program has the same vertex shader. - * So a program just corresponds to a fragment shader. - * A program can only be created by the workspace. - * This class is just a wrapper over an OpenGL program ID. - */ -class Program { - public: - // Move constructor. - Program(Program&& other) noexcept - : workspace_(other.workspace_), program_(other.program_) { - other.program_ = kInvalidProgram; - } - - // Move assignment. - Program& operator=(Program&& other) noexcept { - workspace_ = other.workspace_; - program_ = other.program_; - other.program_ = kInvalidProgram; - return *this; - } - - // Disallow copy. - Program(const Program& other) = delete; - Program& operator=(const Program& other) = delete; - - // Destructor. - ~Program() { - if (program_ != kInvalidProgram) { - workspace_->OnDeleteProgram(program_); - program_ = kInvalidProgram; - } - } - - private: - friend class OpenGLWorkspace; - - // Only OpenGLWorkspace can create a Program. - // We enforce this to make sure OpenGL is initialized. - explicit Program(OpenGLWorkspace* workspace, GLuint program) - : workspace_(workspace), program_(program) {} - - // The internal OpenGL program ID. - GLuint program() const { return program_; } - - static constexpr GLuint kInvalidProgram = static_cast(-1); - - OpenGLWorkspace* workspace_; - GLuint program_; -}; - -/*! - * \brief The storage format of a texture. - * The members match the API of glTexImage2D. 
- */ -struct TextureFormat { - TextureFormat(GLint internal_format, GLenum format, GLenum type) - : internal_format(internal_format), format(format), type(type) {} - - GLsizei elemsz() const { - switch (type) { - case GL_BYTE: case GL_UNSIGNED_BYTE: - return 1; - case GL_SHORT: case GL_UNSIGNED_SHORT: - return 2; - case GL_INT: case GL_UNSIGNED_INT: - return 4; - case GL_FLOAT: - return 4; - default: - LOG(FATAL) << "Unsupported type"; - return -1; - } - } - - bool operator==(const TextureFormat& other) const { - return std::make_tuple(internal_format, format, type) == - std::make_tuple(other.internal_format, other.format, other.type); - } - - GLint internal_format; // OpenGL says this is GLint, not GLenum. - GLenum format; - GLenum type; -}; - -/*! - * \brief An OpenGL texture represents a chunk of GPU memory. - * This is the way we represent tensors. - * We always use 2D textures. - */ -class Texture { - public: - // Move constructor. - Texture(Texture&& other) noexcept - : workspace_(other.workspace_), texture_(other.texture_), - format_(other.format_), width_(other.width_), height_(other.height_) { - other.texture_ = kInvalidTexture; - } - - // Move assignment. - Texture& operator=(Texture&& other) noexcept { - workspace_ = other.workspace_; - texture_ = other.texture_; - format_ = other.format_; - width_ = other.width_; - height_ = other.height_; - other.texture_ = kInvalidTexture; - return *this; - } - - // Disallow copy. - Texture(const Texture& other) = delete; - Texture& operator=(const Texture& other) = delete; - - // Destructor. - ~Texture() { - if (texture_ != kInvalidTexture) { - workspace_->OnDeleteTexture(texture_); - texture_ = kInvalidTexture; - } - } - - /*! - * \brief The width of the texture in number of pixels. - */ - GLsizei width() const { return width_; } - - /*! - * \brief The height of the texture in number of pixels. - */ - GLsizei height() const { return height_; } - - /*! - * \brief The number of bytes of each element in the array. - */ - GLsizei elemsz() const { return format_.elemsz(); } - - private: - friend class OpenGLWorkspace; - - // Only OpenGLWorkspace can create a Texture. - // We enforce this to make sure OpenGL is initialized. - // Always only use the first dimension of a 2D texture. - // The reason is that texelFetch only supports 2D textures. - explicit Texture(OpenGLWorkspace* workspace, GLuint texture, - TextureFormat format, - GLsizei width, GLsizei height) - : workspace_(workspace), texture_(texture), format_(format), - width_(width), height_(height) {} - - // The internal texture ID. - GLuint texture() const { return texture_; } - - static constexpr GLuint kInvalidTexture = static_cast(-1); - - OpenGLWorkspace* workspace_; - GLuint texture_; - TextureFormat format_; - GLsizei width_; - GLsizei height_; -}; - -} // namespace gl -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_OPENGL_OPENGL_COMMON_H_ diff --git a/src/runtime/opengl/opengl_device_api.cc b/src/runtime/opengl/opengl_device_api.cc deleted file mode 100644 index 0be921cb4ae5..000000000000 --- a/src/runtime/opengl/opengl_device_api.cc +++ /dev/null @@ -1,633 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file opengl_device_api.cc - */ -#include -#include -#include "opengl_common.h" -#include "opengl_module.h" - -namespace tvm { -namespace runtime { -namespace gl { - -/*! - * \brief Turn OpenGL error enum to string. - */ -static const char* GLGetErrorString(GLenum error) { - switch (error) { - case GL_NO_ERROR: - return "GL_NO_ERROR"; - case GL_INVALID_ENUM: - return "GL_INVALID_ENUM"; - case GL_INVALID_VALUE: - return "GL_INVALID_VALUE"; - case GL_INVALID_OPERATION: - return "GL_INVALID_OPERATION"; -#if !defined(__APPLE__) - case GL_STACK_OVERFLOW: - return "GL_STACK_OVERFLOW"; - case GL_STACK_UNDERFLOW: - return "GL_STACK_UNDERFLOW"; -#endif - case GL_OUT_OF_MEMORY: - return "GL_OUT_OF_MEMORY"; - default: - return "Unknown OpenGL error code"; - } -} - -/*! - * \brief Get the latest error. - */ -void OpenGLWorkspace::CheckOpenGLError() { - GLenum err = gl->GetError(); - CHECK_EQ(err, GL_NO_ERROR) << "OpenGL error, code=" << err << ": " - << gl::GLGetErrorString(err); -} - -/*! - * \brief Protected OpenGL call. - * \param func Expression to call. - */ -#define OPENGL_CALL(func) \ - { \ - (func); \ - CheckOpenGLError(); \ - } - -/*! - * \brief The error handling callback passed to GLFW. - */ -void GlfwErrorCallback(int err, const char* str) { - LOG(FATAL) << "Error: [" << err << "] " << str; -} - -const std::shared_ptr& OpenGLWorkspace::Global() { - static std::shared_ptr inst(new OpenGLWorkspace); - return inst; -} - -void OpenGLWorkspace::SetDevice(TVMContext ctx) { - CHECK_EQ(ctx.device_type, static_cast(kOpenGL)) - << "Device type must be OpenGL."; - CHECK_EQ(ctx.device_id, 0) << "Only support 1 OpenGL \"device\"."; -} - -void OpenGLWorkspace::GetAttr( - TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) { - switch (kind) { - case kExist: { - *rv = static_cast(ctx.device_id == 0); - break; - } - case kMaxThreadsPerBlock: { - GLint max_texture_size; - OPENGL_CALL(gl->GetIntegerv(GL_MAX_TEXTURE_SIZE, &max_texture_size)); - break; - } - case kWarpSize: { - *rv = 1; - break; - } - case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: { - break; - } - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; - case kGcnArch: return; - } -} - -void* OpenGLWorkspace::AllocDataSpace( - TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) { - return reinterpret_cast(new Texture(CreateTexture(type_hint, nbytes))); -} - -void OpenGLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) { - delete reinterpret_cast(ptr); -} - -void OpenGLWorkspace::CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, - TVMStreamHandle stream) { - CHECK(stream == nullptr); - - // TODO(zhixunt): This is a nasty hack to avoid comparison between - // incompatible enums. We should add kOpenGL to dlpack. 
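// [Editor's note -- illustrative summary, not part of the original source]
// The dispatch below handles three copy directions:
//   OpenGL -> OpenGL : read the source texture into a temporary host buffer
//                      (GetTextureData), then upload it into the destination
//                      texture (PutTextureData);
//   OpenGL -> CPU    : GetTextureData straight into the caller's buffer;
//   CPU    -> OpenGL : PutTextureData straight from the caller's buffer.
// Offsets and sizes arrive in bytes and are converted to element counts by
// dividing by the texture's element size.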
- constexpr int gl_devtype = kOpenGL; - std::tuple type_from_to(ctx_from.device_type, ctx_to.device_type); - - if (type_from_to == std::make_tuple(gl_devtype, gl_devtype)) { - auto from_texture = static_cast(from); - auto to_texture = static_cast(to); - auto temp_buffer = std::unique_ptr(new char[size]); - CHECK(from_texture->format_ == to_texture->format_); - auto elemsz = from_texture->elemsz(); - auto from_begin = static_cast(from_offset / elemsz); - auto to_begin = static_cast(to_offset / elemsz); - auto nelems = static_cast(size / elemsz); - GetTextureData(from_texture, from_begin, nelems, temp_buffer.get()); - PutTextureData(to_texture, to_begin, nelems, temp_buffer.get()); - - } else if (type_from_to == std::make_tuple(gl_devtype, kDLCPU)) { - auto texture = static_cast(from); - void *data = static_cast(to) + to_offset; - auto elemsz = texture->elemsz(); - auto begin = static_cast(from_offset / elemsz); - auto nelems = static_cast(size / elemsz); - GetTextureData(texture, begin, nelems, data); - - } else if (type_from_to == std::make_tuple(kDLCPU, gl_devtype)) { - auto texture = reinterpret_cast(to); - const void* data = static_cast(from) + from_offset; - auto elemsz = texture->elemsz(); - auto begin = static_cast(to_offset / elemsz); - auto nelems = static_cast(size / elemsz); - PutTextureData(texture, begin, nelems, data); - - } else { - LOG(FATAL) << "Expect copy from/to OpenGL or between OpenGL"; - } -} - -void OpenGLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {} - -OpenGLWorkspace::OpenGLWorkspace() { - // Set an error handler. - // This can be called before glfwInit(). - glfwSetErrorCallback(&GlfwErrorCallback); - - // Initialize GLFW. - if (glfwInit() != GL_TRUE) { - LOG(FATAL) << "glfwInit() failed!"; - } - - // Create a window. - glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3); - glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3); - glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GL_TRUE); - glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE); - glfwWindowHint(GLFW_VISIBLE, GL_FALSE); - window_ = glfwCreateWindow(kWindowWidth, kWindowHeight, "", nullptr, nullptr); - if (window_ == nullptr) { - LOG(FATAL) << "glfwCreateWindow() failed!"; - } - - // Before using any OpenGL API, we must specify a context. - glfwMakeContextCurrent(window_); - - // Load all OpenGL API function pointers. - gl = std::unique_ptr(new GLFunctionPointers); - - CheckOpenGLError(); - - // We always render the same vertices and triangles. - GLuint vertex_buffer; - OPENGL_CALL(gl->GenBuffers(1, &vertex_buffer)); - OPENGL_CALL(gl->BindBuffer(GL_ARRAY_BUFFER, vertex_buffer)); - OPENGL_CALL(gl->BufferData(GL_ARRAY_BUFFER, sizeof(vertices), vertices, - GL_STATIC_DRAW)); - - GLuint vertex_array; - OPENGL_CALL(gl->GenVertexArrays(1, &vertex_array)); - OPENGL_CALL(gl->BindVertexArray(vertex_array)); - OPENGL_CALL(gl->BindBuffer(GL_ARRAY_BUFFER, vertex_buffer)); - - // We always use the same vertex shader. - vertex_shader_ = CreateShader(GL_VERTEX_SHADER, vertex_shader_text_); - - LOG(INFO) << "OpenGL initialized, version = " << gl->GetString(GL_VERSION); -} - -OpenGLWorkspace::~OpenGLWorkspace() { - // Paired with glfwCreateWindow(). - glfwDestroyWindow(window_); - - // Paired with glfwInit(). 
- glfwTerminate(); -} - -void OpenGLWorkspace::BindTextureUnit(GLuint unit, GLuint texture) { - OPENGL_CALL(gl->ActiveTexture(GL_TEXTURE0 + unit)); - OPENGL_CALL(gl->BindTexture(GL_TEXTURE_2D, texture)); -} - -void OpenGLWorkspace::OnDeleteTexture(GLuint texture) { - OPENGL_CALL(gl->DeleteTextures(1, &texture)); -} - -void OpenGLWorkspace::OnDeleteProgram(GLuint program) { - OPENGL_CALL(gl->DeleteProgram(program)); -} - -GLuint OpenGLWorkspace::NumTextureUnits() { - GLint num_units; - OPENGL_CALL(gl->GetIntegerv(GL_MAX_COMBINED_TEXTURE_IMAGE_UNITS, &num_units)); - return static_cast(num_units); -} - -const OpenGLWorkspace::Vertex OpenGLWorkspace::vertices[OpenGLWorkspace::kNumVertices] = { - {-1.f, -1.f}, - {1.0f, -1.f}, - {1.0f, 1.0f}, - {-1.f, -1.f}, - {-1.f, 1.0f}, - {1.0f, 1.0f}, -}; - -// Don't need to change this. -// The vertex shader only needs to take in the triangle points. -// No need for point transformations. -const char* OpenGLWorkspace::vertex_shader_text_ = "#version 300 es\n" - "in vec2 point; // input to vertex shader\n" - "void main() {\n" - " gl_Position = vec4(point, 0.0, 1.0);\n" - "}\n"; - -Program OpenGLWorkspace::CreateProgram( - const char* fragment_shader_src) { - // Create and compile the shaders. - GLuint fragment_shader = CreateShader(GL_FRAGMENT_SHADER, - fragment_shader_src); - - // Link the shaders and create the program. - Program program = CreateProgram(fragment_shader); - - OPENGL_CALL(gl->DeleteShader(fragment_shader)); - - return program; -} - -GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, - const char* shader_src) { - // Create the shader. - GLuint shader = gl->CreateShader(shader_kind); - gl->ShaderSource(shader, 1, &shader_src, nullptr); - gl->CompileShader(shader); - - // Check compile errors. - GLint err; - gl->GetShaderiv(shader, GL_COMPILE_STATUS, &err); - - GLint info_log_len; - gl->GetShaderiv(shader, GL_INFO_LOG_LENGTH, &info_log_len); - - if (err != GL_TRUE) { - std::unique_ptr err_msg(new char[info_log_len + 1]); - gl->GetShaderInfoLog(shader, info_log_len, nullptr, err_msg.get()); - LOG(FATAL) << err_msg.get() << "\n" << shader_src; - assert(false); - } - - CheckOpenGLError(); - - return shader; -} - -static TextureFormat GetTextureFormat(DLDataType type) { - CHECK_EQ(type.lanes, 1) << "Not supporting multi-lane types."; - - switch (type.code) { - case kDLInt: { - switch (type.bits) { - case 8: - return {GL_R8I, GL_RED_INTEGER, GL_BYTE}; - case 16: - return {GL_R16I, GL_RED_INTEGER, GL_SHORT}; - case 32: - return {GL_R32I, GL_RED_INTEGER, GL_INT}; - default: - LOG(FATAL) << "Unsupported type bits " << type.bits; - } - } - case kDLUInt: { - switch (type.bits) { - case 8: - return {GL_R8UI, GL_RED_INTEGER, GL_UNSIGNED_BYTE}; - case 16: - return {GL_R16UI, GL_RED_INTEGER, GL_UNSIGNED_SHORT}; - case 32: - return {GL_R32UI, GL_RED_INTEGER, GL_UNSIGNED_INT}; - default: - LOG(FATAL) << "Unsupported type bits " << type.bits; - } - } - case kDLFloat: { - switch (type.bits) { - case 32: - return {GL_R32F, GL_RED, GL_FLOAT}; - default: - LOG(FATAL) << "Unsupported type bits " << type.bits; - } - } - default: { - LOG(FATAL) << "Unsupported type code" << type.code; - } - } - return {GL_R32F, GL_RED, GL_FLOAT}; -} - -Texture OpenGLWorkspace::CreateTexture(DLDataType type, size_t nbytes) { - // Create a texture. - GLuint texture; - OPENGL_CALL(gl->GenTextures(1, &texture)); - - BindTextureUnit(NumTextureUnits() - 1, texture); - - // Use glTexImage2D with nullptr data to specify GPU data storage. 
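// [Editor's note -- illustrative example, not part of the original source]
// The lines below derive the 2D texture shape from the flat element count,
// using rows of kTextureRowSize (1024) elements:
//   nelems =  300  ->  height = 1, width =  300   (single partial row)
//   nelems = 2500  ->  height = 3, width = 1024   (last row only partly used)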
- auto texture_format = GetTextureFormat(type); - auto nelems = static_cast(nbytes / (type.bits / 8)); - auto height = (nelems + kTextureRowSize - 1) / kTextureRowSize; - auto width = (height == 1) ? nelems : kTextureRowSize; - OPENGL_CALL(gl->TexImage2D(GL_TEXTURE_2D, /*level=*/0, - texture_format.internal_format, - width, height, /*border=*/0, - texture_format.format, texture_format.type, - /*data=*/nullptr)); - - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)); - OPENGL_CALL( - gl->TexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)); - - return Texture(this, texture, texture_format, width, height); -} - -Program OpenGLWorkspace::CreateProgram(GLuint fragment_shader) { - // Create the program and link the shaders. - GLuint program = gl->CreateProgram(); - gl->AttachShader(program, vertex_shader_); - gl->AttachShader(program, fragment_shader); - gl->LinkProgram(program); - - // Check link errors. - GLint err; - gl->GetProgramiv(program, GL_LINK_STATUS, &err); - - GLint info_log_len; - gl->GetProgramiv(program, GL_INFO_LOG_LENGTH, &info_log_len); - - if (err != GL_TRUE) { - std::unique_ptr err_msg(new char[info_log_len + 1]); - gl->GetProgramInfoLog(program, info_log_len, nullptr, err_msg.get()); - LOG(FATAL) << err_msg.get(); - assert(false); - } - - CheckOpenGLError(); - - OPENGL_CALL(gl->DetachShader(program, vertex_shader_)); - OPENGL_CALL(gl->DetachShader(program, fragment_shader)); - - auto point_attrib = GLuint(gl->GetAttribLocation(program, "point")); - OPENGL_CALL(gl->EnableVertexAttribArray(point_attrib)); - - OPENGL_CALL(gl->VertexAttribPointer(point_attrib, 2, GL_FLOAT, GL_FALSE, - sizeof(Vertex), nullptr)); - - return Program(this, program); -} - -/*! - * \brief Visit a 1D range of an OpenGL texture-backed TVM array. - * When getting/setting a sub image of a texture, we can only specify a 2D - * block (xbeg, ybeg, width, height). - * Since we are storing all TVM arrays using (kTextureRowSize x nrows) 2D - * textures (row-major), a range in an array does not necessarily map to a 2D - * block. - * This function split a 1D range into 3 2D blocks. - * \param beg The index of the first element in the 1D range. - * \param end The index of the last + 1 element in the 1D range. - * \param on_2d_block Callback for each 2D block. Must have interface - * void(GLint xbeg, GLint ybeg, GLsizei width, GLsizei height). - */ -template -static void Visit1DRange(GLint beg, GLint end, F&& on_2d_block) { - CHECK_LE(beg, end) << "Invalid range."; - - // xbeg kTextureRowSize - // ybeg ....************ - // **************** - // **************** - // ylast *********....... - // xlast - GLint xbeg = beg % kTextureRowSize; - GLint ybeg = beg / kTextureRowSize; - GLint xlast = (end - 1) % kTextureRowSize; - GLint ylast = (end - 1) / kTextureRowSize; - - if (ybeg == ylast) { // Only one row. - on_2d_block(xbeg, ybeg, end - beg, 1); - return; - } - - // First row. - on_2d_block(xbeg, ybeg, kTextureRowSize - xbeg, 1); - - // Middle block. - if (ylast - ybeg > 1) { - on_2d_block(0, ybeg + 1, kTextureRowSize, ylast - ybeg - 1); - } - - // Last row. - on_2d_block(0, ylast, xlast + 1, 1); -} - -void OpenGLWorkspace::PutTextureData(Texture *texture, - GLint begin, - GLsizei nelems, - const GLvoid* data) { - // Bind to temporary unit. 
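// [Editor's note -- worked example for Visit1DRange above, not in the original]
// With kTextureRowSize = 1024, the element range [1000, 3100) is split into
// at most three 2D sub-blocks (xbeg, ybeg, width, height):
//   (1000, 0,   24, 1)  -- tail of the first row
//   (   0, 1, 1024, 2)  -- two full middle rows
//   (   0, 3,   28, 1)  -- head of the last row
// 24 + 2*1024 + 28 = 2100 elements, i.e. exactly end - beg.
// Each callback below uploads one such block with TexSubImage2D, offsetting
// into `data` by the block's start position relative to `begin`.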
- BindTextureUnit(NumTextureUnits() - 1, texture->texture()); - - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { - auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz(); - const GLvoid* ptr = static_cast(data) + offset; - - // Similar to cudaMemcpy. - OPENGL_CALL(gl->TexSubImage2D(GL_TEXTURE_2D, /*level=*/0, - xbeg, ybeg, width, height, - texture->format_.format, - texture->format_.type, ptr)); - }); -} - -void OpenGLWorkspace::GetTextureData(const Texture *texture, - GLint begin, - GLsizei nelems, - GLvoid* data) { - BindTextureUnit(NumTextureUnits() - 1, texture->texture()); - - // Create frame buffer. - GLuint frame_buffer; - OPENGL_CALL(gl->GenFramebuffers(1, &frame_buffer)); - OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer)); - - // Bind texture to framebuffer's attachment 0. - OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_TEXTURE_2D, texture->texture(), 0)); - - // Always check that our framebuffer is okay. - if (gl->CheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) { - LOG(FATAL) << "Framebuffer not complete."; - } - -#ifdef __EMSCRIPTEN__ - // WebGL2's glReadPixels API doesn't allow GL_RED user buffer format. - // Instead, We must use GL_RGBA. This means the data we retrieve has useless - // GBA channels. Here we are applying a dirty hack. - // TODO(zhixunt): We really want to utilize all RGBA channels in textures. - // - // WebGL2's glReadPixels API also doesn't allow GL_RED_INTEGER or - // GL_RGB_INTEGER user buffer format, which means we cannot retrieve integer - // texture data? (need to confirm) - - CHECK_EQ(texture->format_.internal_format, GL_R32F) - << "Retrieving integer texture not supported yet."; - auto elemsz = texture->format_.elemsz(); - auto nchannels = 4; - auto padded_data_size = nchannels * nelems * elemsz; - auto padded_data = std::unique_ptr(new char[padded_data_size]); - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { - auto data_offset = (ybeg * kTextureRowSize + xbeg - begin) * elemsz; - auto padded_data_offset = data_offset * nchannels; - OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, - GL_RGBA, GL_FLOAT, - padded_data.get() + padded_data_offset)); - }); - for (GLsizei i = 0; i != nelems; ++i) { - auto dst = reinterpret_cast(data) + i * elemsz; - auto src = padded_data.get() + nchannels * i * elemsz; - std::memcpy(dst, src, elemsz); - } -#else - Visit1DRange(begin, begin + nelems, [&](GLint xbeg, GLint ybeg, - GLsizei width, GLsizei height) { - auto offset = (ybeg * kTextureRowSize + xbeg - begin) * texture->elemsz(); - GLvoid* ptr = static_cast(data) + offset; - - OPENGL_CALL(gl->ReadPixels(xbeg, ybeg, width, height, - texture->format_.format, texture->format_.type, - ptr)); - }); -#endif - - OPENGL_CALL(gl->DeleteFramebuffers(1, &frame_buffer)); -} - -void OpenGLWorkspace::SetCurrentProgram(const Program& program) { - OPENGL_CALL(gl->UseProgram(program.program())); -} - -void OpenGLWorkspace::SetUniform(const Program& program, - const std::string& name, - DLDataType type, - void* value) { - GLint location = gl->GetUniformLocation(program.program(), name.c_str()); - switch (type.code) { - case kDLInt: { - CHECK_EQ(type.bits, 32) << "Only support 32-bit int for uniform."; - GLint uniform_value = *reinterpret_cast(value); - OPENGL_CALL(gl->Uniform1i(location, uniform_value)); - break; - } - case kDLUInt: { - LOG(FATAL) << "Strangely, emcc WebGL does not support 
glUniform1ui."; - break; - } - case kDLFloat: { - CHECK_EQ(type.bits, 32) << "Only support 32-bit float for uniform."; - GLfloat uniform_value = *reinterpret_cast(value); - OPENGL_CALL(gl->Uniform1f(location, uniform_value)); - break; - } - default: { - LOG(FATAL) << "Unsupported type code for uniform."; - break; - } - } -} - -void OpenGLWorkspace::SetInputTexture(const Program& program, - const std::string& name, - GLuint unit, - Texture* texture) { - // We always use the last texture unit as temporary. - // Therefore, we can have "NumTextureUnits() - 1" input textures. - CHECK_LT(unit, NumTextureUnits() - 1) << "Too many textures."; - - BindTextureUnit(unit, texture->texture()); - GLint location = gl->GetUniformLocation(program.program_, name.c_str()); - OPENGL_CALL(gl->Uniform1i(location, unit)); -} - -void OpenGLWorkspace::Render(Texture* output) { - // Create frame buffer. - GLuint frame_buffer; - OPENGL_CALL(gl->GenFramebuffers(1, &frame_buffer)); - OPENGL_CALL(gl->BindFramebuffer(GL_FRAMEBUFFER, frame_buffer)); - - // Set "renderedTexture" as our colour attachement 0. - OPENGL_CALL(gl->FramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, - GL_TEXTURE_2D, output->texture(), 0)); - - // Specify that we will render to color attachment 0. - GLenum DrawBuffers[1] = {GL_COLOR_ATTACHMENT0}; - OPENGL_CALL(gl->DrawBuffers(1, DrawBuffers)); - - // Always check that our framebuffer is okay. - if (gl->CheckFramebufferStatus(GL_FRAMEBUFFER) != GL_FRAMEBUFFER_COMPLETE) { - LOG(FATAL) << "Framebuffer not complete."; - } - - // Perform rendering. - OPENGL_CALL(gl->Viewport(0, 0, output->width(), output->height())); - OPENGL_CALL(gl->Clear(GL_COLOR_BUFFER_BIT)); - OPENGL_CALL(gl->DrawArrays(GL_TRIANGLES, 0, 6)); - - OPENGL_CALL(gl->DeleteFramebuffers(1, &frame_buffer)); -} - -TVM_REGISTER_GLOBAL("device_api.opengl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = OpenGLWorkspace::Global().get(); - *rv = static_cast(ptr); -}); - -} // namespace gl -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/opengl/opengl_module.cc b/src/runtime/opengl/opengl_module.cc deleted file mode 100644 index 6435aca1bfdd..000000000000 --- a/src/runtime/opengl/opengl_module.cc +++ /dev/null @@ -1,297 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file opengl_module.cc - */ -#include -#include -#include -#include "opengl_common.h" -#include "opengl_module.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" -#include "../file_util.h" - -namespace tvm { -namespace runtime { - -class OpenGLModuleNode final : public ModuleNode { - public: - OpenGLModuleNode(std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap); - - ~OpenGLModuleNode() override = default; - - const char* type_key() const final { return "opengl"; } - - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; - - std::string GetSource(const std::string& format) final; - - void SaveToFile(const std::string& file_name, - const std::string& format) final; - - void SaveToBinary(dmlc::Stream* stream) final; - - const gl::Program& GetProgram(const std::string& func_name) const; - - const OpenGLShader& GetShader(const std::string& func_name) const; - - const FunctionInfo& GetFunctionInfo(const std::string& func_name) const; - - gl::OpenGLWorkspace& workspace() const { return *workspace_; } - - private: - std::shared_ptr workspace_; - std::unordered_map shaders_; - std::string fmt_; - std::unordered_map fmap_; - std::unordered_map programs_; - - DISALLOW_COPY_AND_ASSIGN(OpenGLModuleNode); -}; - -class OpenGLWrappedFunc { - public: - OpenGLWrappedFunc(OpenGLModuleNode* m, - ObjectPtr sptr, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags); - - void operator()(TVMArgs args, TVMRetValue* rv, void** void_args) const; - - private: - // The module - OpenGLModuleNode* m_; - // resource handle - ObjectPtr sptr_; - // The name of the function. - std::string func_name_; - // convert code for void argument - std::vector arg_size_; - // thread axis config - ThreadAxisConfig thread_axis_cfg_; -}; - -OpenGLModuleNode::OpenGLModuleNode( - std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap) - : workspace_(gl::OpenGLWorkspace::Global()), shaders_(std::move(shaders)), - fmt_(std::move(fmt)), fmap_(std::move(fmap)), programs_() { - CHECK_EQ(fmt_, "gl") << "Unknown OpenGL format " << fmt_; - for (auto &pair : shaders_) { - auto &func_name = pair.first; - auto &shader = pair.second; - programs_.emplace(func_name, - workspace_->CreateProgram(shader.source.c_str())); - } -} - -PackedFunc OpenGLModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { - CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; - - auto func_info_it = fmap_.find(name); - if (func_info_it == fmap_.end()) { return PackedFunc(); } - auto &func_info = func_info_it->second; - - std::vector arg_size(func_info.arg_types.size()); - for (size_t i = 0; i < func_info.arg_types.size(); ++i) { - DLDataType t = func_info.arg_types[i]; - CHECK_EQ(t.lanes, 1U); - uint32_t bits = t.bits; - CHECK_EQ(bits % 8, 0U); - arg_size[i] = bits / 8; - } - - // Initialize the wrapped func. 
- OpenGLWrappedFunc f(this, sptr_to_self, name, arg_size, - func_info.thread_axis_tags); - return PackFuncVoidAddr(f, func_info.arg_types); -} - -std::string OpenGLModuleNode::GetSource(const std::string& format) { - if (format != fmt_ && fmt_ != "gl") { return ""; } - - std::ostringstream os; - for (auto &pair : shaders_) { - auto &name = pair.first; - auto &shader = pair.second; - os << "[" << name << "]" << "\n"; - os << shader.source <<"\n"; - } - return os.str(); -} - -void OpenGLModuleNode::SaveToFile(const std::string& file_name, - const std::string& format) { - std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; - std::string meta_file = GetMetaFilePath(file_name); - SaveMetaDataToFile(meta_file, fmap_); - SaveBinaryToFile(file_name, ToJSON(shaders_)); -} - -void OpenGLModuleNode::SaveToBinary(dmlc::Stream* stream) { - stream->Write(fmt_); - stream->Write(fmap_); - stream->Write(ToJSON(shaders_)); -} - -const gl::Program& OpenGLModuleNode::GetProgram( - const std::string& func_name) const { - auto it = programs_.find(func_name); - if (it == programs_.end()) { - LOG(FATAL) << "Cannot find program"; - } - return it->second; -} - -const OpenGLShader& OpenGLModuleNode::GetShader( - const std::string& func_name) const { - auto it = shaders_.find(func_name); - if (it == shaders_.end()) { - LOG(FATAL) << "Cannot find shader"; - } - return it->second; -} - -const FunctionInfo& OpenGLModuleNode::GetFunctionInfo( - const std::string& func_name) const { - auto it = fmap_.find(func_name); - if (it == fmap_.end()) { - LOG(FATAL) << "Cannot find shader"; - } - return it->second; -} - -OpenGLWrappedFunc::OpenGLWrappedFunc( - OpenGLModuleNode* m, - ObjectPtr sptr, - std::string func_name, - std::vector arg_size, - const std::vector& thread_axis_tags) - : m_(m), sptr_(std::move(sptr)), func_name_(std::move(func_name)), - arg_size_(std::move(arg_size)) { - thread_axis_cfg_.Init(arg_size_.size(), thread_axis_tags); -} - -void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - void** void_args) const { - auto &shader = m_->GetShader(func_name_); - auto &program = m_->GetProgram(func_name_); - auto &func_info = m_->GetFunctionInfo(func_name_); - size_t nargs = shader.arg_kinds.size(); - - // Must call this function before setting uniforms & input textures. - m_->workspace().SetCurrentProgram(program); - - // Set all arguments. - GLuint texture_unit = 0; - gl::Texture* output = nullptr; - for (size_t i = 0; i != nargs; ++i) { - auto &name = shader.arg_names.at(i); - auto kind = shader.arg_kinds.at(i); - auto type = func_info.arg_types.at(i); - switch (kind) { - case OpenGLArgKind::kUniform: { - m_->workspace().SetUniform(program, name, type, void_args[i]); - break; - } - case OpenGLArgKind::kInputTexture: { - CHECK_EQ(type.code, kTVMOpaqueHandle) << "Type is not handle?"; - auto texture = *static_cast(void_args[i]); - m_->workspace().SetInputTexture(program, name, texture_unit, texture); - ++texture_unit; - break; - } - case OpenGLArgKind::kOutputTexture: { - CHECK_EQ(type.code, kTVMOpaqueHandle) << "Type is not handle?"; - CHECK(output == nullptr) << "Can only have one output texture."; - output = *static_cast(void_args[i]); - break; - } - } - } - - // Set "thread_extent" uniform. 
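// [Editor's note -- illustrative, not part of the original source]
// In this backend a "kernel launch" is a draw call: the output texture is
// attached to a framebuffer and the full-screen quad is rendered, so the
// fragment shader runs once per output texel. The "thread_extent" uniform
// (block_dim(0), set below) carries the logical output length, since the
// padded output texture can hold more texels than there are elements.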
- ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - std::unique_ptr thread_extent(new GLint(wl.block_dim(0))); - m_->workspace().SetUniform(program, shader.thread_extent_var, - DLDataType{kDLInt, 32, 1}, - static_cast(thread_extent.get())); - - m_->workspace().Render(output); -} - -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap) { - auto n = make_object(std::move(shaders), - std::move(fmt), - std::move(fmap)); - return Module(n); -} - -Module OpenGLModuleLoadFile(const std::string& file_name, - const std::string& format) { - std::string data; - std::unordered_map fmap; - std::string fmt = GetFileFormat(file_name, format); - std::string meta_file = GetMetaFilePath(file_name); - LoadBinaryFromFile(file_name, &data); - LoadMetaDataFromFile(meta_file, &fmap); - return OpenGLModuleCreate(FromJSON(data), fmt, fmap); -} - -Module OpenGLModuleLoadBinary(void* strm) { - auto stream = static_cast(strm); - std::string data; - std::unordered_map fmap; - std::string fmt; - stream->Read(&fmt); - stream->Read(&fmap); - stream->Read(&data); - return OpenGLModuleCreate(FromJSON(data), fmt, fmap); -} - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_gl") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadFile(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_glbin") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadFile(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_opengl") - .set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenGLModuleLoadBinary(args[0]); - }); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/opengl/opengl_module.h b/src/runtime/opengl/opengl_module.h deleted file mode 100644 index 4d2d1c859253..000000000000 --- a/src/runtime/opengl/opengl_module.h +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file opengl_module.h - * \brief Execution handling of OpenGL kernels - */ -#ifndef TVM_RUNTIME_OPENGL_OPENGL_MODULE_H_ -#define TVM_RUNTIME_OPENGL_OPENGL_MODULE_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "../meta_data.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief The fixed row size of all OpenGL textures in TVM. - * - * OpenGL has texture size limit on each dimension. Suppose we have a limit of - * 1024, then we can have a 2D texture of size (2^10 x 2^10) but not (2^20 x 1). - * This means we don't want to just use (n x 1) 2D textures for all arrays, - * because that would limit our array size to be 1024. Here we use (1024 x m) - * 2D textures. Then we can have arrays of size up to 2^20. 
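 * (Editor's illustration, not part of the original header: with
 * kTextureRowBits = 10 the row size is 1024, so a flat element index i maps to
 * texel column i % 1024 and row i / 1024 -- equivalently
 * (i & kTextureRowMask, i >> kTextureRowBits); e.g. i = 2500 lands at
 * row 2, column 452.)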
- */ -static constexpr int kTextureRowBits = 10; -static constexpr int kTextureRowSize = 1 << kTextureRowBits; -static constexpr int kTextureRowMask = kTextureRowSize - 1; - -/*! - * \brief Determines how we supply arguments. - */ -enum class OpenGLArgKind { - kInputTexture = 0, // Bind to "gsampler2D" in GLSL. - kOutputTexture = 1, // Bind to "out" in GLSL. - kUniform = 2, // Bind to "uniform" in GLSL. -}; - -std::string OpenGLArgKind2String(OpenGLArgKind kind); -OpenGLArgKind String2OpenGLArgKind(const std::string& str); - -/*! - * \brief The output of OpenGL codegen. - * Contains necessary information to build a fragment shader and bind arguments. - */ -struct OpenGLShader { - OpenGLShader() = default; - OpenGLShader(std::string source, - std::vector arg_names, - std::vector arg_kinds, - std::string thread_extent_var) - : source(std::move(source)), arg_names(std::move(arg_names)), - arg_kinds(std::move(arg_kinds)), - thread_extent_var(std::move(thread_extent_var)) { - CHECK_EQ(this->arg_names.size(), this->arg_kinds.size()) << "Invalid input"; - } - - std::string source; - std::vector arg_names; // Matches FunctionInfo. - std::vector arg_kinds; // Matches FunctionInfo. - std::string thread_extent_var; // Stores the output length. - - void Save(dmlc::JSONWriter* writer) const; - void Load(dmlc::JSONReader* reader); -}; - -std::string ToJSON(const std::unordered_map& shaders); -std::unordered_map FromJSON(const std::string& str); - -/*! - * \brief Create an OpenGL module from data. - * - * \param data The module data. - * \param fmt The format of the data, - * \param fmap The map function information map of each function. - */ -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap); - -inline std::string OpenGLArgKind2String(OpenGLArgKind kind) { - switch (kind) { - case OpenGLArgKind::kOutputTexture: - return "output_texture"; - case OpenGLArgKind::kInputTexture: - return "input_texture"; - case OpenGLArgKind::kUniform: - return "uniform"; - default: - LOG(FATAL) << "invalid arg kind"; - return ""; - } -} - -inline OpenGLArgKind String2OpenGLArgKind(const std::string& str) { - if (str == "output_texture") { - return OpenGLArgKind::kOutputTexture; - } else if (str == "input_texture") { - return OpenGLArgKind::kInputTexture; - } else if (str == "uniform") { - return OpenGLArgKind::kUniform; - } else { - LOG(FATAL) << "Invalid OpenGL arg kind."; - return OpenGLArgKind::kUniform; - } -} - -inline void OpenGLShader::Save(dmlc::JSONWriter* writer) const { - std::vector arg_kind_strs; - for (auto kind : arg_kinds) { - arg_kind_strs.push_back(OpenGLArgKind2String(kind)); - } - - writer->BeginObject(); - writer->WriteObjectKeyValue("arg_names", arg_names); - writer->WriteObjectKeyValue("arg_kinds", arg_kind_strs); - writer->WriteObjectKeyValue("source", source); - writer->WriteObjectKeyValue("thread_extent_var", thread_extent_var); - writer->EndObject(); -} - -inline void OpenGLShader::Load(dmlc::JSONReader* reader) { - std::vector arg_kind_strs; - dmlc::JSONObjectReadHelper helper; - helper.DeclareField("arg_names", &arg_names); - helper.DeclareField("arg_kinds", &arg_kind_strs); - helper.DeclareField("source", &source); - helper.DeclareField("thread_extent_var", &thread_extent_var); - helper.ReadAllFields(reader); - - arg_kinds.clear(); - for (auto& str : arg_kind_strs) { - arg_kinds.push_back(String2OpenGLArgKind(str)); - } -} - -inline std::string ToJSON( - const std::unordered_map& shaders) { - std::ostringstream os; - dmlc::JSONWriter 
writer(&os); - writer.BeginObject(); - writer.WriteObjectKeyValue("shaders", shaders); - writer.EndObject(); - return os.str(); -} - -inline std::unordered_map FromJSON( - const std::string& str) { - std::unordered_map shaders; - std::istringstream is(str); - dmlc::JSONReader reader(&is); - dmlc::JSONObjectReadHelper helper; - helper.DeclareField("shaders", &shaders); - helper.ReadAllFields(&reader); - return shaders; -} - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_OPENGL_OPENGL_MODULE_H_ diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 9d24ca9072b4..ae9771641b23 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -32,8 +32,9 @@ #define TVM_RUNTIME_PACK_ARGS_H_ #include -#include + #include +#include namespace tvm { namespace runtime { @@ -55,7 +56,7 @@ union ArgUnion { * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function only packs buffer arguments. @@ -66,7 +67,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types); /*! * \brief Create a packed function that from function that takes a packed arguments. @@ -77,7 +78,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t * * \return The wrapped packed function. */ -template +template inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_types); /*! * \brief Extract number of buffer argument from the argument types. @@ -88,23 +89,21 @@ inline size_t NumBufferArgs(const std::vector& arg_types); // implementations details namespace detail { -template +template class TempArray { public: explicit TempArray(int size) {} - T* data() { - return data_; - } + T* data() { return data_; } + private: T data_[kSize]; }; -template +template class TempArray { public: explicit TempArray(int size) : data_(size) {} - T* data() { - return data_.data(); - } + T* data() { return data_.data(); } + private: std::vector data_; }; @@ -120,8 +119,7 @@ enum ArgConvertCode { }; inline ArgConvertCode GetArgConvertCode(DLDataType t) { - CHECK_EQ(t.lanes, 1U) - << "Cannot pass vector type argument to devic function for now"; + CHECK_EQ(t.lanes, 1U) << "Cannot pass vector type argument to devic function for now"; if (t.code == kDLInt) { if (t.bits == 64U) return INT64_TO_INT64; if (t.bits == 32U) return INT64_TO_INT32; @@ -137,7 +135,7 @@ inline ArgConvertCode GetArgConvertCode(DLDataType t) { return HANDLE_TO_HANDLE; } -template +template inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { @@ -158,7 +156,7 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code addr[i] = &(holder[i]); break; } - case INT64_TO_UINT32 : { + case INT64_TO_UINT32: { holder[i].v_uint32 = static_cast(args.values[i].v_int64); addr[i] = &(holder[i]); break; @@ -175,9 +173,8 @@ inline PackedFunc PackFuncVoidAddr_(F f, const std::vector& code return PackedFunc(ret); } -template -inline PackedFunc PackFuncNonBufferArg_( - F f, int base, const std::vector& codes) { +template +inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { TempArray 
holder_(num_args); @@ -186,13 +183,14 @@ inline PackedFunc PackFuncNonBufferArg_( switch (codes[i]) { case INT64_TO_INT64: case FLOAT64_TO_FLOAT64: { - LOG(FATAL) << "Do not support 64bit argument to device function"; break; + LOG(FATAL) << "Do not support 64bit argument to device function"; + break; } case INT64_TO_INT32: { holder[i].v_int32 = static_cast(args.values[base + i].v_int64); break; } - case INT64_TO_UINT32 : { + case INT64_TO_UINT32: { holder[i].v_uint32 = static_cast(args.values[base + i].v_int64); break; } @@ -201,7 +199,8 @@ inline PackedFunc PackFuncNonBufferArg_( break; } case HANDLE_TO_HANDLE: { - LOG(FATAL) << "not reached"; break; + LOG(FATAL) << "not reached"; + break; } } } @@ -210,9 +209,8 @@ inline PackedFunc PackFuncNonBufferArg_( return PackedFunc(ret); } -template -inline PackedFunc PackFuncPackedArg_( - F f, const std::vector& codes) { +template +inline PackedFunc PackFuncPackedArg_(F f, const std::vector& codes) { int num_args = static_cast(codes.size()); auto ret = [f, codes, num_args](TVMArgs args, TVMRetValue* ret) { TempArray pack_(num_args); @@ -238,20 +236,19 @@ inline PackedFunc PackFuncPackedArg_( ++ptr; break; } - case INT64_TO_UINT32 : { - *reinterpret_cast(ptr) = - static_cast(args.values[i].v_int64); + case INT64_TO_UINT32: { + *reinterpret_cast(ptr) = static_cast(args.values[i].v_int64); ++ptr; break; } case FLOAT64_TO_FLOAT32: { - *reinterpret_cast(ptr) = - static_cast(args.values[i].v_float64); + *reinterpret_cast(ptr) = static_cast(args.values[i].v_float64); ++ptr; break; } default: { - LOG(FATAL) << "not reached"; break; + LOG(FATAL) << "not reached"; + break; } } } @@ -261,7 +258,7 @@ inline PackedFunc PackFuncPackedArg_( } } // namespace detail -template +template inline PackedFunc PackFuncVoidAddr(F f, const std::vector& arg_types) { std::vector codes(arg_types.size()); for (size_t i = 0; i < arg_types.size(); ++i) { @@ -282,17 +279,17 @@ inline size_t NumBufferArgs(const std::vector& arg_types) { size_t base = arg_types.size(); for (size_t i = 0; i < arg_types.size(); ++i) { if (arg_types[i].code != kTVMOpaqueHandle) { - base = i; break; + base = i; + break; } } for (size_t i = base; i < arg_types.size(); ++i) { - CHECK(arg_types[i].code != kTVMOpaqueHandle) - << "Device function need to be organized"; + CHECK(arg_types[i].code != kTVMOpaqueHandle) << "Device function need to be organized"; } return base; } -template +template inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_types) { size_t num_buffer = NumBufferArgs(arg_types); std::vector codes; @@ -309,7 +306,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector& arg_t } } -template +template inline PackedFunc PackFuncPackedArg(F f, const std::vector& arg_types) { std::vector codes; for (size_t i = 0; i < arg_types.size(); ++i) { diff --git a/src/runtime/registry.cc b/src/runtime/registry.cc index 4717d89e33c1..641532a83927 100644 --- a/src/runtime/registry.cc +++ b/src/runtime/registry.cc @@ -24,10 +24,12 @@ #include #include #include -#include -#include -#include + #include +#include +#include +#include + #include "runtime_base.h" namespace tvm { @@ -37,14 +39,13 @@ struct Registry::Manager { // map storing the functions. // We delibrately used raw pointer // This is because PackedFunc can contain callbacks into the host languge(python) - // and the resource can become invalid because of indeterminstic order of destruction. + // and the resource can become invalid because of indeterminstic order of destruction and forking. 
// The resources will only be recycled during program exit. std::unordered_map fmap; // mutex std::mutex mutex; - Manager() { - } + Manager() {} static Manager* Global() { // We deliberately leak the Manager instance, to avoid leak sanitizers @@ -60,20 +61,17 @@ Registry& Registry::set_body(PackedFunc f) { // NOLINT(*) return *this; } -Registry& Registry::Register(const std::string& name, bool override) { // NOLINT(*) +Registry& Registry::Register(const std::string& name, bool can_override) { // NOLINT(*) Manager* m = Manager::Global(); std::lock_guard lock(m->mutex); - auto it = m->fmap.find(name); - if (it == m->fmap.end()) { - Registry* r = new Registry(); - r->name_ = name; - m->fmap[name] = r; - return *r; - } else { - CHECK(override) - << "Global PackedFunc " << name << " is already registered"; - return *it->second; + if (m->fmap.count(name)) { + CHECK(can_override) << "Global PackedFunc " << name << " is already registered"; } + + Registry* r = new Registry(); + r->name_ = name; + m->fmap[name] = r; + return *r; } bool Registry::Remove(const std::string& name) { @@ -98,7 +96,7 @@ std::vector Registry::ListNames() { std::lock_guard lock(m->mutex); std::vector keys; keys.reserve(m->fmap.size()); - for (const auto &kv : m->fmap) { + for (const auto& kv : m->fmap) { keys.push_back(kv.first); } return keys; @@ -112,14 +110,13 @@ struct TVMFuncThreadLocalEntry { /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ - std::vector ret_vec_charp; + std::vector ret_vec_charp; }; /*! \brief Thread local store that can be used to hold return values. */ typedef dmlc::ThreadLocalStore TVMFuncThreadLocalStore; -int TVMFuncRegisterGlobal( - const char* name, TVMFunctionHandle f, int override) { +int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) { API_BEGIN(); tvm::runtime::Registry::Register(name, override != 0) .set_body(*static_cast(f)); @@ -128,8 +125,7 @@ int TVMFuncRegisterGlobal( int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_BEGIN(); - const tvm::runtime::PackedFunc* fp = - tvm::runtime::Registry::Get(name); + const tvm::runtime::PackedFunc* fp = tvm::runtime::Registry::Get(name); if (fp != nullptr) { *out = new tvm::runtime::PackedFunc(*fp); // NOLINT(*) } else { @@ -138,10 +134,9 @@ int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { API_END(); } -int TVMFuncListGlobalNames(int *out_size, - const char*** out_array) { +int TVMFuncListGlobalNames(int* out_size, const char*** out_array) { API_BEGIN(); - TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get(); + TVMFuncThreadLocalEntry* ret = TVMFuncThreadLocalStore::Get(); ret->ret_vec_str = tvm::runtime::Registry::ListNames(); ret->ret_vec_charp.clear(); for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { diff --git a/src/runtime/rocm/rocm_common.h b/src/runtime/rocm/rocm_common.h index 5d0d5c972c4b..2e637f5496bb 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/runtime/rocm/rocm_common.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,28 +24,28 @@ #ifndef TVM_RUNTIME_ROCM_ROCM_COMMON_H_ #define TVM_RUNTIME_ROCM_ROCM_COMMON_H_ -#include #include +#include + #include + #include "../workspace_pool.h" namespace tvm { namespace runtime { -#define ROCM_DRIVER_CALL(x) \ - { \ - hipError_t result = x; \ - if (result != hipSuccess && result != hipErrorDeinitialized) { \ - LOG(FATAL) \ - << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ - } \ +#define ROCM_DRIVER_CALL(x) \ + { \ + hipError_t result = x; \ + if (result != hipSuccess && result != hipErrorDeinitialized) { \ + LOG(FATAL) << "ROCM HIP Error: " #x " failed with error: " << hipGetErrorString(result); \ + } \ } -#define ROCM_CALL(func) \ - { \ - hipError_t e = (func); \ - CHECK(e == hipSuccess) \ - << "ROCM HIP: " << hipGetErrorString(e); \ +#define ROCM_CALL(func) \ + { \ + hipError_t e = (func); \ + CHECK(e == hipSuccess) << "ROCM HIP: " << hipGetErrorString(e); \ } /*! \brief Thread local workspace */ diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 25e1ac70c241..475c4fbffadc 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -35,9 +35,7 @@ namespace runtime { class ROCMDeviceAPI final : public DeviceAPI { public: - void SetDevice(TVMContext ctx) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - } + void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { int value = 0; switch (kind) { @@ -53,27 +51,26 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxThreadsPerBlock: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, - ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id)); break; } case kMaxSharedMemoryPerBlock: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock, + ctx.device_id)); break; } case kComputeVersion: { std::ostringstream os; - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); os << value << "."; - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); os << value; *rv = os.str(); return; @@ -86,23 +83,19 @@ class ROCMDeviceAPI final : public DeviceAPI { return; } case kMaxClockRate: { - ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, - ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id)); break; } case kMultiProcessorCount: { - ROCM_CALL(hipDeviceGetAttribute( - &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); + ROCM_CALL( + 
hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); break; } case kMaxThreadDimensions: { int dims[3]; - ROCM_CALL(hipDeviceGetAttribute( - &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); - ROCM_CALL(hipDeviceGetAttribute( - &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); - ROCM_CALL(hipDeviceGetAttribute( - &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); std::stringstream ss; ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; @@ -132,9 +125,8 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipFree(ptr)); } - void CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t size, TVMContext ctx_from, - TVMContext ctx_to, DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { hipStream_t hip_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -144,15 +136,12 @@ class ROCMDeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); } else { - hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, - hip_stream); + hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream); } - } else if (ctx_from.device_type == kDLROCM && - ctx_to.device_type == kDLCPU) { + } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { ROCM_CALL(hipSetDevice(ctx_from.device_id)); GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); - } else if (ctx_from.device_type == kDLCPU && - ctx_to.device_type == kDLROCM) { + } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) { ROCM_CALL(hipSetDevice(ctx_to.device_id)); GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); } else { @@ -178,14 +167,13 @@ class ROCMDeviceAPI final : public DeviceAPI { } static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } private: - static void GPUCopy(const void* from, void* to, size_t size, - hipMemcpyKind kind, hipStream_t stream) { + static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind, + hipStream_t stream) { if (stream != 0) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); } else { @@ -198,14 +186,11 @@ typedef dmlc::ThreadLocalStore ROCMThreadStore; ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} -ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { - return ROCMThreadStore::Get(); -} +ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } -TVM_REGISTER_GLOBAL("device_api.rocm") - .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc 
b/src/runtime/rocm/rocm_module.cc index 1f4b830ce434..79958d20aa1f 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -20,19 +20,22 @@ /*! * \file rocm_module.cc */ -#include +#include "rocm_module.h" + #include -#include +#include + #include -#include #include +#include #include -#include "rocm_module.h" -#include "rocm_common.h" +#include + +#include "../file_util.h" +#include "../meta_data.h" #include "../pack_args.h" #include "../thread_storage_scope.h" -#include "../meta_data.h" -#include "../file_util.h" +#include "rocm_common.h" namespace tvm { namespace runtime { @@ -43,12 +46,10 @@ namespace runtime { // The modules will be lazily loaded class ROCMModuleNode : public runtime::ModuleNode { public: - explicit ROCMModuleNode(std::string data, - std::string fmt, + explicit ROCMModuleNode(std::string data, std::string fmt, std::unordered_map fmap, - std::string hip_source, - std::string assembly) - : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) { + std::string hip_source, std::string assembly) + : data_(data), fmt_(fmt), fmap_(fmap), hip_source_(hip_source), assembly_(assembly) { std::fill(module_.begin(), module_.end(), nullptr); } // destructor @@ -61,17 +62,11 @@ class ROCMModuleNode : public runtime::ModuleNode { } } - const char* type_key() const final { - return "hip"; - } - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final; + const char* type_key() const final { return "hip"; } + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); // note: llvm and asm formats are not laodable, so we don't save them @@ -87,9 +82,15 @@ class ROCMModuleNode : public runtime::ModuleNode { } std::string GetSource(const std::string& format) final { - if (format == fmt_) { return data_; } - if (format == "llvm" || format == "") { return hip_source_; } - if (format == "asm") { return assembly_; } + if (format == fmt_) { + return data_; + } + if (format == "llvm" || format == "") { + return hip_source_; + } + if (format == "asm") { + return assembly_; + } return ""; } @@ -104,16 +105,13 @@ class ROCMModuleNode : public runtime::ModuleNode { hipFunction_t func; hipError_t result = hipModuleGetFunction(&func, module_[device_id], func_name.c_str()); if (result != hipSuccess) { - LOG(FATAL) - << "ROCMError: hipModuleGetFunction " << func_name - << " failed with error: " << hipGetErrorString(result); + LOG(FATAL) << "ROCMError: hipModuleGetFunction " << func_name + << " failed with error: " << hipGetErrorString(result); } return func; } // get a global var from primary context in device_id - hipDeviceptr_t GetGlobal(int device_id, - const std::string& global_name, - size_t expect_nbytes) { + hipDeviceptr_t GetGlobal(int device_id, const std::string& global_name, size_t expect_nbytes) { std::lock_guard lock(mutex_); // must recheck under the lock scope if (module_[device_id] == nullptr) { @@ -122,8 +120,7 @@ class ROCMModuleNode : public runtime::ModuleNode { hipDeviceptr_t global = nullptr; size_t nbytes = 0; - ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, - module_[device_id], global_name.c_str())); + ROCM_DRIVER_CALL(hipModuleGetGlobal(&global, &nbytes, module_[device_id], 
global_name.c_str())); CHECK_EQ(nbytes, expect_nbytes); return global; } @@ -149,11 +146,8 @@ class ROCMModuleNode : public runtime::ModuleNode { class ROCMWrappedFunc { public: // initialize the ROCM function. - void Init(ROCMModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, - size_t num_void_args, - const std::vector& thread_axis_tags) { + void Init(ROCMModuleNode* m, ObjectPtr sptr, const std::string& func_name, + size_t num_void_args, const std::vector& thread_axis_tags) { m_ = m; sptr_ = sptr; func_name_ = func_name; @@ -161,10 +155,7 @@ class ROCMWrappedFunc { thread_axis_cfg_.Init(num_void_args, thread_axis_tags); } // invoke the function with void arguments - void operator()(TVMArgs args, - TVMRetValue* rv, - void* packed_args, - size_t packed_nbytes) const { + void operator()(TVMArgs args, TVMRetValue* rv, void* packed_args, size_t packed_nbytes) const { int device_id; ROCM_CALL(hipGetDevice(&device_id)); if (fcache_[device_id] == nullptr) { @@ -174,22 +165,12 @@ class ROCMWrappedFunc { hipStream_t strm = static_cast(ROCMThreadEntry::ThreadLocal()->stream); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); - void* config[] = { - HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, - HIP_LAUNCH_PARAM_BUFFER_SIZE, &packed_nbytes, - HIP_LAUNCH_PARAM_END - }; + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, packed_args, HIP_LAUNCH_PARAM_BUFFER_SIZE, + &packed_nbytes, HIP_LAUNCH_PARAM_END}; // HIP supports only extra_args. ROCM_DRIVER_CALL(hipModuleLaunchKernel( - fcache_[device_id], - wl.grid_dim(0), - wl.grid_dim(1), - wl.grid_dim(2), - wl.block_dim(0), - wl.block_dim(1), - wl.block_dim(2), - 0, strm, nullptr, - reinterpret_cast(&config))); + fcache_[device_id], wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2), wl.block_dim(0), + wl.block_dim(1), wl.block_dim(2), 0, strm, nullptr, reinterpret_cast(&config))); } private: @@ -206,13 +187,10 @@ class ROCMWrappedFunc { ThreadAxisConfig thread_axis_cfg_; }; - -PackedFunc ROCMModuleNode::GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc ROCMModuleNode::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { CHECK_EQ(sptr_to_self.get(), this); - CHECK_NE(name, symbol::tvm_module_main) - << "Device function do not have main"; + CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return PackedFunc(); const FunctionInfo& info = it->second; @@ -221,18 +199,14 @@ PackedFunc ROCMModuleNode::GetFunction( return PackFuncPackedArg(f, info.arg_types); } -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string hip_source, - std::string assembly) { +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string hip_source, + std::string assembly) { auto n = make_object(data, fmt, fmap, hip_source, assembly); return Module(n); } -Module ROCMModuleLoadFile(const std::string& file_name, - const std::string& format) { +Module ROCMModuleLoadFile(const std::string& file_name, const std::string& format) { std::string data; std::unordered_map fmap; std::string fmt = GetFileFormat(file_name, format); @@ -253,19 +227,12 @@ Module ROCMModuleLoadBinary(void* strm) { return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); } +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hsaco") -.set_body_typed(ROCMModuleLoadBinary); - - 
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip") -.set_body_typed(ROCMModuleLoadBinary); - +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_hip").set_body_typed(ROCMModuleLoadBinary); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco") -.set_body_typed(ROCMModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hsaco").set_body_typed(ROCMModuleLoadFile); -TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip") -.set_body_typed(ROCMModuleLoadFile); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_hip").set_body_typed(ROCMModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h index 7f2a0ce319bf..c17e123c1a12 100644 --- a/src/runtime/rocm/rocm_module.h +++ b/src/runtime/rocm/rocm_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,10 +25,12 @@ #define TVM_RUNTIME_ROCM_ROCM_MODULE_H_ #include + #include -#include #include #include +#include + #include "../meta_data.h" namespace tvm { @@ -45,12 +47,9 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmap The map function information map of each function. * \param rocm_source Optional, rocm source file */ -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string rocm_source, - std::string assembly); +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string rocm_source, + std::string assembly); } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_ diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/rpc/minrpc/minrpc_server.h new file mode 100644 index 000000000000..91a900afd900 --- /dev/null +++ b/src/runtime/rpc/minrpc/minrpc_server.h @@ -0,0 +1,581 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file minrpc_server.h + * \brief Minimum RPC server implementation, + * redirects all the calls to C runtime API. + * + * \note This file do not depend on c++ std or c std, + * and only depends on TVM's C runtime API. + */ +#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ +#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ + +#include +#include + +#include "../../../support/arena.h" +#include "../rpc_protocol.h" + +/*! 
\brief Whether or not to enable glog style DLOG */ +#ifndef TVM_MINRPC_ENABLE_LOGGING +#define TVM_MINRPC_ENABLE_LOGGING 0 +#endif + +#ifndef MINRPC_CHECK +#define MINRPC_CHECK(cond) \ + if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); +#endif + +#if TVM_MINRPC_ENABLE_LOGGING +#include +#endif + +namespace tvm { +namespace runtime { + +/*! + * \brief A minimum RPC server that only depends on the tvm C runtime.. + * + * All the dependencies are provided by the io arguments. + * + * \tparam TIOHandler IO provider to provide io handling. + * An IOHandler needs to provide the following functions: + * - PosixWrite, PosixRead, Close: posix style, read, write, close API. + * - Exit: exit with status code. + */ +template +class MinRPCServer { + public: + /*! + * \brief Constructor. + * \param io The IO handler. + */ + explicit MinRPCServer(TIOHandler io) : io_(io), arena_(PageAllocator(io)) {} + + /*! \brief Run the server loop until shutdown signal is received. */ + void ServerLoop() { + RPCCode code; + uint64_t packet_len; + + while (true) { + arena_.RecycleAll(); + allow_clean_shutdown_ = true; + + this->Read(&packet_len); + if (packet_len == 0) continue; + this->Read(&code); + + allow_clean_shutdown_ = false; + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscallFunc(code); + } else { + switch (code) { + case RPCCode::kCallFunc: { + HandleNormalCallFunc(); + break; + } + case RPCCode::kInitServer: { + HandleInitServer(); + break; + } + case RPCCode::kCopyFromRemote: { + HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + HandleCopyToRemote(); + break; + } + case RPCCode::kShutdown: { + this->Shutdown(); + return; + } + default: { + this->ThrowError(RPCServerStatus::kUnknownRPCCode); + break; + } + } + } + } + } + + void Shutdown() { + arena_.FreeAll(); + io_.Close(); + } + + void HandleNormalCallFunc() { + uint64_t call_handle; + TVMValue* values; + int* tcodes; + int num_args; + TVMValue ret_value[3]; + int ret_tcode[3]; + + this->Read(&call_handle); + RecvPackedSeq(&values, &tcodes, &num_args); + + int call_ecode = TVMFuncCall(reinterpret_cast(call_handle), values, tcodes, num_args, + &(ret_value[1]), &(ret_tcode[1])); + + if (call_ecode == 0) { + // Return value encoding as in LocalSession + int rv_tcode = ret_tcode[1]; + ret_tcode[0] = kDLInt; + ret_value[0].v_int64 = rv_tcode; + if (rv_tcode == kTVMNDArrayHandle) { + ret_tcode[1] = kTVMDLTensorHandle; + ret_value[2].v_handle = ret_value[1].v_handle; + ret_tcode[2] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 3); + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + ret_tcode[1] = kTVMOpaqueHandle; + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } else { + this->ReturnPackedSeq(ret_value, ret_tcode, 2); + } + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + uint8_t* data_ptr; + int call_ecode = 0; + if (ctx.device_type == kDLCPU) { + data_ptr = reinterpret_cast(handle) + offset; + } else { + data_ptr = this->ArenaAlloc(num_bytes); + call_ecode = + TVMDeviceCopyDataFromTo(reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, + ctx, DLContext{kDLCPU, 0}, type_hint, nullptr); + // need sync to make sure that the copy is completed. 
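+      // (the device-to-CPU copy may be asynchronous on the remote device, so the
+      //  TVMSynchronize call below guarantees data_ptr holds valid bytes before
+      //  they are written back to the peer.)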
+ if (call_ecode == 0) { + call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); + } + } + + if (call_ecode == 0) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + int call_ecode = 0; + + if (ctx.device_type == kDLCPU) { + uint8_t* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + } else { + uint8_t* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + call_ecode = + TVMDeviceCopyDataFromTo(temp_data, 0, reinterpret_cast(handle), offset, num_bytes, + DLContext{kDLCPU, 0}, ctx, type_hint, nullptr); + // need sync to make sure that the copy is completed. + if (call_ecode == 0) { + call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); + } + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void HandleSyscallFunc(RPCCode code) { + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + switch (code) { + case RPCCode::kFreeHandle: { + this->SyscallFreeHandle(values, tcodes, num_args); + break; + } + case RPCCode::kGetGlobalFunc: { + this->SyscallGetGlobalFunc(values, tcodes, num_args); + break; + } + case RPCCode::kDevSetDevice: { + this->ReturnException("SetDevice not supported"); + break; + } + case RPCCode::kDevGetAttr: { + this->ReturnException("GetAttr not supported"); + break; + } + case RPCCode::kDevAllocData: { + this->SyscallDevAllocData(values, tcodes, num_args); + break; + } + case RPCCode::kDevFreeData: { + this->SyscallDevFreeData(values, tcodes, num_args); + break; + } + case RPCCode::kDevStreamSync: { + this->SyscallDevStreamSync(values, tcodes, num_args); + break; + } + case RPCCode::kCopyAmongRemote: { + this->SyscallCopyAmongRemote(values, tcodes, num_args); + break; + } + default: { + this->ReturnException("Syscall not recognized"); + break; + } + } + } + + void HandleInitServer() { + uint64_t len; + this->Read(&len); + char* proto_ver = this->ArenaAlloc(len + 1); + this->ReadArray(proto_ver, len); + + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + MINRPC_CHECK(num_args == 0); + this->ReturnVoid(); + } + + void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + + void* handle = values[0].v_handle; + int64_t type_code = values[1].v_int64; + int call_ecode; + + if (type_code == kTVMNDArrayHandle) { + call_ecode = TVMArrayFree(static_cast(handle)); + } else if (type_code == kTVMPackedFuncHandle) { + call_ecode = TVMFuncFree(handle); + } else { + MINRPC_CHECK(type_code == kTVMModuleHandle); + call_ecode = TVMModFree(handle); + } + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 1); + MINRPC_CHECK(tcodes[0] == kTVMStr); + + void* handle; + int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + 
} else { + this->ReturnLastTVMError(); + } + } + + void SyscallCopyAmongRemote(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 9); + // from, from_offset + MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[1] == kDLInt); + // to, to_offset + MINRPC_CHECK(tcodes[2] == kTVMOpaqueHandle); + MINRPC_CHECK(tcodes[3] == kDLInt); + // size + MINRPC_CHECK(tcodes[4] == kDLInt); + // ctx_from, ctx_to + MINRPC_CHECK(tcodes[5] == kTVMContext); + MINRPC_CHECK(tcodes[6] == kTVMContext); + // type_hint, stream + MINRPC_CHECK(tcodes[7] == kTVMDataType); + MINRPC_CHECK(tcodes[8] == kTVMOpaqueHandle); + + void* from = values[0].v_handle; + int64_t from_offset = values[1].v_int64; + void* to = values[2].v_handle; + int64_t to_offset = values[3].v_int64; + int64_t size = values[4].v_int64; + TVMContext ctx_from = values[5].v_ctx; + TVMContext ctx_to = values[6].v_ctx; + DLDataType type_hint = values[7].v_type; + TVMStreamHandle stream = values[8].v_handle; + + int call_ecode = TVMDeviceCopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, + ctx_to, type_hint, stream); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevAllocData(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 4); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kDLInt); + MINRPC_CHECK(tcodes[2] == kDLInt); + MINRPC_CHECK(tcodes[3] == kTVMDataType); + + TVMContext ctx = values[0].v_ctx; + int64_t nbytes = values[1].v_int64; + int64_t alignment = values[2].v_int64; + DLDataType type_hint = values[3].v_type; + + void* handle; + int call_ecode = TVMDeviceAllocDataSpace(ctx, nbytes, alignment, type_hint, &handle); + + if (call_ecode == 0) { + this->ReturnHandle(handle); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevFreeData(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + TVMContext ctx = values[0].v_ctx; + void* handle = values[1].v_handle; + + int call_ecode = TVMDeviceFreeDataSpace(ctx, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void SyscallDevStreamSync(TVMValue* values, int* tcodes, int num_args) { + MINRPC_CHECK(num_args == 2); + MINRPC_CHECK(tcodes[0] == kTVMContext); + MINRPC_CHECK(tcodes[1] == kTVMOpaqueHandle); + + TVMContext ctx = values[0].v_ctx; + void* handle = values[1].v_handle; + + int call_ecode = TVMSynchronize(ctx.device_type, ctx.device_id, handle); + + if (call_ecode == 0) { + this->ReturnVoid(); + } else { + this->ReturnLastTVMError(); + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + io_.Exit(static_cast(code)); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + template + void Read(T* data) { + static_assert(std::is_pod::value, "need to be trival"); + this->ReadRawBytes(data, sizeof(T)); + } + + template + void ReadArray(T* data, size_t count) { + static_assert(std::is_pod::value, "need to be trival"); + return this->ReadRawBytes(data, sizeof(T) * count); + } + + template + void Write(const T& data) { + static_assert(std::is_pod::value, "need to be trival"); + return this->WriteRawBytes(&data, sizeof(T)); + } + + template + void WriteArray(T* data, size_t count) { + static_assert(std::is_pod::value, "need to be 
trival"); + return this->WriteRawBytes(data, sizeof(T) * count); + } + + private: + // Internal allocator that redirects alloc to TVM's C API. + class PageAllocator { + public: + using ArenaPageHeader = tvm::support::ArenaPageHeader; + + explicit PageAllocator(TIOHandler io) : io_(io) {} + + ArenaPageHeader* allocate(size_t min_size) { + size_t npages = ((min_size + kPageSize - 1) / kPageSize); + void* data; + + if (TVMDeviceAllocDataSpace(DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign, + DLDataType{kDLInt, 1, 1}, &data) != 0) { + io_.Exit(static_cast(RPCServerStatus::kAllocError)); + } + + ArenaPageHeader* header = static_cast(data); + header->size = npages * kPageSize; + header->offset = sizeof(ArenaPageHeader); + return header; + } + + void deallocate(ArenaPageHeader* page) { + if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) { + io_.Exit(static_cast(RPCServerStatus::kAllocError)); + } + } + + static const constexpr int kPageSize = 2 << 10; + static const constexpr int kPageAlign = 8; + + private: + TIOHandler io_; + }; + + void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { + RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); + } + + void ReturnVoid() { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); + + this->Write(packet_nbytes); + this->Write(code); + this->Write(num_args); + this->Write(tcode); + } + + void ReturnHandle(void* handle) { + int32_t num_args = 1; + int32_t tcode = kTVMOpaqueHandle; + RPCCode code = RPCCode::kReturn; + uint64_t encode_handle = reinterpret_cast(handle); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); + + this->Write(packet_nbytes); + this->Write(code); + this->Write(num_args); + this->Write(tcode); + this->Write(encode_handle); + } + + void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } + + void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) { + RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); + } + + void ReturnLastTVMError() { this->ReturnException(TVMGetLastError()); } + + void ReadRawBytes(void* data, size_t size) { + uint8_t* buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_.PosixRead(buf, size - ndone); + if (ret == 0) { + if (allow_clean_shutdown_) { + this->Shutdown(); + io_.Exit(0); + } else { + this->ThrowError(RPCServerStatus::kReadError); + } + } + if (ret == -1) { + this->ThrowError(RPCServerStatus::kReadError); + } + ndone += ret; + buf += ret; + } + } + + void WriteRawBytes(const void* data, size_t size) { + const uint8_t* buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_.PosixWrite(buf, size - ndone); + if (ret == 0 || ret == -1) { + this->ThrowError(RPCServerStatus::kWriteError); + } + buf += ret; + ndone += ret; + } + } + + /*! \brief IO handler. */ + TIOHandler io_; + /*! \brief internal arena. */ + support::GenericArena arena_; + /*! \brief Whether we are in a state that allows clean shutdown. 
*/ + bool allow_clean_shutdown_{true}; + static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian."); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/rpc/minrpc/posix_popen_server.cc new file mode 100644 index 000000000000..9784780fea18 --- /dev/null +++ b/src/runtime/rpc/minrpc/posix_popen_server.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Disable constructor to bring minimum dep on c++ABI. +#define TVM_ARENA_HAS_DESTRUCTOR 0 + +#include + +#include + +#include "minrpc_server.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief IOHandler based on posix API. + */ +class PosixIOHandler { + public: + explicit PosixIOHandler(int read_fd = 0, int write_fd = 1) + : read_fd_(read_fd), write_fd_(write_fd) {} + + ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); } + + ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); } + + void Exit(int code) { exit(code); } + + void Close() { + if (read_fd_ != 0) close(read_fd_); + if (write_fd_ != 0) close(write_fd_); + } + + private: + int read_fd_{0}; + int write_fd_{1}; +}; + +/*! \brief Type for the posix version of min rpc server. */ +using PosixMinRPCServer = MinRPCServer; + +} // namespace runtime +} // namespace tvm + +int main(int argc, char* argv[]) { + if (argc != 3) return -1; + // pass the descriptor via arguments. + tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2])); + tvm::runtime::PosixMinRPCServer server(handler); + server.ServerLoop(); + return 0; +} diff --git a/src/target/opt/build_opengl_off.cc b/src/runtime/rpc/rpc_channel.cc similarity index 57% rename from src/target/opt/build_opengl_off.cc rename to src/runtime/rpc/rpc_channel.cc index 781bf51c2cc0..eaa64e3372c6 100644 --- a/src/target/opt/build_opengl_off.cc +++ b/src/runtime/rpc/rpc_channel.cc @@ -18,20 +18,35 @@ */ /*! 
- * Optional module when build opencl is switched to off + * \file rpc_channel.cc */ -#include "../source/codegen_source_base.h" -#include "../../runtime/opengl/opengl_module.h" +#include "rpc_channel.h" + +#include namespace tvm { namespace runtime { -Module OpenGLModuleCreate(std::unordered_map shaders, - std::string fmt, - std::unordered_map fmap) { - LOG(WARNING) << "OpenGL runtime not enabled, return a source module..."; - auto data = ToJSON(shaders); - return codegen::DeviceSourceModuleCreate(data, "gl", fmap, "opengl"); +size_t CallbackChannel::Send(const void* data, size_t size) { + TVMByteArray bytes; + bytes.data = static_cast(data); + bytes.size = size; + int64_t n = fsend_(bytes); + if (n == -1) { + LOG(FATAL) << "CallbackChannel::Send"; + } + return static_cast(n); +} + +size_t CallbackChannel::Recv(void* data, size_t size) { + TVMRetValue ret = frecv_(size); + + if (ret.type_code() != kTVMBytes) { + LOG(FATAL) << "CallbackChannel::Recv"; + } + std::string* bytes = ret.ptr(); + memcpy(static_cast(data), bytes->c_str(), bytes->length()); + return bytes->length(); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_channel.h b/src/runtime/rpc/rpc_channel.h new file mode 100644 index 000000000000..114bc0a2e7bd --- /dev/null +++ b/src/runtime/rpc/rpc_channel.h @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_H_ +#define TVM_RUNTIME_RPC_RPC_CHANNEL_H_ + +#include + +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Abstract channel interface used to create RPCEndpoint. + */ +class RPCChannel { + public: + /*! \brief virtual destructor */ + virtual ~RPCChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + virtual size_t Send(const void* data, size_t size) = 0; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + virtual size_t Recv(void* data, size_t size) = 0; +}; + +/*! + * \brief RPC channel which callback + * frontend (Python/Java/etc.)'s send & recv function + */ +class CallbackChannel final : public RPCChannel { + public: + /*! + * \brief Constructor. + * + * \param fsend The send function, takes in a TVMByteArray and returns the + * number of bytes sent in that array. Returns -1 if error happens. + * \param frecv The recv function, takes an expected maximum size, and return + * a byte array with the actual amount of data received. 
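+   *
+   * A minimal usage sketch; fsend and frecv here stand for frontend-supplied
+   * callbacks and are purely illustrative:
+   *
+   * \code
+   *
+   * PackedFunc fsend;  // int64(TVMByteArray bytes): returns bytes sent, -1 on error
+   * PackedFunc frecv;  // bytes(uint64 max_size): returns the data received
+   * CallbackChannel channel(fsend, frecv);
+   *
+   * \endcode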
+ */ + explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) + : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} + + ~CallbackChannel() {} + /*! + * \brief Send data over to the channel. + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes sent. + */ + size_t Send(const void* data, size_t size) final; + /*! + * \brief Recv data from channel. + * + * \param data The data pointer. + * \param size The size fo the data. + * \return The actual bytes received. + */ + size_t Recv(void* data, size_t size) final; + + private: + PackedFunc fsend_; + PackedFunc frecv_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_H_ diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 9fd45acd14bf..196a97ecbd66 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -21,8 +21,11 @@ * \file rpc_device_api.cc */ #include -#include #include +#include + +#include + #include "rpc_session.h" namespace tvm { @@ -31,20 +34,22 @@ namespace runtime { class RPCDeviceAPI final : public DeviceAPI { public: void SetDevice(TVMContext ctx) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevSetDevice, ctx); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->SetDevice(remote_ctx); } + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { - *rv = GetSess(ctx)->CallRemote( - RPCCode::kDevGetAttr, ctx, static_cast(kind)); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->GetAttr(remote_ctx, kind, rv); } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { auto sess = GetSess(ctx); - void *data = sess->CallRemote( - RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + auto remote_ctx = RemoveSessMask(ctx); + void* data = + sess->GetDeviceAPI(remote_ctx)->AllocDataSpace(remote_ctx, nbytes, alignment, type_hint); + RemoteSpace* space = new RemoteSpace(); space->data = data; space->sess = std::move(sess); @@ -52,68 +57,68 @@ class RPCDeviceAPI final : public DeviceAPI { } void FreeDataSpace(TVMContext ctx, void* ptr) final { RemoteSpace* space = static_cast(ptr); + auto remote_ctx = RemoveSessMask(ctx); try { - GetSess(ctx)->CallRemote( - RPCCode::kDevFreeData, ctx, space->data); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->FreeDataSpace(remote_ctx, space->data); } catch (const dmlc::Error& e) { // fault tolerance to remote close. 
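      // (if the remote session has already been closed, the remote free is silently
      //  skipped; the local RemoteSpace wrapper is still deleted below.)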
} delete space; } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { int from_dev_type = ctx_from.device_type; int to_dev_type = ctx_to.device_type; - if (from_dev_type > kRPCSessMask && - to_dev_type > kRPCSessMask) { + if (from_dev_type > kRPCSessMask && to_dev_type > kRPCSessMask) { CHECK(ctx_from.device_type == ctx_to.device_type) << "Cannot copy across two different remote session"; - GetSess(ctx_from)->CallRemote( - RPCCode::kCopyAmongRemote, - static_cast(from)->data, from_offset, - static_cast(to)->data, to_offset, - size, ctx_from, ctx_to, type_hint, stream); - } else if (from_dev_type > kRPCSessMask && - to_dev_type == kDLCPU) { - GetSess(ctx_from)->CopyFromRemote( - static_cast(from)->data, from_offset, - to, to_offset, size, ctx_from, type_hint); - } else if (from_dev_type == kDLCPU && - to_dev_type > kRPCSessMask) { - GetSess(ctx_to)->CopyToRemote( - (void*)from, from_offset, // NOLINT(*) - static_cast(to)->data, to_offset, - size, ctx_to, type_hint); + auto remote_ctx_from = RemoveSessMask(ctx_from); + auto remote_ctx_to = RemoveSessMask(ctx_to); + auto remote_ctx = remote_ctx_from; + if (remote_ctx.device_type == kDLCPU) remote_ctx = remote_ctx_to; + GetSess(ctx_from) + ->GetDeviceAPI(remote_ctx) + ->CopyDataFromTo(static_cast(from)->data, from_offset, + static_cast(to)->data, to_offset, size, + remote_ctx_from, remote_ctx_to, type_hint, stream); + } else if (from_dev_type > kRPCSessMask && to_dev_type == kDLCPU) { + auto remote_ctx_from = RemoveSessMask(ctx_from); + GetSess(ctx_from)->CopyFromRemote(static_cast(from)->data, from_offset, + to, to_offset, size, remote_ctx_from, type_hint); + } else if (from_dev_type == kDLCPU && to_dev_type > kRPCSessMask) { + auto remote_ctx_to = RemoveSessMask(ctx_to); + GetSess(ctx_to)->CopyToRemote(const_cast(from), from_offset, + static_cast(to)->data, to_offset, size, + remote_ctx_to, type_hint); } else { LOG(FATAL) << "expect copy from/to remote or between remote"; } } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - GetSess(ctx)->CallRemote( - RPCCode::kDevStreamSync, ctx, stream); + auto remote_ctx = RemoveSessMask(ctx); + GetSess(ctx)->GetDeviceAPI(remote_ctx)->StreamSync(remote_ctx, stream); } private: std::shared_ptr GetSess(TVMContext ctx) { int dev_type = ctx.device_type; CHECK_GE(dev_type, kRPCSessMask); - int tbl_index = dev_type / kRPCSessMask - 1; + int tbl_index = dev_type / kRPCSessMask - 1; return RPCSession::Get(tbl_index); } + + static TVMContext RemoveSessMask(TVMContext ctx) { + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } }; -TVM_REGISTER_GLOBAL("device_api.rpc") -.set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCDeviceAPI inst; - DeviceAPI* ptr = &inst; - *rv = static_cast(ptr); - }); +TVM_REGISTER_GLOBAL("device_api.rpc").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCDeviceAPI inst; + DeviceAPI* ptr = &inst; + *rv = static_cast(ptr); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc new file mode 100644 index 000000000000..bf85dc56dac9 --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -0,0 +1,1034 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_session.cc + * \brief RPC session for remote function call. + */ +#include "rpc_endpoint.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../support/arena.h" +#include "../../support/ring_buffer.h" +#include "../object_internal.h" +#include "rpc_local_session.h" + +namespace tvm { +namespace runtime { + +/*! + * Event-driven state-machine based handlers for RPCEndpoint. + * + * Key functions: + * + * - SendPackedSeq: send the arguments over to the peer + * - HandleNextEvent: handle the next request from the peer(RPCCode followed by per code protocol). + */ +class RPCEndpoint::EventHandler : public dmlc::Stream { + public: + EventHandler(support::RingBuffer* reader, support::RingBuffer* writer, std::string name, + std::string* remote_key, std::function flush_writer) + : reader_(reader), + writer_(writer), + name_(name), + remote_key_(remote_key), + flush_writer_(flush_writer) { + this->Clear(); + + if (*remote_key == "%toinit") { + state_ = kInitHeader; + remote_key_->resize(0); + pending_request_bytes_ = sizeof(int32_t); + } + } + + /*! + * \brief Bytes needed to fulfill current request + */ + size_t BytesNeeded() const { + if (reader_->bytes_available() < pending_request_bytes_) { + return pending_request_bytes_ - reader_->bytes_available(); + } else { + return 0; + } + } + + /*! + * \brief Request number of bytes from the reader. + * \param nbytes The number of bytes + */ + void RequestBytes(size_t nbytes) { + pending_request_bytes_ += nbytes; + reader_->Reserve(pending_request_bytes_); + } + + /*! \return Whether we are ready to handle next request. */ + bool Ready() const { return reader_->bytes_available() >= pending_request_bytes_; } + + /*! \return Whether we can perform a clean shutdown */ + bool CanCleanShutdown() const { return state_ == kRecvPacketNumBytes; } + + /*! \brief Finish the copy ack stage. */ + void FinishCopyAck() { this->SwitchToState(kRecvPacketNumBytes); } + + /*! + * \brief Enter the io loop until the next event. + * \param client_mode Whether we are in the client. + * \param async_server_mode Whether we are in the async server mode. + * \param setreturn The function to set the return value encoding. + * \return The function to set return values when there is a return event. 
+ */ + RPCCode HandleNextEvent(bool client_mode, bool async_server_mode, + RPCSession::FEncodeReturn setreturn) { + std::swap(client_mode_, client_mode); + std::swap(async_server_mode_, async_server_mode); + + RPCCode status = RPCCode::kNone; + + while (status == RPCCode::kNone && state_ != kWaitForAsyncCallback && this->Ready()) { + switch (state_) { + case kInitHeader: + HandleInitHeader(); + break; + case kRecvPacketNumBytes: { + uint64_t packet_nbytes; + CHECK(this->Read(&packet_nbytes)); + if (packet_nbytes != 0) { + this->SwitchToState(kProcessPacket); + this->RequestBytes(packet_nbytes); + } else { + this->SwitchToState(kRecvPacketNumBytes); + } + break; + } + case kProcessPacket: { + this->HandleProcessPacket(setreturn); + break; + } + case kWaitForAsyncCallback: { + break; + } + case kReturnReceived: { + this->SwitchToState(kRecvPacketNumBytes); + status = RPCCode::kReturn; + break; + } + case kCopyAckReceived: { + status = RPCCode::kCopyAck; + break; + } + case kShutdownReceived: { + status = RPCCode::kShutdown; + } + } + } + + std::swap(async_server_mode_, async_server_mode); + std::swap(client_mode_, client_mode); + return status; + } + + /*! \brief Clear all the states in the Handler.*/ + void Clear() { + state_ = kRecvPacketNumBytes; + pending_request_bytes_ = sizeof(uint64_t); + } + + /*! + * \brief Validate that the arguments can be sent through RPC. + * \param arg_values The argument values. + * \param type_codes The type codes. + */ + void ValidateArguments(const TVMValue* arg_values, const int* type_codes, int num_args) { + TVMArgs args(arg_values, type_codes, num_args); + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) { + LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type " + << args[i].AsObjectRef()->GetTypeKey() << " is not supported by RPC"; + } else if (tcode == kTVMContext) { + DLContext ctx = args[i]; + CHECK_LT(static_cast(ctx.device_type), kRPCSessMask) + << "InternalError: cannot pass RPC context in the channel"; + } + } + } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + LOG(FATAL) << "RPCServerError:" << RPCServerStatusToString(code); + } + + uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, int num_args, + bool client_mode) { + return RPCReference::PackedSeqGetNumBytes(arg_values, type_codes, num_args, client_mode, this); + } + + void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, + bool client_mode) { + RPCReference::SendPackedSeq(arg_values, type_codes, num_args, client_mode, this); + } + + // Endian aware IO handling + using Stream::Read; + using Stream::ReadArray; + using Stream::Write; + using Stream::WriteArray; + + bool Read(RPCCode* code) { + int32_t cdata; + if (!this->Read(&cdata)) return false; + *code = static_cast(cdata); + return true; + } + void Write(RPCCode code) { + int32_t cdata = static_cast(code); + this->Write(cdata); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_pod::value, "need to be trival"); + return arena_.template allocate_(count); + } + + protected: + enum State { + kInitHeader, + kRecvPacketNumBytes, + kProcessPacket, + kWaitForAsyncCallback, + kReturnReceived, + kCopyAckReceived, + kShutdownReceived + }; + // Current state; + State state_; + // Initialize remote header + bool init_header_step_{0}; + // Whether current handler is client or server mode. 
+ bool client_mode_{false}; + // Whether current handler is in the async server mode. + bool async_server_mode_{false}; + // Internal arena + support::Arena arena_; + + // State switcher + void SwitchToState(State state) { + // invariant + if (state != kCopyAckReceived) { + CHECK_EQ(pending_request_bytes_, 0U) << "state=" << state; + } + // need to actively flush the writer + // so the data get pushed out. + if (state_ == kWaitForAsyncCallback) { + flush_writer_(); + } + state_ = state; + CHECK(state != kInitHeader) << "cannot switch to init header"; + if (state == kRecvPacketNumBytes) { + this->RequestBytes(sizeof(uint64_t)); + // recycle arena for the next session. + arena_.RecycleAll(); + } + } + + // handler for initial header read + void HandleInitHeader() { + if (init_header_step_ == 0) { + int32_t len; + this->Read(&len); + remote_key_->resize(len); + init_header_step_ = 1; + this->RequestBytes(len); + return; + } else { + CHECK_EQ(init_header_step_, 1); + this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); + this->SwitchToState(kRecvPacketNumBytes); + } + } + + // Handler for read code. + void HandleProcessPacket(RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kNone; + this->Read(&code); + + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscall(code); + } else { + switch (code) { + case RPCCode::kInitServer: { + this->HandleInitServer(); + break; + } + case RPCCode::kCallFunc: { + this->HandleNormalCallFunc(); + break; + } + case RPCCode::kCopyFromRemote: { + this->HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + this->HandleCopyToRemote(); + break; + } + case RPCCode::kException: + case RPCCode::kReturn: { + this->HandleReturn(code, setreturn); + break; + } + case RPCCode::kCopyAck: { + this->SwitchToState(kCopyAckReceived); + break; + } + case RPCCode::kShutdown: { + this->SwitchToState(kShutdownReceived); + break; + } + default: + LOG(FATAL) << "Unknown event " << static_cast(code); + } + } + } + + /*! + * \brief Recive incoming packed seq from the stream. + * \return The received argments. + * \note The TVMArgs is available until we switchstate. + */ + TVMArgs RecvPackedSeq() { + TVMValue* values; + int* tcodes; + int num_args; + RPCReference::RecvPackedSeq(&values, &tcodes, &num_args, this); + return TVMArgs(values, tcodes, num_args); + } + + /*! + * \brief Return exception to the remote. + * \param err_msg The error message. + */ + void ReturnException(const char* err_msg) { RPCReference::ReturnException(err_msg, this); } + + /*! + * \brief Return nullptr to the remote. + * \param err_msg The error message. + */ + void ReturnVoid() { RPCReference::ReturnVoid(this); } + + /*! + * \brief Return a packed sequence to the remote. + * \param args The arguments. + */ + void ReturnPackedSeq(TVMArgs args) { + RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.size(), this); + } + + /*! + * \brief Handle the case when return/exception value is received. + * \param code The RPC code. + * \param setreturn The function to encode return. + */ + void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) { + TVMArgs args = RecvPackedSeq(); + + if (code == RPCCode::kException) { + // switch to the state before sending exception. 
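+      // (LOG(FATAL) below throws, so the handler must already be back in the
+      //  packet-receiving state when the error propagates to the caller.)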
+ this->SwitchToState(kRecvPacketNumBytes); + std::string msg = args[0]; + LOG(FATAL) << "RPCError: Error caught from RPC call:\n" << msg; + } + + CHECK(setreturn != nullptr) << "fsetreturn not available"; + setreturn(args); + + this->SwitchToState(kReturnReceived); + } + + void HandleSyscall(RPCCode code); + + void HandleCopyFromRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + + auto* sess = GetServingSession(); + + // Return Copy Ack with the given data + auto fcopyack = [this](char* data_ptr, size_t num_bytes) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + this->Write(packet_nbytes); + this->Write(code); + this->WriteArray(data_ptr, num_bytes); + this->SwitchToState(kRecvPacketNumBytes); + }; + + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && sess->IsLocalSession() && DMLC_IO_NO_ENDIAN_SWAP) { + char* data_ptr = reinterpret_cast(handle) + offset; + fcopyack(data_ptr, num_bytes); + } else { + char* data_ptr = this->ArenaAlloc(num_bytes); + + auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](RPCCode status, + TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + this->SwitchToState(kRecvPacketNumBytes); + } else { + // endian aware handling + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes); + } + fcopyack(data_ptr, num_bytes); + } + }; + + this->SwitchToState(kWaitForAsyncCallback); + sess->AsyncCopyFromRemote(reinterpret_cast(handle), offset, data_ptr, 0, num_bytes, + ctx, type_hint, on_copy_complete); + } + } + + void HandleCopyToRemote() { + uint64_t handle, offset, num_bytes; + TVMContext ctx; + DLDataType type_hint; + + this->Read(&handle); + this->Read(&offset); + this->Read(&num_bytes); + this->Read(&ctx); + this->Read(&type_hint); + + size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; + auto* sess = GetServingSession(); + + // When session is local, we can directly treat handle + // as the cpu pointer without allocating a temp space. + if (ctx.device_type == kDLCPU && sess->IsLocalSession()) { + char* dptr = reinterpret_cast(handle) + offset; + this->ReadArray(dptr, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes); + } + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } else { + char* temp_data = this->ArenaAlloc(num_bytes); + this->ReadArray(temp_data, num_bytes); + + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes); + } + + auto on_copy_complete = [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + this->SwitchToState(kRecvPacketNumBytes); + } else { + this->ReturnVoid(); + this->SwitchToState(kRecvPacketNumBytes); + } + }; + + this->SwitchToState(kWaitForAsyncCallback); + sess->AsyncCopyToRemote(temp_data, 0, reinterpret_cast(handle), offset, num_bytes, ctx, + type_hint, on_copy_complete); + } + } + + // Handle for packed call. 
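+  // On-wire request layout, as written by RPCEndpoint::CallFunc:
+  //   uint64 packet_nbytes | int32 RPCCode::kCallFunc | uint64 call_handle | packed-seq args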
+ void HandleNormalCallFunc() { + uint64_t call_handle; + + this->Read(&call_handle); + TVMArgs args = RecvPackedSeq(); + + this->SwitchToState(kWaitForAsyncCallback); + GetServingSession()->AsyncCallFunc(reinterpret_cast(call_handle), args.values, + args.type_codes, args.size(), + [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnPackedSeq(args); + } + this->SwitchToState(kRecvPacketNumBytes); + }); + } + + void HandleInitServer() { + std::string client_protocol_ver; + + uint64_t len; + this->Read(&len); + client_protocol_ver.resize(len); + this->Read(dmlc::BeginPtr(client_protocol_ver), len); + + TVMArgs args = RecvPackedSeq(); + + try { + CHECK(serving_session_ == nullptr) << "Server has already been initialized"; + + std::string server_protocol_ver = kRPCProtocolVer; + CHECK_EQ(client_protocol_ver, server_protocol_ver) + << "Server[" << name_ << "]: Client protocol version mismatch with the server " + << " server protocol=" << server_protocol_ver + << ", client protocol=" << client_protocol_ver; + + std::string constructor_name; + TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0); + + if (args.size() == 0) { + constructor_name = "rpc.LocalSession"; + serving_session_ = std::make_shared(); + } else { + constructor_name = args[0].operator std::string(); + constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1); + } + + auto* fconstructor = Registry::Get(constructor_name); + CHECK(fconstructor != nullptr) << " Cannot find session constructor " << constructor_name; + TVMRetValue con_ret; + + try { + fconstructor->CallPacked(constructor_args, &con_ret); + } catch (const dmlc::Error& e) { + LOG(FATAL) << "Server[" << name_ << "]:" + << " Error caught from session constructor " << constructor_name << ":\n" + << e.what(); + } + + CHECK_EQ(con_ret.type_code(), kTVMModuleHandle) + << "Server[" << name_ << "]:" + << " Constructor " << constructor_name << " need to return an RPCModule"; + Module mod = con_ret; + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") << "Constructor " << constructor_name << " to return an RPCModule"; + serving_session_ = RPCModuleGetSession(mod); + this->ReturnVoid(); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + + this->SwitchToState(kRecvPacketNumBytes); + } + + void HandleSyscallStreamSync() { + TVMArgs args = RecvPackedSeq(); + try { + TVMContext ctx = args[0]; + TVMStreamHandle handle = args[1]; + + this->SwitchToState(kWaitForAsyncCallback); + GetServingSession()->AsyncStreamWait(ctx, handle, [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + this->ReturnVoid(); + } + this->SwitchToState(kRecvPacketNumBytes); + }); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + this->SwitchToState(kRecvPacketNumBytes); + } + } + + // Handler for special syscalls that have a specific RPCCode. 
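+  // (decodes the packed arguments, invokes f against the serving session, and replies
+  //  with a single-value packed sequence, or with the exception message on failure.)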
+ template + void SysCallHandler(F f) { + TVMArgs args = RecvPackedSeq(); + try { + TVMRetValue rv; + f(GetServingSession(), args, &rv); + TVMValue ret_value; + int ret_tcode; + TVMArgsSetter setter(&ret_value, &ret_tcode); + setter(0, rv); + + this->ReturnPackedSeq(TVMArgs(&ret_value, &ret_tcode, 1)); + } catch (const std::runtime_error& e) { + this->ReturnException(e.what()); + } + this->SwitchToState(kRecvPacketNumBytes); + } + + private: + RPCSession* GetServingSession() const { + CHECK(serving_session_ != nullptr) + << "Need to call InitRemoteSession first before any further actions"; + CHECK(!serving_session_->IsAsync() || async_server_mode_) + << "Cannot host an async session in a non-Event driven server"; + + return serving_session_.get(); + } + // Utility functions + // Internal read function, update pending_request_bytes_ + size_t Read(void* data, size_t size) final { + CHECK_LE(size, pending_request_bytes_); + reader_->Read(data, size); + pending_request_bytes_ -= size; + return size; + } + // wriite the data to the channel. + void Write(const void* data, size_t size) final { writer_->Write(data, size); } + // Number of pending bytes requests + size_t pending_request_bytes_{0}; + // The ring buffer to read data from. + support::RingBuffer* reader_; + // The ringr buffer to write reply to. + support::RingBuffer* writer_; + // The session used to serve the RPC requests. + std::shared_ptr serving_session_; + // Name of endpoint. + std::string name_; + // remote key + std::string* remote_key_; + // function to flush the writer. + std::function flush_writer_; +}; + +RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) { + RPCCode code = RPCCode::kCallFunc; + while (code != RPCCode::kReturn && code != RPCCode::kShutdown && code != RPCCode::kCopyAck) { + while (writer_.bytes_available() != 0) { + writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); + } + size_t bytes_needed = handler_->BytesNeeded(); + if (bytes_needed != 0) { + size_t n = reader_.WriteWithCallback( + [this](void* data, size_t size) { return channel_->Recv(data, size); }, bytes_needed); + if (n == 0) { + if (handler_->CanCleanShutdown()) { + return RPCCode::kShutdown; + } else { + LOG(FATAL) << "Channel closes before we get neded bytes"; + } + } + } + code = handler_->HandleNextEvent(client_mode, false, setreturn); + } + return code; +} + +void RPCEndpoint::Init() { + // callback to flush the writer. + auto flush_writer = [this]() { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); + if (n == 0) break; + } + }; + + // Event handler + handler_ = std::make_shared(&reader_, &writer_, name_, &remote_key_, flush_writer); + + // Quick function to for syscall remote. 
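+  // (argument 0 carries the RPCCode of the syscall; the remaining arguments are
+  //  forwarded as the packed payload of the request.)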
+ syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) { + std::lock_guard lock(mutex_); + RPCCode code = static_cast(all_args[0].operator int()); + TVMArgs args(all_args.values + 1, all_args.type_codes + 1, all_args.num_args - 1); + + uint64_t packet_nbytes = sizeof(code) + handler_->PackedSeqGetNumBytes( + args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [rv](TVMArgs args) { + CHECK_EQ(args.size(), 1); + *rv = args[0]; + }); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + }); +} + +std::shared_ptr RPCEndpoint::Create(std::unique_ptr channel, + std::string name, std::string remote_key) { + std::shared_ptr endpt = std::make_shared(); + endpt->channel_ = std::move(channel); + endpt->name_ = std::move(name); + endpt->remote_key_ = std::move(remote_key); + endpt->Init(); + return endpt; +} + +RPCEndpoint::~RPCEndpoint() { this->Shutdown(); } + +void RPCEndpoint::Shutdown() { + if (channel_ != nullptr) { + RPCCode code = RPCCode::kShutdown; + uint64_t packet_nbytes = sizeof(code); + + handler_->Write(packet_nbytes); + handler_->Write(code); + + // flush all writing buffer to output channel. + try { + while (writer_.bytes_available() != 0) { + size_t n = writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); + if (n == 0) break; + } + } catch (const dmlc::Error& e) { + } + channel_.reset(nullptr); + } +} + +void RPCEndpoint::ServerLoop() { + if (const auto* f = Registry::Get("tvm.rpc.server.start")) { + (*f)(); + } + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(false, [](TVMArgs) {}) == RPCCode::kShutdown); + if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { + (*f)(); + } + channel_.reset(nullptr); +} + +int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag) { + RPCCode code = RPCCode::kNone; + if (in_bytes.length() != 0) { + reader_.Write(in_bytes.c_str(), in_bytes.length()); + code = handler_->HandleNextEvent(false, true, [](TVMArgs) {}); + } + if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { + writer_.ReadWithCallback( + [this](const void* data, size_t size) { return channel_->Send(data, size); }, + writer_.bytes_available()); + } + CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); + if (code == RPCCode::kShutdown) return 0; + if (writer_.bytes_available() != 0) return 2; + return 1; +} + +void RPCEndpoint::InitRemoteSession(TVMArgs args) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kInitServer; + std::string protocol_ver = kRPCProtocolVer; + uint64_t length = protocol_ver.length(); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(length) + length + + handler_->PackedSeqGetNumBytes(args.values, args.type_codes, args.num_args, true); + + // All packet begins with packet nbytes + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(length); + handler_->WriteArray(protocol_ver.data(), length); + handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); + + code = HandleUntilReturnEvent(true, [](TVMArgs args) {}); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +// Get remote function with name +void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_values, + const int* 
arg_type_codes, int num_args, + RPCSession::FEncodeReturn encode_return) { + std::lock_guard lock(mutex_); + + handler_->ValidateArguments(arg_values, arg_type_codes, num_args); + RPCCode code = RPCCode::kCallFunc; + uint64_t handle = reinterpret_cast(h); + + uint64_t packet_nbytes = + sizeof(code) + sizeof(handle) + + handler_->PackedSeqGetNumBytes(arg_values, arg_type_codes, num_args, true); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true); + + code = HandleUntilReturnEvent(true, encode_return); + CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); +} + +void RPCEndpoint::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t data_size, TVMContext ctx_to, DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyToRemote; + uint64_t handle = reinterpret_cast(to); + uint64_t offset = static_cast(to_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) + + sizeof(ctx_to) + sizeof(type_hint) + data_size; + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_to); + handler_->Write(type_hint); + handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); + + CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kReturn); +} + +void RPCEndpoint::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t data_size, TVMContext ctx_from, DLDataType type_hint) { + std::lock_guard lock(mutex_); + RPCCode code = RPCCode::kCopyFromRemote; + uint64_t handle = reinterpret_cast(from); + uint64_t offset = static_cast(from_offset); + uint64_t size = static_cast(data_size); + + uint64_t packet_nbytes = sizeof(code) + sizeof(handle) + sizeof(offset) + sizeof(size) + + sizeof(ctx_from) + sizeof(type_hint); + + handler_->Write(packet_nbytes); + handler_->Write(code); + handler_->Write(handle); + handler_->Write(offset); + handler_->Write(size); + handler_->Write(ctx_from); + handler_->Write(type_hint); + + TVMRetValue rv; + CHECK(HandleUntilReturnEvent(true, [](TVMArgs) {}) == RPCCode::kCopyAck); + handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); + handler_->FinishCopyAck(); +} + +// SysCallEventHandler functions +void RPCGetGlobalFunc(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + std::string name = args[0]; + *rv = handler->GetFunction(name); +} + +void RPCFreeHandle(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + void* handle = args[0]; + int type_code = args[1]; + handler->FreeHandle(handle, type_code); +} + +void RPCDevSetDevice(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + handler->GetDeviceAPI(ctx)->SetDevice(ctx); +} + +void RPCDevGetAttr(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + DeviceAttrKind kind = static_cast(args[1].operator int()); + if (kind == kExist) { + DeviceAPI* api = handler->GetDeviceAPI(ctx, true); + if (api != nullptr) { + api->GetAttr(ctx, kind, rv); + } else { + *rv = 0; + } + } else { + handler->GetDeviceAPI(ctx)->GetAttr(ctx, static_cast(kind), rv); + } +} + +void RPCDevAllocData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + uint64_t nbytes = args[1]; + uint64_t alignment = args[2]; + DLDataType type_hint = args[3]; + 
void* data = handler->GetDeviceAPI(ctx)->AllocDataSpace(ctx, nbytes, alignment, type_hint); + *rv = data; +} + +void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + TVMContext ctx = args[0]; + void* ptr = args[1]; + handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr); +} + +void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue* rv) { + void* from = args[0]; + uint64_t from_offset = args[1]; + void* to = args[2]; + uint64_t to_offset = args[3]; + uint64_t size = args[4]; + TVMContext ctx_from = args[5]; + TVMContext ctx_to = args[6]; + DLDataType type_hint = args[7]; + TVMStreamHandle stream = args[8]; + TVMContext ctx = ctx_from; + + if (ctx.device_type == kDLCPU) { + ctx = ctx_to; + } else { + CHECK(ctx_to.device_type == kDLCPU || ctx_to.device_type == ctx_from.device_type) + << "Can not copy across different ctx types directly"; + } + handler->GetDeviceAPI(ctx)->CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, + ctx_to, type_hint, stream); +} + +void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) { + // Event handler sit at clean state at this point. + switch (code) { + // system functions + case RPCCode::kFreeHandle: + SysCallHandler(RPCFreeHandle); + break; + case RPCCode::kGetGlobalFunc: + SysCallHandler(RPCGetGlobalFunc); + break; + case RPCCode::kDevSetDevice: + SysCallHandler(RPCDevSetDevice); + break; + case RPCCode::kDevGetAttr: + SysCallHandler(RPCDevGetAttr); + break; + case RPCCode::kDevAllocData: + SysCallHandler(RPCDevAllocData); + break; + case RPCCode::kDevFreeData: + SysCallHandler(RPCDevFreeData); + break; + case RPCCode::kDevStreamSync: + this->HandleSyscallStreamSync(); + break; + case RPCCode::kCopyAmongRemote: + SysCallHandler(RPCCopyAmongRemote); + break; + default: + LOG(FATAL) << "Unknown event " << static_cast(code); + } + + if (state_ != kWaitForAsyncCallback) { + CHECK_EQ(state_, kRecvPacketNumBytes); + } +} + +/*! + * \brief RPC client session that proxies all calls to an endpoint. + */ +class RPCClientSession : public RPCSession, public DeviceAPI { + public: + /*! + * \brief param endpoint The client endpoint of the session. + */ + explicit RPCClientSession(std::shared_ptr endpoint) : endpoint_(endpoint) {} + + // function overrides + PackedFuncHandle GetFunction(const std::string& name) final { + return endpoint_->SysCallRemote(RPCCode::kGetGlobalFunc, name); + } + + void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, const FEncodeReturn& fencode_return) final { + endpoint_->CallFunc(func, arg_values, arg_type_codes, num_args, fencode_return); + } + + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint) final { + endpoint_->CopyToRemote(from, from_offset, to, to_offset, nbytes, ctx_to, type_hint); + } + + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint) final { + endpoint_->CopyFromRemote(from, from_offset, to, to_offset, nbytes, ctx_from, type_hint); + } + + void FreeHandle(void* handle, int type_code) final { + endpoint_->SysCallRemote(RPCCode::kFreeHandle, handle, type_code); + } + + void SetDevice(TVMContext ctx) final { endpoint_->SysCallRemote(RPCCode::kDevSetDevice, ctx); } + + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + if (ctx.device_type == kDLCPU && kind == kExist) { + // cpu always exists. 
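+      // (answered locally so this common existence query needs no remote round trip)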
+ *rv = 1; + } else { + *rv = endpoint_->SysCallRemote(RPCCode::kDevGetAttr, ctx, static_cast(kind)); + } + } + + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) final { + return endpoint_->SysCallRemote(RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); + } + + void FreeDataSpace(TVMContext ctx, void* ptr) final { + endpoint_->SysCallRemote(RPCCode::kDevFreeData, ctx, ptr); + } + + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, + TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, const_cast(from), from_offset, to, + to_offset, size, ctx_from, ctx_to, type_hint, stream); + } + + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + endpoint_->SysCallRemote(RPCCode::kDevStreamSync, ctx, stream); + } + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing) final { return this; } + + bool IsLocalSession() const final { return false; } + + private: + std::shared_ptr endpoint_; +}; + +std::shared_ptr CreateClientSession(std::shared_ptr endpoint) { + return std::make_shared(endpoint); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h new file mode 100644 index 000000000000..2b88cee15c01 --- /dev/null +++ b/src/runtime/rpc/rpc_endpoint.h @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_endpoint.h + * \brief Communication endpoints to connect local and remote RPC sessions. + */ +#ifndef TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ +#define TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ + +#include + +#include +#include +#include +#include + +#include "../../support/ring_buffer.h" +#include "rpc_channel.h" +#include "rpc_protocol.h" +#include "rpc_session.h" + +namespace tvm { +namespace runtime { + +// Magic header for RPC data plane +const int kRPCMagic = 0xff271; +// magic header for RPC tracker(control plane) +const int kRPCTrackerMagic = 0x2f271; +// sucess response +const int kRPCSuccess = kRPCMagic + 0; +// cannot found matched key in server +const int kRPCMismatch = kRPCMagic + 2; + +/*! \brief Enumeration code for the RPC tracker */ +enum class TrackerCode : int { + kFail = -1, + kSuccess = 0, + kPing = 1, + kStop = 2, + kPut = 3, + kRequest = 4, + kUpdateInfo = 5, + kSummary = 6, + kGetPendingMatchKeys = 7 +}; + +/*! + * \brief Communication endpoints to connect local and remote RPC sessions. + * An endpoint can either be a client or a server. + */ +class RPCEndpoint { + public: + /*! \brief virtual destructor */ + ~RPCEndpoint(); + /*! + * \brief The server loop that server runs to handle RPC calls. 
+ */ + void ServerLoop(); + /*! + * \brief Message handling function for an async IO event driven server. + * + * Called when the server receives a message or an IO event update. + * The event driven handler never calls recv on the channel + * and always relies on the ServerIOEventHandler to receive the data. + * + * \param in_bytes The incoming bytes. + * \param event_flag 1: read_available, 2: write_available. + * \return State flag. + * 1: continue running, no need to write, + * 2: need to write + * 0: shutdown + */ + int ServerAsyncIOEventHandler(const std::string& in_bytes, int event_flag); + + /*! + * \brief Initialize the session on the remote that will be used to back all the RPC requests. + * + * If no session constructor arguments are passed, LocalSession will be used in the remote. + * Otherwise the remote serving session will be constructed using the arguments + * specified in the session_constructor_args. + * + * The construction rule can be summarized as follows: + * + * \code + * + * auto args = session_constructor_args; + * int n = args.size(); + * if (n != 0) { + * std::string constructor = args[0]; + * server.serving_session_ = GetGlobalFunc(constructor)( + * args[1], args[2] ... args[n - 1]); + * } else { + * server.serving_session_ = LocalSession(); + * } + * \endcode + * + * \param session_constructor_args Optional sequence of arguments for the remote session constructor. + */ + void InitRemoteSession(TVMArgs session_constructor_args); + + /*! + * \brief Call into a remote function. + * \param handle The function handle. + * \param arg_values The argument values. + * \param arg_type_codes The type codes of the arguments. + * \param num_args Number of arguments. + * \param encode_return The function to receive return value encodings. + */ + void CallFunc(RPCSession::PackedFuncHandle handle, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, RPCSession::FEncodeReturn encode_return); + /*! + * \brief Copy bytes into remote array content. + * \param from The source host data. + * \param from_offset The byte offset in the source. + * \param to The target array. + * \param to_offset The byte offset in the target. + * \param nbytes The size of the memory in bytes. + * \param ctx_to The target context. + * \param type_hint Hint of content data type. + */ + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint); + /*! + * \brief Copy bytes from remote array content. + * \param from The source host data. + * \param from_offset The byte offset in the source. + * \param to The target array. + * \param to_offset The byte offset in the target. + * \param nbytes The size of the memory in bytes. + * \param ctx_from The source context. + * \param type_hint Hint of content data type. + */ + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint); + + /*! + * \brief Call a remotely defined system function with arguments. + * \param fcode The function code. + * \param args The arguments. + * \return The returned remote value. + */ + template <typename... Args> + inline TVMRetValue SysCallRemote(RPCCode fcode, Args&&... args); + /*! + * \brief Create an RPC session with the given channel. + * \param channel The communication channel. + * \param name The local name of the session, used for debugging. + * \param remote_key The remote key of the session; + * if remote_key equals "%toinit", we need to re-initialize + * it by the event handler.
+ */ + static std::shared_ptr Create(std::unique_ptr channel, std::string name, + std::string remote_key); + + private: + class EventHandler; + // Handle events until receives a return + // Also flushes channels so that the function advances. + RPCCode HandleUntilReturnEvent(bool client_mode, RPCSession::FEncodeReturn setreturn); + // Initalization + void Init(); + // Shutdown + void Shutdown(); + // Internal channel. + std::unique_ptr channel_; + // Internal mutex + std::mutex mutex_; + // Internal ring buffer. + support::RingBuffer reader_, writer_; + // Event handler. + std::shared_ptr handler_; + // syscall remote with specified function code. + PackedFunc syscall_remote_; + // The name of the session. + std::string name_; + // The remote key + std::string remote_key_; +}; + +/*! + * \brief Create an RPC client session from an RPC client endpoint. + * \param endpoint The endpoint. + * \return The created session. + */ +std::shared_ptr CreateClientSession(std::shared_ptr endpoint); + +// implementation of inline functions +template +inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) { + return syscall_remote_(static_cast(code), std::forward(args)...); +} +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_ diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index 29adb0fed108..f5b933fcf79f 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,32 +19,32 @@ /*! * \file rpc_event_impl.cc - * \brief Event based RPC server implementation. + * \brief Event driven RPC server implementation. */ #include + #include -#include "rpc_session.h" + +#include "rpc_endpoint.h" +#include "rpc_local_session.h" namespace tvm { namespace runtime { -PackedFunc CreateEventDrivenServer(PackedFunc fsend, - std::string name, - std::string remote_key) { +PackedFunc CreateEventDrivenServer(PackedFunc fsend, std::string name, std::string remote_key) { static PackedFunc frecv([](TVMArgs args, TVMRetValue* rv) { LOG(FATAL) << "Do not allow explicit receive"; return 0; }); + std::unique_ptr ch(new CallbackChannel(fsend, frecv)); - std::shared_ptr sess = - RPCSession::Create(std::move(ch), name, remote_key); + std::shared_ptr sess = RPCEndpoint::Create(std::move(ch), name, remote_key); return PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - int ret = sess->ServerEventHandler(args[0], args[1]); - *rv = ret; - }); + int ret = sess->ServerAsyncIOEventHandler(args[0], args[1]); + *rv = ret; + }); } -TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") -.set_body_typed(CreateEventDrivenServer); +TVM_REGISTER_GLOBAL("rpc.CreateEventDrivenServer").set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc new file mode 100644 index 000000000000..b35c62d255fc --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.cc @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file local_session.cc + * \brief Local session that directs requests to local API. + */ +#include "rpc_local_session.h" + +#include +#include + +#include + +namespace tvm { +namespace runtime { + +RPCSession::PackedFuncHandle LocalSession::GetFunction(const std::string& name) { + if (auto* fp = tvm::runtime::Registry::Get(name)) { + // return raw handle because the remote need to explicitly manage it. + return new PackedFunc(*fp); + } else { + return nullptr; + } +} + +void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) { + int rv_tcode = rv.type_code(); + + // return value encoding. + TVMValue ret_value_pack[3]; + int ret_tcode_pack[3]; + TVMArgsSetter set_arg(ret_value_pack, ret_tcode_pack); + // first location always encode type code. + set_arg(0, rv_tcode); + + if (rv_tcode == kTVMNDArrayHandle) { + // We follow a special protocol to return NDArray to client side + // The first pack value is the NDArray handle as DLTensor + // The second pack value is a customized deleter that deletes the NDArray. + rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMDLTensorHandle; + ret_value_pack[2].v_handle = ret_value_pack[1].v_handle; + ret_tcode_pack[2] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3)); + } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { + // MoveToCHost means rv no longer manages the object. + // return handle instead. 
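Taken together, the branches of EncodeReturn produce one of a handful of return layouts, which the client-side WrapRemoteReturnToValue later consumes:

    NDArray:             [kTVMNDArrayHandle, DLTensor* as kTVMDLTensorHandle, handle as kTVMOpaqueHandle]
    PackedFunc / Module: [original type code, handle as kTVMOpaqueHandle]
    Bytes:               [kTVMBytes, TVMByteArray]
    everything else:     [type code, plain value]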
+ rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]); + ret_tcode_pack[1] = kTVMOpaqueHandle; + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else if (rv_tcode == kTVMBytes) { + TVMByteArray byte_arr; + auto* sptr = rv.ptr(); + byte_arr.data = sptr->data(); + byte_arr.size = sptr->length(); + set_arg(1, byte_arr); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } else { + set_arg(1, rv); + encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 2)); + } +} + +void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, + const FEncodeReturn& encode_return) { + auto* pf = static_cast(func); + TVMRetValue rv; + pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv); + this->EncodeReturn(std::move(rv), encode_return); +} + +void LocalSession::CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t nbytes, TVMContext ctx_to, DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + this->GetDeviceAPI(ctx_to)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, cpu_ctx, + ctx_to, type_hint, nullptr); + // Copy can happen asynchrously + // synchronize to make sure that copy is completed + this->GetDeviceAPI(ctx_to)->StreamSync(ctx_to, nullptr); +} + +void LocalSession::CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, + size_t nbytes, TVMContext ctx_from, DLDataType type_hint) { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + + this->GetDeviceAPI(ctx_from)->CopyDataFromTo(from, from_offset, to, to_offset, nbytes, ctx_from, + cpu_ctx, type_hint, nullptr); + // Copy can happen asynchrously + // synchronize to make sure that copy is completed + this->GetDeviceAPI(ctx_from)->StreamSync(ctx_from, nullptr); +} + +void LocalSession::FreeHandle(void* handle, int type_code) { + TVMValue value; + value.v_handle = handle; + // will trigger deleter once the rv goes out of the scope. + TVMRetValue rv = TVMRetValue::MoveFromCHost(value, type_code); +} + +DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) { + return DeviceAPI::Get(ctx, allow_missing); +} + +TVM_REGISTER_GLOBAL("rpc.LocalSession").set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h new file mode 100644 index 000000000000..7a67ce86bf80 --- /dev/null +++ b/src/runtime/rpc/rpc_local_session.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file rpc_local_session.h + * \brief Local session that directs all request to the local runtime API. + */ +#ifndef TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ +#define TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ + +#include +#include + +#include +#include +#include + +#include "rpc_session.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief A local session that directly use the handle repr of the + * local tvm runtime objects on the same process. + */ +class LocalSession : public RPCSession { + public: + // function overrides + PackedFuncHandle GetFunction(const std::string& name) override; + + void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, const FEncodeReturn& fencode_return) override; + + void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_to, DLDataType type_hint) override; + + void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, + TVMContext ctx_from, DLDataType type_hint) override; + + void FreeHandle(void* handle, int type_code) override; + + DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) override; + + bool IsLocalSession() const override { return true; } + + protected: + /*! + * \brief internal encode return fucntion. + * \param rv The return value. + * \param encode_return The encoding function. + */ + void EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return); +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_LOCAL_SESSION_H_ diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 0e48e6fb2708..89f3e7c6c7f8 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -18,69 +18,121 @@ */ /*! - * \file rpc_device_api.cc - * \brief RPC module. + * \file rpc_module.cc + * \brief RPC runtime module. */ +#include #include -#include + #include +#include + +#include "rpc_endpoint.h" #include "rpc_session.h" namespace tvm { namespace runtime { -// Wrapped remote function to packed func. -class RPCWrappedFunc { +/*! + * \brief A wrapped remote function as a PackedFunc. + */ +class RPCWrappedFunc : public Object { public: - RPCWrappedFunc(void* handle, - std::shared_ptr sess) - : handle_(handle), sess_(sess) { - fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) { - WrapRemote(sess, args, rv); - }); - } + RPCWrappedFunc(void* handle, std::shared_ptr sess) : handle_(handle), sess_(sess) {} + + void operator()(TVMArgs args, TVMRetValue* rv) const { + std::vector values(args.values, args.values + args.size()); + std::vector type_codes(args.type_codes, args.type_codes + args.size()); + std::vector> temp_dltensors; - void operator()(TVMArgs args, TVMRetValue *rv) const { - sess_->CallFunc(handle_, args, rv, UnwrapRemote, &fwrap_); + // scan and check whether we need rewrite these arguments + // to their remote variant. + for (int i = 0; i < args.size(); ++i) { + if (args[i].IsObjectRef()) { + String str = args[i]; + type_codes[i] = kTVMStr; + values[i].v_str = str.c_str(); + continue; + } + int tcode = type_codes[i]; + switch (tcode) { + case kTVMDLTensorHandle: + case kTVMNDArrayHandle: { + // Pass NDArray as DLTensor, NDArray and DLTensor + // are compatible to each other, just need to change the index. 
+ type_codes[i] = kTVMDLTensorHandle; + // translate to a remote view of DLTensor + auto dptr = std::make_unique(*static_cast(values[i].v_handle)); + dptr->ctx = RemoveSessMask(dptr->ctx); + dptr->data = static_cast(dptr->data)->data; + values[i].v_handle = dptr.get(); + temp_dltensors.emplace_back(std::move(dptr)); + break; + } + case kTVMContext: { + values[i].v_ctx = RemoveSessMask(values[i].v_ctx); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + values[i].v_handle = UnwrapRemoteValueToHandle(TVMArgValue(values[i], tcode)); + break; + } + } + } + auto set_return = [this, rv](TVMArgs args) { this->WrapRemoteReturnToValue(args, rv); }; + sess_->CallFunc(handle_, values.data(), type_codes.data(), args.size(), set_return); } + ~RPCWrappedFunc() { try { - sess_->CallRemote(RPCCode::kFreeFunc, handle_); + sess_->FreeHandle(handle_, kTVMPackedFuncHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } } - static void WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue* rv); + private: + // remote function handle + void* handle_{nullptr}; + // pointer to the session. + std::shared_ptr sess_; - static void* UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg); + // unwrap a remote value to the underlying handle. + void* UnwrapRemoteValueToHandle(const TVMArgValue& arg) const; + // wrap a remote return via Set + void WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const; + + // remove a remote session mask + TVMContext RemoveSessMask(TVMContext ctx) const { + int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "Can not pass in local context or context with a different remote session"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + return ctx; + } // deleter of RPC remote array static void RemoteNDArrayDeleter(Object* obj) { auto* ptr = static_cast(obj); RemoteSpace* space = static_cast(ptr->dl_tensor.data); - space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx); + space->sess->FreeHandle(ptr->manager_ctx, kTVMNDArrayHandle); delete space; delete ptr; } + // wrap return value as remote NDArray. - static NDArray WrapRemoteNDArray(std::shared_ptr sess, - DLTensor* tensor, - void* nd_handle) { + NDArray WrapRemoteNDArray(DLTensor* tensor, void* nd_handle) const { NDArray::Container* data = new NDArray::Container(); data->manager_ctx = nd_handle; data->SetDeleter(RemoteNDArrayDeleter); RemoteSpace* space = new RemoteSpace(); - space->sess = sess; + space->sess = sess_; space->data = tensor->data; data->dl_tensor.data = space; NDArray ret(GetObjectPtr(data)); // RAII now in effect - data->shape_ = std::vector( - tensor->shape, tensor->shape + tensor->ndim); + data->shape_ = std::vector(tensor->shape, tensor->shape + tensor->ndim); data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); data->dl_tensor.ndim = static_cast(data->shape_.size()); // setup dtype @@ -88,31 +140,25 @@ class RPCWrappedFunc { // setup ctx, encode as remote session data->dl_tensor.ctx.device_id = tensor->ctx.device_id; data->dl_tensor.ctx.device_type = static_cast( - static_cast(tensor->ctx.device_type) + - kRPCSessMask * (sess->table_index() + 1)); + static_cast(tensor->ctx.device_type) + kRPCSessMask * (sess_->table_index() + 1)); // check strides. 
CHECK(tensor->strides == nullptr); // setup byteoffset data->dl_tensor.byte_offset = tensor->byte_offset; return ret; } - - private: - PackedFunc fwrap_; - void* handle_{nullptr}; - std::shared_ptr sess_; }; // RPC that represents a remote module session. class RPCModuleNode final : public ModuleNode { public: RPCModuleNode(void* module_handle, std::shared_ptr sess) - : module_handle_(module_handle), sess_(sess) { - } + : module_handle_(module_handle), sess_(sess) {} + ~RPCModuleNode() { if (module_handle_ != nullptr) { try { - sess_->CallRemote(RPCCode::kModuleFree, module_handle_); + sess_->FreeHandle(module_handle_, kTVMModuleHandle); } catch (const dmlc::Error& e) { // fault tolerance to remote close } @@ -120,177 +166,247 @@ class RPCModuleNode final : public ModuleNode { } } - const char* type_key() const final { - return "rpc"; - } + const char* type_key() const final { return "rpc"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { - RPCFuncHandle handle = GetFuncHandle(name); - return WrapRemote(handle); + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + if (module_handle_ == nullptr) { + return WrapRemoteFunc(sess_->GetFunction(name)); + } else { + InitRemoteFunc(&remote_mod_get_function_, "tvm.rpc.server.ModuleGetFunction"); + return remote_mod_get_function_(GetRef(this), name, false); + } } std::string GetSource(const std::string& format) final { - if (module_handle_ != nullptr) { - std::string ret = sess_->CallRemote( - RPCCode::kModuleGetSource, module_handle_, format); - } + LOG(FATAL) << "GetSource for rpc Module is not supported"; return ""; } - std::shared_ptr& sess() { - return sess_; + PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat, + int min_repeat_ms) { + InitRemoteFunc(&remote_get_time_evaluator_, "runtime.RPCTimeEvaluator"); + // Remove session mask because we pass ctx by parts. 
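The mask arithmetic used here mirrors what WrapRemoteNDArray does above when it tags a returned tensor with its owning session. As a sketch (AddSessMask and StripSessMask are illustrative names, not part of the runtime):

    // Tag a local context with its owning RPC session (client-side view).
    inline TVMContext AddSessMask(TVMContext ctx, int table_index) {
      ctx.device_type = static_cast<DLDeviceType>(static_cast<int>(ctx.device_type) +
                                                  kRPCSessMask * (table_index + 1));
      return ctx;
    }
    // Recover the context that the remote server actually understands.
    inline TVMContext StripSessMask(TVMContext ctx) {
      ctx.device_type = static_cast<DLDeviceType>(static_cast<int>(ctx.device_type) % kRPCSessMask);
      return ctx;
    }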
+ int dev_type = ctx.device_type; + CHECK_EQ(dev_type / kRPCSessMask, sess_->table_index() + 1) + << "ValueError: Need to pass the matched remote context to RPCModule.GetTimeEvaluator"; + ctx.device_type = static_cast(ctx.device_type % kRPCSessMask); + + if (module_handle_ != nullptr) { + return remote_get_time_evaluator_(GetRef(this), name, + static_cast(ctx.device_type), ctx.device_id, number, + repeat, min_repeat_ms); + } else { + return remote_get_time_evaluator_(Optional(nullptr), name, + static_cast(ctx.device_type), ctx.device_id, number, + repeat, min_repeat_ms); + } } - PackedFunc GetTimeEvaluator(const std::string& name, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - RPCFuncHandle handle = GetFuncHandle(name); - if (handle == nullptr) return PackedFunc(); - handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms); - return WrapRemote(handle); + Module LoadModule(std::string name) { + InitRemoteFunc(&remote_load_module_, "tvm.rpc.server.load_module"); + return remote_load_module_(name); } - void* module_handle() const { - return module_handle_; + void ImportModule(Module other) { + InitRemoteFunc(&remote_import_module_, "tvm.rpc.server.ImportModule"); + remote_import_module_(GetRef(this), other); } + const std::shared_ptr& sess() { return sess_; } + + void* module_handle() const { return module_handle_; } + private: - PackedFunc WrapRemote(RPCFuncHandle handle) { + template + void InitRemoteFunc(FType* func, const std::string& name) { + if (*func != nullptr) return; + RPCSession::PackedFuncHandle handle = sess_->GetFunction(name); + CHECK(handle != nullptr) << "Cannot found remote function " << name; + *func = WrapRemoteFunc(handle); + } + + PackedFunc WrapRemoteFunc(RPCSession::PackedFuncHandle handle) { if (handle == nullptr) return PackedFunc(); auto wf = std::make_shared(handle, sess_); - return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); } - RPCFuncHandle GetFuncHandle(const std::string& name) { - RPCFuncHandle handle = nullptr; - if (module_handle_ == nullptr) { - handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); - } else { - handle = sess_->CallRemote( - RPCCode::kModuleGetFunc, module_handle_, name); - } - return handle; - } // The module handle void* module_handle_{nullptr}; // The local channel std::shared_ptr sess_; - // Wrap function to wrap remote module/function. - PackedFunc fwrap_; + // remote function to get time evaluator + TypedPackedFunc, std::string, int, int, int, int, int)> + remote_get_time_evaluator_; + // remote function getter for modules. 
+ TypedPackedFunc remote_mod_get_function_; + // remote function getter for load module + TypedPackedFunc remote_load_module_; + // remote function getter for load module + TypedPackedFunc remote_import_module_; }; -void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, - const TVMArgValue& arg) { +void* RPCWrappedFunc::UnwrapRemoteValueToHandle(const TVMArgValue& arg) const { if (arg.type_code() == kTVMModuleHandle) { Module mod = arg; std::string tkey = mod->type_key(); - CHECK_EQ(tkey, "rpc") - << "ValueError: Cannot pass a non-RPC module to remote"; + CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; auto* rmod = static_cast(mod.operator->()); - CHECK_EQ(rmod->sess()->table_index(), rpc_sess_table_index) + CHECK(rmod->sess() == sess_) << "ValueError: Cannot pass in module into a different remote session"; return rmod->module_handle(); } else { - LOG(FATAL) << "ValueError: Cannot pass type " - << runtime::TypeCode2Str(arg.type_code()) + LOG(FATAL) << "ValueError: Cannot pass type " << runtime::ArgTypeCode2Str(arg.type_code()) << " as an argument to the remote"; return nullptr; } } -void RPCWrappedFunc::WrapRemote(std::shared_ptr sess, - TVMArgs args, - TVMRetValue *rv) { - void* handle = args.values[0].v_handle; - int tcode = args.type_codes[0]; +void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) const { + int tcode = args[0]; - if (handle == nullptr) return; + if (tcode == kTVMNullptr) return; if (tcode == kTVMPackedFuncHandle) { - auto wf = std::make_shared(handle, sess); - *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { - return wf->operator()(args, rv); - }); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto wf = std::make_shared(handle, sess_); + *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); } else if (tcode == kTVMModuleHandle) { - auto n = make_object(handle, sess); + CHECK_EQ(args.size(), 2); + void* handle = args[1]; + auto n = make_object(handle, sess_); *rv = Module(n); } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) { - CHECK_EQ(args.size(), 2); - DLTensor* tensor = args[0]; - void* nd_handle = args[1]; - *rv = WrapRemoteNDArray(sess, tensor, nd_handle); + CHECK_EQ(args.size(), 3); + DLTensor* tensor = args[1]; + void* nd_handle = args[2]; + *rv = WrapRemoteNDArray(tensor, nd_handle); } else { - LOG(FATAL) << "Cannot wrap tcode=" << tcode; + CHECK_EQ(args.size(), 2); + *rv = args[1]; } } -Module CreateRPCModule(std::shared_ptr sess) { +Module CreateRPCSessionModule(std::shared_ptr sess) { auto n = make_object(nullptr, sess); + RPCSession::InsertToSessionTable(sess); return Module(n); } -TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - TVMContext ctx; - ctx.device_type = static_cast(args[2].operator int()); - ctx.device_id = args[3]; - if (tkey == "rpc") { - *rv = static_cast(m.operator->()) - ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]); - } else { - *rv = WrapTimeEvaluator( - m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]); +std::shared_ptr RPCModuleGetSession(Module mod) { + std::string tkey = mod->type_key(); + CHECK_EQ(tkey, "rpc") << "ValueError: Cannot pass a non-RPC module to remote"; + auto* rmod = static_cast(mod.operator->()); + return rmod->sess(); +} + +PackedFunc WrapTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, + int min_repeat_ms) { + CHECK(pf != nullptr); + + 
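The wrapped evaluator built below skips one warm-up call, then for each of `repeat` rounds keeps growing `number` until a round takes at least `min_repeat_ms`, and finally returns one double per round (mean seconds per call) packed back-to-back in a TVMByteArray. A minimal client-side decoding sketch, assuming the returned bytes have been copied into a std::string and the usual <cstring>/<string>/<vector> headers are available:

    // Unpack the byte blob returned by the time evaluator into per-round means.
    std::vector<double> DecodeTimeEvalResult(const std::string& blob) {
      std::vector<double> means(blob.size() / sizeof(double));
      std::memcpy(means.data(), blob.data(), means.size() * sizeof(double));
      return means;  // one entry per repeat round, in seconds per call
    }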
if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { + auto get_micro_time_evaluator = runtime::Registry::Get("micro._GetMicroTimeEvaluator"); + CHECK(get_micro_time_evaluator != nullptr) << "micro backend not enabled"; + return (*get_micro_time_evaluator)(pf, ctx, number, repeat); + } + + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) mutable { + TVMRetValue temp; + std::ostringstream os; + // skip first time call, to activate lazy compilation components. + pf.CallPacked(args, &temp); + + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + + for (int i = 0; i < repeat; ++i) { + std::chrono::time_point tbegin, + tend; + double duration_ms = 0.0; + + do { + if (duration_ms > 0.0) { + number = static_cast(std::max((min_repeat_ms / (duration_ms / number) + 1), + number * 1.618)); // 1.618 is chosen by random + } + + tbegin = std::chrono::high_resolution_clock::now(); + // start timing + for (int i = 0; i < number; ++i) { + pf.CallPacked(args, &temp); + } + DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); + tend = std::chrono::high_resolution_clock::now(); + + duration_ms = + std::chrono::duration_cast>(tend - tbegin).count() * 1000; + } while (duration_ms < min_repeat_ms); + + double speed = + std::chrono::duration_cast>(tend - tbegin).count() / number; + os.write(reinterpret_cast(&speed), sizeof(speed)); } - }); - -TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - auto& sess = static_cast(m.operator->())->sess(); - void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); - auto n = make_object(mhandle, sess); - *rv = Module(n); - }); - -TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module parent = args[0]; - Module child = args[1]; - CHECK(!std::strcmp(parent->type_key(), "rpc") && - !std::strcmp(child->type_key(), "rpc")); - auto* pmod = static_cast(parent.operator->()); - auto* cmod = static_cast(child.operator->()); - CHECK(pmod->sess().get() == cmod->sess().get()) - << "Import of remote module need to belong to same session."; - pmod->sess()->CallRemote(RPCCode::kModuleImport, - pmod->module_handle(), - cmod->module_handle()); - }); - -TVM_REGISTER_GLOBAL("rpc._ModuleHandle") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->module_handle(); - }); - -TVM_REGISTER_GLOBAL("rpc._SessTableIndex") -.set_body([](TVMArgs args, TVMRetValue* rv) { - Module m = args[0]; - std::string tkey = m->type_key(); - CHECK_EQ(tkey, "rpc"); - *rv = static_cast(m.operator->())->sess()->table_index(); - }); + + std::string blob = os.str(); + TVMByteArray arr; + arr.size = blob.length(); + arr.data = blob.data(); + // return the time. 
+ *rv = arr; + }; + return PackedFunc(ftimer); +} + +TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") + .set_body_typed([](Optional opt_mod, std::string name, int device_type, int device_id, + int number, int repeat, int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + if (tkey == "rpc") { + return static_cast(m.operator->()) + ->GetTimeEvaluator(name, ctx, number, repeat, min_repeat_ms); + } else { + return WrapTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms); + } + }); + +// server function registration. +TVM_REGISTER_GLOBAL("tvm.rpc.server.ImportModule").set_body_typed([](Module parent, Module child) { + parent->Import(child); +}); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.ModuleGetFunction") + .set_body_typed([](Module parent, std::string name, bool query_imports) { + return parent->GetFunction(name, query_imports); + }); + +// functions to access an RPC module. +TVM_REGISTER_GLOBAL("rpc.LoadRemoteModule").set_body_typed([](Module sess, std::string name) { + std::string tkey = sess->type_key(); + CHECK_EQ(tkey, "rpc"); + return static_cast(sess.operator->())->LoadModule(name); +}); + +TVM_REGISTER_GLOBAL("rpc.ImportRemoteModule").set_body_typed([](Module parent, Module child) { + std::string tkey = parent->type_key(); + CHECK_EQ(tkey, "rpc"); + static_cast(parent.operator->())->ImportModule(child); +}); + +TVM_REGISTER_GLOBAL("rpc.SessTableIndex").set_body([](TVMArgs args, TVMRetValue* rv) { + Module m = args[0]; + std::string tkey = m->type_key(); + CHECK_EQ(tkey, "rpc"); + *rv = static_cast(m.operator->())->sess()->table_index(); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_pipe_impl.cc b/src/runtime/rpc/rpc_pipe_impl.cc new file mode 100644 index 000000000000..2f4243574909 --- /dev/null +++ b/src/runtime/rpc/rpc_pipe_impl.cc @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_pipe_impl.cc + * \brief Pipe-based RPC channel. + */ +// Linux only for now, as linux is the most common usecase. 
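The pipe transport below works by creating two anonymous pipes, forking, and exec'ing the requested command in the child with the child's read and write file descriptor numbers appended as the last two argv entries; the parent wraps its own pipe ends in a PipeChannel and builds an ordinary RPC client session on top of it. A hypothetical use, with the server path purely illustrative:

    // Spawn a local RPC server process connected back over the pipe pair.
    // The launched binary must read its pipe fds from its final two arguments.
    Module client = CreatePipeClient({"/path/to/rpc_server_binary"});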
+#if defined(__linux__) || defined(__ANDROID__) + +#include +#include +#include +#include +#include + +#include +#include + +#include "../../support/pipe.h" +#include "rpc_endpoint.h" +#include "rpc_local_session.h" + +namespace tvm { +namespace runtime { + +class PipeChannel final : public RPCChannel { + public: + explicit PipeChannel(int readfd, int writefd, pid_t child_pid) + : readfd_(readfd), writefd_(writefd), child_pid_(child_pid) {} + + ~PipeChannel() { Close(); } + + size_t Send(const void* data, size_t size) final { + ssize_t n = write(writefd_, data, size); + if (n == -1) { + LOG(FATAL) << "Pipe write error"; + } + return static_cast(n); + } + + size_t Recv(void* data, size_t size) final { + ssize_t n = read(readfd_, data, size); + if (n == -1) { + LOG(FATAL) << "Pipe read error"; + } + return static_cast(n); + } + + void Close() { + close(readfd_); + close(writefd_); + kill(child_pid_, SIGKILL); + } + + private: + int readfd_; + int writefd_; + pid_t child_pid_; +}; + +Module CreatePipeClient(std::vector cmd) { + int parent2child[2]; + int child2parent[2]; + CHECK_EQ(pipe(parent2child), 0); + CHECK_EQ(pipe(child2parent), 0); + + int parent_read = child2parent[0]; + int parent_write = parent2child[1]; + int child_read = parent2child[0]; + int child_write = child2parent[1]; + + pid_t pid = fork(); + if (pid == 0) { + // child process + close(parent_read); + close(parent_write); + std::string sread_pipe = std::to_string(child_read); + std::string swrite_pipe = std::to_string(child_write); + std::vector argv; + for (auto& str : cmd) { + argv.push_back(dmlc::BeginPtr(str)); + } + argv.push_back(dmlc::BeginPtr(sread_pipe)); + argv.push_back(dmlc::BeginPtr(swrite_pipe)); + argv.push_back(nullptr); + execvp(argv[0], &argv[0]); + } + // parent process + close(child_read); + close(child_write); + + auto endpt = RPCEndpoint::Create( + std::unique_ptr(new PipeChannel(parent_read, parent_write, pid)), "pipe", + "pipe"); + endpt->InitRemoteSession(TVMArgs(nullptr, nullptr, 0)); + return CreateRPCSessionModule(CreateClientSession(endpt)); +} + +TVM_REGISTER_GLOBAL("rpc.CreatePipeClient").set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector cmd; + for (int i = 0; i < args.size(); ++i) { + cmd.push_back(args[i].operator std::string()); + } + *rv = CreatePipeClient(cmd); +}); + +} // namespace runtime +} // namespace tvm +#endif diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/rpc/rpc_protocol.h new file mode 100644 index 000000000000..3a0555d0cc6d --- /dev/null +++ b/src/runtime/rpc/rpc_protocol.h @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_procotol.h + * \brief Common header defining the communication code used in the RPC protocol. 
+ */ +#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ +#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ + +namespace tvm { +namespace runtime { + +/*! \brief The current RPC procotol version. */ +constexpr const char* kRPCProtocolVer = "0.7.0"; + +/*! \brief The RPC code */ +enum class RPCCode : int { + kNone, + kShutdown, + kInitServer, + kCallFunc, + kReturn, + kException, + kCopyFromRemote, + kCopyToRemote, + kCopyAck, + // The following are syscall code that can send over CallRemote + kSyscallCodeStart, + kGetGlobalFunc = kSyscallCodeStart, + kFreeHandle, + kDevSetDevice, + kDevGetAttr, + kDevAllocData, + kDevFreeData, + kDevStreamSync, + kCopyAmongRemote, +}; + +/*! + * \brief List of potential error status during rpc communication. + */ +enum class RPCServerStatus : int { + kSuccess = 0, + kInvalidTypeCodeObject, + kInvalidTypeCodeNDArray, + kInvalidDLTensorFieldStride, + kInvalidDLTensorFieldByteOffset, + kUnknownTypeCode, + kUnknownRPCCode, + kRPCCodeNotSupported, + kUnknownRPCSyscall, + kCheckError, + kReadError, + kWriteError, + kAllocError +}; + +/*! + * \brief Convert RPC server status to string. + * \param status The status. + * \return The corresponding string. + */ +inline const char* RPCServerStatusToString(RPCServerStatus status) { + switch (status) { + case RPCServerStatus::kSuccess: + return "kSuccess"; + case RPCServerStatus::kInvalidTypeCodeObject: + return "kInvalidTypeCodeObject"; + case RPCServerStatus::kInvalidTypeCodeNDArray: + return "kInvalidTypeCodeNDArray"; + case RPCServerStatus::kInvalidDLTensorFieldStride: + return "kInvalidDLTensorFieldStride"; + case RPCServerStatus::kInvalidDLTensorFieldByteOffset: { + return "kInvalidDLTensorFieldByteOffset"; + } + case RPCServerStatus::kUnknownTypeCode: + return "kUnknownTypeCode"; + case RPCServerStatus::kUnknownRPCCode: + return "kUnknownRPCCode"; + case RPCServerStatus::kRPCCodeNotSupported: + return "RPCCodeNotSupported"; + case RPCServerStatus::kUnknownRPCSyscall: + return "kUnknownRPCSyscall"; + case RPCServerStatus::kCheckError: + return "kCheckError"; + case RPCServerStatus::kReadError: + return "kReadError"; + case RPCServerStatus::kWriteError: + return "kWriteError"; + case RPCServerStatus::kAllocError: + return "kAllocError"; + default: + return ""; + } +} + +/*! + * \brief Reference implementation of the communication protocol. + * + * \note The implementation is intentionally written via template + * so it can be used in a dependency free setting. + * + * \sa src/runtime/rpc/device/min_rpc_server.h + */ +struct RPCReference { + /*! + * \brief Auxiliary class to get the packed sequence. + * \tparam TChannel The channel to throw errror. + */ + template + struct PackedSeqNumBytesGetter { + public: + explicit PackedSeqNumBytesGetter(TChannel* channel) : channel_(channel) {} + + template + void Write(const T& value) { + num_bytes_ += sizeof(T); + } + + template + void WriteArray(const T* value, size_t num) { + num_bytes_ += sizeof(T) * num; + } + + void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); } + + uint64_t num_bytes() const { return num_bytes_; } + + private: + TChannel* channel_; + uint64_t num_bytes_{0}; + }; + + /*! + * \return the length of the str. + * \param str the string. + * \return The length. + */ + static uint64_t StrLength(const char* str) { + uint64_t len = 0; + while (str[len] != '\0') ++len; + return len; + } + + /*! + * \brief Get the total nbytes to be sent in the packed sequence. + * + * \param arg_values The values to be sent over. 
+ * \param type_codes The type codes to be sent over. + * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + * \return The total number of bytes. + */ + template + static uint64_t PackedSeqGetNumBytes(const TVMValue* arg_values, const int* type_codes, + int num_args, bool client_mode, TChannel* channel) { + PackedSeqNumBytesGetter getter(channel); + SendPackedSeq(arg_values, type_codes, num_args, client_mode, &getter); + return getter.num_bytes(); + } + + /*! + * \brief Send packed argument sequnce to the other peer. + * + * This function serves as the foundational communication primitive between peers. + * + * TVMValue sequence encoding protocol(according to the type): + * + * - int/float/uint/bytes/str: Serialize all content. + * - DLTensor: send meta-data, send data handle as opaque handle(via uint64_t) + * - OpaqueHandle: send as uint64_t + * - ModuleHandle, PackedFuncHandle: send as uint64_t, + * The support to Module/PackedFuncHandle are reserved for arguments + * in the CallFunc from a client to server only. + * Note that we cannot simply take these argument out(as the handle) + * refers to a value on the remote(instead of local). + * + * \param arg_values The values to be sent over. + * \param type_codes The type codes to be sent over. + * \param num_args Number of argument. + * \param client_mode Whether it is a client to server call. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, + bool client_mode, TChannel* channel) { + channel->Write(num_args); + channel->WriteArray(type_codes, num_args); + + // Argument packing. + for (int i = 0; i < num_args; ++i) { + int tcode = type_codes[i]; + TVMValue value = arg_values[i]; + switch (tcode) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + channel->template Write(value.v_int64); + break; + } + case kTVMDataType: { + channel->Write(value.v_type); + // padding + int32_t padding = 0; + channel->template Write(padding); + break; + } + case kTVMContext: { + channel->Write(value.v_ctx); + break; + } + + case kTVMPackedFuncHandle: + case kTVMModuleHandle: { + if (!client_mode) { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeObject); + } + // always send handle in 64 bit. + uint64_t handle = reinterpret_cast(value.v_handle); + channel->Write(handle); + break; + } + case kTVMOpaqueHandle: { + // always send handle in 64 bit. 
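For orientation, the cases of SendPackedSeq add up to the following wire layout for a packed argument sequence (each field written through the channel in order):

    [int num_args][int type_codes[num_args]]
    then, per argument:
      int / uint / float                 -> [int64 value]
      DLDataType                         -> [DLDataType][int32 padding]
      TVMContext                         -> [TVMContext]
      PackedFunc / Module / OpaqueHandle -> [uint64 handle]
      DLTensor                           -> [uint64 data][TVMContext][int ndim][DLDataType][int64 shape[ndim]]
      nullptr                            -> nothing
      str / bytes                        -> [uint64 length][length raw bytes]

RecvPackedSeq below reads exactly the same sequence back, allocating its temporaries from the channel's arena.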
+ uint64_t handle = reinterpret_cast(value.v_handle); + channel->Write(handle); + break; + } + case kTVMNDArrayHandle: { + channel->ThrowError(RPCServerStatus::kInvalidTypeCodeNDArray); + break; + } + case kTVMDLTensorHandle: { + DLTensor* arr = static_cast(value.v_handle); + TVMContext ctx; + uint64_t data; + // When we return NDArray, we directly return + // the space and the context + // The client will be further wrapping + ctx = arr->ctx; + data = reinterpret_cast(arr->data); + channel->Write(data); + channel->Write(ctx); + channel->Write(arr->ndim); + channel->Write(arr->dtype); + channel->WriteArray(arr->shape, arr->ndim); + if (arr->strides != nullptr) { + channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldStride); + } + if (arr->byte_offset != 0) { + channel->ThrowError(RPCServerStatus::kInvalidDLTensorFieldByteOffset); + } + break; + } + case kTVMNullptr: + break; + case kTVMStr: { + const char* s = value.v_str; + uint64_t len = StrLength(s); + channel->Write(len); + channel->WriteArray(s, len); + break; + } + case kTVMBytes: { + TVMByteArray* bytes = static_cast(arg_values[i].v_handle); + uint64_t len = bytes->size; + channel->Write(len); + channel->WriteArray(bytes->data, len); + break; + } + default: { + channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + break; + } + } + } + } + + /*! + * \brief Receive packed seq from the channel. + * + * \param out_arg_values The values to be received. + * \param out_tcodes The type codes to be received. + * \param out_num_args Number of argument. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + * \note The temporary space are populated via an arena inside channel. + */ + template + static void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args, + TChannel* channel) { + // receive number of args + int num_args; + channel->Read(&num_args); + *out_num_args = num_args; + + if (num_args == 0) { + *out_values = nullptr; + *out_tcodes = nullptr; + return; + } + + TVMValue* values = channel->template ArenaAlloc(num_args); + int* tcodes = channel->template ArenaAlloc(num_args); + *out_values = values; + *out_tcodes = tcodes; + + // receive type code. + channel->ReadArray(tcodes, num_args); + + // receive arguments + for (int i = 0; i < num_args; ++i) { + auto& value = values[i]; + switch (tcodes[i]) { + case kDLInt: + case kDLUInt: + case kDLFloat: { + channel->template Read(&(value.v_int64)); + break; + } + case kTVMDataType: { + channel->Read(&(value.v_type)); + int32_t padding = 0; + channel->template Read(&padding); + break; + } + case kTVMContext: { + channel->Read(&(value.v_ctx)); + break; + } + case kTVMPackedFuncHandle: + case kTVMModuleHandle: + case kTVMOpaqueHandle: { + // always send handle in 64 bit. 
+ uint64_t handle; + channel->Read(&handle); + value.v_handle = reinterpret_cast(handle); + break; + } + case kTVMNullptr: { + value.v_handle = nullptr; + break; + } + case kTVMStr: { + uint64_t len; + channel->Read(&len); + char* str = channel->template ArenaAlloc(len + 1); + str[len] = '\0'; + channel->ReadArray(str, len); + value.v_str = str; + break; + } + case kTVMBytes: { + uint64_t len; + channel->Read(&len); + TVMByteArray* arr = channel->template ArenaAlloc(1); + char* data = channel->template ArenaAlloc(len); + arr->size = len; + arr->data = data; + channel->ReadArray(data, len); + value.v_handle = arr; + break; + } + case kTVMDLTensorHandle: { + uint64_t handle; + channel->Read(&handle); + DLTensor* arr = channel->template ArenaAlloc(1); + DLTensor& tensor = *arr; + tensor.data = reinterpret_cast(handle); + channel->Read(&(tensor.ctx)); + channel->Read(&(tensor.ndim)); + channel->Read(&(tensor.dtype)); + tensor.shape = channel->template ArenaAlloc(tensor.ndim); + channel->ReadArray(tensor.shape, tensor.ndim); + tensor.strides = nullptr; + tensor.byte_offset = 0; + value.v_handle = arr; + break; + } + default: { + channel->ThrowError(RPCServerStatus::kUnknownTypeCode); + break; + } + } + } + } + + /*! + * \brief Return an exception packet. + * + * \param msg The error message. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnException(const char* msg, TChannel* channel) { + RPCCode code = RPCCode::kException; + int32_t num_args = 1; + int32_t tcode = kTVMStr; + uint64_t len = StrLength(msg); + + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len; + + channel->Write(packet_nbytes); + channel->Write(code); + channel->Write(num_args); + channel->Write(tcode); + channel->Write(len); + channel->WriteArray(msg, len); + } + + /*! + * \brief Return a normal packed sequence packet. + * + * \param msg The error message. + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args, + TChannel* channel) { + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = + sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel); + + channel->Write(packet_nbytes); + channel->Write(code); + SendPackedSeq(arg_values, type_codes, num_args, false, channel); + } + + /*! + * \brief Return a null(void) packet. + * + * \param channel The communication channel handler. + * \tparam TChannel The type of the communication channel. + */ + template + static void ReturnVoid(TChannel* channel) { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; + + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); + + channel->Write(packet_nbytes); + channel->Write(code); + channel->Write(num_args); + channel->Write(tcode); + } +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ diff --git a/src/runtime/rpc/rpc_server_env.cc b/src/runtime/rpc/rpc_server_env.cc index f6a7fb60b5f4..b999a48a376a 100644 --- a/src/runtime/rpc/rpc_server_env.cc +++ b/src/runtime/rpc/rpc_server_env.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,42 +22,40 @@ * \brief Server environment of the RPC. */ #include + #include "../file_util.h" namespace tvm { namespace runtime { std::string RPCGetPath(const std::string& name) { - static const PackedFunc* f = - runtime::Registry::Get("tvm.rpc.server.workpath"); + // do live lookup everytime as workpath can change. + const PackedFunc* f = runtime::Registry::Get("tvm.rpc.server.workpath"); CHECK(f != nullptr) << "require tvm.rpc.server.workpath"; return (*f)(name); } -TVM_REGISTER_GLOBAL("tvm.rpc.server.upload"). -set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = RPCGetPath(args[0]); - std::string data = args[1]; - SaveBinaryToFile(file_name, data); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.download") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = RPCGetPath(args[0]); - std::string data; - LoadBinaryFromFile(file_name, &data); - TVMByteArray arr; - arr.data = data.c_str(); - arr.size = data.length(); - LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size; - *rv = arr; - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.remove") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = RPCGetPath(args[0]); - RemoveFile(file_name); - }); +TVM_REGISTER_GLOBAL("tvm.rpc.server.upload").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string file_name = RPCGetPath(args[0]); + std::string data = args[1]; + SaveBinaryToFile(file_name, data); +}); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.download").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string file_name = RPCGetPath(args[0]); + std::string data; + LoadBinaryFromFile(file_name, &data); + TVMByteArray arr; + arr.data = data.c_str(); + arr.size = data.length(); + LOG(INFO) << "Download " << file_name << "... nbytes=" << arr.size; + *rv = arr; +}); + +TVM_REGISTER_GLOBAL("tvm.rpc.server.remove").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string file_name = RPCGetPath(args[0]); + RemoveFile(file_name); +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 43ca630f9496..9e05e5d1628d 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -21,816 +21,84 @@ * \file rpc_session.cc * \brief RPC session for remote function call. */ -#include -#include +#include "rpc_session.h" + #include -#include -#include -#include +#include + #include -#include -#include -#include -#include -#include -#include -#include "rpc_session.h" -#include "../object_internal.h" -#include "../../support/ring_buffer.h" -#include "../../support/socket.h" +#include namespace tvm { namespace runtime { -// Temp buffer for data array -struct RPCByteArrayBuffer { - TVMByteArray arr; - std::string data; -}; -// Temp buffer for data array -struct RPCDataArrayBuffer { - DLTensor tensor; - std::vector shape; -}; -/*! - * \brief Temporal argument buffer. - */ -struct RPCArgBuffer { - // The argument values - std::vector value; - // The type codes. - std::vector tcode; - // Temporal resources. 
- std::vector > temp_bytes; - // Temporal array - std::vector > temp_array; - // convert buffer as TVMArgs - TVMArgs AsTVMArgs() const { - return TVMArgs(value.data(), tcode.data(), static_cast(value.size())); - } -}; - -// Event handler for RPC events. -class RPCSession::EventHandler : public dmlc::Stream { - public: - EventHandler(support::RingBuffer* reader, - support::RingBuffer* writer, - int rpc_sess_table_index, - std::string name, - std::string* remote_key) - : reader_(reader), - writer_(writer), - rpc_sess_table_index_(rpc_sess_table_index), - name_(name), - remote_key_(remote_key) { - this->Clear(); - if (*remote_key == "%toinit") { - state_ = kInitHeader; - remote_key_->resize(0); - pending_request_bytes_ = sizeof(int32_t); - } - } - // Bytes needed to fulfill current request - size_t BytesNeeded() { - if (reader_->bytes_available() < pending_request_bytes_) { - return pending_request_bytes_ - reader_->bytes_available(); - } else { - return 0; - } - } - // Request number of bytes from reader. - void RequestBytes(size_t nbytes) { - pending_request_bytes_ += nbytes; - reader_->Reserve(pending_request_bytes_); - } - // Whether we are ready to handle next request. - bool Ready() { - return reader_->bytes_available() >= pending_request_bytes_; - } - bool CanCleanShutdown() const { - return state_ == kRecvCode; - } - void FinishCopyAck() { - this->SwitchToState(kRecvCode); - } - RPCCode HandleNextEvent(TVMRetValue* rv, - bool client_mode, - const PackedFunc* fwrap) { - std::swap(client_mode_, client_mode); - while (this->Ready()) { - switch (state_) { - case kInitHeader: HandleInitHeader(); break; - case kRecvCode: HandleRecvCode(); break; - case kRecvCallHandle: { - CHECK(this->Read(&call_handle_)); - this->SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case kRecvPackedSeqNumArgs: { - CHECK(this->Read(&num_packed_args_)); - arg_buf_.reset(new RPCArgBuffer()); - arg_buf_->value.resize(num_packed_args_); - arg_buf_->tcode.resize(num_packed_args_); - this->SwitchToState(kRecvPackedSeqTypeCode); - break; - } - case kRecvPackedSeqTypeCode: { - if (num_packed_args_ != 0) { - this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); - } - arg_index_ = 0; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kRecvPackedSeqArg: { - this->HandleRecvPackedSeqArg(); - break; - } - case kDoCopyFromRemote: { - this->HandleCopyFromRemote(); - break; - } - case kDoCopyToRemote: { - this->HandleCopyToRemote(); - break; - } - case kReturnReceived: { - CHECK_GE(arg_buf_->value.size(), 1U); +bool RPCSession::IsAsync() const { return false; } - TVMArgValue argv = arg_buf_->AsTVMArgs()[0]; - if (argv.type_code() == kTVMPackedFuncHandle || - argv.type_code() == kTVMModuleHandle || - argv.type_code() == kTVMDLTensorHandle) { - CHECK(fwrap != nullptr) << "function/module wrapper not available"; - fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv); - } else { - CHECK_EQ(arg_buf_->value.size(), 1U); - *rv = argv; - } - arg_buf_.reset(); - this->SwitchToState(kRecvCode); - std::swap(client_mode_, client_mode); - return RPCCode::kReturn; - } - case kCopyAckReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kCopyAck; - } - case kShutdownReceived: { - std::swap(client_mode_, client_mode); - return RPCCode::kShutdown; - } - } - } - std::swap(client_mode_, client_mode); - return RPCCode::kNone; - } - // Reset and clear all states. 
- void Clear() { - state_ = kRecvCode; - pending_request_bytes_ = sizeof(RPCCode); - arg_recv_stage_ = 0; - arg_buf_.reset(); - } - // strip session on mask - TVMContext StripSessMask(TVMContext ctx) { - int dev_type = ctx.device_type; - CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1) - << "Can not pass in local context or context with a different remote session"; - ctx.device_type = static_cast(dev_type % kRPCSessMask); - return ctx; - } - // Send Packed sequence to writer. - // - // client_mode: whether we are in client mode. - // - // funwrap: auxiliary function to unwrap remote Object - // when it is provided, we need to unwrap objects. - // - // return_ndarray is a special flag to handle returning of ndarray - // In this case, we return the shape, context and data of the array, - // as well as a customized PackedFunc that handles deletion of - // the array in the remote. - void SendPackedSeq(const TVMValue* arg_values, - const int* type_codes, - int num_args, - bool client_mode, - FUnwrapRemoteObject funwrap = nullptr, - bool return_ndarray = false) { - std::swap(client_mode_, client_mode); - - this->Write(num_args); - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - if (tcode == kTVMNDArrayHandle) tcode = kTVMDLTensorHandle; - this->Write(tcode); - } - - // Argument packing. - for (int i = 0; i < num_args; ++i) { - int tcode = type_codes[i]; - TVMValue value = arg_values[i]; - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Write(value.v_int64); - break; - } - case kTVMDataType: { - this->Write(value.v_type); - // padding - int32_t padding = 0; - this->Write(padding); - break; - } - case kTVMContext: { - value.v_ctx = StripSessMask(value.v_ctx); - this->Write(value.v_ctx); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: { - // always send handle in 64 bit. - uint64_t handle; - // allow pass module as argument to remote. - if (funwrap != nullptr) { - void* remote_handle = (*funwrap)( - rpc_sess_table_index_, - runtime::TVMArgValue(value, tcode)); - handle = reinterpret_cast(remote_handle); - } else { - CHECK(!client_mode_) - << "Cannot directly pass remote object as argument"; - handle = reinterpret_cast(value.v_handle); - } - this->Write(handle); - break; - } - case kTVMOpaqueHandle: { - // always send handle in 64 bit. - uint64_t handle = reinterpret_cast(value.v_handle); - this->Write(handle); - break; - } - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* arr = static_cast(value.v_handle); - TVMContext ctx; - uint64_t data; - if (!return_ndarray) { - // in the client mode - // ctx contains the remote table index - // the space is wrapped by an RemoteSpace - // that holds reference to the session. 
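// A hedged sketch of the wire layout SendPackedSeq above produces for a single
// integer argument, using a plain std::string instead of the handler's
// endian-aware stream (byte swapping omitted; function name illustrative).
#include <cstddef>
#include <cstdint>
#include <string>

std::string EncodeSingleInt(int64_t value) {
  std::string out;
  auto put = [&out](const void* p, size_t n) { out.append(static_cast<const char*>(p), n); };
  int32_t num_args = 1;
  int32_t tcode = 0;                 // kDLInt
  put(&num_args, sizeof(num_args));  // argument count
  put(&tcode, sizeof(tcode));        // type-code array (one entry)
  put(&value, sizeof(value));        // the 64-bit payload
  return out;                        // 16 bytes on the wire
}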
- ctx = StripSessMask(arr->ctx); - data = reinterpret_cast( - static_cast(arr->data)->data); - } else { - // When we return NDArray, we directly return - // the space and the context - // The client will be further wrapping - ctx = arr->ctx; - data = reinterpret_cast(arr->data); - } - this->Write(data); - this->Write(ctx); - this->Write(arr->ndim); - this->Write(arr->dtype); - this->WriteArray(arr->shape, arr->ndim); - CHECK(arr->strides == nullptr) - << "Do not support strided remote array"; - CHECK_EQ(arr->byte_offset, 0) - << "Do not support send byte offset"; - break; - } - case kTVMNullptr: break; - case kTVMStr: { - const char* s = value.v_str; - uint64_t len = strlen(s); - this->Write(len); - this->WriteArray(s, len); - break; - } - case kTVMBytes: { - TVMByteArray* bytes = static_cast(arg_values[i].v_handle); - uint64_t len = bytes->size; - this->Write(len); - this->WriteArray(bytes->data, len); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - std::swap(client_mode_, client_mode); - } - - // Endian aware IO handling - using Stream::Read; - using Stream::Write; - using Stream::ReadArray; - using Stream::WriteArray; - - inline bool Read(RPCCode* code) { - int cdata; - if (!this->Read(&cdata)) return false; - *code = static_cast(cdata); - return true; - } - inline void Write(RPCCode code) { - int cdata = static_cast(code); - this->Write(cdata); - } +void RPCSession::SendException(FAsyncCallback callback, const char* msg) { + TVMValue value; + value.v_str = msg; + int32_t tcode = kTVMStr; + callback(RPCCode::kException, TVMArgs(&value, &tcode, 1)); +} - protected: - enum State { - kInitHeader, - kRecvCode, - kRecvCallHandle, - kRecvPackedSeqNumArgs, - kRecvPackedSeqTypeCode, - kRecvPackedSeqArg, - kDoCopyFromRemote, - kDoCopyToRemote, - kReturnReceived, - kCopyAckReceived, - kShutdownReceived - }; - // Current state; - State state_; - // The RPCCode to be read. - RPCCode code_; - // Handle for the remote function call. - uint64_t call_handle_; - // Initialize remote header - bool init_header_step_{0}; - // Number of packed arguments. - int num_packed_args_; - // Current argument index. - int arg_index_; - // The stage of each argument receiver. - int arg_recv_stage_; - // Whether current handler is client or server mode. - bool client_mode_{false}; - // Argument buffer - std::unique_ptr arg_buf_; - // Temp byte buffer. - std::unique_ptr temp_bytes_; - // Temp array buffer. - std::unique_ptr temp_array_; - // Internal temporal data space. - std::string temp_data_; - // Temp variables for copy request state. - TVMContext copy_ctx_; - DLDataType copy_dtype_; - uint64_t copy_handle_, copy_offset_, copy_size_; - // State switcher - void SwitchToState(State state) { - // invariant - CHECK_EQ(pending_request_bytes_, 0U) - << "state=" << state; - state_ = state; - switch (state) { - case kInitHeader: { - LOG(FATAL) << "cannot switch to init header"; - break; - } - case kRecvCode: { - this->RequestBytes(sizeof(RPCCode)); - break; - } - case kRecvCallHandle: { - this->RequestBytes(sizeof(call_handle_)); - break; - } - case kRecvPackedSeqNumArgs: { - this->RequestBytes(sizeof(num_packed_args_)); - break; - } - case kRecvPackedSeqTypeCode: { - this->RequestBytes(sizeof(int) * num_packed_args_); - break; - } - case kRecvPackedSeqArg: { - CHECK_LE(arg_index_, num_packed_args_); - if (arg_index_ == num_packed_args_) { - // The function can change state_ again. 
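// Summary of the DLTensor argument layout written by SendPackedSeq above:
// only the tensor header crosses the wire, never the contents.
//
//   uint64       data    remote data handle (session mask stripped from ctx)
//   TVMContext   ctx     device_type + device_id
//   int          ndim
//   DLDataType   dtype
//   int64[ndim]  shape
//
// Strided tensors and non-zero byte_offset are rejected by the CHECKs above.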
- HandlePackedCall(); - } else { - RequestRecvPackedSeqArg(); - } - break; - } - case kDoCopyFromRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kDoCopyToRemote: { - this->RequestBytes(sizeof(uint64_t) * 3); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - case kCopyAckReceived: - case kReturnReceived: - case kShutdownReceived: { - break; - } - } - } - // Requets bytes needed for next computation. - void RequestRecvPackedSeqArg() { - CHECK_EQ(arg_recv_stage_, 0); - int tcode = arg_buf_->tcode[arg_index_]; - static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant"); - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: - case kTVMDataType: - case kTVMOpaqueHandle: - case kTVMStr: - case kTVMBytes: - case kTVMModuleHandle: - case kTVMContext: { - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMPackedFuncHandle: { - CHECK(client_mode_) - << "Only client can receive remote functions"; - this->RequestBytes(sizeof(TVMValue)); break; - } - case kTVMNullptr: break; - case kTVMDLTensorHandle: { - this->RequestBytes(sizeof(uint64_t)); - this->RequestBytes(sizeof(TVMContext)); - this->RequestBytes(sizeof(int)); - this->RequestBytes(sizeof(DLDataType)); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } - // Handler for packed sequence argument receive. - void HandleRecvPackedSeqArg() { - CHECK_LT(arg_index_, num_packed_args_); - int tcode = arg_buf_->tcode[arg_index_]; - TVMValue& value = arg_buf_->value[arg_index_]; - if (arg_recv_stage_ == 0) { - switch (tcode) { - case kDLInt: - case kDLUInt: - case kDLFloat: { - this->Read(&(value.v_int64)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMDataType: { - this->Read(&(value.v_type)); - int32_t padding = 0; - this->Read(&padding); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMContext: { - this->Read(&(value.v_ctx)); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMPackedFuncHandle: - case kTVMModuleHandle: - case kTVMOpaqueHandle: { - // always send handle in 64 bit. 
- uint64_t handle; - this->Read(&handle); - value.v_handle = reinterpret_cast(handle); - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMNullptr: { - value.v_handle = nullptr; - ++arg_index_; - this->SwitchToState(kRecvPackedSeqArg); - break; - } - case kTVMStr: - case kTVMBytes: { - uint64_t len; - this->Read(&len); - temp_bytes_.reset( new RPCByteArrayBuffer()); - temp_bytes_->data.resize(len); - arg_recv_stage_ = 1; - this->RequestBytes(len); - break; - } - case kTVMDLTensorHandle: { - temp_array_.reset(new RPCDataArrayBuffer()); - uint64_t handle; - this->Read(&handle); - DLTensor& tensor = temp_array_->tensor; - tensor.data = reinterpret_cast(handle); - this->Read(&(tensor.ctx)); - this->Read(&(tensor.ndim)); - this->Read(&(tensor.dtype)); - temp_array_->shape.resize(tensor.ndim); - tensor.shape = temp_array_->shape.data(); - arg_recv_stage_ = 1; - tensor.strides = nullptr; - tensor.byte_offset = 0; - this->RequestBytes(sizeof(int64_t) * tensor.ndim); - break; - } - default: { - LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode); - break; - } - } - } else { - CHECK_EQ(arg_recv_stage_, 1); - if (tcode == kTVMStr || tcode == kTVMBytes) { - if (temp_bytes_->data.size() != 0) { - this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); - } - if (tcode == kTVMStr) { - value.v_str = temp_bytes_->data.c_str(); - } else { - temp_bytes_->arr.size = static_cast(temp_bytes_->data.size()); - temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data); - value.v_handle = &(temp_bytes_->arr); - } - arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_)); - } else { - CHECK_EQ(tcode, kTVMDLTensorHandle); - DLTensor& tensor = temp_array_->tensor; - this->ReadArray(tensor.shape, tensor.ndim); - value.v_handle = &tensor; - arg_buf_->temp_array.emplace_back(std::move(temp_array_)); - } - ++arg_index_; - arg_recv_stage_ = 0; - this->SwitchToState(kRecvPackedSeqArg); - } - } - // handler for initial header read - void HandleInitHeader() { - if (init_header_step_ == 0) { - int32_t len; - this->Read(&len); - remote_key_->resize(len); - init_header_step_ = 1; - this->RequestBytes(len); - return; - } else { - CHECK_EQ(init_header_step_, 1); - this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); - this->SwitchToState(kRecvCode); - } - } - // Handler for read code. - void HandleRecvCode() { - this->Read(&code_); - if (code_ > RPCCode::kSystemFuncStart) { - SwitchToState(kRecvPackedSeqNumArgs); - return; - } - // invariant. 
- CHECK_EQ(arg_recv_stage_, 0); - switch (code_) { - case RPCCode::kCallFunc: { - SwitchToState(kRecvCallHandle); - break; - } - case RPCCode::kException: - case RPCCode::kReturn: { - SwitchToState(kRecvPackedSeqNumArgs); - break; - } - case RPCCode::kCopyFromRemote: { - SwitchToState(kDoCopyFromRemote); - break; - } - case RPCCode::kCopyToRemote: { - SwitchToState(kDoCopyToRemote); - break; - } - case RPCCode::kShutdown: { - SwitchToState(kShutdownReceived); - break; - } - case RPCCode::kCopyAck: { - SwitchToState(kCopyAckReceived); - break; - } - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } +void RPCSession::AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, FAsyncCallback callback) { + try { + this->CallFunc(func, arg_values, arg_type_codes, num_args, + [&callback](TVMArgs args) { callback(RPCCode::kReturn, args); }); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } +} - void HandleCopyFromRemote() { - uint64_t handle, offset, num_bytes; - TVMContext ctx; - DLDataType type_hint; - this->Read(&handle); - this->Read(&offset); - this->Read(&num_bytes); - this->Read(&ctx); - this->Read(&type_hint); - size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8; +void RPCSession::AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; - if (ctx.device_type == kDLCPU) { - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - char* dptr = reinterpret_cast(handle) + offset; - if (!DMLC_IO_NO_ENDIAN_SWAP) { - temp_data_.resize(0); - temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes); - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - this->WriteArray(temp_data_.data(), num_bytes); - } else { - this->WriteArray(dptr, num_bytes); - } - } else { - temp_data_.resize(num_bytes + 1); - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(ctx)->CopyDataFromTo( - reinterpret_cast(handle), offset, - dmlc::BeginPtr(temp_data_), 0, - num_bytes, ctx, cpu_ctx, type_hint, nullptr); - RPCCode code = RPCCode::kCopyAck; - this->Write(code); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes); - } - this->WriteArray(&temp_data_[0], num_bytes); - } catch (const std::runtime_error &e) { - RPCCode code = RPCCode::kException; - this->Write(code); - TVMValue ret_value; - ret_value.v_str = e.what(); - int ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } - this->SwitchToState(kRecvCode); + try { + this->CopyToRemote(local_from, local_from_offset, remote_to, remote_to_offset, nbytes, + remote_ctx_to, type_hint); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } +} - void HandleCopyToRemote() { - // use static variable to persist state. - // This only works if next stage is immediately after this. 
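// A standalone sketch of the per-element byte swap that the dmlc::ByteSwap
// calls above perform before data crosses the wire on a big-endian host;
// elem_bytes is computed as (bits * lanes + 7) / 8, exactly as in
// HandleCopyFromRemote. This helper is an illustration, not the dmlc API.
#include <algorithm>
#include <cstddef>

void SwapElements(void* data, size_t elem_bytes, size_t num_elems) {
  char* p = static_cast<char*>(data);
  for (size_t i = 0; i < num_elems; ++i, p += elem_bytes) {
    std::reverse(p, p + elem_bytes);  // reverse the bytes of one element
  }
}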
- if (arg_recv_stage_ == 0) { - CHECK(this->Read(©_handle_)); - CHECK(this->Read(©_offset_)); - CHECK(this->Read(©_size_)); - CHECK(this->Read(©_ctx_)); - CHECK(this->Read(©_dtype_)); - arg_recv_stage_ = 1; - CHECK_EQ(pending_request_bytes_, 0U); - this->RequestBytes(copy_size_); - } else { - CHECK_EQ(arg_recv_stage_, 1); - TVMValue ret_value; - ret_value.v_handle = nullptr; - int ret_tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; - std::string errmsg; +void RPCSession::AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, + TVMContext remote_ctx_from, DLDataType type_hint, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; - size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8; - if (copy_ctx_.device_type == kDLCPU) { - char* dptr = reinterpret_cast(copy_handle_) + copy_offset_; - this->ReadArray(dptr, copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes); - } - } else { - temp_data_.resize(copy_size_ + 1); - this->ReadArray(&temp_data_[0], copy_size_); - if (!DMLC_IO_NO_ENDIAN_SWAP) { - dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes); - } - try { - TVMContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; - DeviceAPI::Get(copy_ctx_)->CopyDataFromTo( - temp_data_.data(), 0, - reinterpret_cast(copy_handle_), copy_offset_, - copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr); - } catch (const std::runtime_error &e) { - code = RPCCode::kException; - errmsg = e.what(); - ret_value.v_str = errmsg.c_str(); - ret_tcode = kTVMStr; - } - } - this->Write(code); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - arg_recv_stage_ = 0; - this->SwitchToState(kRecvCode); - } + try { + this->CopyFromRemote(remote_from, remote_from_offset, local_to, local_to_offset, nbytes, + remote_ctx_from, type_hint); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } - // Handle for packed call. - void HandlePackedCall(); +} - template - void CallHandler(F f) { - TVMRetValue rv; - TVMValue ret_value; - int ret_tcode; - try { - // Need to move out, in case f itself need to call RecvPackedSeq - // Which will override argbuf again. - std::unique_ptr args = std::move(arg_buf_); - f(args->AsTVMArgs(), &rv); - RPCCode code = RPCCode::kReturn; - this->Write(code); - if (rv.type_code() == kTVMStr) { - ret_value.v_str = rv.ptr()->c_str(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMBytes) { - std::string* bytes = rv.ptr(); - TVMByteArray arr; - arr.data = bytes->c_str(); - arr.size = bytes->length(); - ret_value.v_handle = &arr; - ret_tcode = kTVMBytes; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMPackedFuncHandle || - rv.type_code() == kTVMModuleHandle) { - // always send handle in 64 bit. - CHECK(!client_mode_) - << "Only server can send function and module handle back."; - rv.MoveToCHost(&ret_value, &ret_tcode); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } else if (rv.type_code() == kTVMNDArrayHandle) { - // always send handle in 64 bit. 
- CHECK(!client_mode_) - << "Only server can send NDArray back"; - // We follow a special protocol to return NDArray to client side - // The first pack value is the NDArray handle as DLTensor - // The second pack value is a customized deleter that deletes the NDArray. - TVMValue ret_value_pack[2]; - int ret_tcode_pack[2]; - rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]); - ret_value_pack[1].v_handle = ret_value_pack[0].v_handle; - ret_tcode_pack[1] = kTVMOpaqueHandle; - SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, false, nullptr, true); - } else { - ret_value = rv.value(); - ret_tcode = rv.type_code(); - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } catch (const std::runtime_error& e) { - RPCCode code = RPCCode::kException; - this->Write(code); - ret_value.v_str = e.what(); - ret_tcode = kTVMStr; - SendPackedSeq(&ret_value, &ret_tcode, 1, false); - } - } +void RPCSession::AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, + RPCSession::FAsyncCallback callback) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; - private: - // Utility functions - // Internal read function, update pending_request_bytes_ - size_t Read(void* data, size_t size) final { - CHECK_LE(size, pending_request_bytes_); - reader_->Read(data, size); - pending_request_bytes_ -= size; - return size; - } - void Write(const void* data, size_t size) final { - writer_->Write(data, size); + try { + this->GetDeviceAPI(ctx)->StreamSync(ctx, stream); + callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); } - // Number of pending bytes requests - size_t pending_request_bytes_; - // The ring buffer to read data from. - support::RingBuffer* reader_; - // The ringr buffer to write reply to. - support::RingBuffer* writer_; - // Session table index. - int rpc_sess_table_index_; - // Name of session. - std::string name_; - // remote key - std::string* remote_key_; -}; +} -struct RPCSessTable { +class RPCSessTable { public: static constexpr int kMaxRPCSession = 32; // Get global singleton @@ -848,7 +116,8 @@ struct RPCSessTable { std::lock_guard lock(mutex_); for (int i = 0; i < kMaxRPCSession; ++i) { if (tbl_[i].lock() == nullptr) { - tbl_[i] = ptr; return i; + tbl_[i] = ptr; + return i; } } LOG(FATAL) << "maximum number of RPC session reached"; @@ -863,493 +132,13 @@ struct RPCSessTable { std::array, kMaxRPCSession> tbl_; }; -RPCCode RPCSession::HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) { - RPCCode code = RPCCode::kCallFunc; - while (code != RPCCode::kReturn && - code != RPCCode::kShutdown && - code != RPCCode::kCopyAck) { - while (writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - size_t bytes_needed = handler_->BytesNeeded(); - if (bytes_needed != 0) { - size_t n = reader_.WriteWithCallback([this](void* data, size_t size) { - return channel_->Recv(data, size); - }, bytes_needed); - if (n == 0) { - if (handler_->CanCleanShutdown()) { - return RPCCode::kShutdown; - } else { - LOG(FATAL) << "Channel closes before we get neded bytes"; - } - } - } - code = handler_->HandleNextEvent(rv, client_mode, fwrap); - } - return code; -} - -void RPCSession::Init() { - // Event handler - handler_ = std::make_shared( - &reader_, &writer_, table_index_, name_, &remote_key_); - // Quick function to call remote. 
- call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) { - handler_->SendPackedSeq(args.values, args.type_codes, args.num_args, true); - RPCCode code = HandleUntilReturnEvent(rv, true, nullptr); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); - }); -} - -std::shared_ptr RPCSession::Create( - std::unique_ptr channel, - std::string name, - std::string remote_key) { - std::shared_ptr sess = std::make_shared(); - sess->channel_ = std::move(channel); - sess->name_ = std::move(name); - sess->remote_key_ = std::move(remote_key); - sess->table_index_ = RPCSessTable::Global()->Insert(sess); - sess->Init(); - return sess; -} - std::shared_ptr RPCSession::Get(int table_index) { return RPCSessTable::Global()->Get(table_index); } -RPCSession::~RPCSession() { - this->Shutdown(); -} - -void RPCSession::Shutdown() { - if (channel_ != nullptr) { - RPCCode code = RPCCode::kShutdown; - handler_->Write(code); - // flush all writing buffer to output channel. - try { - while (writer_.bytes_available() != 0) { - size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - if (n == 0) break; - } - } catch (const dmlc::Error& e) { - } - channel_.reset(nullptr); - } -} - -void RPCSession::ServerLoop() { - std::lock_guard lock(mutex_); - if (const auto* f = Registry::Get("tvm.rpc.server.start")) { - (*f)(); - } - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); - if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) { - (*f)(); - } - channel_.reset(nullptr); -} - -int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) { - std::lock_guard lock(mutex_); - RPCCode code = RPCCode::kNone; - if (bytes.length() != 0) { - reader_.Write(bytes.c_str(), bytes.length()); - TVMRetValue rv; - code = handler_->HandleNextEvent(&rv, false, nullptr); - } - if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) { - writer_.ReadWithCallback([this](const void *data, size_t size) { - return channel_->Send(data, size); - }, writer_.bytes_available()); - } - CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck); - if (code == RPCCode::kShutdown) return 0; - if (writer_.bytes_available() != 0) return 2; - return 1; -} - -// Get remote function with name -void RPCSession::CallFunc(void* h, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap) { - std::lock_guard lock(mutex_); - - RPCCode code = RPCCode::kCallFunc; - handler_->Write(code); - uint64_t handle = reinterpret_cast(h); - handler_->Write(handle); - handler_->SendPackedSeq( - args.values, args.type_codes, args.num_args, true, funwrap); - code = HandleUntilReturnEvent(rv, true, fwrap); - CHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); -} - -void RPCSession::CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_to, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_to = handler_->StripSessMask(ctx_to); - RPCCode code = RPCCode::kCopyToRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(to); - handler_->Write(handle); - uint64_t offset = static_cast(to_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_to); - handler_->Write(type_hint); - handler_->WriteArray(reinterpret_cast(from) + from_offset, data_size); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, 
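// Summary of the request stream CopyToRemote above writes before blocking on
// the server's kReturn reply:
//
//   int32        code       RPCCode::kCopyToRemote (written as a plain int)
//   uint64       handle     remote destination pointer
//   uint64       offset     byte offset into the destination
//   uint64       size       number of payload bytes that follow
//   TVMContext   ctx_to     target context (session mask stripped)
//   DLDataType   type_hint  element type, used for endian handling
//   byte[size]   payload    raw data copied from `from + from_offset`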
nullptr) == RPCCode::kReturn); -} - -void RPCSession::CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t data_size, - TVMContext ctx_from, - DLDataType type_hint) { - std::lock_guard lock(mutex_); - ctx_from = handler_->StripSessMask(ctx_from); - RPCCode code = RPCCode::kCopyFromRemote; - handler_->Write(code); - uint64_t handle = reinterpret_cast(from); - handler_->Write(handle); - uint64_t offset = static_cast(from_offset); - handler_->Write(offset); - uint64_t size = static_cast(data_size); - handler_->Write(size); - handler_->Write(ctx_from); - handler_->Write(type_hint); - TVMRetValue rv; - CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); - reader_.Reserve(data_size); - handler_->RequestBytes(data_size); - while (!handler_->Ready()) { - size_t bytes_needed = handler_->BytesNeeded(); - reader_.WriteWithCallback([this](void* data, size_t size) { - size_t n = channel_->Recv(data, size); - CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; - return n; - }, bytes_needed); - } - handler_->ReadArray(reinterpret_cast(to) + to_offset, data_size); - handler_->FinishCopyAck(); -} - -RPCFuncHandle RPCSession::GetTimeEvaluator( - RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) { - return this->CallRemote( - RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms); -} - -// Event handler functions -void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) { - std::string name = args[0]; - auto *fp = tvm::runtime::Registry::Get(name); - if (fp != nullptr) { - *rv = static_cast(new tvm::runtime::PackedFunc(*fp)); - } else { - *rv = nullptr; - } -} - -void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - delete static_cast(handle); -} - -void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAPI::Get(ctx)->SetDevice(ctx); -} - -void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - DeviceAttrKind kind = static_cast(args[1].operator int()); - if (kind == kExist) { - DeviceAPI* api = DeviceAPI::Get(ctx, true); - if (api != nullptr) { - api->GetAttr(ctx, kind, rv); - } else { - *rv = 0; - } - } else { - DeviceAPI::Get(ctx)->GetAttr( - ctx, static_cast(kind), rv); - } -} - -void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - uint64_t nbytes = args[1]; - uint64_t alignment = args[2]; - DLDataType type_hint = args[3]; - void* data = DeviceAPI::Get(ctx)->AllocDataSpace( - ctx, nbytes, alignment, type_hint); - *rv = data; -} - -void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - void* ptr = args[1]; - DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr); -} - -void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) { - TVMContext ctx = args[0]; - TVMStreamHandle handle = args[1]; - DeviceAPI::Get(ctx)->StreamSync(ctx, handle); -} - -void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) { - void* from = args[0]; - uint64_t from_offset = args[1]; - void* to = args[2]; - uint64_t to_offset = args[3]; - uint64_t size = args[4]; - TVMContext ctx_from = args[5]; - TVMContext ctx_to = args[6]; - DLDataType type_hint = args[7]; - TVMStreamHandle stream = args[8]; - TVMContext ctx = ctx_from; - if (ctx.device_type == kDLCPU) { - ctx = ctx_to; - } else { - CHECK(ctx_to.device_type == kDLCPU || - ctx_to.device_type == ctx_from.device_type) - << "Can not copy across different ctx types directly"; - } - DeviceAPI::Get(ctx)->CopyDataFromTo( - from, from_offset, - to, 
to_offset, - size, ctx_from, ctx_to, type_hint, stream); -} - -void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { - static const PackedFunc* fsys_load_ = nullptr; - if (fsys_load_ == nullptr) { - fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module"); - CHECK(fsys_load_ != nullptr); - } - std::string file_name = args[0]; - TVMRetValue ret = (*fsys_load_)(file_name); - // pass via void* - TVMValue value; - int rcode; - ret.MoveToCHost(&value, &rcode); - CHECK_EQ(rcode, kTVMModuleHandle); - *rv = static_cast(value.v_handle); -} - -void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { - void* pmod = args[0]; - void* cmod = args[1]; - ObjectInternal::GetModuleNode(pmod)->Import( - GetRef(ObjectInternal::GetModuleNode(cmod))); -} - -void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - ObjectInternal::ObjectFree(mhandle); -} - -void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction( - args[1], false); - if (pf != nullptr) { - *rv = static_cast(new PackedFunc(pf)); - } else { - *rv = nullptr; - } -} - -void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) { - void* mhandle = args[0]; - std::string fmt = args[1]; - *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt); -} - -void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) { - void* handle = args[0]; - static_cast( - reinterpret_cast(handle))->DecRef(); -} - -void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) { - PackedFunc *pf = static_cast(args[0].operator void*()); - void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4])); - delete pf; - *rv = fhandle; -} - -void RPCSession::EventHandler::HandlePackedCall() { - CHECK_EQ(pending_request_bytes_, 0U); - if (code_ == RPCCode::kReturn) { - state_ = kReturnReceived; return; - } - // reset state to clean init state - state_ = kRecvCode; - this->RequestBytes(sizeof(RPCCode)); - // Event handler sit at clean state at this point. 
- switch (code_) { - case RPCCode::kCallFunc: { - PackedFunc* pf = reinterpret_cast(call_handle_); - CallHandler([pf](TVMArgs args, TVMRetValue* rv) { - pf->CallPacked(args, rv); - }); - break; - } - case RPCCode::kException: { - CHECK_EQ(arg_buf_->value.size(), 1U); - CHECK_EQ(arg_buf_->tcode[0], kTVMStr); - std::ostringstream os; - os << "Except caught from RPC call: " << arg_buf_->value[0].v_str; - arg_buf_.reset(); - throw dmlc::Error(os.str()); - break; - } - // system functions - case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break; - case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break; - case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break; - case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break; - case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break; - case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break; - case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break; - case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; - case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; - case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; - case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break; - case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; - case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; - case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; - case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break; - default: LOG(FATAL) << "Unknown event " << static_cast(code_); - } - CHECK_EQ(state_, kRecvCode); -} - -PackedFunc MicroTimeEvaluator( - PackedFunc pf, - TVMContext ctx, - int number, - int repeat) { - auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - for (int i = 0; i < repeat; ++i) { - double speed = 0.0; - for (int j = 0; j < number; ++j) { - pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - speed += (temp.operator double()) / number; - } - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - -PackedFunc WrapTimeEvaluator(PackedFunc pf, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms) { - if (static_cast(ctx.device_type) == static_cast(kDLMicroDev)) { - return MicroTimeEvaluator(pf, ctx, number, repeat); - } - - auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable { - TVMRetValue temp; - std::ostringstream os; - // skip first time call, to activate lazy compilation components. 
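// A hedged client-side sketch: both MicroTimeEvaluator and WrapTimeEvaluator
// above return a TVMByteArray whose payload is `repeat` doubles (the average
// per-call cost, in seconds for the host-timed evaluator), written back to
// back. Decoding is a memcpy per element; the function name is illustrative.
#include <cstddef>
#include <cstring>
#include <vector>

std::vector<double> DecodeCosts(const char* data, size_t size) {
  std::vector<double> costs(size / sizeof(double));
  for (size_t i = 0; i < costs.size(); ++i) {
    std::memcpy(&costs[i], data + i * sizeof(double), sizeof(double));
  }
  return costs;
}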
- pf.CallPacked(args, &temp); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - - for (int i = 0; i < repeat; ++i) { - std::chrono::time_point< - std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend; - double duration_ms = 0.0; - - do { - if (duration_ms > 0.0) { - number = static_cast( - std::max((min_repeat_ms / (duration_ms / number) + 1), - number * 1.618)); // 1.618 is chosen by random - } - - tbegin = std::chrono::high_resolution_clock::now(); - // start timing - for (int i = 0; i < number; ++i) { - pf.CallPacked(args, &temp); - } - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - tend = std::chrono::high_resolution_clock::now(); - - duration_ms = std::chrono::duration_cast > - (tend - tbegin).count() * 1000; - } while (duration_ms < min_repeat_ms); - - double speed = std::chrono::duration_cast >( - tend - tbegin).count() / number; - os.write(reinterpret_cast(&speed), sizeof(speed)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - return PackedFunc(ftimer); -} - -size_t CallbackChannel::Send(const void* data, size_t size) { - TVMByteArray bytes; - bytes.data = static_cast(data); - bytes.size = size; - int64_t n = fsend_(bytes); - if (n == -1) { - support::Socket::Error("CallbackChannel::Send"); - } - return static_cast(n); -} - -size_t CallbackChannel::Recv(void* data, size_t size) { - TVMRetValue ret = frecv_(size); - - if (ret.type_code() != kTVMBytes) { - support::Socket::Error("CallbackChannel::Recv"); - } - std::string* bytes = ret.ptr(); - memcpy(static_cast(data), bytes->c_str(), bytes->length()); - return bytes->length(); +void RPCSession::InsertToSessionTable(std::shared_ptr sess) { + CHECK_EQ(sess->table_index_, 0); + sess->table_index_ = RPCSessTable::Global()->Insert(sess); } } // namespace runtime diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index db63be4be74d..6a7e6d6e41c1 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -24,230 +24,253 @@ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ -#include #include -#include -#include +#include + +#include #include -#include -#include "../../support/ring_buffer.h" +#include + +#include "rpc_protocol.h" namespace tvm { namespace runtime { -// Magic header for RPC data plane -const int kRPCMagic = 0xff271; -// magic header for RPC tracker(control plane) -const int kRPCTrackerMagic = 0x2f271; -// sucess response -const int kRPCSuccess = kRPCMagic + 0; -// cannot found matched key in server -const int kRPCMismatch = kRPCMagic + 2; - -/*! \brief Enumeration code for the RPC tracker */ -enum class TrackerCode : int { - kFail = -1, - kSuccess = 0, - kPing = 1, - kStop = 2, - kPut = 3, - kRequest = 4, - kUpdateInfo = 5, - kSummary = 6, - kGetPendingMatchKeys = 7 -}; -/*! \brief The remote functio handle */ -using RPCFuncHandle = void*; - -struct RPCArgBuffer; - -/*! \brief The RPC code */ -enum class RPCCode : int { - kNone, - kCallFunc, - kReturn, - kException, - kShutdown, - kCopyFromRemote, - kCopyToRemote, - kCopyAck, - // The following are code that can send over CallRemote - kSystemFuncStart, - kGetGlobalFunc, - kGetTimeEvaluator, - kFreeFunc, - kDevSetDevice, - kDevGetAttr, - kDevAllocData, - kDevFreeData, - kDevStreamSync, - kCopyAmongRemote, - kModuleLoad, - kModuleImport, - kModuleFree, - kModuleGetFunc, - kModuleGetSource, - kNDArrayFree -}; - -/*! - * \brief Function that unwraps a remote object to its handle. 
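// Worked example of the adaptive `number` update in WrapTimeEvaluator above
// (illustrative values): with min_repeat_ms = 100, if the first pass ran
// number = 10 calls in duration_ms = 20, the next pass uses
//
//   number = max(100 / (20 / 10) + 1, 10 * 1.618) = max(51, 16.18) -> 51
//
// calls, and the do/while keeps growing `number` until a single repeat takes
// at least min_repeat_ms of wall time.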
- * \param rpc_sess_table_index RPC session table index for validation. - * \param obj Handle to the object argument. - * \return The corresponding handle. - */ -typedef void* (*FUnwrapRemoteObject)( - int rpc_sess_table_index, - const TVMArgValue& obj); - /*! - * \brief Abstract channel interface used to create RPCSession. + * \brief The interface of all remote RPC sessions. + * + * It contains all the necessary interface to implement + * remote call and resource management. + * + * The interface is designed to allow easy proxy-chaining + * by forward requests to another RPCSession. */ -class RPCChannel { +class RPCSession { public: - /*! \brief virtual destructor */ - virtual ~RPCChannel() {} + /*! \brief PackedFunc Handle in the remote. */ + using PackedFuncHandle = void*; + + /*! \brief Module handle in the remote. */ + using ModuleHandle = void*; + + /*! \brief NDArray handle in the remote. */ + using NDArrayHandle = void*; + /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. + * \brief Callback to send an encoded return values via encode_args. + * + * \param encode_args The arguments that we can encode the return values into. + * + * Encoding convention (as list of arguments): + * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention. + * - PackedFunc/Module: [tcode: int, handle: void*] + * - NDArray: [tcode: int, meta: DLTensor*, nd_handle: void*] + * DLTensor* contains the meta-data as well as handle into the remote data. + * nd_handle can be used for deletion. */ - virtual size_t Send(const void* data, size_t size) = 0; + using FEncodeReturn = std::function; + /*! - * \brief Recv data from channel. + * \brief Callback to send an encoded return values via encode_args. * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. + * \param status The return status, can be RPCCode::kReturn or RPCCode::kException. + * \param encode_args The arguments that we can encode the return values into. */ - virtual size_t Recv(void* data, size_t size) = 0; -}; + using FAsyncCallback = std::function; + + /*! \brief Destructor.*/ + virtual ~RPCSession() {} -// Bidirectional Communication Session of PackedRPC -class RPCSession { - public: - /*! \brief virtual destructor */ - ~RPCSession(); /*! - * \brief The server loop that server runs to handle RPC calls. + * \brief Get function in the session. + * \param name The name of the function. + * \return The function handle. */ - void ServerLoop(); + virtual PackedFuncHandle GetFunction(const std::string& name) = 0; + /*! - * \brief Message handling function for event driven server. - * Called when the server receives a message. - * Event driven handler will never call recv on the channel - * and always relies on the ServerEventHandler. - * to receive the data. + * \brief Call into a remote Packed function. * - * \param in_bytes The incoming bytes. - * \param event_flag 1: read_available, 2: write_avaiable. - * \return State flag. - * 1: continue running, no need to write, - * 2: need to write - * 0: shutdown - */ - int ServerEventHandler(const std::string& in_bytes, - int event_flag); - /*! - * \brief Call into remote function - * \param handle The function handle - * \param args The arguments - * \param rv The return value. - * \param funpwrap Function that takes a remote object and returns the raw handle. 
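// Illustration of the FEncodeReturn convention documented above, assuming the
// callee returns a plain integer: the result is delivered to the callback as
// two packed arguments, [tcode, value] (variable names illustrative).
//
//   TVMValue values[2];  int type_codes[2];
//   values[0].v_int64 = kDLInt;  type_codes[0] = kDLInt;  // tcode of the result
//   values[1].v_int64 = result;  type_codes[1] = kDLInt;  // the result itself
//   fencode_return(TVMArgs(values, type_codes, 2));
//
// PackedFunc/Module results pass [tcode, handle] instead, and NDArray results
// pass [tcode, DLTensor* meta, nd_handle], matching the list above.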
- * \param fwrap Wrapper function to turn Function/Module handle into real return. + * Calling convention: + * + * - type_code is follows the PackedFunc convention. + * - int/float/string/bytes follows the PackedFunc convention, all data are local. + * - PackedFunc/Module and future remote objects: pass remote handle instead. + * - NDArray/DLTensor: pass a DLTensor pointer, the data field of DLTensor + * points to a remote data handle returned by the Device API. + * The meta-data of the DLTensor sits on local. + * + * The caller populates the arguments and manages these arguments. + * + * The callee can change the content of arg_values and arg_type_codes + * if they want to do inplace modify and forward. + * + * The callee need to store the return value into ret_value. + * - PackedFunc/Module are stored as void* + * - NDArray is stored as local NDArray, whose data field is a remote handle. + * Notably the NDArray's deleter won't delete remote handle. + * It is up to the user of the RPCSession to such wrapping. + * - In short, remote handles are "moved" as return values + * and the callee needs to explicitly manage them by calling + * the deleter functions when they are no longer needed. + * + * \param func The function handle. + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * \param fencode_return The function to set the return value, + * if not called, return value is null. */ - void CallFunc(RPCFuncHandle handle, - TVMArgs args, - TVMRetValue* rv, - FUnwrapRemoteObject funwrap, - const PackedFunc* fwrap); + virtual void CallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, + const FEncodeReturn& fencode_return) = 0; + /*! * \brief Copy bytes into remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. - * \param to The target array. - * \param to_offset The byte offset in the to. + * \param local_from The source host data. + * \param local_from_offset The byte offeset in the from. + * \param remote_to The target array. + * \param remote_to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_to The target context. + * \param remote_ctx_to The target context. * \param type_hint Hint of content data type. */ - void CopyToRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_to, - DLDataType type_hint); + virtual void CopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint) = 0; /*! * \brief Copy bytes from remote array content. - * \param from The source host data. - * \param from_offset The byte offeset in the from. + * \param remote_from The source host data. + * \param remote_from_offset The byte offeset in the from. * \param to The target array. * \param to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. - * \param ctx_from The source context. + * \param remote_ctx_from The source context in the remote. * \param type_hint Hint of content data type. 
*/ - void CopyFromRemote(void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t nbytes, - TVMContext ctx_from, - DLDataType type_hint); + virtual void CopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from, + DLDataType type_hint) = 0; + + /*! + * \brief Free a remote function. + * \param handle The remote handle, can be NDArray/PackedFunc/Module + * \param type_code The type code of the underlying type. + */ + virtual void FreeHandle(void* handle, int type_code) = 0; + + /*! + * \brief Get device API that represents the remote + * actions that can be taken on the remote. + * + * The caller can then call into the Alloc/Free functions + * to allocate free spaces and taking the pointer as the handle. + * + * The device API is guaranteed to be alive during the + * lifetime of the Session. + * + * \param ctx The remote context. + * \param allow_missing Whether can we return nullptr if it is not available. + * + * \return The device API. + */ + virtual DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) = 0; + + /*! + * \brief Whether the session is a local session and we can directly + * the data handle returned by the session and treat it as pointer + * to the local memory. + * + * This information is useful for RPC server to directly copy into the + * local memory without creating a temporary buffer. + * + * \return Whether it is a local session. + */ + virtual bool IsLocalSession() const = 0; + + // Asynchrous variant of API + // These APIs are used by the RPC server to allow sessions that + // have special implementations for the async functions. + // + // In the async APIs, an exception is returned by the passing + // async_error=true, encode_args=[error_msg]. + /*! - * \brief Get a remote timer function on ctx. - * This function consumes fhandle, caller should not call Free on fhandle. + * \brief Whether the session is async. + * + * If the session is not async, its Aync implementations + * simply calls into the their synchronize counterparts, + * and the callback is guaranteed to be called before the async function finishes. + * + * \return the async state. * - * \param fhandle The function handle. - * \param ctx The ctx to run measurement on. - * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. - * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. - * \return A remote timer function + * \note We can only use async session in an Event driven RPC server. */ - RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, - TVMContext ctx, - int number, - int repeat, - int min_repeat_ms); + virtual bool IsAsync() const; + /*! - * \brief Call a remote defined system function with arguments. - * \param fcode The function code. 
- * \param args The arguments - * \return The returned remote value. + * \brief Asynchrously call func. + * \param func The function handle. + * \param arg_values The argument values. + * \param arg_type_codes the type codes of the argument. + * \param num_args Number of arguments. + * + * \param callback The callback to pass the return value or exception. */ - template - inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args); + virtual void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, + const int* arg_type_codes, int num_args, FAsyncCallback callback); + /*! - * \return The session table index of the session. + * \brief Asynchrous version of CopyToRemote. + * + * \param local_from The source host data. + * \param local_from_offset The byte offeset in the from. + * \param remote_to The target array. + * \param remote_to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param remote_ctx_to The target context. + * \param type_hint Hint of content data type. + * + * \param on_complete The callback to signal copy complete. + * \note All the allocated memory in local_from, and remote_to + * must stay alive until on_compelete is called. */ - int table_index() const { - return table_index_; - } + virtual void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, FAsyncCallback on_complete); + /*! - * \brief Create a RPC session with given channel. - * \param channel The communication channel. - * \param name The local name of the session, used for debug - * \param remote_key The remote key of the session - * if remote_key equals "%toinit", we need to re-intialize - * it by event handler. + * \brief Asynchrous version of CopyFromRemote. + * + * \param remote_from The source host data. + * \param remote_from_offset The byte offeset in the from. + * \param to The target array. + * \param to_offset The byte offset in the to. + * \param nbytes The size of the memory in bytes. + * \param remote_ctx_from The source context in the remote. + * \param type_hint Hint of content data type. + * + * \param on_complete The callback to signal copy complete. + * \note All the allocated memory in remote_from, and local_to + * must stay alive until on_compelete is called. + */ + virtual void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, + TVMContext remote_ctx_from, DLDataType type_hint, + FAsyncCallback on_complete); + /*! + * \brief Asynchrously wait for all events in ctx, stream compeletes. + * \param ctx The device context. + * \param stream The stream to wait on. + * \param on_complete The callback to signal copy complete. */ - static std::shared_ptr Create( - std::unique_ptr channel, - std::string name, - std::string remote_key); + virtual void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_compelte); + + /*! + * \return The session table index of the session. + */ + int table_index() const { return table_index_; } + /*! * \brief Try get session from the global session table by table index. * \param table_index The table index of the session. @@ -255,63 +278,33 @@ class RPCSession { */ static std::shared_ptr Get(int table_index); + protected: + /*! + * \brief Send an exception to the callback. + * \param msg The exception message. 
+ */ + void SendException(FAsyncCallback callback, const char* msg); + private: - class EventHandler; - // Handle events until receives a return - // Also flushes channels so that the function advances. - RPCCode HandleUntilReturnEvent( - TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap); - // Initalization - void Init(); - // Shutdown - void Shutdown(); - // Internal channel. - std::unique_ptr channel_; - // Internal mutex - std::recursive_mutex mutex_; - // Internal ring buffer. - support::RingBuffer reader_, writer_; - // Event handler. - std::shared_ptr handler_; - // call remote with specified function code. - PackedFunc call_remote_; - // The index of this session in RPC session table. + /*! \brief index of this session in RPC session table */ int table_index_{0}; - // The name of the session. - std::string name_; - // The remote key - std::string remote_key_; + /*! \brief Insert the current session to the session table.*/ + static void InsertToSessionTable(std::shared_ptr sess); + // friend declaration + friend Module CreateRPCSessionModule(std::shared_ptr sess); }; /*! - * \brief RPC channel which callback - * frontend (Python/Java/etc.)'s send & recv function + * \brief Remote space handle cell used by the RPC runtime API. + * + * When we allocate space using a rpc context, the data pointer + * points to an allocated RemoteSpace. */ -class CallbackChannel final : public RPCChannel { - public: - explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv) - : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {} - - ~CallbackChannel() {} - /*! - * \brief Send data over to the channel. - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes sent. - */ - size_t Send(const void* data, size_t size) final; - /*! - * \brief Recv data from channel. - * - * \param data The data pointer. - * \param size The size fo the data. - * \return The actual bytes received. - */ - size_t Recv(void* data, size_t size) final; - - private: - PackedFunc fsend_; - PackedFunc frecv_; +struct RemoteSpace { + /*! \brief The remote data handle. */ + void* data; + /*! \brief Reference to the underlying RPC session. */ + std::shared_ptr sess; }; /*! @@ -319,24 +312,21 @@ class CallbackChannel final : public RPCChannel { * \param f The function argument. * \param ctx The context. * \param number The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. + * We call these runs as one `repeat` of measurement. * \param repeat The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. + * In total, the function will be invoked (1 + number x repeat) times, + * where the first one is warm up and will be discarded. + * The returned result contains `repeat` costs, + * each of which is an average of `number` costs. * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds. - By default, one `repeat` contains `number` runs. If this parameter is set, - the parameters `number` will be dynamically adjusted to meet the - minimum duration requirement of one `repeat`. - i.e., When the run time of one `repeat` falls below this time, - the `number` parameter will be automatically increased. + * By default, one `repeat` contains `number` runs. 
If this parameter is set, + * the parameters `number` will be dynamically adjusted to meet the + * minimum duration requirement of one `repeat`. + * i.e., When the run time of one `repeat` falls below this time, + * the `number` parameter will be automatically increased. * \return f_timer A timer function. */ -PackedFunc WrapTimeEvaluator(PackedFunc f, - TVMContext ctx, - int number, - int repeat, +PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat, int min_repeat_ms); /*! @@ -344,21 +334,15 @@ PackedFunc WrapTimeEvaluator(PackedFunc f, * \param sess The RPC session of the global module. * \return The created module. */ -Module CreateRPCModule(std::shared_ptr sess); +Module CreateRPCSessionModule(std::shared_ptr sess); -// Remote space pointer. -struct RemoteSpace { - void* data; - std::shared_ptr sess; -}; +/*! + * \brief Get the session module from a RPC session Module. + * \param mod The input module(must be an RPCModule). + * \return The internal RPCSession. + */ +std::shared_ptr RPCModuleGetSession(Module mod); -// implementation of inline functions -template -inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) { - std::lock_guard lock(mutex_); - writer_.Write(&code, sizeof(code)); - return call_remote_(std::forward(args)...); -} } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_SESSION_H_ diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 642fbb8ec7f2..77a743be0de6 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -21,18 +21,22 @@ * \file rpc_socket_impl.cc * \brief Socket based RPC implementation. */ +#include #include + #include -#include "rpc_session.h" + #include "../../support/socket.h" +#include "rpc_endpoint.h" +#include "rpc_local_session.h" +#include "rpc_session.h" namespace tvm { namespace runtime { class SockChannel final : public RPCChannel { public: - explicit SockChannel(support::TCPSocket sock) - : sock_(sock) {} + explicit SockChannel(support::TCPSocket sock) : sock_(sock) {} ~SockChannel() { try { // BadSocket can throw @@ -61,13 +65,12 @@ class SockChannel final : public RPCChannel { support::TCPSocket sock_; }; -std::shared_ptr -RPCConnect(std::string url, int port, std::string key) { +std::shared_ptr RPCConnect(std::string url, int port, std::string key, + TVMArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); - CHECK(sock.Connect(addr)) - << "Connect to " << addr.AsString() << " failed"; + CHECK(sock.Connect(addr)) << "Connect to " << addr.AsString() << " failed"; // hand shake std::ostringstream os; int code = kRPCMagic; @@ -80,12 +83,10 @@ RPCConnect(std::string url, int port, std::string key) { CHECK_EQ(sock.RecvAll(&code, sizeof(code)), sizeof(code)); if (code == kRPCMagic + 2) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port - << " cannot find server that matches key=" << key; + LOG(FATAL) << "URL " << url << ":" << port << " cannot find server that matches key=" << key; } else if (code == kRPCMagic + 1) { sock.Close(); - LOG(FATAL) << "URL " << url << ":" << port - << " server already have key=" << key; + LOG(FATAL) << "URL " << url << ":" << port << " server already have key=" << key; } else if (code != kRPCMagic) { sock.Close(); LOG(FATAL) << "URL " << url << ":" << port << " is not TVM RPC server"; @@ -96,42 +97,46 @@ RPCConnect(std::string url, int port, std::string key) { remote_key.resize(keylen); 
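// Summary of the handshake RPCConnect above performs over the fresh TCP socket
// (all integers 32-bit; the length-prefixed client key send is an assumption
// here, since those lines are unchanged context not shown in this diff):
//
//   client -> server: int32 kRPCMagic (0xff271), int32 key length, key bytes
//   server -> client: int32 code
//       kRPCMagic        accepted
//       kRPCMagic + 1    server already has a session with this key
//       kRPCMagic + 2    no server matches the requested key
//       anything else    not a TVM RPC server
//   server -> client: int32 remote key length, remote key bytes (on success)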
CHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } - return RPCSession::Create( - std::unique_ptr(new SockChannel(sock)), key, remote_key); + auto endpt = + RPCEndpoint::Create(std::unique_ptr(new SockChannel(sock)), key, remote_key); + endpt->InitRemoteSession(init_seq); + return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key) { - return CreateRPCModule(RPCConnect(url, port, "client:" + key)); +Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) { + auto endpt = RPCConnect(url, port, "client:" + key, init_seq); + return CreateRPCSessionModule(CreateClientSession(endpt)); } // TVM_DLL needed for MSVC TVM_DLL void RPCServerLoop(int sockfd) { - support::TCPSocket sock( - static_cast(sockfd)); - RPCSession::Create( - std::unique_ptr(new SockChannel(sock)), - "SockServerLoop", "")->ServerLoop(); + support::TCPSocket sock(static_cast(sockfd)); + RPCEndpoint::Create(std::unique_ptr(new SockChannel(sock)), "SockServerLoop", "") + ->ServerLoop(); } void RPCServerLoop(PackedFunc fsend, PackedFunc frecv) { - RPCSession::Create(std::unique_ptr( - new CallbackChannel(fsend, frecv)), - "SockServerLoop", "")->ServerLoop(); + RPCEndpoint::Create(std::unique_ptr(new CallbackChannel(fsend, frecv)), + "SockServerLoop", "") + ->ServerLoop(); } -TVM_REGISTER_GLOBAL("rpc._Connect") -.set_body_typed(RPCClientConnect); +TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) { + std::string url = args[0]; + int port = args[1]; + std::string key = args[2]; + *rv = RPCClientConnect(url, port, key, + TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); +}); + +TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) { + if (args[0].type_code() == kDLInt) { + RPCServerLoop(args[0]); + } else { + RPCServerLoop(args[0].operator tvm::runtime::PackedFunc(), + args[1].operator tvm::runtime::PackedFunc()); + } +}); -TVM_REGISTER_GLOBAL("rpc._ServerLoop") -.set_body([](TVMArgs args, TVMRetValue* rv) { - if (args.size() == 1) { - RPCServerLoop(args[0]); - } else { - CHECK_EQ(args.size(), 2); - RPCServerLoop( - args[0].operator tvm::runtime::PackedFunc(), - args[1].operator tvm::runtime::PackedFunc()); - } - }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/runtime_base.h b/src/runtime/runtime_base.h index 84fc3c462c3d..21601df1ad39 100644 --- a/src/runtime/runtime_base.h +++ b/src/runtime/runtime_base.h @@ -25,25 +25,37 @@ #define TVM_RUNTIME_RUNTIME_BASE_H_ #include + #include /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { /*! \brief every function starts with API_BEGIN(); and finishes with API_END() or API_END_HANDLE_ERROR */ -#define API_END() } catch(std::runtime_error &_except_) { return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END() \ + } \ + catch (std::runtime_error & _except_) { \ + return TVMAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! * \brief every function starts with API_BEGIN(); * and finishes with API_END() or API_END_HANDLE_ERROR * The finally clause contains procedure to cleanup states when an error happens. */ -#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*) +#define API_END_HANDLE_ERROR(Finalize) \ + } \ + catch (std::runtime_error & _except_) { \ + Finalize; \ + return TVMAPIHandleException(_except_); \ + } \ + return 0; // NOLINT(*) /*! 
* \brief handle exception throwed out * \param e the exception * \return the return value of API after exception is handled */ -int TVMAPIHandleException(const std::runtime_error &e); +int TVMAPIHandleException(const std::runtime_error& e); #endif // TVM_RUNTIME_RUNTIME_BASE_H_ diff --git a/src/runtime/stackvm/stackvm.cc b/src/runtime/stackvm/stackvm.cc index 0f17f9e4b4a2..042815b3d68b 100644 --- a/src/runtime/stackvm/stackvm.cc +++ b/src/runtime/stackvm/stackvm.cc @@ -21,87 +21,88 @@ * Implementation stack VM. * \file stackvm.cc */ +#include "stackvm.h" + #include #include + #include -#include "stackvm.h" namespace tvm { namespace runtime { typedef dmlc::ThreadLocalStore StackVMStateStore; -StackVM::State* StackVM::ThreadLocalState() { - return StackVMStateStore::Get(); -} +StackVM::State* StackVM::ThreadLocalState() { return StackVMStateStore::Get(); } #define STACK_VM_BINOP(OP, FIELD) \ { \ stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \ - sp -= 1; pc += 1; \ + sp -= 1; \ + pc += 1; \ } #define STACK_VM_CMPOP(OP, FIELD) \ { \ stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \ - sp -= 1; pc += 1; \ + sp -= 1; \ + pc += 1; \ } -#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \ - { \ - int index = code[pc + 1].v_int; \ - stack[sp]FIELD = static_cast( \ - static_cast(stack[sp].v_handle)[index]); \ - pc += 2; \ +#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE) \ + { \ + int index = code[pc + 1].v_int; \ + stack[sp] FIELD = static_cast(static_cast(stack[sp].v_handle)[index]); \ + pc += 2; \ } -#define STACK_VM_STORE(FIELD, DST_TYPE) \ - { \ - int index = code[pc + 1].v_int; \ - static_cast(stack[sp - 1].v_handle)[index] = \ - static_cast(stack[sp]FIELD); \ - sp -= 2; pc += 2; \ +#define STACK_VM_STORE(FIELD, DST_TYPE) \ + { \ + int index = code[pc + 1].v_int; \ + static_cast(stack[sp - 1].v_handle)[index] = \ + static_cast(stack[sp] FIELD); \ + sp -= 2; \ + pc += 2; \ } -#define STACK_VM_PRINT_CODE0(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << std::endl; return pc + 1; \ +#define STACK_VM_PRINT_CODE0(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << std::endl; \ + return pc + 1; \ } -#define STACK_VM_PRINT_CODE1(CODE) \ - case CODE: { \ +#define STACK_VM_PRINT_CODE1(CODE) \ + case CODE: { \ os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } -#define STACK_VM_PRINT_CODE2(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE \ - << " " << code[pc + 1].v_int \ - << " " << code[pc + 2].v_int << "\n" \ - << "[" << pc + 1 << "]" << std::endl \ - << "[" << pc + 2 << "]" << std::endl; \ - return pc + 3; \ +#define STACK_VM_PRINT_CODE2(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " << code[pc + 2].v_int \ + << "\n" \ + << "[" << pc + 1 << "]" << std::endl \ + << "[" << pc + 2 << "]" << std::endl; \ + return pc + 3; \ } -#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \ - case CODE: { \ - os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int \ - << " " << heap_id_name[code[pc + 1].v_int] << "\n" \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ +#define STACK_VM_PRINT_HEAP_ACCESS(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " \ + << heap_id_name[code[pc + 1].v_int] << "\n" \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } -#define STACK_VM_PRINT_JUMP(CODE) \ - 
case CODE: { \ - os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int \ - << " to " << pc + code[pc + 1].v_int << '\n' \ - << "[" << pc + 1 << "]" << std::endl; \ - return pc + 2; \ +#define STACK_VM_PRINT_JUMP(CODE) \ + case CODE: { \ + os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int << " to " \ + << pc + code[pc + 1].v_int << '\n' \ + << "[" << pc + 1 << "]" << std::endl; \ + return pc + 2; \ } - int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { switch (code[pc].op_code) { // int @@ -164,9 +165,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { int begin = code[pc + 2].v_int; int end = code[pc + 3].v_int; os << "[" << pc << "]\tCALL_PACKED_FUNC " - << " fid=" << call_fid - << " begin=" << begin - << " end=" << end; + << " fid=" << call_fid << " begin=" << begin << " end=" << end; os << '\n'; for (int i = 0; i < 3; ++i) { os << "[" << pc + 1 + i << "]" << std::endl; @@ -181,8 +180,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const { std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) int64_t pc = 0; const int64_t code_size = static_cast(vm.code.size()); - os << "Program dump: code-size=" << code_size << '\n' - << "----------begin-----------------\n"; + os << "Program dump: code-size=" << code_size << '\n' << "----------begin-----------------\n"; while (pc < code_size) { pc = vm.PrintCode(os, pc); } @@ -190,8 +188,7 @@ std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) return os; } -void StackVM::Run(const runtime::TVMArgs& args, - runtime::ModuleNode* mod_ctx) const { +void StackVM::Run(const runtime::TVMArgs& args, runtime::ModuleNode* mod_ctx) const { StackVM::State* s = StackVM::ThreadLocalState(); if (s->heap.size() < heap_size) { s->heap.resize(heap_size); @@ -199,7 +196,7 @@ void StackVM::Run(const runtime::TVMArgs& args, s->sp = 0; s->pc = 0; s->mod_ctx = mod_ctx; - s->heap[0].v_handle = (void*)args.values; // NOLINT(*) + s->heap[0].v_handle = (void*)args.values; // NOLINT(*) s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) s->heap[2].v_int64 = args.num_args; this->Run(s); @@ -207,16 +204,13 @@ void StackVM::Run(const runtime::TVMArgs& args, void StackVM::InitCache() { extern_func_cache_.clear(); - extern_func_cache_.resize( - extern_func_name.size(), PackedFunc(nullptr)); + extern_func_cache_.resize(extern_func_name.size(), PackedFunc(nullptr)); } void StackVM::Save(dmlc::Stream* strm) const { // to be endian invariant. std::vector code_copy(code.size()); - std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { - return c.v_int; - }); + std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { return c.v_int; }); strm->Write(code_copy); strm->Write(str_data); strm->Write(extern_func_name); @@ -225,14 +219,16 @@ void StackVM::Save(dmlc::Stream* strm) const { strm->Write(stack_size); } -bool StackVM::Load(dmlc::Stream* strm) { +bool StackVM::Load(dmlc::Stream* strm) { // to be endian invariant. 
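  // Mirrors Save() above: instructions travel on disk as plain ints, so read the
  // int vector back and rebuild each Code union value before executing.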
std::vector code_copy; if (!strm->Read(&code_copy)) return false; code.resize(code_copy.size()); std::transform(code_copy.begin(), code_copy.end(), code.begin(), [](int v) { - Code code; code.v_int = v; return code; - }); + Code code; + code.v_int = v; + return code; + }); if (!strm->Read(&str_data)) return false; if (!strm->Read(&extern_func_name)) return false; if (!strm->Read(&heap_id_name)) return false; @@ -258,36 +254,92 @@ void StackVM::Run(State* s) const { const int64_t code_size = static_cast(code.size()); while (pc < code_size) { switch (code[pc].op_code) { - case ADD_I64: STACK_VM_BINOP(+, v_int64); break; - case SUB_I64: STACK_VM_BINOP(-, v_int64); break; - case MUL_I64: STACK_VM_BINOP(*, v_int64); break; - case DIV_I64: STACK_VM_BINOP(/, v_int64); break; - case MOD_I64: STACK_VM_BINOP(%, v_int64); break; - case EQ_I64: STACK_VM_CMPOP(==, v_int64); break; - case LT_I64: STACK_VM_CMPOP(<, v_int64); break; - case LE_I64: STACK_VM_CMPOP(<=, v_int64); break; - case ADD_F64: STACK_VM_BINOP(+, v_float64); break; - case SUB_F64: STACK_VM_BINOP(-, v_float64); break; - case MUL_F64: STACK_VM_BINOP(*, v_float64); break; - case DIV_F64: STACK_VM_BINOP(/, v_float64); break; - case EQ_F64: STACK_VM_CMPOP(==, v_float64); break; - case LT_F64: STACK_VM_CMPOP(<, v_float64); break; - case LE_F64: STACK_VM_CMPOP(<=, v_float64); break; - case EQ_HANDLE: STACK_VM_CMPOP(==, v_handle); break; + case ADD_I64: + STACK_VM_BINOP(+, v_int64); + break; + case SUB_I64: + STACK_VM_BINOP(-, v_int64); + break; + case MUL_I64: + STACK_VM_BINOP(*, v_int64); + break; + case DIV_I64: + STACK_VM_BINOP(/, v_int64); + break; + case MOD_I64: + STACK_VM_BINOP(%, v_int64); + break; + case EQ_I64: + STACK_VM_CMPOP(==, v_int64); + break; + case LT_I64: + STACK_VM_CMPOP(<, v_int64); + break; + case LE_I64: + STACK_VM_CMPOP(<=, v_int64); + break; + case ADD_F64: + STACK_VM_BINOP(+, v_float64); + break; + case SUB_F64: + STACK_VM_BINOP(-, v_float64); + break; + case MUL_F64: + STACK_VM_BINOP(*, v_float64); + break; + case DIV_F64: + STACK_VM_BINOP(/, v_float64); + break; + case EQ_F64: + STACK_VM_CMPOP(==, v_float64); + break; + case LT_F64: + STACK_VM_CMPOP(<, v_float64); + break; + case LE_F64: + STACK_VM_CMPOP(<=, v_float64); + break; + case EQ_HANDLE: + STACK_VM_CMPOP(==, v_handle); + break; // addressing - case ARRAY_LOAD_UINT32: STACK_VM_LOAD(.v_int64, int64_t, uint32_t); break; - case ARRAY_LOAD_INT32: STACK_VM_LOAD(.v_int64, int64_t, int32_t); break; - case ARRAY_LOAD_INT64: STACK_VM_LOAD(.v_int64, int64_t, int64_t); break; - case ARRAY_LOAD_FP64: STACK_VM_LOAD(.v_float64, double, double); break; - case ARRAY_LOAD_HANDLE: STACK_VM_LOAD(.v_handle, void*, void*); break; - case ARRAY_LOAD_TVMVALUE: STACK_VM_LOAD(, TVMValue, TVMValue); break; + case ARRAY_LOAD_UINT32: + STACK_VM_LOAD(.v_int64, int64_t, uint32_t); + break; + case ARRAY_LOAD_INT32: + STACK_VM_LOAD(.v_int64, int64_t, int32_t); + break; + case ARRAY_LOAD_INT64: + STACK_VM_LOAD(.v_int64, int64_t, int64_t); + break; + case ARRAY_LOAD_FP64: + STACK_VM_LOAD(.v_float64, double, double); + break; + case ARRAY_LOAD_HANDLE: + STACK_VM_LOAD(.v_handle, void*, void*); + break; + case ARRAY_LOAD_TVMVALUE: + STACK_VM_LOAD(, TVMValue, TVMValue); + break; // store - case ARRAY_STORE_UINT32: STACK_VM_STORE(.v_int64, uint32_t); break; - case ARRAY_STORE_INT32: STACK_VM_STORE(.v_int64, int32_t); break; - case ARRAY_STORE_INT64: STACK_VM_STORE(.v_int64, int64_t); break; - case ARRAY_STORE_FP64: STACK_VM_STORE(.v_float64, double); break; - case ARRAY_STORE_HANDLE: 
STACK_VM_STORE(.v_handle, void*); break; - case ARRAY_STORE_TVMVALUE: STACK_VM_STORE(, TVMValue); break; + case ARRAY_STORE_UINT32: + STACK_VM_STORE(.v_int64, uint32_t); + break; + case ARRAY_STORE_INT32: + STACK_VM_STORE(.v_int64, int32_t); + break; + case ARRAY_STORE_INT64: + STACK_VM_STORE(.v_int64, int64_t); + break; + case ARRAY_STORE_FP64: + STACK_VM_STORE(.v_float64, double); + break; + case ARRAY_STORE_HANDLE: + STACK_VM_STORE(.v_handle, void*); + break; + case ARRAY_STORE_TVMVALUE: + STACK_VM_STORE(, TVMValue); + break; // add case ADDR_ADD: { stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*) @@ -365,9 +417,8 @@ void StackVM::Run(State* s) const { } case ASSERT_SP: { int64_t expected = code[pc + 1].v_int; - CHECK_EQ(sp, expected) - << "sp assertion failed, expected=" - << expected << " now=" << sp << ", pc=" << pc; + CHECK_EQ(sp, expected) << "sp assertion failed, expected=" << expected << " now=" << sp + << ", pc=" << pc; pc += 2; break; } @@ -379,11 +430,10 @@ void StackVM::Run(State* s) const { int begin = code[pc + 2].v_int; int end = code[pc + 3].v_int; int num_args = end - begin; - static_assert(sizeof(Code) == sizeof(int) && - alignof(Code) == alignof(int), "asusmption"); + static_assert(sizeof(Code) == sizeof(int) && alignof(Code) == alignof(int), "asusmption"); runtime::TVMRetValue rv; - GetExtern(s, call_fid).CallPacked( - runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv); + GetExtern(s, call_fid) + .CallPacked(runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv); sp = sp - 1; stack[sp] = rv.value(); pc += 4; @@ -396,47 +446,55 @@ void StackVM::Run(State* s) const { DLTensor* arr = static_cast(stack[sp].v_handle); switch (kind) { case StackVM::kArrData: { - stack[sp].v_handle = arr[index].data; break; + stack[sp].v_handle = arr[index].data; + break; } case StackVM::kArrShape: { - stack[sp].v_handle = arr[index].shape; break; + stack[sp].v_handle = arr[index].shape; + break; } case StackVM::kArrStrides: { - stack[sp].v_handle = arr[index].strides; break; + stack[sp].v_handle = arr[index].strides; + break; } case StackVM::kArrNDim: { - stack[sp].v_int64 = arr[index].ndim; break; + stack[sp].v_int64 = arr[index].ndim; + break; } case StackVM::kArrTypeCode: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.code); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.code); + break; } case StackVM::kArrTypeBits: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.bits); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.bits); + break; } case StackVM::kArrTypeLanes: { - stack[sp].v_int64 = static_cast( - arr[index].dtype.lanes); break; + stack[sp].v_int64 = static_cast(arr[index].dtype.lanes); + break; } case StackVM::kArrByteOffset: { - stack[sp].v_int64 = static_cast( - arr[index].byte_offset); break; + stack[sp].v_int64 = static_cast(arr[index].byte_offset); + break; } case StackVM::kArrDeviceId: { - stack[sp].v_int64 = arr[index].ctx.device_id; break; + stack[sp].v_int64 = arr[index].ctx.device_id; + break; } case StackVM::kArrDeviceType: { - stack[sp].v_int64 = static_cast( - arr[index].ctx.device_type); break; + stack[sp].v_int64 = static_cast(arr[index].ctx.device_type); + break; } case StackVM::kArrAddr: { - stack[sp].v_handle = arr + index; break; + stack[sp].v_handle = arr + index; + break; } case StackVM::kTVMValueContent: { - stack[sp] = static_cast(stack[sp].v_handle)[index]; break; + stack[sp] = static_cast(stack[sp].v_handle)[index]; + break; } - default: 
LOG(FATAL) << "unhandled get " << kind; + default: + LOG(FATAL) << "unhandled get " << kind; } pc = pc + 3; break; @@ -447,7 +505,8 @@ void StackVM::Run(State* s) const { DLTensor* arr = static_cast(stack[sp - 1].v_handle); switch (kind) { case StackVM::kArrData: { - arr[index].data = stack[sp].v_handle; break; + arr[index].data = stack[sp].v_handle; + break; } case StackVM::kArrShape: { arr[index].shape = static_cast(stack[sp].v_handle); @@ -486,9 +545,11 @@ void StackVM::Run(State* s) const { break; } case StackVM::kTVMValueContent: { - static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; break; + static_cast(stack[sp - 1].v_handle)[index] = stack[sp]; + break; } - default: LOG(FATAL) << "unhandled tvm_struct_set " << kind; + default: + LOG(FATAL) << "unhandled tvm_struct_set " << kind; } sp -= 2; pc += 3; @@ -511,8 +572,8 @@ void StackVM::Run(State* s) const { size_t nbytes = static_cast(stack[sp - 2].v_int64); int dtype_code_hint = static_cast(stack[sp - 1].v_int64); int dtype_bits_hint = static_cast(stack[sp].v_int64); - void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, - dtype_code_hint, dtype_bits_hint); + void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, + dtype_bits_hint); stack[sp - 4].v_handle = ptr; sp = sp - 4; pc = pc + 1; @@ -543,8 +604,7 @@ const PackedFunc& StackVM::GetExtern(State* s, int fid) const { // allow race write in this, since write is idempotent PackedFunc& f = extern_func_cache_[fid]; if (f == nullptr) { - CHECK(s->mod_ctx != nullptr) - << "No local context is set in stackvm"; + CHECK(s->mod_ctx != nullptr) << "No local context is set in stackvm"; const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]); CHECK(pf != nullptr); f = *pf; diff --git a/src/runtime/stackvm/stackvm.h b/src/runtime/stackvm/stackvm.h index f36e171cdf3e..09581a6d0b62 100644 --- a/src/runtime/stackvm/stackvm.h +++ b/src/runtime/stackvm/stackvm.h @@ -29,8 +29,9 @@ #define TVM_RUNTIME_STACKVM_STACKVM_H_ #include -#include #include +#include + #include #include @@ -339,7 +340,7 @@ class StackVM { * \param pc The pc * \return the pc to next instruction. */ - int64_t PrintCode(std::ostream&os, int64_t pc) const; // NOLINT(*) + int64_t PrintCode(std::ostream& os, int64_t pc) const; // NOLINT(*) /*! \brief Get thread local state of the stack VM */ static State* ThreadLocalState(); // The code below are programs @@ -362,15 +363,26 @@ class StackVM { */ static OpCode CodeI64ToF64(OpCode code) { switch (code) { - case ADD_I64: return ADD_F64; - case SUB_I64: return SUB_F64; - case MUL_I64: return MUL_F64; - case DIV_I64: return DIV_F64; - case EQ_I64: return EQ_F64; - case LT_I64: return LT_F64; - case LE_I64: return LE_F64; - case MOD_I64: LOG(FATAL) << "cannot handle mod for float"; return ADD_F64; - default: LOG(FATAL) << "cannot handle op " << code; return ADD_F64; + case ADD_I64: + return ADD_F64; + case SUB_I64: + return SUB_F64; + case MUL_I64: + return MUL_F64; + case DIV_I64: + return DIV_F64; + case EQ_I64: + return EQ_F64; + case LT_I64: + return LT_F64; + case LE_I64: + return LE_F64; + case MOD_I64: + LOG(FATAL) << "cannot handle mod for float"; + return ADD_F64; + default: + LOG(FATAL) << "cannot handle op " << code; + return ADD_F64; } } /*! 
@@ -383,16 +395,20 @@ class StackVM { if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE; if (t.code == kDLInt) { switch (t.bits) { - case 32 : return ARRAY_LOAD_INT32; - case 64 : return ARRAY_LOAD_INT64; + case 32: + return ARRAY_LOAD_INT32; + case 64: + return ARRAY_LOAD_INT64; } } else if (t.code == kDLUInt) { switch (t.bits) { - case 32 : return ARRAY_LOAD_UINT32; + case 32: + return ARRAY_LOAD_UINT32; } } else if (t.code == kDLFloat) { switch (t.bits) { - case 64 : return ARRAY_LOAD_FP64; + case 64: + return ARRAY_LOAD_FP64; } } LOG(FATAL) << "Cannot load type " << t; @@ -408,16 +424,20 @@ class StackVM { if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE; if (t.code == kDLInt) { switch (t.bits) { - case 32 : return ARRAY_STORE_INT32; - case 64 : return ARRAY_STORE_INT64; + case 32: + return ARRAY_STORE_INT32; + case 64: + return ARRAY_STORE_INT64; } } else if (t.code == kDLUInt) { switch (t.bits) { - case 32 : return ARRAY_STORE_UINT32; + case 32: + return ARRAY_STORE_UINT32; } } else if (t.code == kDLFloat) { switch (t.bits) { - case 64 : return ARRAY_STORE_FP64; + case 64: + return ARRAY_STORE_FP64; } } LOG(FATAL) << "Cannot store type " << t; diff --git a/src/runtime/stackvm/stackvm_module.cc b/src/runtime/stackvm/stackvm_module.cc index 8b30b750e714..9e1f1f515f4a 100644 --- a/src/runtime/stackvm/stackvm_module.cc +++ b/src/runtime/stackvm/stackvm_module.cc @@ -20,13 +20,16 @@ /*! * \file stackvm_module.cc */ -#include -#include +#include "stackvm_module.h" + #include +#include +#include + #include -#include #include -#include "stackvm_module.h" +#include + #include "../file_util.h" namespace tvm { @@ -34,13 +37,9 @@ namespace runtime { class StackVMModuleNode : public runtime::ModuleNode { public: - const char* type_key() const { - return "stackvm"; - } + const char* type_key() const { return "stackvm"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == runtime::symbol::tvm_module_main) { return GetFunction(entry_func_, sptr_to_self); } @@ -48,9 +47,8 @@ class StackVMModuleNode : public runtime::ModuleNode { if (it == fmap_.end()) return PackedFunc(); const StackVM& vm = it->second; // capture sptr_to_self to keep module node alive. 
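    // Because the ObjectPtr is captured by value, the returned PackedFunc keeps this
    // StackVMModuleNode (and the VM it dispatches to) alive until callers drop it.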
- return PackedFunc([vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - vm.Run(args, this); - }); + return PackedFunc( + [vm, sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { vm.Run(args, this); }); } std::string GetSource(const std::string& format) final { @@ -62,8 +60,7 @@ class StackVMModuleNode : public runtime::ModuleNode { return os.str(); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string data, mblob; dmlc::MemoryStringStream writer(&data); dmlc::Stream* strm = &writer; @@ -74,8 +71,7 @@ class StackVMModuleNode : public runtime::ModuleNode { strm->Write(num_imports); for (runtime::Module im : imports_) { - CHECK_EQ(im->imports().size(), 0U) - << "Only support simply one-level hierarchy"; + CHECK_EQ(im->imports().size(), 0U) << "Only support simply one-level hierarchy"; std::string tkey = im->type_key(); strm->Write(tkey); LOG(INFO) << "save " << tkey; @@ -85,8 +81,7 @@ class StackVMModuleNode : public runtime::ModuleNode { SaveBinaryToFile(file_name, data); } - static Module Create(std::unordered_map fmap, - std::string entry_func) { + static Module Create(std::unordered_map fmap, std::string entry_func) { auto n = make_object(); n->fmap_ = std::move(fmap); n->entry_func_ = std::move(entry_func); @@ -108,17 +103,14 @@ class StackVMModuleNode : public runtime::ModuleNode { CHECK(strm->Read(&tkey)); std::string fkey = "runtime.module.loadbinary_" + tkey; const PackedFunc* f = Registry::Get(fkey); - CHECK(f != nullptr) - << "Loader of " << tkey << "(" - << fkey << ") is not presented."; + CHECK(f != nullptr) << "Loader of " << tkey << "(" << fkey << ") is not presented."; Module m = (*f)(static_cast(strm)); n->imports_.emplace_back(std::move(m)); } return Module(n); } - static Module LoadFromFile(std::string file_name, - std::string format) { + static Module LoadFromFile(std::string file_name, std::string format) { std::string data; LoadBinaryFromFile(file_name, &data); dmlc::MemoryStringStream reader(&data); @@ -132,13 +124,12 @@ class StackVMModuleNode : public runtime::ModuleNode { std::string entry_func_; }; -Module StackVMModuleCreate(std::unordered_map fmap, - std::string entry_func) { +Module StackVMModuleCreate(std::unordered_map fmap, std::string entry_func) { return StackVMModuleNode::Create(fmap, entry_func); } TVM_REGISTER_GLOBAL("runtime.module.loadfile_stackvm") -.set_body_typed(StackVMModuleNode::LoadFromFile); + .set_body_typed(StackVMModuleNode::LoadFromFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/stackvm/stackvm_module.h b/src/runtime/stackvm/stackvm_module.h index c84eb6fe4945..6ae4ae47a92c 100644 --- a/src/runtime/stackvm/stackvm_module.h +++ b/src/runtime/stackvm/stackvm_module.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,10 @@ #define TVM_RUNTIME_STACKVM_STACKVM_MODULE_H_ #include + #include #include + #include "stackvm.h" namespace tvm { @@ -38,8 +40,7 @@ namespace runtime { * \param entry_func The entry function name. 
* \return The created module */ -Module StackVMModuleCreate(std::unordered_map fmap, - std::string entry_func); +Module StackVMModuleCreate(std::unordered_map fmap, std::string entry_func); } // namespace runtime } // namespace tvm diff --git a/src/runtime/system_library.cc b/src/runtime/system_library.cc index 3eb7b1c46b45..fe29146d8b7b 100644 --- a/src/runtime/system_library.cc +++ b/src/runtime/system_library.cc @@ -21,10 +21,12 @@ * \file system_library.cc * \brief Create library module that directly get symbol from the system lib. */ -#include -#include #include +#include +#include + #include + #include "library_module.h" namespace tvm { @@ -48,10 +50,8 @@ class SystemLibrary : public Library { std::lock_guard lock(mutex_); auto it = tbl_.find(name); if (it != tbl_.end() && ptr != it->second) { - LOG(WARNING) - << "SystemLib symbol " << name - << " get overriden to a different address " - << ptr << "->" << it->second; + LOG(WARNING) << "SystemLib symbol " << name << " get overriden to a different address " << ptr + << "->" << it->second; } tbl_[name] = ptr; } @@ -68,11 +68,9 @@ class SystemLibrary : public Library { std::unordered_map tbl_; }; -TVM_REGISTER_GLOBAL("runtime.SystemLib") -.set_body_typed([]() { - static auto mod = CreateModuleFromLibrary( - SystemLibrary::Global()); - return mod; +TVM_REGISTER_GLOBAL("runtime.SystemLib").set_body_typed([]() { + static auto mod = CreateModuleFromLibrary(SystemLibrary::Global()); + return mod; }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 00f089b86b0f..0cc881ceb7f2 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -21,26 +21,26 @@ * \file thread_pool.cc * \brief Threadpool for multi-threading runtime. */ -#include +#include +#include #include -#include +#include #include +#include #include -#include -#include #if TVM_THREADPOOL_USE_OPENMP #include #endif -#include -#include -#include -#include #include -#include -#include +#include +#include #include #include +#include #include +#include +#include +#include const constexpr int kL1CacheBytes = 64; @@ -69,10 +69,7 @@ constexpr int kSyncStride = 64 / sizeof(std::atomic); class ParallelLauncher { public: // Reset the the task request. - void Init(FTVMParallelLambda flambda, - void* cdata, - int num_task, - bool need_sync) { + void Init(FTVMParallelLambda flambda, void* cdata, int num_task, bool need_sync) { num_pending_.store(num_task); this->cdata = cdata; this->flambda = flambda; @@ -88,17 +85,14 @@ class ParallelLauncher { } if (need_sync) { for (int i = 0; i < num_task; ++i) { - sync_counter_[i * kSyncStride].store( - 0, std::memory_order_relaxed); + sync_counter_[i * kSyncStride].store(0, std::memory_order_relaxed); } this->env.sync_handle = sync_counter_; } else { this->env.sync_handle = nullptr; } } - ~ParallelLauncher() { - delete[] sync_counter_; - } + ~ParallelLauncher() { delete[] sync_counter_; } // Wait n jobs to finish int WaitForJobs() { while (num_pending_.load() != 0) { @@ -122,13 +116,9 @@ class ParallelLauncher { has_error_.store(true); } // Signal that one job has finished. - void SignalJobFinish() { - num_pending_.fetch_sub(1); - } + void SignalJobFinish() { num_pending_.fetch_sub(1); } // Get thread local version of the store. 
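  // Backed by dmlc::ThreadLocalStore, so each OS thread owns a separate launcher;
  // ThreadPool::Launch checks the launcher's is_worker flag to reject nested parallel jobs.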
- static ParallelLauncher* ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); - } + static ParallelLauncher* ThreadLocal() { return dmlc::ThreadLocalStore::Get(); } // The parallel lambda FTVMParallelLambda flambda; // The closure data @@ -159,15 +149,9 @@ class SpscTaskQueue { int32_t task_id; }; - SpscTaskQueue() : - buffer_(new Task[kRingSize]), - head_(0), - tail_(0) { - } + SpscTaskQueue() : buffer_(new Task[kRingSize]), head_(0), tail_(0) {} - ~SpscTaskQueue() { - delete[] buffer_; - } + ~SpscTaskQueue() { delete[] buffer_; } /*! * \brief Push a task into the queue and notify the comsumer if it is on wait. @@ -198,9 +182,7 @@ class SpscTaskQueue { } if (pending_.fetch_sub(1) == 0) { std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { - return pending_.load() >= 0 || exit_now_.load(); - }); + cv_.wait(lock, [this] { return pending_.load() >= 0 || exit_now_.load(); }); } if (exit_now_.load(std::memory_order_relaxed)) { return false; @@ -275,7 +257,7 @@ class SpscTaskQueue { // The thread pool class ThreadPool { public: - ThreadPool(): num_workers_(tvm::runtime::threading::MaxConcurrency()) { + ThreadPool() : num_workers_(tvm::runtime::threading::MaxConcurrency()) { for (int i = 0; i < num_workers_; ++i) { // The SpscTaskQueue only hosts ONE item at a time queues_.emplace_back(std::unique_ptr(new SpscTaskQueue())); @@ -286,8 +268,8 @@ class ThreadPool { } threads_ = std::unique_ptr( new tvm::runtime::threading::ThreadGroup( - num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, - exclude_worker0_ /* include_main_thread */)); + num_workers_, [this](int worker_id) { this->RunWorker(worker_id); }, + exclude_worker0_ /* include_main_thread */)); num_workers_used_ = threads_->Configure(threading::ThreadGroup::kBig, 0, exclude_worker0_); } ~ThreadPool() { @@ -296,10 +278,7 @@ class ThreadPool { } threads_.reset(); } - int Launch(FTVMParallelLambda flambda, - void* cdata, - int num_task, - int need_sync) { + int Launch(FTVMParallelLambda flambda, void* cdata, int num_task, int need_sync) { ParallelLauncher* launcher = ParallelLauncher::ThreadLocal(); CHECK(!launcher->is_worker) << "Cannot launch parallel job inside worker, consider fuse then parallel"; @@ -332,15 +311,12 @@ class ThreadPool { return res; } - static ThreadPool* ThreadLocal() { - return dmlc::ThreadLocalStore::Get(); - } + static ThreadPool* ThreadLocal() { return dmlc::ThreadLocalStore::Get(); } void UpdateWorkerConfiguration(threading::ThreadGroup::AffinityMode mode, int nthreads) { // this will also reset the affinity of the ThreadGroup // may use less than the MaxConcurrency number of workers - num_workers_used_ = threads_->Configure(mode, nthreads, - exclude_worker0_); + num_workers_used_ = threads_->Configure(mode, nthreads, exclude_worker0_); // if MaxConcurrency restricted the number of workers (e.g., due to // hyperthreading), respect the restriction num_workers_used_ = std::min(num_workers_, num_workers_used_); @@ -376,33 +352,25 @@ class ThreadPool { std::unique_ptr threads_; }; -TVM_REGISTER_GLOBAL("runtime.config_threadpool") -.set_body([](TVMArgs args, TVMRetValue* rv) { - threading::ThreadGroup::AffinityMode mode =\ - static_cast(\ - static_cast(args[0])); - int nthreads = args[1]; - ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); +TVM_REGISTER_GLOBAL("runtime.config_threadpool").set_body([](TVMArgs args, TVMRetValue* rv) { + threading::ThreadGroup::AffinityMode mode = + static_cast(static_cast(args[0])); + int nthreads = args[1]; + 
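  // Minimal usage sketch (illustrative only; assumes the packed function is looked up
  // through the registry exactly as registered here):
  //   const PackedFunc* f = Registry::Get("runtime.config_threadpool");
  //   (*f)(static_cast<int>(threading::ThreadGroup::kBig), 4);  // prefer big cores, 4 workers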
ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, nthreads); }); - } // namespace runtime } // namespace tvm - -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { #if !TVM_THREADPOOL_USE_OPENMP - int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch( - flambda, cdata, num_task, 1); + int res = tvm::runtime::ThreadPool::ThreadLocal()->Launch(flambda, cdata, num_task, 1); return res; #else int num_workers = tvm::runtime::threading::MaxConcurrency(); if (num_task == 0) num_task = num_workers; omp_set_num_threads(num_workers); - #pragma omp parallel num_threads(num_workers) +#pragma omp parallel num_threads(num_workers) { TVMParallelGroupEnv env; env.num_task = num_task; @@ -414,18 +382,15 @@ int TVMBackendParallelLaunch( int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { #if TVM_THREADPOOL_USE_OPENMP - #pragma omp barrier +#pragma omp barrier #else using tvm::runtime::kSyncStride; int num_task = penv->num_task; - std::atomic* sync_counter = - reinterpret_cast*>(penv->sync_handle); - int old_counter = sync_counter[task_id * kSyncStride].fetch_add( - 1, std::memory_order_release); + std::atomic* sync_counter = reinterpret_cast*>(penv->sync_handle); + int old_counter = sync_counter[task_id * kSyncStride].fetch_add(1, std::memory_order_release); for (int i = 0; i < num_task; ++i) { if (i != task_id) { - while (sync_counter[i * kSyncStride].load( - std::memory_order_relaxed) <= old_counter) { + while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) { tvm::runtime::threading::Yield(); } } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index 3e6fd781023c..1917096bb24c 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,6 +25,7 @@ #define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_ #include + #include #include @@ -64,9 +65,12 @@ enum class StorageRank { */ inline StorageRank DefaultStorageRank(int thread_scope_rank) { switch (thread_scope_rank) { - case -1: return StorageRank::kGlobal; - case 0: return StorageRank::kShared; - case 1: return StorageRank::kLocal; + case -1: + return StorageRank::kGlobal; + case 0: + return StorageRank::kShared; + case 1: + return StorageRank::kLocal; default: { LOG(FATAL) << "unknown rank"; return StorageRank::kGlobal; @@ -84,30 +88,37 @@ struct StorageScope { inline bool operator==(const StorageScope& other) const { return rank == other.rank && tag == other.tag; } - inline bool operator!=(const StorageScope& other) const { - return !(*this == other); - } + inline bool operator!=(const StorageScope& other) const { return !(*this == other); } inline std::string to_string() const { std::string ret; switch (rank) { - case StorageRank::kGlobal: return "global" + tag; - case StorageRank::kShared: return "shared" + tag; - case StorageRank::kWarp: return "warp" + tag; - case StorageRank::kLocal: return "local" + tag; - case StorageRank::kWMMAMatrixA: return "wmma.matrix_a" + tag; - case StorageRank::kWMMAMatrixB: return "wmma.matrix_b" + tag; - case StorageRank::kWMMAAccumulator: return "wmma.accumulator" + tag; - default: LOG(FATAL) << "unknown storage scope"; return ""; + case StorageRank::kGlobal: + return "global" + tag; + case StorageRank::kShared: + return "shared" + tag; + case StorageRank::kWarp: + return "warp" + tag; + case StorageRank::kLocal: + return "local" + tag; + case StorageRank::kWMMAMatrixA: + return "wmma.matrix_a" + tag; + case StorageRank::kWMMAMatrixB: + return "wmma.matrix_b" + tag; + case StorageRank::kWMMAAccumulator: + return "wmma.accumulator" + tag; + default: + LOG(FATAL) << "unknown storage scope"; + return ""; } } /*! - * \brief make storage scope from string + * \brief Create storage scope from string * \param s The string to be parsed. * \return The storage scope. */ - static StorageScope make(const std::string& s) { + static StorageScope Create(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { @@ -142,11 +153,11 @@ struct ThreadScope { /*! \brief the dimension index under the rank */ int dim_index{0}; /*! - * \brief make storage scope from string + * \brief Create storage scope from string * \param s The string to be parsed. * \return The storage scope. */ - static ThreadScope make(const std::string& s) { + static ThreadScope Create(const std::string& s) { ThreadScope r; if (s == "vthread" || s == "cthread") { // virtual thread at the same level as local @@ -165,7 +176,6 @@ struct ThreadScope { } }; - /*! \brief workload specification */ struct ThreadWorkLoad { // array, first three are thread configuration. @@ -174,27 +184,22 @@ struct ThreadWorkLoad { * \param i The block dimension. * \return i-th block dim */ - inline size_t block_dim(size_t i) const { - return work_size[i + 3]; - } + inline size_t block_dim(size_t i) const { return work_size[i + 3]; } /*! 
* \param i The grid dimension. * \return i-th grid dim */ - inline size_t grid_dim(size_t i) const { - return work_size[i]; - } + inline size_t grid_dim(size_t i) const { return work_size[i]; } }; /*! \brief Thread axis configuration */ class ThreadAxisConfig { public: - void Init(size_t base, - const std::vector& thread_axis_tags) { + void Init(size_t base, const std::vector& thread_axis_tags) { base_ = base; std::vector filled(6, false); for (size_t i = 0; i < thread_axis_tags.size(); ++i) { const std::string& tag = thread_axis_tags[i]; - ThreadScope ts = ThreadScope::make(tag); + ThreadScope ts = ThreadScope::Create(tag); arg_index_map_.push_back(ts.rank * 3 + ts.dim_index); filled[ts.rank * 3 + ts.dim_index] = true; } @@ -210,15 +215,12 @@ class ThreadAxisConfig { ThreadWorkLoad w; std::fill(w.work_size, w.work_size + 6, 1); for (size_t i = 0; i < arg_index_map_.size(); ++i) { - w.work_size[arg_index_map_[i]] = - static_cast(x.values[base_ + i].v_int64); + w.work_size[arg_index_map_[i]] = static_cast(x.values[base_ + i].v_int64); } return w; } // return the work dim - size_t work_dim() const { - return work_dim_; - } + size_t work_dim() const { return work_dim_; } private: /*! \brief base axis */ diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 9d14d3a14d03..e5520efe30a6 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -21,10 +21,11 @@ * \file threading_backend.cc * \brief Native threading backend */ -#include #include -#include +#include + #include +#include #if defined(__linux__) || defined(__ANDROID__) #include #include @@ -33,6 +34,9 @@ #if defined(__linux__) #include #endif +#if defined(__hexagon__) +#include +#endif namespace tvm { namespace runtime { @@ -40,12 +44,9 @@ namespace threading { class ThreadGroup::Impl { public: - Impl(int num_workers, - std::function worker_callback, - bool exclude_worker0) + Impl(int num_workers, std::function worker_callback, bool exclude_worker0) : num_workers_(num_workers) { - CHECK_GE(num_workers, 1) - << "Requested a non-positive number of worker threads."; + CHECK_GE(num_workers, 1) << "Requested a non-positive number of worker threads."; for (int i = exclude_worker0; i < num_workers_; ++i) { threads_.emplace_back([worker_callback, i] { worker_callback(i); }); } @@ -79,15 +80,14 @@ class ThreadGroup::Impl { // ones. 
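      // Clamp to the number of workers this group actually constructed; whether they get
      // pinned to cores below is controlled by TVM_BIND_THREADS (unset or "1" enables it).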
num_workers_used = std::min(num_workers_, num_workers_used); - const char *val = getenv("TVM_BIND_THREADS"); + const char* val = getenv("TVM_BIND_THREADS"); if (val == nullptr || atoi(val) == 1) { // Do not set affinity if there are more workers than found cores if (sorted_order_.size() >= static_cast(num_workers_)) { - SetAffinity(exclude_worker0, mode == kLittle); + SetAffinity(exclude_worker0, mode == kLittle); } else { - LOG(WARNING) - << "The thread affinity cannot be set when the number of workers" - << "is larger than the number of available cores in the system."; + LOG(WARNING) << "The thread affinity cannot be set when the number of workers" + << "is larger than the number of available cores in the system."; } } return num_workers_used; @@ -101,15 +101,14 @@ class ThreadGroup::Impl { #if defined(__ANDROID__) #ifndef CPU_SET #define CPU_SETSIZE 1024 -#define __NCPUBITS (8 * sizeof (uint64_t)) +#define __NCPUBITS (8 * sizeof(uint64_t)) typedef struct { uint64_t __bits[CPU_SETSIZE / __NCPUBITS]; } cpu_set_t; #define CPU_SET(cpu, cpusetp) \ - ((cpusetp)->__bits[(cpu)/__NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) -#define CPU_ZERO(cpusetp) \ - memset((cpusetp), 0, sizeof(cpu_set_t)) + ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t)) #endif #endif #if defined(__linux__) || defined(__ANDROID__) @@ -128,8 +127,7 @@ class ThreadGroup::Impl { #if defined(__ANDROID__) sched_setaffinity(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); #else - pthread_setaffinity_np(threads_[i].native_handle(), - sizeof(cpu_set_t), &cpuset); + pthread_setaffinity_np(threads_[i].native_handle(), sizeof(cpu_set_t), &cpuset); #endif } if (exclude_worker0) { // master thread run task @@ -182,27 +180,32 @@ class ThreadGroup::Impl { void InitSortedOrder() { unsigned int threads = std::thread::hardware_concurrency(); - std::vector > max_freqs; +#if defined(__hexagon__) + // With unsigned PDs, getting the number of available hardware threads + // is not supported in earlier versions of QuRT. In such cases assume 4. + if (threads == 0) threads = 4; +#endif + std::vector > max_freqs; for (unsigned int i = 0; i < threads; ++i) { int64_t cur_freq = 0; - #if defined(__linux__) || defined(__ANDROID__) - std::ostringstream filepath; - filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq"; - std::ifstream ifs(filepath.str()); - if (!ifs.fail()) { - if (!(ifs >> cur_freq)) { - cur_freq = -1; - } - ifs.close(); +#if defined(__linux__) || defined(__ANDROID__) + std::ostringstream filepath; + filepath << "/sys/devices/system/cpu/cpu" << i << "/cpufreq/cpuinfo_max_freq"; + std::ifstream ifs(filepath.str()); + if (!ifs.fail()) { + if (!(ifs >> cur_freq)) { + cur_freq = -1; } - #endif + ifs.close(); + } +#endif max_freqs.push_back(std::make_pair(i, cur_freq)); } - auto fcmpbyfreq = [] (const std::pair &a, - const std::pair &b) { - return a.second == b.second ? a.first < b.first : a.second > b.second; + auto fcmpbyfreq = [](const std::pair& a, + const std::pair& b) { + return a.second == b.second ? 
a.first < b.first : a.second > b.second; }; std::sort(max_freqs.begin(), max_freqs.end(), fcmpbyfreq); int64_t big_freq = max_freqs.begin()->second; @@ -228,10 +231,9 @@ class ThreadGroup::Impl { int little_count_ = 0; }; -ThreadGroup::ThreadGroup(int num_workers, - std::function worker_callback, +ThreadGroup::ThreadGroup(int num_workers, std::function worker_callback, bool exclude_worker0) - : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} + : impl_(new ThreadGroup::Impl(num_workers, worker_callback, exclude_worker0)) {} ThreadGroup::~ThreadGroup() { delete impl_; } void ThreadGroup::Join() { impl_->Join(); } @@ -239,13 +241,11 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0 return impl_->Configure(mode, nthreads, exclude_worker0); } -void Yield() { - std::this_thread::yield(); -} +void Yield() { std::this_thread::yield(); } int MaxConcurrency() { int max_concurrency = 1; - const char *val = getenv("TVM_NUM_THREADS"); + const char* val = getenv("TVM_NUM_THREADS"); if (val == nullptr) { val = getenv("OMP_NUM_THREADS"); } @@ -255,12 +255,22 @@ int MaxConcurrency() { max_concurrency = std::thread::hardware_concurrency(); #if defined(_M_X64) || defined(__x86_64__) max_concurrency /= 2; // ignore hyper-threading +#elif defined(__hexagon__) + // With unsigned PDs, getting the number of available hardware threads + // is not supported in earlier versions of QuRT. In such cases assume 4. + // If running on simulator, set max_concurrency to 1. + if (max_concurrency == 0) { + if (dlsym(RTLD_DEFAULT, "running_in_sim_dev_17bc90206f6cf5a7")) { + max_concurrency = 1; + } else { + max_concurrency = 4; + } + } #endif } return std::max(max_concurrency, 1); } - } // namespace threading } // namespace runtime } // namespace tvm diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index c2036da46e09..47bdd1c705de 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -28,9 +28,9 @@ #include #include -#include -#include #include +#include +#include #include #include #include @@ -50,24 +50,17 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr); // Helper to deserialize a serialized vm instruction. 
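// Each serialized instruction is an opcode plus a flat vector of integer fields;
// this helper reverses SerializeInstruction and rebuilds the typed Instruction.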
Instruction DeserializeInstruction(const VMInstructionSerializer& instr); -PackedFunc Executable::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { +PackedFunc Executable::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "get_lib") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetLib(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLib(); }); } else if (name == "get_bytecode") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->GetBytecode(); - }); + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetBytecode(); }); } else if (name == "get_stats") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Stats(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Stats(); }); } else if (name == "save") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->Save(); - }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->Save(); }); } else if (name == "get_function_arity") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { std::string func_name = args[0]; @@ -172,7 +165,8 @@ std::string Executable::Stats() const { // Get the number of globals and the name of each of them. oss << " Globals (#" << global_map.size() << "): ["; for (const auto& it : global_map) { - oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + oss << "(\"" << it.first << "\", " << it.second << ")" + << ", "; } if (!global_map.empty()) oss.seekp(-2, oss.cur); oss << "]" << std::endl; @@ -232,8 +226,7 @@ TVMByteArray Executable::Save() { void Executable::SaveGlobalSection(dmlc::Stream* strm) { std::vector > globals(this->global_map.begin(), this->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { + auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; std::sort(globals.begin(), globals.end(), comp); @@ -314,9 +307,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { break; } case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim + // Number of fields = 7 + instr.alloc_tensor.ndim fields.push_back(instr.alloc_tensor.storage); - + fields.push_back(instr.alloc_tensor.offset); // Save `DLDataType` and the dst register. const auto& dtype = instr.alloc_tensor.dtype; fields.push_back(dtype.code); @@ -337,8 +330,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { break; } case Opcode::AllocTensorReg: { - // Number of fields = 6 + // Number of fields = 7 fields.push_back(instr.alloc_tensor_reg.storage); + fields.push_back(instr.alloc_tensor_reg.offset); fields.push_back(instr.alloc_tensor_reg.shape_register); // Save `DLDataType` and the dst register. const auto& dtype = instr.alloc_tensor_reg.dtype; @@ -364,8 +358,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); // Save the fields. 
- fields.insert(fields.end(), instr.datatype_fields, - instr.datatype_fields + instr.num_fields); + fields.insert(fields.end(), instr.datatype_fields, instr.datatype_fields + instr.num_fields); break; } case Opcode::AllocClosure: { @@ -373,15 +366,12 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); // Save the free vars. - fields.insert(fields.end(), instr.free_vars, - instr.free_vars + instr.num_freevar); + fields.insert(fields.end(), instr.free_vars, instr.free_vars + instr.num_freevar); break; } case Opcode::If: { // Number of fields = 4 - fields.assign({instr.if_op.test, - instr.if_op.target, - instr.if_op.true_offset, + fields.assign({instr.if_op.test, instr.if_op.target, instr.if_op.true_offset, instr.if_op.false_offset}); break; } @@ -399,8 +389,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.closure, instr.num_closure_args, instr.dst}); // Save the args. - fields.insert(fields.end(), instr.closure_args, - instr.closure_args + instr.num_closure_args); + fields.insert(fields.end(), instr.closure_args, instr.closure_args + instr.num_closure_args); break; } case Opcode::LoadConst: { @@ -441,9 +430,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { strm->Write(static_cast(this->functions.size())); for (const auto& func : this->functions) { // Save the function info. - VMFunctionSerializer func_format(func.name, - func.register_file_size, - func.instructions.size(), + VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), func.params); func_format.Save(strm); @@ -523,8 +510,7 @@ void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { // Extract the `cnt` number of fields started at `start` from the list // `instr_fields`. 
-inline std::vector ExtractFields(const std::vector& instr_fields, - Index start, +inline std::vector ExtractFields(const std::vector& instr_fields, Index start, Index cnt) { CHECK_LE(static_cast(start + cnt), instr_fields.size()); std::vector ret; @@ -564,39 +550,41 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::InvokePacked(packed_index, arity, output_size, args); } case Opcode::AllocTensor: { - // Number of fields = 6 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 6U); - DCHECK_EQ(instr.fields.size(), 6U + static_cast(instr.fields[4])); + // Number of fields = 7 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 7U); + DCHECK_EQ(instr.fields.size(), 7U + static_cast(instr.fields[4])); RegName storage_reg = instr.fields[0]; + RegName offset = instr.fields[1]; DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; + dtype.code = instr.fields[2]; + dtype.bits = instr.fields[3]; + dtype.lanes = instr.fields[4]; - Index ndim = instr.fields[4]; - RegName dst = instr.fields[5]; + Index ndim = instr.fields[5]; + RegName dst = instr.fields[6]; - std::vector shape = ExtractFields(instr.fields, 6, ndim); + std::vector shape = ExtractFields(instr.fields, 7, ndim); - return Instruction::AllocTensor(storage_reg, shape, dtype, dst); + return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst); } case Opcode::AllocTensorReg: { - // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 6U); + // Number of fields = 7 + DCHECK_EQ(instr.fields.size(), 7U); RegName storage_reg = instr.fields[0]; - Index shape_register = instr.fields[1]; + RegName offset = instr.fields[1]; + Index shape_register = instr.fields[2]; DLDataType dtype; - dtype.code = instr.fields[2]; - dtype.bits = instr.fields[3]; - dtype.lanes = instr.fields[4]; + dtype.code = instr.fields[3]; + dtype.bits = instr.fields[4]; + dtype.lanes = instr.fields[5]; - RegName dst = instr.fields[5]; + RegName dst = instr.fields[6]; - return Instruction::AllocTensorReg(storage_reg, shape_register, dtype, dst); + return Instruction::AllocTensorReg(storage_reg, offset, shape_register, dtype, dst); } case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields @@ -634,11 +622,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { RegName dst = instr.fields[5]; - return Instruction::AllocStorage( - allocation_size, - alignment, - dtype, - dst); + return Instruction::AllocStorage(allocation_size, alignment, dtype, dst); } case Opcode::If: { // Number of fields = 4 @@ -727,9 +711,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } // Create the VM function. 
- VMFunction vm_func = VMFunction(loaded_func.name, - loaded_func.params, - instructions, + VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, loaded_func.register_file_size); auto it = this->global_map.find(loaded_func.name); CHECK(it != this->global_map.end()); @@ -738,24 +720,21 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { } } -TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetNumOfGlobals").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); *rv = static_cast(exec->global_map.size()); }); -TVM_REGISTER_GLOBAL("runtime.GetGlobalFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetGlobalFields").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); int idx = args[1]; std::vector > globals(exec->global_map.begin(), exec->global_map.end()); - auto comp = [](const std::pair& a, - const std::pair& b) { + auto comp = [](const std::pair& a, const std::pair& b) { return a.second < b.second; }; std::sort(globals.begin(), globals.end(), comp); @@ -763,17 +742,14 @@ TVM_REGISTER_GLOBAL("runtime.GetGlobalFields") *rv = globals[idx].first; }); -TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetNumOfPrimitives").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); *rv = static_cast(exec->primitive_map.size()); }); - -TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec); @@ -790,11 +766,9 @@ TVM_REGISTER_GLOBAL("runtime.GetPrimitiveFields") }); TVM_REGISTER_GLOBAL("runtime.Load_Executable") -.set_body_typed([]( - std::string code, - runtime::Module lib) { - return Executable::Load(code, lib); -}); + .set_body_typed([](std::string code, runtime::Module lib) { + return Executable::Load(code, lib); + }); } // namespace vm } // namespace runtime diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 3e6140ed3830..4c220bbe61c8 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -21,9 +21,11 @@ * \file tvm/runtime/vm/memory_manager.cc * \brief Allocate and manage memory for the runtime. */ -#include -#include #include "memory_manager.h" + +#include +#include + #include "naive_allocator.h" #include "pooled_allocator.h" @@ -35,8 +37,7 @@ static void BufferDeleter(Object* obj) { auto* ptr = static_cast(obj); CHECK(ptr->manager_ctx != nullptr); Buffer* buffer = reinterpret_cast(ptr->manager_ctx); - MemoryManager::Global()->GetAllocator(buffer->ctx)-> - Free(*(buffer)); + MemoryManager::Global()->GetAllocator(buffer->ctx)->Free(*(buffer)); delete buffer; delete ptr; } @@ -76,8 +77,6 @@ inline size_t GetDataAlignment(const DLTensor& arr) { } NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDataType dtype) { - // TODO(@jroesch): generalize later to non-overlapping allocations. 
- CHECK_EQ(offset, 0u); VerifyDataType(dtype); // crtical zone: allocate header, cannot throw @@ -86,14 +85,26 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa container->SetDeleter(StorageObj::Deleter); size_t needed_size = GetDataSize(container->dl_tensor); this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. container->manager_ctx = reinterpret_cast(this); - container->dl_tensor.data = this->buffer.data; - NDArray ret(GetObjectPtr(container)); + // is this UB? + // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start. + auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + NDArray ret(GetObjectPtr(container)); // RAII in effect, now run the check. - // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; + + CHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << "bytes"; return ret; } @@ -106,8 +117,8 @@ MemoryManager* MemoryManager::Global() { Allocator* MemoryManager::GetAllocator(TVMContext ctx) { std::lock_guard lock(mu_); if (allocators_.find(ctx) == allocators_.end()) { - DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" - << ctx.device_id << ")"; + DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id + << ")"; std::unique_ptr alloc(new NaiveAllocator(ctx)); allocators_.emplace(ctx, std::move(alloc)); } @@ -120,7 +131,7 @@ NDArray Allocator::Empty(std::vector shape, DLDataType dtype, DLContext container->SetDeleter(BufferDeleter); size_t size = GetDataSize(container->dl_tensor); size_t alignment = GetDataAlignment(container->dl_tensor); - Buffer *buffer = new Buffer; + Buffer* buffer = new Buffer; *buffer = this->Alloc(size, alignment, dtype); container->manager_ctx = reinterpret_cast(buffer); container->dl_tensor.data = buffer->data; diff --git a/src/runtime/vm/memory_manager.h b/src/runtime/vm/memory_manager.h index b4453524d996..f59d584fcfba 100644 --- a/src/runtime/vm/memory_manager.h +++ b/src/runtime/vm/memory_manager.h @@ -27,6 +27,7 @@ #include #include #include + #include #include #include @@ -73,15 +74,13 @@ class Allocator { * \param ctx The context where the array is allocated. * \return The empty NDArray. */ - NDArray Empty(std::vector shape, - DLDataType dtype, - DLContext ctx); + NDArray Empty(std::vector shape, DLDataType dtype, DLContext ctx); /*! \brief Allocate a buffer given a size, alignment and type. * \param nbytes The size of the buffer. * \param alignment The alignment of the buffer. * \param type_hint A type hint to the allocator. * \return A sized allocation in the form of a buffer. - */ + */ virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0; /*! \brief Free a buffer allocated by the allocator. * \param buffer The buffer to free. @@ -115,9 +114,7 @@ class StorageObj : public Object { Buffer buffer; /*! 
\brief Allocate an NDArray from a given piece of storage. */ - NDArray AllocNDArray(size_t offset, - std::vector shape, - DLDataType dtype); + NDArray AllocNDArray(size_t offset, std::vector shape, DLDataType dtype); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ static void Deleter(Object* ptr); diff --git a/src/runtime/vm/naive_allocator.h b/src/runtime/vm/naive_allocator.h index db47a62a7c39..5ac2ca61817e 100644 --- a/src/runtime/vm/naive_allocator.h +++ b/src/runtime/vm/naive_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_VM_NAIVE_ALLOCATOR_H_ #include + #include #include "memory_manager.h" @@ -52,9 +53,7 @@ class NaiveAllocator final : public Allocator { DLOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; } - size_t UsedMemory() const override { - return used_memory_.load(std::memory_order_relaxed); - } + size_t UsedMemory() const override { return used_memory_.load(std::memory_order_relaxed); } private: std::atomic used_memory_; diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h index 5965a4e8cf23..e09628f72e97 100644 --- a/src/runtime/vm/pooled_allocator.h +++ b/src/runtime/vm/pooled_allocator.h @@ -24,6 +24,7 @@ #define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_ #include + #include #include #include diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 4dac66e50a82..6e4682d1ab96 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -22,6 +22,8 @@ * \brief The Relay debug virtual machine. */ +#include "vm.h" + #include #include @@ -34,27 +36,24 @@ #include #include -#include "vm.h" - namespace tvm { namespace runtime { namespace vm { -PackedFunc VirtualMachineDebug::GetFunction( - const std::string& name, const ObjectPtr& sptr_to_self) { +PackedFunc VirtualMachineDebug::GetFunction(const std::string& name, + const ObjectPtr& sptr_to_self) { if (name == "get_stat") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK_EQ(args.size(), 1U); std::vector> op_acc_time; for (auto kv : op_durations_) { - auto val = std::make_pair( - kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); + auto val = + std::make_pair(kv.first, std::accumulate(kv.second.begin(), kv.second.end(), 0.0)); op_acc_time.push_back(val); } bool sort_by_time = args[0]; if (sort_by_time) { - auto comp = [](const std::pair& lhs, - const std::pair& rhs) { + auto comp = [](const std::pair& lhs, const std::pair& rhs) { return lhs.second > rhs.second; }; std::sort(op_acc_time.begin(), op_acc_time.end(), comp); @@ -74,9 +73,9 @@ PackedFunc VirtualMachineDebug::GetFunction( auto min_value = *std::min_element(vals.begin(), vals.end()); auto max_value = *std::max_element(vals.begin(), vals.end()); - os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" - << std::setw(10) << std::left << op_invokes_[kv.first] << "\t" - << sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl; + os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t" << std::setw(10) + << std::left << op_invokes_[kv.first] << "\t" << sum << "/" << mean << "/" << min_value + << "/" << max_value << std::endl; total_duration += sum; total_packed_funcs += op_invokes_[kv.first]; @@ -104,10 +103,8 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { } } -void VirtualMachineDebug::InvokePacked(Index packed_index, - const PackedFunc& func, Index arg_count, - Index output_size, - const std::vector& args) { +void 
VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args) { CHECK(exec_); auto ctx = this->GetParamsContext(); // warmup @@ -119,9 +116,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_end = std::chrono::high_resolution_clock::now(); double op_duration = - std::chrono::duration_cast >(op_end - - op_begin) - .count(); + std::chrono::duration_cast>(op_end - op_begin).count(); op_durations_[packed_index].push_back(op_duration * 1e6); op_invokes_[packed_index] += 1; @@ -133,8 +128,7 @@ runtime::Module CreateVirtualMachineDebug(const Executable* exec) { return runtime::Module(vm); } -TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime._VirtualMachineDebug").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec) << "Virtual machine has not been defined yet." diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h index f0a407fd7266..c286828231b0 100644 --- a/src/runtime/vm/profiler/vm.h +++ b/src/runtime/vm/profiler/vm.h @@ -40,16 +40,15 @@ class VirtualMachineDebug : public VirtualMachine { public: VirtualMachineDebug() : VirtualMachine() {} - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final; + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; void LoadExecutable(const Executable* exec) final; ~VirtualMachineDebug() {} private: - void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, - Index output_size, const std::vector& args) final; + void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, + const std::vector& args) final; std::unordered_map packed_index_map_; std::unordered_map> op_durations_; diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h index 3423f7a94167..8bd1f86f8887 100644 --- a/src/runtime/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -60,9 +60,7 @@ struct VMFunctionSerializer { VMFunctionSerializer() = default; - VMFunctionSerializer(const std::string& name, - Index register_file_size, - size_t num_instructions, + VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, const std::vector& params) : name(name), register_file_size(register_file_size), @@ -87,7 +85,7 @@ struct VMFunctionSerializer { } /*! - * \brief Save the VM function header into the serialized form. + * \brief Save the VM function header into the serialized form. * \param strm The stream used to save data. */ void Save(dmlc::Stream* strm) const { @@ -108,11 +106,11 @@ struct VMInstructionSerializer { VMInstructionSerializer() = default; - VMInstructionSerializer(Index opcode, const std::vector& fields) : - opcode(opcode), fields(fields) {} + VMInstructionSerializer(Index opcode, const std::vector& fields) + : opcode(opcode), fields(fields) {} /*! - * \brief Compute the hash of the serialized instruction. + * \brief Compute the hash of the serialized instruction. * \return The hash that combines the opcode and all fields of the VM * instruction. 
*/ @@ -139,13 +137,12 @@ struct VMInstructionSerializer { } Index hash = Hash(); - CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " - << opcode << "\n"; + CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " << opcode << "\n"; return true; } /*! - * \brief Save the instruction into the serialized form. + * \brief Save the instruction into the serialized form. * \param strm The stream used to save data. */ void Save(dmlc::Stream* strm) const { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index fedbbe9bb083..42bca37ee58b 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,11 +23,11 @@ */ #include -#include #include -#include #include #include +#include +#include #include #include @@ -56,8 +56,7 @@ inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint // We could put cache in here, from ctx to storage allocator. auto storage_obj = SimpleObjAllocator().make_object(); auto alloc = MemoryManager::Global()->GetAllocator(ctx); - DCHECK(alloc != nullptr) - << "allocator must not null"; + DCHECK(alloc != nullptr) << "allocator must not null"; storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint); return Storage(storage_obj); } @@ -86,13 +85,15 @@ Instruction::Instruction(const Instruction& instr) { return; case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; + this->alloc_tensor.offset = instr.alloc_tensor.offset; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; - this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, - instr.alloc_tensor.ndim); + this->alloc_tensor.shape = + Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); this->alloc_tensor.dtype = instr.alloc_tensor.dtype; return; case Opcode::AllocTensorReg: this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage; + this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return; @@ -151,7 +152,7 @@ Instruction::Instruction(const Instruction& instr) { } } -template +template static inline void FreeIf(T* t) { if (t != nullptr) { delete t; @@ -175,14 +176,16 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->result = instr.result; return *this; case Opcode::AllocTensor: - this->alloc_tensor.storage = instr.alloc_tensor.storage; + this->alloc_tensor.storage = this->alloc_tensor.storage; + this->alloc_tensor.offset = instr.alloc_tensor.offset; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; - this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, - instr.alloc_tensor.ndim); + this->alloc_tensor.shape = + Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); this->alloc_tensor.dtype = instr.alloc_tensor.dtype; return *this; case Opcode::AllocTensorReg: this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage; + this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return *this; @@ -258,22 +261,22 @@ Instruction::~Instruction() { case Opcode::Fatal: return; case Opcode::AllocTensor: - delete this->alloc_tensor.shape; + delete[] this->alloc_tensor.shape; return; case Opcode::AllocADT: - delete this->datatype_fields; + delete[] this->datatype_fields; return; case Opcode::AllocClosure: - delete this->free_vars; + delete[] this->free_vars; return; case 
Opcode::InvokePacked: - delete this->packed_args; + delete[] this->packed_args; return; case Opcode::InvokeClosure: - delete this->closure_args; + delete[] this->closure_args; return; case Opcode::Invoke: - delete this->invoke_args_registers; + delete[] this->invoke_args_registers; return; default: std::ostringstream out; @@ -294,9 +297,7 @@ Instruction Instruction::Fatal() { return instr; } -Instruction Instruction::InvokePacked(Index packed_index, - Index arity, - Index output_size, +Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size, const std::vector& args) { Instruction instr; instr.op = Opcode::InvokePacked; @@ -310,14 +311,14 @@ Instruction Instruction::InvokePacked(Index packed_index, return instr; } -Instruction Instruction::AllocTensor( - RegName storage, - const std::vector& shape, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensor(RegName storage, RegName offset, + const std::vector& shape, DLDataType dtype, + Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; instr.alloc_tensor.storage = storage; + instr.alloc_tensor.offset = offset; instr.alloc_tensor.ndim = shape.size(); instr.alloc_tensor.shape = new int64_t[shape.size()]; for (size_t i = 0; i < shape.size(); ++i) { @@ -327,22 +328,19 @@ Instruction Instruction::AllocTensor( return instr; } -Instruction Instruction::AllocTensorReg( - RegName storage, - RegName shape_register, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register, + DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; instr.alloc_tensor_reg.storage = storage; + instr.alloc_tensor_reg.offset = offset; instr.alloc_tensor_reg.shape_register = shape_register; instr.alloc_tensor_reg.dtype = dtype; return instr; } -Instruction Instruction::AllocStorage(RegName size, - Index alignment, - DLDataType dtype_hint, +Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, Index dst) { Instruction instr; instr.op = Opcode::AllocStorage; @@ -354,7 +352,7 @@ Instruction Instruction::AllocStorage(RegName size, } Instruction Instruction::AllocADT(Index tag, Index num_fields, - const std::vector& datatype_fields, Index dst) { + const std::vector& datatype_fields, Index dst) { Instruction instr; instr.op = Opcode::AllocADT; instr.dst = dst; @@ -486,7 +484,7 @@ void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) { } } -template +template std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") { if (cnt == 0) { return ""; @@ -515,26 +513,23 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { } case Opcode::InvokePacked: { os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $" - << StrJoin(instr.packed_args, 0, - instr.arity - instr.output_size, ", $") + << StrJoin(instr.packed_args, 0, instr.arity - instr.output_size, ", $") << ", out: $" - << StrJoin(instr.packed_args, instr.arity - instr.output_size, - instr.output_size, ", $") + << StrJoin(instr.packed_args, instr.arity - instr.output_size, instr.output_size, + ", $") << ")"; break; } case Opcode::AllocTensor: { - os << "alloc_tensor $" << instr.dst << " $" - << instr.alloc_tensor.storage << " [" - << StrJoin(instr.alloc_tensor.shape, 0, - instr.alloc_tensor.ndim) - << "] "; + os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " $" + << instr.alloc_tensor.offset << " [" + << 
StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); break; } case Opcode::AllocTensorReg: { - os << "alloc_tensor_reg $" << instr.dst << " $" - << instr.alloc_tensor_reg.storage << " $" + os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" + << instr.alloc_tensor_reg.storage << " $" << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; @@ -545,26 +540,24 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocClosure: { - os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index - << "]($" << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") - << ")"; + os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($" + << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") << ")"; break; } case Opcode::If: { - os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " " - << instr.if_op.true_offset << " " << instr.if_op.false_offset; + os << "if " + << "$" << instr.if_op.test << " $" << instr.if_op.target << " " << instr.if_op.true_offset + << " " << instr.if_op.false_offset; break; } case Opcode::Invoke: { os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" - << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") - << ")"; + << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") << ")"; break; } case Opcode::InvokeClosure: { os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" - << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") - << ")"; + << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") << ")"; break; } case Opcode::LoadConst: { @@ -576,8 +569,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::GetField: { - os << "get_field $" << instr.dst << " $" << instr.object << "[" - << instr.field_index << "]"; + os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]"; break; } case Opcode::GetTag: { @@ -589,11 +581,9 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocStorage: { - os << "alloc_storage $" << - instr.dst << " $" << - instr.alloc_storage.allocation_size << " $" << - instr.alloc_storage.alignment << " " << - DLDataType2String(instr.alloc_storage.dtype_hint); + os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " $" + << instr.alloc_storage.alignment << " " + << DLDataType2String(instr.alloc_storage.dtype_hint); break; } default: @@ -629,6 +619,36 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { return src; } +std::vector ToShape(NDArray shape_tensor) { + std::vector shape; + auto rank = shape_tensor.Shape().size(); + auto dtype = shape_tensor.DataType(); + + // For 0-rank shapes we need to allocate a single scalar. + if (rank == 0) { + return shape; + } + + // Otherwise we should be rank-1, and we will extract the number of dimensions + // for the output vector. 
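// (Illustrative example: a rank-1 int64 or int32 tensor holding the values
//  {2, 3, 224, 224} is converted to std::vector<int64_t>{2, 3, 224, 224};
//  any other dtype trips the fatal error below.)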
+ CHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank; + int64_t ndim = shape_tensor.Shape().at(0); + shape.resize(ndim); + + const DLTensor* dl_tensor = shape_tensor.operator->(); + if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { + int32_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else if (dtype.is_int() && dtype.bits() == 64 && dtype.lanes() == 1) { + int64_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else { + LOG(FATAL) << "invalid shape tensor datatype: " << dtype; + } + + return shape; +} + PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "invoke") { @@ -637,14 +657,14 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, std::string func_name = args[0]; auto git = exec_->global_map.find(func_name); CHECK(git != exec_->global_map.end()) - << "Cannot find function " << func_name << " in the executable"; + << "Cannot find function " << func_name << " in the executable"; auto func = exec_->functions[git->second]; if (func.params.empty()) { *rv = Invoke(func, {}); } else { auto it = inputs_.find(func_name); CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name; - const std::vector &func_args = it->second; + const std::vector& func_args = it->second; *rv = Invoke(func, func_args); } }); @@ -672,8 +692,8 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, const auto& param_names = vm_func.params; // TODO(icemelon9): For heterogeneous execution, get input device information TVMContext ctx = ctxs_[0]; - CHECK_EQ(args.size() - 1, param_names.size()) << - "The number of provided parameters doesn't match the number of arguments"; + CHECK_EQ(args.size() - 1, param_names.size()) + << "The number of provided parameters doesn't match the number of arguments"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { ObjectRef obj = CopyTo(args[i], ctx); @@ -745,16 +765,14 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector& args) { CHECK(exec_) << "The executable has not been created yet."; auto it = exec_->global_map.find(name); - CHECK(it != exec_->global_map.end()) - << "Cannot find function " << name << " in the executable"; + CHECK(it != exec_->global_map.end()) << "Cannot find function " << name << " in the executable"; auto func_index_ = it->second; DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_; return Invoke(exec_->functions[func_index_], args); } -void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, - Index arg_count, Index output_size, - const std::vector& args) { +void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, + Index output_size, const std::vector& args) { size_t arity = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* obj = args[i].as()) { @@ -806,10 +824,7 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } - -void VirtualMachine::Init(const std::vector& ctxs) { - ctxs_ = ctxs; -} +void VirtualMachine::Init(const std::vector& ctxs) { ctxs_ = ctxs; } inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { frames_.back().register_file[r] = val; @@ -893,13 +908,13 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::InvokePacked: { - DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity; + DLOG(INFO) << "InvokedPacked " + << "arity=" << 
instr.arity; const auto& func = packed_funcs_[instr.packed_index]; const auto& arity = instr.arity; std::vector args; for (Index i = 0; i < arity; ++i) { - DLOG(INFO) << - "arg" << i << " $" << instr.packed_args[i]; + DLOG(INFO) << "arg" << i << " $" << instr.packed_args[i]; auto arg = ReadRegister(instr.packed_args[i]); args.push_back(arg); } @@ -969,8 +984,9 @@ void VirtualMachine::RunLoop() { } auto storage_obj = ReadRegister(instr.alloc_tensor.storage); + auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto storage = Downcast(storage_obj); - auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype); + auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype); WriteRegister(instr.dst, obj); pc_++; @@ -983,17 +999,11 @@ void VirtualMachine::RunLoop() { auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); const auto shape_arr = Downcast(shape_tensor_obj); NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx); - const DLTensor* dl_tensor = shape_tensor.operator->(); - CHECK_EQ(dl_tensor->dtype.code, 0u); - CHECK_LE(dl_tensor->dtype.bits, 64); - int64_t* dims = reinterpret_cast(dl_tensor->data); - auto num_dims = shape_tensor->shape[0]; - auto shape = std::vector(num_dims); - shape.assign(dims, dims + num_dims); - + auto shape = ToShape(shape_tensor); auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); - auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype); + auto offset = LoadScalarInt(instr.alloc_tensor.offset); + auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); WriteRegister(instr.dst, obj); pc_++; @@ -1022,10 +1032,8 @@ void VirtualMachine::RunLoop() { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = LoadScalarInt(instr.alloc_storage.alignment); - DLOG(INFO) << - "AllocStorage: allocation_size=" << size << - "alignment=" << alignment << - "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); + DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment + << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]); WriteRegister(instr.dst, storage); @@ -1057,8 +1065,7 @@ runtime::Module CreateVirtualMachine(const Executable* exec) { return runtime::Module(vm); } -TVM_REGISTER_GLOBAL("runtime._VirtualMachine") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("runtime._VirtualMachine").set_body([](TVMArgs args, TVMRetValue* rv) { runtime::Module mod = args[0]; const auto* exec = dynamic_cast(mod.operator->()); CHECK(exec) << "The virtual machine executable has not been defined yet."; diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 80486406187b..44810116c3c2 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -17,21 +17,19 @@ * under the License. */ -#include #include #include #include #include +#include #include #include - #include "../file_util.h" #include "../pack_args.h" #include "../thread_storage_scope.h" #include "../workspace_pool.h" - #include "vulkan_common.h" #include "vulkan_module.h" #include "vulkan_shader.h" @@ -58,6 +56,8 @@ class VulkanThreadEntry { // the instance and device get destroyed. // The destruction need to be manually called // to ensure the destruction order. 
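// Resetting the pool here (now held through a std::unique_ptr) releases the
// cached workspace allocations while the Vulkan device objects they reference
// are still alive, before the streams and staging buffers are torn down below.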
+ + pool.reset(); streams_.clear(); for (const auto& kv : staging_buffers_) { if (!kv.second) { @@ -77,7 +77,7 @@ class VulkanThreadEntry { } TVMContext ctx; - WorkspacePool pool; + std::unique_ptr pool; VulkanStream* Stream(size_t device_id); VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); @@ -117,9 +117,7 @@ class VulkanDeviceAPI final : public DeviceAPI { } void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; } void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) final { const auto& vctx = context(ctx.device_id); VkBufferCreateInfo info; @@ -189,6 +187,10 @@ class VulkanDeviceAPI final : public DeviceAPI { } void FreeDataSpace(TVMContext ctx, void* ptr) final { + // Before releasing the vkBuffer, call sync to + // finish all the vulkan commands that reference the buffer. + StreamSync(ctx, nullptr); + const auto& vctx = context(ctx.device_id); auto* pbuf = static_cast(ptr); vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr); @@ -331,11 +333,11 @@ class VulkanDeviceAPI final : public DeviceAPI { } void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { - return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); + return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(ctx, size); } void FreeWorkspace(TVMContext ctx, void* data) final { - VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); + VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data); } static const std::shared_ptr& Global() { @@ -366,7 +368,7 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* case kMaxThreadsPerBlock: { VkPhysicalDeviceProperties phy_prop; vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); - int64_t value = phy_prop.limits.maxComputeWorkGroupSize[0]; + int64_t value = phy_prop.limits.maxComputeWorkGroupInvocations; *rv = value; break; } @@ -399,8 +401,18 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* return; case kExist: break; - case kMaxThreadDimensions: + case kMaxThreadDimensions: { + VkPhysicalDeviceProperties phy_prop; + vkGetPhysicalDeviceProperties(vctx.phy_device, &phy_prop); + int64_t dims[3]; + dims[0] = phy_prop.limits.maxComputeWorkGroupSize[0]; + dims[1] = phy_prop.limits.maxComputeWorkGroupSize[1]; + dims[2] = phy_prop.limits.maxComputeWorkGroupSize[2]; + std::stringstream ss; // use json string to return multiple int values; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; + *rv = ss.str(); break; + } case kGcnArch: return; } @@ -624,9 +636,8 @@ VulkanDeviceAPI::VulkanDeviceAPI() { #ifdef USE_VULKAN_IMMEDIATE_MODE if (has_extension("VK_KHR_push_descriptor") && has_extension("VK_KHR_descriptor_update_template")) { - ctx.descriptor_template_khr_functions = - std::unique_ptr( - new VulkanDescriptorTemplateKHRFunctions()); + ctx.descriptor_template_khr_functions = std::unique_ptr( + new VulkanDescriptorTemplateKHRFunctions()); ctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR = CHECK_NOTNULL((PFN_vkCreateDescriptorUpdateTemplateKHR)vkGetDeviceProcAddr( ctx.device, "vkCreateDescriptorUpdateTemplateKHR")); @@ -668,9 +679,7 @@ class VulkanModuleNode; // a wrapped function class to get packed func. 
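// Note on the kMaxThreadDimensions query above: because a TVMRetValue carries
// a single value, the three per-axis limits are returned as one JSON-style
// string. A caller might consume it roughly like this (the limits shown are
// made-up values, and `vulkan_api` stands for any pointer to this DeviceAPI):
//   TVMRetValue rv;
//   vulkan_api->GetAttr(ctx, kMaxThreadDimensions, &rv);
//   std::string dims = rv;   // e.g. "[1024, 1024, 64]", parsed as a JSON array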
class VulkanWrappedFunc { public: - void Init(VulkanModuleNode* m, - ObjectPtr sptr, - const std::string& func_name, + void Init(VulkanModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { m_ = m; @@ -706,13 +715,12 @@ class VulkanWrappedFunc { class VulkanModuleNode final : public runtime::ModuleNode { public: explicit VulkanModuleNode(std::unordered_map smap, - std::unordered_map fmap, std::string source) + std::unordered_map fmap, std::string source) : smap_(smap), fmap_(fmap), source_(source) {} const char* type_key() const final { return "vulkan"; } - PackedFunc GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { CHECK_EQ(sptr_to_self.get(), this); CHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; auto it = fmap_.find(name); @@ -747,7 +755,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { } std::shared_ptr GetPipeline(size_t device_id, const std::string& func_name, - size_t num_pack_args) { + size_t num_pack_args) { const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; @@ -772,6 +780,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::vector arg_binding; std::vector arg_template; uint32_t num_pod = 0, num_buffer = 0; + { auto fit = fmap_.find(func_name); CHECK(fit != fmap_.end()); @@ -927,8 +936,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { } private: - // the binary data - std::vector data_; // function information table. std::unordered_map smap_; // function information table. 
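// The GetPipeline() lookup above follows a common two-level cache pattern:
// device_id -> function name -> compiled pipeline, guarded by one mutex so a
// pipeline is built at most once per device. A self-contained sketch of the
// same idea with the Vulkan specifics stripped out (names and types here are
// illustrative only):
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Pipeline {};  // stand-in for the per-function pipeline state

class PipelineCache {
 public:
  std::shared_ptr<Pipeline> GetOrCreate(size_t device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto& entry = cache_[device_id][func_name];
    if (!entry) {
      entry = std::make_shared<Pipeline>();  // compile/build only on first use
    }
    return entry;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<size_t,
                     std::unordered_map<std::string, std::shared_ptr<Pipeline>>>
      cache_;
};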
@@ -1004,7 +1011,8 @@ VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size } VulkanThreadEntry::VulkanThreadEntry() - : pool(static_cast(kDLVulkan), VulkanDeviceAPI::Global()) { + : pool(std::make_unique(static_cast(kDLVulkan), + VulkanDeviceAPI::Global())) { ctx.device_id = 0; ctx.device_type = static_cast(kDLVulkan); } @@ -1017,8 +1025,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, - const ArgUnion* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; CHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 9242d3d6d680..780b11184931 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -22,8 +22,8 @@ #include #include #include - #include + #include #include #include @@ -140,7 +140,6 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } }; - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_shader.h b/src/runtime/vulkan/vulkan_shader.h index 1b2e45458f9c..d56ca61e91cb 100644 --- a/src/runtime/vulkan/vulkan_shader.h +++ b/src/runtime/vulkan/vulkan_shader.h @@ -18,7 +18,6 @@ */ #pragma once - #include #include #include diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index 1a24d2873a60..388cacc577b0 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -20,12 +20,11 @@ #include #include -#include #include +#include #include "vulkan_common.h" - namespace tvm { namespace runtime { namespace vulkan { @@ -44,8 +43,7 @@ struct VulkanStreamToken { class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx) - : vctx_(vctx), state_(new VulkanStreamState()) { + explicit VulkanStream(const VulkanContext* vctx) : vctx_(vctx), state_(new VulkanStreamState()) { // create command pool VkCommandPoolCreateInfo cmd_pool_cinfo; cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; diff --git a/src/runtime/workspace_pool.cc b/src/runtime/workspace_pool.cc index fc316cdeded1..8ee905e4ea84 100644 --- a/src/runtime/workspace_pool.cc +++ b/src/runtime/workspace_pool.cc @@ -21,9 +21,10 @@ * \file workspace_pool.h * \brief Workspace pool utility. 
*/ -#include #include "workspace_pool.h" +#include + namespace tvm { namespace runtime { @@ -67,7 +68,8 @@ class WorkspacePool::Pool { if (free_list_.back().size >= nbytes) { // find smallest fit auto it = free_list_.end() - 2; - for (; it->size >= nbytes; --it) {} + for (; it->size >= nbytes; --it) { + } e = *(it + 1); free_list_.erase(it + 1); } else { @@ -91,7 +93,8 @@ class WorkspacePool::Pool { allocated_.pop_back(); } else { int index = static_cast(allocated_.size()) - 2; - for (; index > 0 && allocated_[index].data != data; --index) {} + for (; index > 0 && allocated_[index].data != data; --index) { + } CHECK_GT(index, 0) << "trying to free things that has not been allocated"; e = allocated_[index]; allocated_.erase(allocated_.begin() + index); @@ -132,8 +135,7 @@ class WorkspacePool::Pool { }; WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr device) - : device_type_(device_type), device_(device) { -} + : device_type_(device_type), device_(device) {} WorkspacePool::~WorkspacePool() { for (size_t i = 0; i < array_.size(); ++i) { @@ -158,8 +160,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) { } void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) { - CHECK(static_cast(ctx.device_id) < array_.size() && - array_[ctx.device_id] != nullptr); + CHECK(static_cast(ctx.device_id) < array_.size() && array_[ctx.device_id] != nullptr); array_[ctx.device_id]->Free(ptr); } diff --git a/src/runtime/workspace_pool.h b/src/runtime/workspace_pool.h index 72613caffb8e..288da7d10483 100644 --- a/src/runtime/workspace_pool.h +++ b/src/runtime/workspace_pool.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -25,8 +25,9 @@ #define TVM_RUNTIME_WORKSPACE_POOL_H_ #include -#include + #include +#include namespace tvm { namespace runtime { diff --git a/src/support/arena.h b/src/support/arena.h index 744ff4f12188..cb08db93641d 100644 --- a/src/support/arena.h +++ b/src/support/arena.h @@ -26,42 +26,107 @@ #ifndef TVM_SUPPORT_ARENA_H_ #define TVM_SUPPORT_ARENA_H_ -#include +#ifndef TVM_ARENA_HAS_DESTRUCTOR +#define TVM_ARENA_HAS_DESTRUCTOR 1 +#endif + +#include #include +#include namespace tvm { namespace support { -const constexpr int kArenaPageSize = 16 << 10; +/*! + * \brief An arena page header. + */ +struct ArenaPageHeader { + /*! \brief points to the next page. */ + ArenaPageHeader* next; + /*! + * \brief Total size of the page. + */ + size_t size; + /*! \brief memory allocator offset inside page. */ + size_t offset; +}; + +/*! + * \brief Simple page allocator that uses new and delete. + */ +class SimplePageAllocator { + public: + /*! + * \brief Allocate a new page. + * \param min_size Minimum size of the page. + * \return The allocated page. + * \note This function can return a bigger page to meet the min_size requirement. + */ + ArenaPageHeader* allocate(size_t min_size) { + size_t npages = ((min_size + kPageSize - 1) / kPageSize); + ArenaPageHeader* header = reinterpret_cast(new Page[npages]); + header->size = npages * kPageSize; + header->offset = sizeof(ArenaPageHeader); + return header; + } + /*! + * \brief De-allocate an allocate page. 
+ * \param page The page to be de-allocated. + */ + void deallocate(ArenaPageHeader* page) { delete[] reinterpret_cast(page); } + + static const constexpr int kPageSize = 16 << 10; + static const constexpr int kPageAlign = 1024; + + private: + // page size 16 KB + // The page data type; + using Page = std::aligned_storage::type; +}; /*! * \brief Arena allocator that allocates memory from continuous * chunk and frees them all only during destruction. */ -class Arena { +template +class GenericArena { public: - Arena() { + explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) { // eagerly allocate the first page. - head_ = reinterpret_cast(new Page()); + head_ = tail_ = alloc_.allocate(1); head_->next = nullptr; - head_->ptr = sizeof(PageHeader); } - ~Arena() { - // delete all the allocated pages. - while (head_ != nullptr) { - Page* page = reinterpret_cast(head_); - head_ = head_->next; - delete page; - } + +#if TVM_ARENA_HAS_DESTRUCTOR + ~GenericArena() { this->FreeAll(); } +#endif + + /*! \brief Free all pages. */ + void FreeAll() { + FreePageList(&head_); + FreePageList(&free_list_); + } + /*! \brief Recycle all the pages in the arena */ + void RecycleAll() { + // put all the current list to the free list. + tail_->next = free_list_; + // allocate the first in the free list to head + free_list_ = head_->next; + head_->next = nullptr; + // Reset the head. + head_->offset = sizeof(ArenaPageHeader); + tail_ = head_; } /*! * \brief Allocate a space from Arena for type T * \param T the data type to be allocated + * \param count Numberof elements * \note The space of T is not initialized. */ - template - T* allocate_() { - return static_cast(Alloc(sizeof(T), alignof(T))); + template + T* allocate_(int count = 1) { + static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment"); + return static_cast(Alloc(sizeof(T) * count, alignof(T))); } /*! * \brief Create a new instance of type T. @@ -74,7 +139,7 @@ class Arena { * memory allocated from the same arena. * Otherwise the destructor needs to be called explicitly. */ - template + template T* make(Args&&... args) { T* ptr = allocate_(); new (ptr) T(std::forward(args)...); @@ -82,25 +147,21 @@ class Arena { } private: - // page size 16 KB - // The page data type; - using Page = std::aligned_storage::type; - /*! \brief Page header */ - struct PageHeader { - /*! \brief points to the next page */ - PageHeader* next; - /*! \brief memory allocator ptr inside page */ - size_t ptr; - }; - /* \brief The page header */ - PageHeader* head_{nullptr}; + /*! \brief internal page allocator. */ + PageAllocator alloc_; + /* \brief The the head of the allocated list. */ + ArenaPageHeader* head_{nullptr}; + /*! \brief The tail of the allocated list. */ + ArenaPageHeader* tail_{nullptr}; + /* \brief List of free pages. */ + ArenaPageHeader* free_list_{nullptr}; /*! * \brief Align ptr by upper bound. - * \param ptr The pointer value. + * \param offset The offset value. * \param align The alignment requirement. */ - size_t UpperAlign(size_t ptr, size_t align) { - return ptr + (align - (ptr % align)) % align; + size_t UpperAlign(size_t offset, size_t align) { + return offset + (align - (offset % align)) % align; } /*! * \brief Internal aligned alloc function. @@ -108,27 +169,46 @@ class Arena { * \param align The alignment requirement. 
*/ void* Alloc(size_t size, size_t align) { - size_t ptr = UpperAlign(head_->ptr, align); - if (ptr + size <= kArenaPageSize) { - head_->ptr = ptr + size; - return reinterpret_cast(head_) + ptr; + size_t offset = UpperAlign(head_->offset, align); + if (offset + size <= head_->size) { + head_->offset = offset + size; + return reinterpret_cast(head_) + offset; } else { - PageHeader* new_head = reinterpret_cast(new Page()); + ArenaPageHeader* new_head; + offset = UpperAlign(sizeof(ArenaPageHeader), align); + if (free_list_ != nullptr && offset + size <= free_list_->size) { + new_head = free_list_; + free_list_ = free_list_->next; + } else { + new_head = alloc_.allocate(offset + size); + } new_head->next = head_; - ptr = UpperAlign(sizeof(PageHeader), align); - CHECK_LE(ptr + size, kArenaPageSize); - new_head->ptr = ptr + size; + new_head->offset = offset + size; head_ = new_head; - return reinterpret_cast(head_) + ptr; + return reinterpret_cast(head_) + offset; + } + } + /*! + * \brief Free all the pages in the list. + * \param ptr The head ptr. + */ + void FreePageList(ArenaPageHeader** ptr) { + // delete all the allocated pages. + while (ptr[0] != nullptr) { + ArenaPageHeader* temp = ptr[0]; + ptr[0] = ptr[0]->next; + alloc_.deallocate(temp); } } }; +using Arena = GenericArena; + /*! * \brief Link list node * \tparam T the content data type */ -template +template struct LinkNode { /*! \brief The content value */ T value; @@ -141,7 +221,7 @@ struct LinkNode { * \note This is a simple data structure that can be used together with the arena. * \sa LinkNode */ -template +template struct LinkedList { /*! \brief Head pointer */ LinkNode* head{nullptr}; diff --git a/src/support/base64.h b/src/support/base64.h index c85b268fd7b6..9849542471c2 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -27,7 +27,7 @@ #define TVM_SUPPORT_BASE64_H_ #include -#include + #include #include #include @@ -38,18 +38,16 @@ namespace support { namespace base64 { // decoding table const char DecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, - 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, - 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' }; // encoding table static const char EncodeTable[] = @@ -62,14 +60,12 @@ static const char EncodeTable[] = */ class StreamBufferReader { public: - explicit StreamBufferReader(size_t buffer_size) { - buffer_.resize(buffer_size); - } + explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); } /*! 
* \brief set input stream * \param stream The stream to be set */ - void set_stream(dmlc::Stream *stream) { + void set_stream(dmlc::Stream* stream) { stream_ = stream; read_len_ = read_ptr_ = 1; } @@ -88,13 +84,11 @@ class StreamBufferReader { } } /*! \return whether we are reaching the end of file */ - bool AtEnd() const { - return read_len_ == 0; - } + bool AtEnd() const { return read_len_ == 0; } private: /*! \brief the underlying stream */ - dmlc::Stream *stream_{nullptr}; + dmlc::Stream* stream_{nullptr}; /*! \brief buffer to hold data */ std::string buffer_; /*! \brief length of valid data in buffer */ @@ -106,11 +100,9 @@ class StreamBufferReader { /*! * \brief Input stream from base64 encoding */ -class Base64InStream: public dmlc::Stream { +class Base64InStream : public dmlc::Stream { public: - explicit Base64InStream(dmlc::Stream *fs) : reader_(256) { - reader_.set_stream(fs); - } + explicit Base64InStream(dmlc::Stream* fs) : reader_(256) { reader_.set_stream(fs); } /*! * \brief initialize the stream position to beginning of next base64 stream * \note call this function before actually start read @@ -122,16 +114,14 @@ class Base64InStream: public dmlc::Stream { } while (isspace(temp_ch_)); } /*! \brief whether current position is end of a base64 stream */ - bool IsEOF(void) const { - return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); - } + bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); } // override read function. - virtual size_t Read(void *ptr, size_t size) { + virtual size_t Read(void* ptr, size_t size) { using base64::DecodeTable; if (size == 0) return 0; // use tlen to record left size size_t tlen = size; - unsigned char *cptr = static_cast(ptr); + unsigned char* cptr = static_cast(ptr); // if anything left, load from previous buffered result if (num_prev_ != 0) { if (num_prev_ == 2) { @@ -142,13 +132,16 @@ class Base64InStream: public dmlc::Stream { num_prev_ = 0; } else { // assert tlen == 1 - *cptr++ = buf_prev[0]; --tlen; + *cptr++ = buf_prev[0]; + --tlen; buf_prev[0] = buf_prev[1]; num_prev_ = 1; } } else { // assert num_prev_ == 1 - *cptr++ = buf_prev[0]; --tlen; num_prev_ = 0; + *cptr++ = buf_prev[0]; + --tlen; + num_prev_ = 0; } } if (tlen == 0) return size; @@ -163,8 +156,9 @@ class Base64InStream: public dmlc::Stream { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; nvalue |= DecodeTable[temp_ch_] << 12; - *cptr++ = (nvalue >> 16) & 0xFF; --tlen; - } + *cptr++ = (nvalue >> 16) & 0xFF; + --tlen; + } { // third byte temp_ch_ = reader_.GetChar(); @@ -174,13 +168,13 @@ class Base64InStream: public dmlc::Stream { temp_ch_ = reader_.GetChar(); CHECK(temp_ch_ == '=') << "invalid base64 format"; temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ == EOF || isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_] << 6; if (tlen) { - *cptr++ = (nvalue >> 8) & 0xFF; --tlen; + *cptr++ = (nvalue >> 8) & 0xFF; + --tlen; } else { buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF; } @@ -188,19 +182,18 @@ class Base64InStream: public dmlc::Stream { { // fourth byte temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) - << "invalid base64 format"; + CHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format"; if (temp_ch_ == '=') { temp_ch_ = reader_.GetChar(); - CHECK(temp_ch_ == EOF || isspace(temp_ch_)) - << "invalid base64 format"; + 
CHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format"; break; } nvalue |= DecodeTable[temp_ch_]; if (tlen) { - *cptr++ = nvalue & 0xFF; --tlen; + *cptr++ = nvalue & 0xFF; + --tlen; } else { - buf_prev[num_prev_ ++] = nvalue & 0xFF; + buf_prev[num_prev_++] = nvalue & 0xFF; } } // get next char @@ -211,7 +204,7 @@ class Base64InStream: public dmlc::Stream { } return size - tlen; } - virtual void Write(const void *ptr, size_t size) { + virtual void Write(const void* ptr, size_t size) { LOG(FATAL) << "Base64InStream do not support write"; } @@ -228,17 +221,17 @@ class Base64InStream: public dmlc::Stream { /*! * \brief Stream to write to base64 format. */ -class Base64OutStream: public dmlc::Stream { +class Base64OutStream : public dmlc::Stream { public: - explicit Base64OutStream(dmlc::Stream *fp) : fp_(fp) { - } - virtual void Write(const void *ptr, size_t size) { + explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {} + virtual void Write(const void* ptr, size_t size) { using base64::EncodeTable; size_t tlen = size; - const unsigned char *cptr = static_cast(ptr); + const unsigned char* cptr = static_cast(ptr); while (tlen) { - while (buf__top_ < 3 && tlen != 0) { - buf_[++buf__top_] = *cptr++; --tlen; + while (buf__top_ < 3 && tlen != 0) { + buf_[++buf__top_] = *cptr++; + --tlen; } if (buf__top_ == 3) { // flush 4 bytes out @@ -250,7 +243,7 @@ class Base64OutStream: public dmlc::Stream { } } } - virtual size_t Read(void *ptr, size_t size) { + virtual size_t Read(void* ptr, size_t size) { LOG(FATAL) << "Base64OutStream do not support read"; return 0; } @@ -280,12 +273,11 @@ class Base64OutStream: public dmlc::Stream { private: static constexpr size_t kBufferSize = 256; - dmlc::Stream *fp_{nullptr}; + dmlc::Stream* fp_{nullptr}; int buf__top_{0}; unsigned char buf_[4]; std::string out_buf_; - void PutChar(char ch) { out_buf_ += ch; if (out_buf_.length() >= kBufferSize) Flush(); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 90fcfff0eef3..839f52968b82 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -17,35 +17,29 @@ * under the License. */ - /*! +/*! * FFI registration code used for frontend testing purposes. 
* \file ffi_testing.cc */ -#include -#include -#include #include #include +#include +#include +#include namespace tvm { // Attrs used to python API struct TestAttrs : public AttrsNode { int axis; - std::string name; + String name; Array padding; TypedEnvFunc func; TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") { - TVM_ATTR_FIELD(axis) - .set_default(10) - .set_lower_bound(1) - .set_upper_bound(10) - .describe("axis field"); - TVM_ATTR_FIELD(name) - .describe("name"); - TVM_ATTR_FIELD(padding) - .describe("padding of input") - .set_default(Array({0, 0})); + TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( + "axis field"); + TVM_ATTR_FIELD(name).describe("name"); + TVM_ATTR_FIELD(padding).describe("padding of input").set_default(Array({0, 0})); TVM_ATTR_FIELD(func) .describe("some random env function") .set_default(TypedEnvFunc(nullptr)); @@ -54,44 +48,37 @@ struct TestAttrs : public AttrsNode { TVM_REGISTER_NODE_TYPE(TestAttrs); -TVM_REGISTER_GLOBAL("testing.nop") -.set_body([](TVMArgs args, TVMRetValue *ret) { - }); +TVM_REGISTER_GLOBAL("testing.nop").set_body([](TVMArgs args, TVMRetValue* ret) {}); -TVM_REGISTER_GLOBAL("testing.test_wrap_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc pf = args[0]; - *ret = runtime::TypedPackedFunc([pf](){ - pf(); - }); - }); +TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0]; +}); -TVM_REGISTER_GLOBAL("testing.test_raise_error_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string msg = args[0]; - *ret = runtime::TypedPackedFunc([msg](){ - LOG(FATAL) << msg; - }); - }); +TVM_REGISTER_GLOBAL("testing.test_wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf]() { pf(); }); +}); -TVM_REGISTER_GLOBAL("testing.test_check_eq_callback") -.set_body([](TVMArgs args, TVMRetValue *ret) { - std::string msg = args[0]; - *ret = runtime::TypedPackedFunc([msg](int x, int y){ - CHECK_EQ(x, y) << msg; - }); - }); +TVM_REGISTER_GLOBAL("testing.test_raise_error_callback") + .set_body([](TVMArgs args, TVMRetValue* ret) { + std::string msg = args[0]; + *ret = runtime::TypedPackedFunc([msg]() { LOG(FATAL) << msg; }); + }); -TVM_REGISTER_GLOBAL("testing.context_test") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DLContext ctx = args[0]; - int dtype = args[1]; - int did = args[2]; - CHECK_EQ(static_cast(ctx.device_type), dtype); - CHECK_EQ(static_cast(ctx.device_id), did); - *ret = ctx; - }); +TVM_REGISTER_GLOBAL("testing.test_check_eq_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + std::string msg = args[0]; + *ret = + runtime::TypedPackedFunc([msg](int x, int y) { CHECK_EQ(x, y) << msg; }); +}); +TVM_REGISTER_GLOBAL("testing.context_test").set_body([](TVMArgs args, TVMRetValue* ret) { + DLContext ctx = args[0]; + int dtype = args[1]; + int did = args[2]; + CHECK_EQ(static_cast(ctx.device_type), dtype); + CHECK_EQ(static_cast(ctx.device_id), did); + *ret = ctx; +}); // in src/api_test.cc void ErrorTest(int x, int y) { @@ -103,15 +90,13 @@ void ErrorTest(int x, int y) { } } -TVM_REGISTER_GLOBAL("testing.ErrorTest") -.set_body_typed(ErrorTest); +TVM_REGISTER_GLOBAL("testing.ErrorTest").set_body_typed(ErrorTest); // internal function used for debug and testing purposes -TVM_REGISTER_GLOBAL("testing.object_use_count") -.set_body([](TVMArgs args, TVMRetValue *ret) { - runtime::ObjectRef obj = args[0]; - // substract the current one because we always copy - // and get 
another value. - *ret = (obj.use_count() - 1); - }); +TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRetValue* ret) { + runtime::ObjectRef obj = args[0]; + // substract the current one because we always copy + // and get another value. + *ret = (obj.use_count() - 1); +}); } // namespace tvm diff --git a/src/support/pipe.h b/src/support/pipe.h index 120bbdb95e77..dcebd0ddf32f 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -24,16 +24,17 @@ #ifndef TVM_SUPPORT_PIPE_H_ #define TVM_SUPPORT_PIPE_H_ -#include #include +#include #ifdef _WIN32 #include #else -#include #include -#include +#include + #include +#include #endif namespace tvm { @@ -48,12 +49,9 @@ class Pipe : public dmlc::Stream { using PipeHandle = int; #endif /*! \brief Construct a pipe from system handle. */ - explicit Pipe(int64_t handle) - : handle_(static_cast(handle)) {} + explicit Pipe(int64_t handle) : handle_(static_cast(handle)) {} /*! \brief destructor */ - ~Pipe() { - Flush(); - } + ~Pipe() { Flush(); } using Stream::Read; using Stream::Write; /*! @@ -62,18 +60,16 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - size_t Read(void *ptr, size_t size) final { + size_t Read(void* ptr, size_t size) final { if (size == 0) return 0; #ifdef _WIN32 DWORD nread; - CHECK(ReadFile(handle_, static_cast(ptr), - &nread, nullptr)) + CHECK(ReadFile(handle_, static_cast(ptr), &nread, nullptr)) << "Read Error: " << GetLastError(); #else ssize_t nread; nread = read(handle_, ptr, size); - CHECK_GE(nread, 0) - << "Write Error: " << strerror(errno); + CHECK_GE(nread, 0) << "Write Error: " << strerror(errno); #endif return static_cast(nread); } @@ -83,19 +79,17 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - void Write(const void *ptr, size_t size) final { + void Write(const void* ptr, size_t size) final { if (size == 0) return; #ifdef _WIN32 DWORD nwrite; - CHECK(WriteFile(handle_, static_cast(ptr), - &nwrite, nullptr) && + CHECK(WriteFile(handle_, static_cast(ptr), &nwrite, nullptr) && static_cast(nwrite) == size) << "Write Error: " << GetLastError(); #else ssize_t nwrite; nwrite = write(handle_, ptr, size); - CHECK_EQ(static_cast(nwrite), size) - << "Write Error: " << strerror(errno); + CHECK_EQ(static_cast(nwrite), size) << "Write Error: " << strerror(errno); #endif } /*! diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index e6e3b04ec7a9..a3938491f1d1 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -24,9 +24,9 @@ #ifndef TVM_SUPPORT_RING_BUFFER_H_ #define TVM_SUPPORT_RING_BUFFER_H_ -#include -#include #include +#include +#include namespace tvm { namespace support { @@ -41,41 +41,48 @@ class RingBuffer { /*! \brief constructor */ RingBuffer() : ring_(kInitCapacity) {} /*! \return number of bytes available in buffer. */ - size_t bytes_available() const { - return bytes_available_; - } + size_t bytes_available() const { return bytes_available_; } /*! \return Current capacity of buffer. */ - size_t capacity() const { - return ring_.size(); - } + size_t capacity() const { return ring_.size(); } /*! - * Reserve capacity to be at least n. - * Will only increase capacity if n is bigger than current capacity. + * Reserve capacity to be at least n. + * Will only increase capacity if n is bigger than current capacity. + * + * The effect of Reserve only lasts before the next call to Reserve. + * Other functions in the ring buffer can also call into the reserve. 
+ * * \param n The size of capacity. */ void Reserve(size_t n) { if (ring_.size() < n) { - size_t old_size = ring_.size(); - size_t new_size = static_cast(n * 1.2); - ring_.resize(new_size); - if (head_ptr_ + bytes_available_ > old_size) { - // copy the ring overflow part into the tail. - size_t ncopy = head_ptr_ + bytes_available_ - old_size; - memcpy(&ring_[0] + old_size, &ring_[0], ncopy); - } - } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) { - // shrink too large temporary buffer to avoid out of memory on some embedded devices + size_t old_size = ring_.size(); + size_t new_size = static_cast(n * 1.2); + ring_.resize(new_size); + if (head_ptr_ + bytes_available_ > old_size) { + // copy the ring overflow part into the tail. + size_t ncopy = head_ptr_ + bytes_available_ - old_size; + memcpy(&ring_[0] + old_size, &ring_[0], ncopy); + } + } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { + // shrink too large temporary buffer to + // avoid out of memory on some embedded devices + if (bytes_available_ != 0) { + // move existing bytes to the head. size_t old_bytes = bytes_available_; - std::vector tmp(old_bytes); - Read(&tmp[0], old_bytes); - ring_.resize(kInitCapacity); - ring_.shrink_to_fit(); memcpy(&ring_[0], &tmp[0], old_bytes); - head_ptr_ = 0; bytes_available_ = old_bytes; + } + // shrink the ring. + size_t new_size = kInitCapacity; + new_size = std::max(new_size, n); + new_size = std::max(new_size, bytes_available_); + + ring_.resize(new_size); + ring_.shrink_to_fit(); + head_ptr_ = 0; } } @@ -90,8 +97,7 @@ class RingBuffer { size_t ncopy = std::min(size, ring_.size() - head_ptr_); memcpy(data, &ring_[0] + head_ptr_, ncopy); if (ncopy < size) { - memcpy(reinterpret_cast(data) + ncopy, - &ring_[0], size - ncopy); + memcpy(reinterpret_cast(data) + ncopy, &ring_[0], size - ncopy); } head_ptr_ = (head_ptr_ + size) % ring_.size(); bytes_available_ -= size; @@ -103,7 +109,7 @@ class RingBuffer { * \param max_nbytes Maximum number of bytes can to read. * \tparam FSend A non-blocking function with signature size_t (const void* data, size_t size); */ - template + template size_t ReadWithCallback(FSend fsend, size_t max_nbytes) { size_t size = std::min(max_nbytes, bytes_available_); CHECK_NE(size, 0U); @@ -137,13 +143,13 @@ class RingBuffer { bytes_available_ += size; } /*! - * \brief Writen data into the buffer by give it a non-blocking callback function. + * \brief Written data into the buffer by give it a non-blocking callback function. * * \param frecv A receive function handle * \param max_nbytes Maximum number of bytes can write. * \tparam FRecv A non-blocking function with signature size_t (void* data, size_t size); */ - template + template size_t WriteWithCallback(FRecv frecv, size_t max_nbytes) { this->Reserve(bytes_available_ + max_nbytes); size_t nbytes = max_nbytes; @@ -168,9 +174,9 @@ class RingBuffer { private: // buffer head size_t head_ptr_{0}; - // number of bytes in the buffer. + // number of bytes occupied in the buffer. size_t bytes_available_{0}; - // The internald ata ring. + // The internal data ring. 
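To make the Reserve/Read/Write contract of this ring-buffer hunk concrete, here is a minimal usage sketch of the support::RingBuffer API touched above. It is illustrative only, not part of the patch, and the include path assumes an in-tree translation unit:

#include <cstring>
#include <string>

#include "ring_buffer.h"  // i.e. src/support/ring_buffer.h in this tree

void RingBufferSketch() {
  tvm::support::RingBuffer buf;
  std::string payload = "hello ring";
  // Write() grows the ring via Reserve() as needed and copies the bytes in,
  // wrapping around the end of the internal vector when necessary.
  buf.Write(payload.data(), payload.size());
  // Read() drains bytes in FIFO order; bytes_available() reports what is left.
  char out[16] = {0};
  buf.Read(out, payload.size());
  // ReadWithCallback()/WriteWithCallback() move bytes through a non-blocking
  // callable (e.g. a socket send/recv) and return the count actually moved.
}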
std::vector ring_; }; } // namespace support diff --git a/src/support/socket.h b/src/support/socket.h index aeb4626b5d47..3ccfaaab5ab5 100644 --- a/src/support/socket.h +++ b/src/support/socket.h @@ -35,26 +35,27 @@ using ssize_t = int; #pragma comment(lib, "Ws2_32.lib") #endif #else +#include +#include #include #include -#include -#include -#include #include -#include -#include #include +#include +#include +#include #endif #include -#include + #include -#include +#include #include +#include + #include "../support/util.h" #if defined(_WIN32) -static inline int poll(struct pollfd *pfd, int nfds, - int timeout) { +static inline int poll(struct pollfd* pfd, int nfds, int timeout) { return WSAPoll(pfd, nfds, timeout); } #else @@ -68,7 +69,8 @@ namespace support { * \return The hostname. */ inline std::string GetHostName() { - std::string buf; buf.resize(256); + std::string buf; + buf.resize(256); CHECK_NE(gethostname(&buf[0], 256), -1); return std::string(buf.c_str()); } @@ -100,16 +102,14 @@ struct SockAddr { * \param url The url of the address * \param port The port of the address. */ - SockAddr(const char *url, int port) { - this->Set(url, port); - } + SockAddr(const char* url, int port) { this->Set(url, port); } /*! - * \brief SockAddr Get the socket address from tracker. - * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) - * \return SockAddr parsed from url. - */ - explicit SockAddr(const std::string &url) { + * \brief SockAddr Get the socket address from tracker. + * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) + * \return SockAddr parsed from url. + */ + explicit SockAddr(const std::string& url) { size_t sep = url.find(","); std::string host = url.substr(2, sep - 3); std::string port = url.substr(sep + 1, url.length() - 1); @@ -125,31 +125,28 @@ struct SockAddr { * \param host the url of the address * \param port the port of address */ - void Set(const char *host, int port) { + void Set(const char* host, int port) { addrinfo hints; memset(&hints, 0, sizeof(hints)); hints.ai_family = PF_UNSPEC; hints.ai_flags = AI_PASSIVE; hints.ai_socktype = SOCK_STREAM; - addrinfo *res = NULL; + addrinfo* res = NULL; int sig = getaddrinfo(host, NULL, &hints, &res); - CHECK(sig == 0 && res != NULL) - << "cannot obtain address of " << host; + CHECK(sig == 0 && res != NULL) << "cannot obtain address of " << host; switch (res->ai_family) { case AF_INET: { - sockaddr_in *addr4 = reinterpret_cast(&addr); - memcpy(addr4, res->ai_addr, res->ai_addrlen); - addr4->sin_port = htons(port); - addr4->sin_family = AF_INET; - } - break; + sockaddr_in* addr4 = reinterpret_cast(&addr); + memcpy(addr4, res->ai_addr, res->ai_addrlen); + addr4->sin_port = htons(port); + addr4->sin_family = AF_INET; + } break; case AF_INET6: { - sockaddr_in6 *addr6 = reinterpret_cast(&addr); - memcpy(addr6, res->ai_addr, res->ai_addrlen); - addr6->sin6_port = htons(port); - addr6->sin6_family = AF_INET6; - } - break; + sockaddr_in6* addr6 = reinterpret_cast(&addr); + memcpy(addr6, res->ai_addr, res->ai_addrlen); + addr6->sin6_port = htons(port); + addr6->sin6_family = AF_INET6; + } break; default: CHECK(false) << "cannot decode address"; } @@ -157,35 +154,34 @@ struct SockAddr { } /*! \brief return port of the address */ int port() const { - return ntohs((addr.ss_family == AF_INET6)? \ - reinterpret_cast(&addr)->sin6_port : \ - reinterpret_cast(&addr)->sin_port); + return ntohs((addr.ss_family == AF_INET6) + ? 
reinterpret_cast(&addr)->sin6_port + : reinterpret_cast(&addr)->sin_port); } /*! \brief return the ip address family */ - int ss_family() const { - return addr.ss_family; - } + int ss_family() const { return addr.ss_family; } /*! \return a string representation of the address */ std::string AsString() const { - std::string buf; buf.resize(256); + std::string buf; + buf.resize(256); - const void *sinx_addr = nullptr; - if (addr.ss_family == AF_INET6) { - const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; - sinx_addr = reinterpret_cast(&addr6); - } else if (addr.ss_family == AF_INET) { - const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; - sinx_addr = reinterpret_cast(&addr4); - } else { - CHECK(false) << "illegal address"; - } + const void* sinx_addr = nullptr; + if (addr.ss_family == AF_INET6) { + const in6_addr& addr6 = reinterpret_cast(&addr)->sin6_addr; + sinx_addr = reinterpret_cast(&addr6); + } else if (addr.ss_family == AF_INET) { + const in_addr& addr4 = reinterpret_cast(&addr)->sin_addr; + sinx_addr = reinterpret_cast(&addr4); + } else { + CHECK(false) << "illegal address"; + } #ifdef _WIN32 - const char *s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) + const char* s = inet_ntop(addr.ss_family, (PVOID)sinx_addr, // NOLINT(*) &buf[0], buf.length()); #else - const char *s = inet_ntop(addr.ss_family, sinx_addr, - &buf[0], static_cast(buf.length())); + const char* s = + inet_ntop(addr.ss_family, sinx_addr, &buf[0], static_cast(buf.length())); #endif CHECK(s != nullptr) << "cannot decode address"; std::ostringstream os; @@ -238,10 +234,10 @@ class Socket { * \brief bind the socket to an address * \param addr The address to be binded */ - void Bind(const SockAddr &addr) { + void Bind(const SockAddr& addr) { if (bind(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == -1) { + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == + -1) { Socket::Error("Bind"); } } @@ -256,8 +252,8 @@ class Socket { for (int port = start_port; port < end_port; ++port) { SockAddr addr(host.c_str(), port); if (bind(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == 0) { + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == + 0) { return port; } else { LOG(WARNING) << "Bind failed to " << host << ":" << port; @@ -278,7 +274,7 @@ class Socket { int GetSockError() const { int error = 0; socklen_t len = sizeof(error); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len) != 0) { Error("GetSockError"); } return error; @@ -291,9 +287,7 @@ class Socket { return false; } /*! \brief check if socket is already closed */ - bool IsClosed() const { - return sockfd == INVALID_SOCKET; - } + bool IsClosed() const { return sockfd == INVALID_SOCKET; } /*! \brief close the socket */ void Close() { if (sockfd != INVALID_SOCKET) { @@ -354,7 +348,7 @@ class Socket { * \brief Report an socket error. * \param msg The error message. */ - static void Error(const char *msg) { + static void Error(const char* msg) { int errsv = GetLastError(); #ifdef _WIN32 LOG(FATAL) << "Socket " << msg << " Error:WSAError-code=" << errsv; @@ -364,8 +358,7 @@ class Socket { } protected: - explicit Socket(SockType sockfd) : sockfd(sockfd) { - } + explicit Socket(SockType sockfd) : sockfd(sockfd) {} }; /*! 
@@ -373,22 +366,20 @@ class Socket { */ class TCPSocket : public Socket { public: - TCPSocket() : Socket(INVALID_SOCKET) { - } + TCPSocket() : Socket(INVALID_SOCKET) {} /*! * \brief construct a TCP socket from existing descriptor * \param sockfd The descriptor */ - explicit TCPSocket(SockType sockfd) : Socket(sockfd) { - } + explicit TCPSocket(SockType sockfd) : Socket(sockfd) {} /*! * \brief enable/disable TCP keepalive * \param keepalive whether to set the keep alive option on */ void SetKeepAlive(bool keepalive) { int opt = static_cast(keepalive); - if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, - reinterpret_cast(&opt), sizeof(opt)) < 0) { + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&opt), sizeof(opt)) < + 0) { Socket::Error("SetKeepAlive"); } } @@ -406,9 +397,7 @@ class TCPSocket : public Socket { * \brief perform listen of the socket * \param backlog backlog parameter */ - void Listen(int backlog = 16) { - listen(sockfd, backlog); - } + void Listen(int backlog = 16) { listen(sockfd, backlog); } /*! * \brief get a new connection * \return The accepted socket connection. @@ -421,14 +410,13 @@ class TCPSocket : public Socket { return TCPSocket(newfd); } /*! - * \brief get a new connection - * \param addr client address from which connection accepted - * \return The accepted socket connection. - */ - TCPSocket Accept(SockAddr *addr) { + * \brief get a new connection + * \param addr client address from which connection accepted + * \return The accepted socket connection. + */ + TCPSocket Accept(SockAddr* addr) { socklen_t addrlen = sizeof(addr->addr); - SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), - &addrlen); + SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), &addrlen); if (newfd == INVALID_SOCKET) { Socket::Error("Accept"); } @@ -453,10 +441,10 @@ class TCPSocket : public Socket { * \param addr the address to connect to * \return whether connect is successful */ - bool Connect(const SockAddr &addr) { - return connect(sockfd, reinterpret_cast(&addr.addr), - (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : - sizeof(sockaddr_in))) == 0; + bool Connect(const SockAddr& addr) { + return connect( + sockfd, reinterpret_cast(&addr.addr), + (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0; } /*! * \brief send data using the socket @@ -466,8 +454,8 @@ class TCPSocket : public Socket { * \return size of data actually sent * return -1 if error occurs */ - ssize_t Send(const void *buf_, size_t len, int flag = 0) { - const char *buf = reinterpret_cast(buf_); + ssize_t Send(const void* buf_, size_t len, int flag = 0) { + const char* buf = reinterpret_cast(buf_); return send(sockfd, buf, static_cast(len), flag); } /*! @@ -478,8 +466,8 @@ class TCPSocket : public Socket { * \return size of data actually received * return -1 if error occurs */ - ssize_t Recv(void *buf_, size_t len, int flags = 0) { - char *buf = reinterpret_cast(buf_); + ssize_t Recv(void* buf_, size_t len, int flags = 0) { + char* buf = reinterpret_cast(buf_); return recv(sockfd, buf, static_cast(len), flags); } /*! 
@@ -489,10 +477,10 @@ class TCPSocket : public Socket { * \param len the size of the buffer * \return size of data actually sent */ - size_t SendAll(const void *buf_, size_t len) { - const char *buf = reinterpret_cast(buf_); + size_t SendAll(const void* buf_, size_t len) { + const char* buf = reinterpret_cast(buf_); size_t ndone = 0; - while (ndone < len) { + while (ndone < len) { ssize_t ret = send(sockfd, buf, static_cast(len - ndone), 0); if (ret == -1) { if (LastErrorWouldBlock()) return ndone; @@ -510,14 +498,13 @@ class TCPSocket : public Socket { * \param len length of data to recv * \return size of data actually sent */ - size_t RecvAll(void *buf_, size_t len) { - char *buf = reinterpret_cast(buf_); + size_t RecvAll(void* buf_, size_t len) { + char* buf = reinterpret_cast(buf_); size_t ndone = 0; - while (ndone < len) { - ssize_t ret = recv(sockfd, buf, - static_cast(len - ndone), MSG_WAITALL); + while (ndone < len) { + ssize_t ret = recv(sockfd, buf, static_cast(len - ndone), MSG_WAITALL); if (ret == -1) { - if (LastErrorWouldBlock()) { + if (LastErrorWouldBlock()) { LOG(FATAL) << "would block"; return ndone; } @@ -612,7 +599,7 @@ struct PollHelper { * \param timeout the timeout counter, can be negative, which means wait until the event happen * \return 1 if success, 0 if timeout, and -1 if error occurs */ - inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) + inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) pollfd pfd; pfd.fd = fd; pfd.events = POLLPRI; diff --git a/src/support/str_escape.h b/src/support/str_escape.h index fd25c019e6dc..65eec682086e 100644 --- a/src/support/str_escape.h +++ b/src/support/str_escape.h @@ -25,8 +25,8 @@ #ifndef TVM_SUPPORT_STR_ESCAPE_H_ #define TVM_SUPPORT_STR_ESCAPE_H_ -#include #include +#include namespace tvm { namespace support { @@ -76,9 +76,7 @@ inline std::string StrEscape(const char* data, size_t size) { * \param size The size of the string. * \return the Result string. */ -inline std::string StrEscape(const std::string& val) { - return StrEscape(val.data(), val.length()); -} +inline std::string StrEscape(const std::string& val) { return StrEscape(val.data(), val.length()); } } // namespace support } // namespace tvm diff --git a/src/support/util.h b/src/support/util.h index 9a477e6f81f2..859b372bd761 100644 --- a/src/support/util.h +++ b/src/support/util.h @@ -26,16 +26,16 @@ #include #ifndef _WIN32 -#include #include +#include #endif -#include -#include -#include #include #include #include #include +#include +#include +#include namespace tvm { namespace support { @@ -92,15 +92,14 @@ inline int TVMWexitstatus(int status) { #endif } - /*! * \brief IsNumber check whether string is a number. * \param str input string * \return result of operation. */ inline bool IsNumber(const std::string& str) { - return !str.empty() && std::find_if(str.begin(), - str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); + return !str.empty() && + std::find_if(str.begin(), str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); } /*! 
diff --git a/src/target/build_common.h b/src/target/build_common.h index 93687c2578ac..ec5b522397ed 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -24,27 +24,27 @@ #ifndef TVM_TARGET_BUILD_COMMON_H_ #define TVM_TARGET_BUILD_COMMON_H_ -#include -#include -#include #include -#include +#include +#include +#include #include +#include #include -#include + #include +#include + #include "../runtime/meta_data.h" namespace tvm { namespace codegen { -inline std::unordered_map -ExtractFuncInfo(const IRModule& mod) { +inline std::unordered_map ExtractFuncInfo(const IRModule& mod) { std::unordered_map fmap; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); runtime::FunctionInfo info; diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 0eceea81da17..d0a3156c9cc9 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -21,46 +21,43 @@ * \file codegen.cc * \brief Common utilities to generated C style code. */ +#include +#include +#include +#include +#include +#include #include #include - -#include -#include -#include #include +#include -#include -#include -#include -#include -#include -#include -#include #include -#include #include +#include +#include +#include namespace tvm { namespace codegen { runtime::Module Build(IRModule mod, const Target& target) { - if (BuildConfig::Current()->disable_assert) { + if (transform::PassContext::Current() + ->GetConfig("tir.disable_assert", Bool(false)) + .value()) { mod = tir::transform::SkipAssert()(mod); } std::string build_f_name = "target.build." + target->target_name; // the build function. const PackedFunc* bf = runtime::Registry::Get(build_f_name); - CHECK(bf != nullptr) - << "target.build." << target << " is not enabled"; + CHECK(bf != nullptr) << "target.build." << target << " is not enabled"; return (*bf)(mod, target->str()); } /*! \brief Helper class to serialize module */ class ModuleSerializer { public: - explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { - Init(); - } + explicit ModuleSerializer(runtime::Module mod) : mod_(mod) { Init(); } void SerializeModule(dmlc::Stream* stream) { // Only have one DSO module and it is in the root, then @@ -109,8 +106,8 @@ class ModuleSerializer { // invariance: root module is always at location 0. 
// The module order is collected via DFS void CreateModuleIndex() { - std::unordered_set visited {mod_.operator->()}; - std::vector stack {mod_.operator->()}; + std::unordered_set visited{mod_.operator->()}; + std::vector stack{mod_.operator->()}; uint64_t module_index = 0; while (!stack.empty()) { @@ -139,8 +136,7 @@ class ModuleSerializer { } bool DSOExportable(const runtime::ModuleNode* mod) { - return !std::strcmp(mod->type_key(), "llvm") || - !std::strcmp(mod->type_key(), "c"); + return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c"); } runtime::Module mod_; @@ -148,21 +144,21 @@ class ModuleSerializer { std::unordered_map mod2index_; // index -> module std::vector mod_vec_; - std::vector import_tree_row_ptr_ {0}; + std::vector import_tree_row_ptr_{0}; std::vector import_tree_child_indices_; }; namespace { - std::string SerializeModule(const runtime::Module& mod) { - std::string bin; - dmlc::MemoryStringStream ms(&bin); - dmlc::Stream* stream = &ms; +std::string SerializeModule(const runtime::Module& mod) { + std::string bin; + dmlc::MemoryStringStream ms(&bin); + dmlc::Stream* stream = &ms; - ModuleSerializer module_serializer(mod); - module_serializer.SerializeModule(stream); + ModuleSerializer module_serializer(mod); + module_serializer.SerializeModule(stream); - return bin; - } + return bin; +} } // namespace std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { @@ -180,8 +176,8 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { << "#endif\n"; os << "TVM_EXPORT extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n"; uint64_t nbytes = bin.length(); - os << "const unsigned char " << runtime::symbol::tvm_dev_mblob - << "[" << bin.length() + sizeof(nbytes) << "] = {\n "; + os << "const unsigned char " << runtime::symbol::tvm_dev_mblob << "[" + << bin.length() + sizeof(nbytes) << "] = {\n "; os << std::hex; size_t nunit = 80 / 4; for (size_t i = 0; i < sizeof(nbytes); ++i) { @@ -214,8 +210,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { return os.str(); } -runtime::Module PackImportsToLLVM(const runtime::Module& mod, - bool system_lib, +runtime::Module PackImportsToLLVM(const runtime::Module& mod, bool system_lib, const std::string& target_triple) { std::string bin = SerializeModule(mod); @@ -233,19 +228,16 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, std::string codegen_f_name = "codegen.codegen_blob"; // the codegen function. const PackedFunc* codegen_f = runtime::Registry::Get(codegen_f_name); - CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; + CHECK(codegen_f != nullptr) << "codegen.codegen_blob is not presented."; return (*codegen_f)(blob_byte_array, system_lib, target_triple); } -TVM_REGISTER_GLOBAL("target.Build") -.set_body_typed(Build); +TVM_REGISTER_GLOBAL("target.Build").set_body_typed(Build); // Export two auxiliary function to the runtime namespace. 
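To visualize what the PackImportsToC path above produces: the generated translation unit is essentially one byte array whose first sizeof(uint64_t) bytes carry the blob length, followed by the serialized module blob recorded by ModuleSerializer. A schematic of the emitted C source follows; the byte values are placeholders, the array name comes from runtime::symbol::tvm_dev_mblob, and the TVM_EXPORT stand-in replaces the real export attribute emitted in the prologue:

#define TVM_EXPORT  /* stand-in; the emitted prologue defines the real attribute */

TVM_EXPORT extern const unsigned char __tvm_dev_mblob[];
const unsigned char __tvm_dev_mblob[12] = {
    // 8-byte length prefix written by the loop over sizeof(nbytes) above
    0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    // followed by the serialized module blob (placeholder bytes)
    0xAA, 0xBB, 0xCC, 0xDD};

At load time the runtime looks this symbol up and deserializes the imported modules, restoring the DFS module order and import tree that ModuleSerializer wrote out.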
-TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") -.set_body_typed(PackImportsToC); +TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC").set_body_typed(PackImportsToC); -TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM") -.set_body_typed(PackImportsToLLVM); +TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToLLVM").set_body_typed(PackImportsToLLVM); } // namespace codegen } // namespace tvm diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc index c16182da3674..5ed3ce4f7c03 100644 --- a/src/target/datatype/registry.cc +++ b/src/target/datatype/registry.cc @@ -16,34 +16,32 @@ * specific language governing permissions and limitations * under the License. */ -#include #include "registry.h" +#include + namespace tvm { namespace datatype { using runtime::TVMArgs; using runtime::TVMRetValue; -TVM_REGISTER_GLOBAL("runtime._datatype_register") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = datatype::Registry::Global()->GetTypeCode(args[0]); }); -TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = Registry::Global()->GetTypeName(args[0].operator int()); }); TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); -}); + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); + }); Registry* Registry::Global() { static Registry inst; @@ -51,8 +49,8 @@ Registry* Registry::Global() { } void Registry::Register(const std::string& type_name, uint8_t type_code) { - CHECK(type_code >= kTVMCustomBegin) - << "Please choose a type code >= kTVMCustomBegin for custom types"; + CHECK(type_code >= DataType::kCustomBegin) + << "Please choose a type code >= DataType::kCustomBegin for custom types"; code_to_name_[type_code] = type_name; name_to_code_[type_name] = type_code; } @@ -80,7 +78,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { ss << datatype::Registry::Global()->GetTypeName(type_code); } else { - ss << runtime::TypeCode2Str(type_code); + ss << runtime::DLDataTypeCode2Str(static_cast(type_code)); } ss << "."; @@ -88,7 +86,7 @@ const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t t if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { ss << datatype::Registry::Global()->GetTypeName(src_type_code); } else { - ss << runtime::TypeCode2Str(src_type_code); + ss << runtime::DLDataTypeCode2Str(static_cast(src_type_code)); } return runtime::Registry::Get(ss.str()); } diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h index 919409f6e4f3..5df8ef8164db 100644 --- a/src/target/datatype/registry.h +++ b/src/target/datatype/registry.h @@ -22,6 +22,7 @@ #include #include + #include #include @@ -60,7 +61,7 @@ class Registry { * same code. 
Generally, this should be straightforward, as the user will be manually registering * all of their custom types. * \param type_name The name of the type, e.g. "bfloat" - * \param type_code The type code, which should be greater than TVMTypeCode::kTVMExtEnd + * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd */ void Register(const std::string& type_name, uint8_t type_code); @@ -69,7 +70,7 @@ class Registry { * \param type_name The type name * \return The type code */ - uint8_t GetTypeCode(const std::string &type_name); + uint8_t GetTypeCode(const std::string& type_name); /*! * \brief Get type name from type code diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc index 44d017f4ac5b..9ad9f56f7c58 100644 --- a/src/target/generic_func.cc +++ b/src/target/generic_func.cc @@ -20,14 +20,12 @@ * \file src/target/generic_func.cc */ #include - -#include -#include #include #include -#include -#include +#include #include +#include +#include #include #include @@ -43,8 +41,7 @@ struct GenericFunc::Manager { // mutex std::mutex mutex; - Manager() { - } + Manager() {} static Manager* Global() { static Manager inst; @@ -76,25 +73,23 @@ void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) m->fmap[name] = func; } -GenericFunc& GenericFunc::set_default(const PackedFunc value, - bool allow_override) { +GenericFunc& GenericFunc::set_default(const PackedFunc value, bool allow_override) { auto node = static_cast(operator->()); if (!allow_override) { CHECK(node->generic_func_ == nullptr) - << "Generic function already registered for " << node->name_; + << "Generic function already registered for " << node->name_; } node->generic_func_ = value; return *this; } GenericFunc& GenericFunc::register_func(const std::vector& tags, - const PackedFunc value, - bool allow_override) { - for (auto &t : tags) { + const PackedFunc value, bool allow_override) { + for (auto& t : tags) { if (!allow_override) { auto iter = (*this)->dispatch_dict_.find(t); CHECK(iter == (*this)->dispatch_dict_.end()) - << "Tag " << t << " already registered for schedule factory " << (*this)->name_; + << "Tag " << t << " already registered for schedule factory " << (*this)->name_; } (*this)->dispatch_dict_[t] = value; } @@ -107,7 +102,7 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { PackedFunc func; if (target.defined()) { - for (auto &k : target->keys()) { + for (auto& k : target->keys()) { auto iter = node->dispatch_dict_.find(k); if (iter != node->dispatch_dict_.end()) { func = iter->second; @@ -124,30 +119,25 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { func.CallPacked(args, ret); } -TVM_REGISTER_GLOBAL("target.GenericFuncCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncCreate").set_body([](TVMArgs args, TVMRetValue* ret) { *ret = GenericFunc(make_object()); - }); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncGetGlobal").set_body([](TVMArgs args, TVMRetValue* ret) { std::string func_name = args[0]; *ret = GenericFunc::Get(func_name); - }); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncSetDefault").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown 
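The GenericFunc changes above keep the same registration and dispatch flow; a minimal C++ sketch of that flow follows. The function name "demo.my_func", the "cuda" tag, and the header path are illustrative assumptions, not part of the patch:

#include <tvm/runtime/packed_func.h>
#include <tvm/target/generic_func.h>  // assumed location of GenericFunc in this tree

void GenericFuncSketch() {
  using tvm::GenericFunc;
  using tvm::runtime::PackedFunc;
  using tvm::runtime::TVMArgs;
  using tvm::runtime::TVMRetValue;

  GenericFunc f = GenericFunc::Get("demo.my_func");
  // Fallback used when none of the current Target's keys has an override.
  f.set_default(PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = 1; }),
                /*allow_override=*/true);
  // Override picked by CallPacked() when the current Target carries the "cuda" key.
  f.register_func({"cuda"}, PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = 2; }),
                  /*allow_override=*/true);
}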
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); bool allow_override = args[2]; - generic_func - .set_default(*func, allow_override); - }); + generic_func.set_default(*func, allow_override); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown PackedFunc* func = new PackedFunc(args[1].operator PackedFunc()); @@ -159,17 +149,14 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc") tags_vector.push_back(tag); } - generic_func - .register_func(tags_vector, *func, allow_override); - }); + generic_func.register_func(tags_vector, *func, allow_override); +}); -TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GenericFuncCallFunc").set_body([](TVMArgs args, TVMRetValue* ret) { GenericFunc generic_func = args[0]; TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1); - generic_func - .CallPacked(func_args, ret); - }); + generic_func.CallPacked(func_args, ret); +}); } // namespace tvm diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 5d393ab8ebb2..37855fb39179 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -21,96 +21,99 @@ * \file intrin_rule_default.cc * \brief Default intrinsic rules. */ -#include #include "intrin_rule.h" +#include + namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan").set_body(DispatchExtern); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin") 
-.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.rsqrt") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / sqrt(call->args[0]); - }); + auto one = make_const(call->args[0].dtype(), 1); + *rv = one / sqrt(call->args[0]); + }); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.pow").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); - auto one = make_const(call->args[0].dtype(), 1); - *rv = one / (one + exp(-call->args[0])); - }); + auto one = make_const(call->args[0].dtype(), 1); + *rv = one / (one + exp(-call->args[0])); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); - *rv = isfinite(call->args[0]); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isfinite(call->args[0]); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isinf") -.set_body([](const TVMArgs& args, TVMRetValue* rv){ - PrimExpr e = args[0]; - const CallNode* call = e.as(); - CHECK(call != nullptr); - *rv = isinf(call->args[0]); - }); + .set_body([](const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + *rv = isinf(call->args[0]); + }); } // namespace intrin } // namespace codegen diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 091474254114..5a23e83af219 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -24,9 +24,9 @@ #ifndef TVM_TARGET_INTRIN_RULE_H_ #define TVM_TARGET_INTRIN_RULE_H_ 
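The block of default rules above all funnel through DispatchExtern, keyed by globals named "tvm.intrin.rule.<target-key>.<op>". As a sketch of how a backend plugs into the same mechanism, the registration below adds an exp rule under an "opencl" key using the Direct name functor from intrin_rule.h; the key and functor choice are illustrative here, since each target picks the functor that matches its math-library naming convention:

#include <tvm/runtime/registry.h>

#include "intrin_rule.h"  // i.e. src/target/intrin_rule.h in this tree

namespace tvm {
namespace codegen {
namespace intrin {

// Lower exp(x) to a plain extern call named "exp" whenever the active target
// carries the "opencl" key (illustrative for this sketch).
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern<Direct>);

}  // namespace intrin
}  // namespace codegen
}  // namespace tvm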
-#include -#include #include +#include + #include namespace tvm { @@ -49,21 +49,18 @@ struct FloatSuffix { // Return the intrinsic name struct Direct { - std::string operator()(DataType t, std::string name) const { - return name; - } + std::string operator()(DataType t, std::string name) const { return name; } }; // Call pure extern function. -template +template inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) { PrimExpr e = args[0]; const CallNode* call = e.as(); CHECK(call != nullptr); std::string name = T()(call->dtype, call->name); if (name.length() != 0) { - *rv = CallNode::make( - call->dtype, name, call->args, CallNode::PureExtern); + *rv = Call(call->dtype, name, call->args, CallNode::PureExtern); } else { *rv = e; } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 61121f67d111..8e6b3a2ff22c 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -23,12 +23,13 @@ */ #ifdef TVM_LLVM_VERSION -#include #include +#include #include -#include "codegen_llvm.h" -#include "../build_common.h" + #include "../../runtime/rocm/rocm_module.h" +#include "../build_common.h" +#include "codegen_llvm.h" namespace tvm { namespace codegen { @@ -45,8 +46,8 @@ static inline int DetectROCMmaxThreadsPerBlock() { TVMRetValue val; api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); if (val.operator int() == 1) { - tvm::runtime::DeviceAPI::Get(tvm_ctx)-> - GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val); + tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, + &val); return val.operator int(); } } @@ -73,8 +74,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::Value* buf = nullptr; int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { @@ -88,9 +88,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 alloca->setAlignment(llvm::Align(info.alignment)); @@ -104,12 +103,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); + llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size); // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, + llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(info.alignment)); #else @@ -119,8 +117,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -128,22 +125,36 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Return the thread index via intrinsics. 
llvm::Value* GetThreadIndex(const IterVar& iv) final { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; if (ts.rank == 1) { switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; break; - case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; break; - case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; break; - default: LOG(FATAL) << "unknown workitem idx"; + case 0: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::amdgcn_workitem_id_z; + break; + default: + LOG(FATAL) << "unknown workitem idx"; } } else { CHECK_EQ(ts.rank, 0); switch (ts.dim_index) { - case 0: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; break; - case 1: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; break; - case 2: intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; break; - default: LOG(FATAL) << "unknown workgroup idx"; + case 0: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_x; + break; + case 1: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_y; + break; + case 2: + intrin_id = ::llvm::Intrinsic::amdgcn_workgroup_id_z; + break; + default: + LOG(FATAL) << "unknown workgroup idx"; } } llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id); @@ -155,9 +166,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { if (sync == "warp") { return nullptr; } else if (sync == "shared") { - llvm::Function* f = llvm::Intrinsic::getDeclaration( - module_.get(), - ::llvm::Intrinsic::amdgcn_s_barrier); + llvm::Function* f = + llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::amdgcn_s_barrier); return builder_->CreateCall(f, {}); } else { LOG(FATAL) << "Do not support sync " << sync; @@ -169,9 +179,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { // Additional optimization hook to tweak the builder. } - unsigned GetGlobalAddressSpace() const final { - return 1; - } + unsigned GetGlobalAddressSpace() const final { return 1; } protected: void InitTarget(llvm::TargetMachine* tm) final { @@ -211,32 +219,29 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { // issue #4087 for a discussion #endif InitializeLLVM(); - CHECK(target.length() >= 4 && - target.substr(0, 4) == "rocm"); + CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm"); std::ostringstream config; - config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" - << DetectROCMComputeVersion(target) - << " -mattr=-code-object-v3 " - << target.substr(4, target.length() - 4); + config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target) + << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4); std::unique_ptr tm = GetLLVMTargetMachine(config.str()); - std::unique_ptr cg(new CodeGenAMDGPU()); std::unique_ptr ctx(new llvm::LLVMContext()); + // careful: cg will hold a naked pointer reference to ctx, so it should + // have a shorter lifetime than the ctx. 
+ std::unique_ptr cg(new CodeGenAMDGPU()); cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); cg->AddFunction(f); } - const auto *find_rocm_bitcodes = - tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); - Array bitcode_files = (*find_rocm_bitcodes)(); + const auto* find_rocm_bitcodes = tvm::runtime::Registry::Get("tvm_callback_rocm_bitcode_path"); + Array bitcode_files = (*find_rocm_bitcodes)(); - for (auto &bitcode : bitcode_files) { - std::string path = bitcode.as()->value; + for (auto& bitcode_path : bitcode_files) { + std::string path = bitcode_path; llvm::SMDiagnostic err; std::unique_ptr mlib = llvm::parseIRFile(path, err, *ctx); if (mlib.get() == nullptr) { @@ -246,7 +251,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { } mlib->setTargetTriple(tm->getTargetTriple().str()); mlib->setDataLayout(tm->createDataLayout()); - for (llvm::Function &f : mlib->functions()) { + for (llvm::Function& f : mlib->functions()) { f.addFnAttr(llvm::Attribute::AlwaysInline); } cg->AddLinkModule(std::move(mlib)); @@ -269,33 +274,28 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile( - pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm->addPassesToEmitFile( - pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm->addPassesToEmitFile( - pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) - << "Cannot emit target CGFT_ObjectFile"; + CHECK(tm->addPassesToEmitFile(pass, destObj, nullptr, llvm::CGFT_ObjectFile) == 0) + << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*mObj); std::string obj(dataObj.begin(), dataObj.end()); llvm::legacy::PassManager passAsm; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile(passAsm, destAsm, - llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(passAsm, destAsm, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #else - CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, - llvm::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(passAsm, destAsm, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif passAsm.run(*mAsm); @@ -313,8 +313,7 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly); } -TVM_REGISTER_GLOBAL("target.build.rocm") -.set_body_typed(BuildAMDGPU); +TVM_REGISTER_GLOBAL("target.build.rocm").set_body_typed(BuildAMDGPU); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc 
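BuildAMDGPU above resolves "tvm_callback_rocm_bitcode_path" through the global registry and links every returned bitcode file into the generated module with AlwaysInline. That callback is normally provided by the Python-side ROCm support; as a sketch, a C++ stand-in could be registered as below. The bitcode paths are purely hypothetical and installation-specific, and the container header location is an assumption for this revision:

#include <tvm/runtime/container.h>  // Array / String
#include <tvm/runtime/registry.h>

TVM_REGISTER_GLOBAL("tvm_callback_rocm_bitcode_path").set_body_typed([]() {
  // Hypothetical device-library paths; real values depend on the ROCm install.
  return tvm::Array<tvm::runtime::String>{"/opt/rocm/amdgcn/bitcode/ocml.bc",
                                          "/opt/rocm/amdgcn/bitcode/ockl.bc"};
});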
index 73d849a7b3d1..991d4730a136 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -47,8 +47,7 @@ class CodeGenARM final : public CodeGenCPU { llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { - llvm::Intrinsic::ID id = static_cast( - Downcast(op->args[0])->value); + llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); if (id == ::llvm::Intrinsic::ctpop) { PrimExpr e = ARMPopcount(op); return CodeGenCPU::CreateIntrinsic(e.as()); @@ -57,21 +56,21 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { return CodeGenCPU::CreateIntrinsic(op); } -PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { +PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { using namespace tir; const PrimExpr& e = call->args[2]; ::llvm::Intrinsic::ID ctpop_id = ::llvm::Intrinsic::ctpop; ::llvm::Intrinsic::ID vpaddlu_id = ::llvm::Intrinsic::arm_neon_vpaddlu; // Fallback to default llvm lowering rule if input type not a full vector or half vector length - int total_size = call->dtype.bits() * call->dtype.lanes(); + int total_size = call->dtype.bits() * call->dtype.lanes(); if (!call->dtype.is_vector() || call->dtype.bits() == 8 || - (total_size != 128 && total_size != 64)) { + (total_size != 128 && total_size != 64)) { Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt_args.push_back(e); - return tir::CallNode::make(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, "llvm_intrin", vcnt_args, CallNode::PureIntrinsic); } // Popcount lowering rule: @@ -80,12 +79,11 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { // to return back to original input type // Dvisions are always divisible (number of bits = 64 or 128) - DataType uint8_type = DataType( - e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); - DataType uint16_type = DataType( - uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); - DataType uint32_type = DataType( - uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); + DataType uint8_type = DataType(e.dtype().code(), 8, e.dtype().bits() * e.dtype().lanes() / 8); + DataType uint16_type = + DataType(uint8_type.code(), 16, uint8_type.bits() * uint8_type.lanes() / 16); + DataType uint32_type = + DataType(uint16_type.code(), 32, uint8_type.bits() * uint8_type.lanes() / 32); // Interpret input as vector of 8bit values PrimExpr input8 = reinterpret(uint8_type, e); @@ -96,16 +94,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::CallNode::make( - uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); + PrimExpr vcnt8 = tir::Call(uint8_type, "llvm_intrin", vcnt8_args, CallNode::PureIntrinsic); // Accumulation 8->16bit Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::CallNode::make( - uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); + PrimExpr vcnt16 = tir::Call(uint16_type, "llvm_intrin", vcnt16_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 16) { return vcnt16; } @@ -115,8 +111,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { 
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::CallNode::make( - uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); + PrimExpr vcnt32 = tir::Call(uint32_type, "llvm_intrin", vcnt32_args, CallNode::PureIntrinsic); if (call->dtype.bits() == 32) { return vcnt32; } @@ -126,15 +121,14 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode *call) { vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(IntImm(DataType::UInt(32), 1)); vcnt64_args.push_back(vcnt32); - return tir::CallNode::make( - call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); + return tir::Call(call->dtype, "llvm_intrin", vcnt64_args, CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_arm") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - CodeGenLLVM* cg = new CodeGenARM(); - *rv = static_cast(cg); - }); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenARM(); + *rv = static_cast(cg); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index be8ef9262765..b7c48c779073 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -21,17 +21,17 @@ * \file codegen_blob.cc */ #ifdef TVM_LLVM_VERSION +#include "codegen_blob.h" + #include + #include -#include "codegen_blob.h" namespace tvm { namespace codegen { -std::pair, - std::shared_ptr> CodeGenBlob(const std::string& data, - bool system_lib, - const std::string& target_triple) { +std::pair, std::shared_ptr> CodeGenBlob( + const std::string& data, bool system_lib, const std::string& target_triple) { InitializeLLVM(); auto tm = GetLLVMTargetMachine(std::string("-target ") + target_triple); auto triple = tm->getTargetTriple(); @@ -41,10 +41,9 @@ std::pair, module->setTargetTriple(triple.str()); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); - auto* tvm_dev_mblob = new llvm::GlobalVariable(*module, blob_value->getType(), true, - llvm::GlobalValue::ExternalLinkage, blob_value, - runtime::symbol::tvm_dev_mblob, nullptr, - llvm::GlobalVariable::NotThreadLocal, 0); + auto* tvm_dev_mblob = new llvm::GlobalVariable( + *module, blob_value->getType(), true, llvm::GlobalValue::ExternalLinkage, blob_value, + runtime::symbol::tvm_dev_mblob, nullptr, llvm::GlobalVariable::NotThreadLocal, 0); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob->setAlignment(llvm::Align(1)); @@ -64,11 +63,9 @@ std::pair, auto int8_ptr_ty = int8_ty->getPointerTo(0); llvm::Constant* constant_zero = llvm::Constant::getNullValue(int32_ty); - auto* tvm_dev_mblob_reg = - new llvm::GlobalVariable(*module, int32_ty, - false, llvm::GlobalValue::InternalLinkage, - constant_zero, - std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); + auto* tvm_dev_mblob_reg = new llvm::GlobalVariable( + *module, int32_ty, false, llvm::GlobalValue::InternalLinkage, constant_zero, + std::string(runtime::symbol::tvm_dev_mblob) + "_reg_"); auto tvm_dev_mblob_reg_alignment = module->getDataLayout().getABITypeAlignment(int32_ty); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob_reg->setAlignment(llvm::Align(tvm_dev_mblob_reg_alignment)); @@ -80,11 +77,9 @@ std::pair, llvm::ArrayType::get(int8_ty, std::strlen(runtime::symbol::tvm_dev_mblob) + 1); auto* tvm_dev_mblob_string_value = llvm::ConstantDataArray::getString(*ctx, 
runtime::symbol::tvm_dev_mblob, true); - auto* tvm_dev_mblob_string = - new llvm::GlobalVariable(*module, tvm_dev_mblob_string_ty, - true, llvm::GlobalValue::PrivateLinkage, - tvm_dev_mblob_string_value, - std::string(runtime::symbol::tvm_dev_mblob) + ".str"); + auto* tvm_dev_mblob_string = new llvm::GlobalVariable( + *module, tvm_dev_mblob_string_ty, true, llvm::GlobalValue::PrivateLinkage, + tvm_dev_mblob_string_value, std::string(runtime::symbol::tvm_dev_mblob) + ".str"); #if TVM_LLVM_VERSION >= 100 tvm_dev_mblob_string->setAlignment(llvm::Align(1)); #else @@ -92,33 +87,30 @@ std::pair, #endif // Global init function - llvm::Function* init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), - llvm::GlobalValue::InternalLinkage, - llvm::Twine("_GLOBAL__sub_I_", module_name), - module.get()); + llvm::Function* init_fn = llvm::Function::Create( + llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, + llvm::Twine("_GLOBAL__sub_I_", module_name), module.get()); // Create variable initialization function. - llvm::Function* var_init_fn = llvm::Function::Create(llvm::FunctionType::get(void_ty, false), - llvm::GlobalValue::InternalLinkage, - llvm::Twine("__cxx_global_var_init"), - module.get()); + llvm::Function* var_init_fn = llvm::Function::Create( + llvm::FunctionType::get(void_ty, false), llvm::GlobalValue::InternalLinkage, + llvm::Twine("__cxx_global_var_init"), module.get()); // Create TVMBackendRegisterSystemLibSymbol function llvm::Function* tvm_backend_fn = llvm::Function::Create(llvm::FunctionType::get(int32_ty, {int8_ptr_ty, int8_ptr_ty}, false), llvm::GlobalValue::ExternalLinkage, - llvm::Twine("TVMBackendRegisterSystemLibSymbol"), - module.get()); + llvm::Twine("TVMBackendRegisterSystemLibSymbol"), module.get()); // Set necessary fn sections auto get_static_init_section_specifier = [&triple]() -> std::string { - if (triple.isOSLinux()) { - return ".text.startup"; - } else if (triple.isOSDarwin()) { - return "__TEXT,__StaticInit,regular,pure_instructions"; - } else { - return ""; - } + if (triple.isOSLinux()) { + return ".text.startup"; + } else if (triple.isOSDarwin()) { + return "__TEXT,__StaticInit,regular,pure_instructions"; + } else { + return ""; + } }; auto static_init_section_specifier = get_static_init_section_specifier(); @@ -144,11 +136,9 @@ std::pair, llvm::Constant* indices[] = {constant_zero, constant_zero}; llvm::SmallVector args; args.push_back(llvm::ConstantExpr::getGetElementPtr(tvm_dev_mblob_string_ty, - tvm_dev_mblob_string, - indices)); - args.push_back(llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), - tvm_dev_mblob, - indices)); + tvm_dev_mblob_string, indices)); + args.push_back( + llvm::ConstantExpr::getGetElementPtr(blob_value->getType(), tvm_dev_mblob, indices)); auto* tvm_backend_fn_ret_value = ir_builder.CreateCall(tvm_backend_fn, args); ir_builder.CreateStore(tvm_backend_fn_ret_value, tvm_dev_mblob_reg); ir_builder.CreateRetVoid(); diff --git a/src/target/llvm/codegen_blob.h b/src/target/llvm/codegen_blob.h index a394f77a6638..2821f44ebd3c 100644 --- a/src/target/llvm/codegen_blob.h +++ b/src/target/llvm/codegen_blob.h @@ -24,9 +24,10 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_BLOB_H_ #define TVM_TARGET_LLVM_CODEGEN_BLOB_H_ #ifdef TVM_LLVM_VERSION -#include #include #include +#include + #include "llvm_common.h" namespace tvm { @@ -40,10 +41,8 @@ namespace codegen { * * \return LLVM module and LLVM context */ -std::pair, - std::shared_ptr> CodeGenBlob(const std::string& data, - bool system_lib, - const 
std::string& target_triple); +std::pair, std::shared_ptr> CodeGenBlob( + const std::string& data, bool system_lib, const std::string& target_triple); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index dde842765c78..6ad050ace9a3 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -22,21 +22,20 @@ */ #ifdef TVM_LLVM_VERSION +#include "codegen_cpu.h" + #include -#include #include + +#include #include #include -#include "codegen_cpu.h" namespace tvm { namespace codegen { -void CodeGenCPU::Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) { +void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup); dbg_info_ = CreateDebugInfo(module_.get()); static_assert(sizeof(TVMValue) == sizeof(double), "invariant"); @@ -47,53 +46,34 @@ void CodeGenCPU::Init(const std::string& module_name, t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}); t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}); t_tvm_func_handle_ = t_void_p_; - t_tvm_array_ = llvm::StructType::create( - {t_void_p_, - t_tvm_context_, - t_int_, - t_tvm_type_, - t_tvm_shape_index_->getPointerTo(), - t_tvm_shape_index_->getPointerTo(), - t_int64_}); + t_tvm_array_ = llvm::StructType::create({t_void_p_, t_tvm_context_, t_int_, t_tvm_type_, + t_tvm_shape_index_->getPointerTo(), + t_tvm_shape_index_->getPointerTo(), t_int64_}); t_tvm_value_ = llvm::StructType::create({t_float64_}); - t_tvm_parallel_group_env_ = llvm::StructType::create({ - t_int32_->getPointerTo(), t_int32_}); + t_tvm_parallel_group_env_ = llvm::StructType::create({t_int32_->getPointerTo(), t_int32_}); ftype_tvm_parallel_lambda_ = llvm::FunctionType::get( - t_int_, - {t_int_, - t_tvm_parallel_group_env_->getPointerTo(), - t_void_p_}, false); + t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo(), t_void_p_}, false); md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_); // Runtime functions. 
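  // These FunctionTypes are meant to line up with the C runtime entry points,
  // roughly:
  //   int TVMFuncCall(TVMFunctionHandle f, TVMValue* args, int* type_codes,
  //                   int num_args, TVMValue* ret_val, int* ret_type_code);
  //   int TVMBackendGetFuncFromEnv(void* mod_node, const char* name,
  //                                TVMFunctionHandle* out);
  // with t_tvm_value_, t_int_ and t_tvm_func_handle_ standing in for TVMValue,
  // the type codes and the opaque function handle.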
- ftype_tvm_func_call_ = llvm::FunctionType::get(t_int_, { - t_tvm_func_handle_, - t_tvm_value_->getPointerTo(), - t_int_->getPointerTo(), + ftype_tvm_func_call_ = llvm::FunctionType::get( t_int_, - t_tvm_value_->getPointerTo(), - t_int_->getPointerTo()}, false); - ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(t_int_, { - t_void_p_, - t_char_->getPointerTo(), - t_tvm_func_handle_->getPointerTo()}, false); - ftype_tvm_api_set_last_error_ = llvm::FunctionType::get( - t_void_, {t_char_->getPointerTo()}, false); - ftype_tvm_parallel_launch_ = - llvm::FunctionType::get(t_int_, { - ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_} - , false); + {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_, + t_tvm_value_->getPointerTo(), t_int_->getPointerTo()}, + false); + ftype_tvm_get_func_from_env_ = llvm::FunctionType::get( + t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false); + ftype_tvm_api_set_last_error_ = + llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false); + ftype_tvm_parallel_launch_ = llvm::FunctionType::get( + t_int_, {ftype_tvm_parallel_lambda_->getPointerTo(), t_void_p_, t_int_}, false); ftype_tvm_parallel_barrier_ = - llvm::FunctionType::get(t_int_, { - t_int_, t_tvm_parallel_group_env_->getPointerTo()} - , false); - ftype_tvm_static_init_callback_ = - llvm::FunctionType::get(t_int_, {t_void_p_}, false); + llvm::FunctionType::get(t_int_, {t_int_, t_tvm_parallel_group_env_->getPointerTo()}, false); + ftype_tvm_static_init_callback_ = llvm::FunctionType::get(t_int_, {t_void_p_}, false); ftype_tvm_static_init_ = - llvm::FunctionType::get(t_int_, { - t_void_p_->getPointerTo(), - ftype_tvm_static_init_callback_->getPointerTo(), - t_void_p_, t_int_} - , false); + llvm::FunctionType::get(t_int_, + {t_void_p_->getPointerTo(), + ftype_tvm_static_init_callback_->getPointerTo(), t_void_p_, t_int_}, + false); // initialize TVM runtime API if (system_lib) { // We will need this in environment for backward registration. 
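  // In system-lib mode the generated module is linked statically, so exported
  // functions are registered at startup through something like
  //   int TVMBackendRegisterSystemLibSymbol(const char* name, void* ptr);
  // which is the signature f_tvm_register_system_symbol_ is declared against.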
@@ -104,21 +84,20 @@ void CodeGenCPU::Init(const std::string& module_name, f_tvm_register_system_symbol_ = nullptr; } if (dynamic_lookup || system_lib) { - f_tvm_func_call_ = llvm::Function::Create( - ftype_tvm_func_call_, - llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get()); - f_tvm_get_func_from_env_ = llvm::Function::Create( - ftype_tvm_get_func_from_env_, - llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get()); - f_tvm_api_set_last_error_ = llvm::Function::Create( - ftype_tvm_api_set_last_error_, - llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); - f_tvm_parallel_launch_ = llvm::Function::Create( - ftype_tvm_parallel_launch_, - llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get()); - f_tvm_parallel_barrier_ = llvm::Function::Create( - ftype_tvm_parallel_barrier_, - llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get()); + f_tvm_func_call_ = llvm::Function::Create(ftype_tvm_func_call_, llvm::Function::ExternalLinkage, + "TVMFuncCall", module_.get()); + f_tvm_get_func_from_env_ = + llvm::Function::Create(ftype_tvm_get_func_from_env_, llvm::Function::ExternalLinkage, + "TVMBackendGetFuncFromEnv", module_.get()); + f_tvm_api_set_last_error_ = + llvm::Function::Create(ftype_tvm_api_set_last_error_, llvm::Function::ExternalLinkage, + "TVMAPISetLastError", module_.get()); + f_tvm_parallel_launch_ = + llvm::Function::Create(ftype_tvm_parallel_launch_, llvm::Function::ExternalLinkage, + "TVMBackendParallelLaunch", module_.get()); + f_tvm_parallel_barrier_ = + llvm::Function::Create(ftype_tvm_parallel_barrier_, llvm::Function::ExternalLinkage, + "TVMBackendParallelBarrier", module_.get()); } this->InitGlobalContext(dynamic_lookup); } @@ -153,22 +132,13 @@ void CodeGenCPU::AddDebugInformation(llvm::Function* function) { #if TVM_LLVM_VERSION >= 80 auto* DIFunction = dbg_info_->di_builder_->createFunction( - dbg_info_->file_, function->getName(), "", - dbg_info_->file_, - 0 /* line number */, - DIFunctionTy, - false /* internal linkage */); + dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */, + DIFunctionTy, false /* internal linkage */); #else auto* DIFunction = dbg_info_->di_builder_->createFunction( - dbg_info_->file_, function->getName(), "", - dbg_info_->file_, - 0 /* line number */, - DIFunctionTy, - false, /* internal linkage */ - true, - 0 /* line number */, - llvm::DINode::FlagPrototyped, - true /* isOptimized */); + dbg_info_->file_, function->getName(), "", dbg_info_->file_, 0 /* line number */, + DIFunctionTy, false, /* internal linkage */ + true, 0 /* line number */, llvm::DINode::FlagPrototyped, true /* isOptimized */); #endif CHECK(DIFunction); @@ -237,9 +207,8 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { llvm::Function* f = module_->getFunction(entry_func_name); CHECK(f) << "Function " << entry_func_name << "does not in module"; llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, - runtime::symbol::tvm_module_main); + llvm::GlobalVariable* global = new llvm::GlobalVariable( + *module_, type, true, llvm::GlobalValue::WeakAnyLinkage, 0, runtime::symbol::tvm_module_main); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -255,8 +224,8 @@ std::unique_ptr CodeGenCPU::Finish() { } return CodeGenLLVM::Finish(); } -llvm::Value* CodeGenCPU::CreateStructRefPtr( - 
DataType t, llvm::Value* buf, llvm::Value* index, int kind) { +llvm::Value* CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, + int kind) { if (kind < intrinsic::kArrKindBound_) { if (buf->getType() == t_void_p_) { buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo()); @@ -281,27 +250,22 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(2)}); } case intrinsic::kArrTypeCode: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(0)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(0)}); } case intrinsic::kArrTypeBits: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(1)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(1)}); } case intrinsic::kArrTypeLanes: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(3), ConstInt32(2)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(3), ConstInt32(2)}); } case intrinsic::kArrByteOffset: { return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(6)}); } case intrinsic::kArrDeviceId: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(1), ConstInt32(1)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(1)}); } case intrinsic::kArrDeviceType: { - return builder_->CreateInBoundsGEP( - buf, {index, ConstInt32(1), ConstInt32(0)}); + return builder_->CreateInBoundsGEP(buf, {index, ConstInt32(1), ConstInt32(0)}); } case intrinsic::kTVMValueContent: { CHECK_EQ(t.lanes(), 1); @@ -319,7 +283,9 @@ llvm::Value* CodeGenCPU::CreateStructRefPtr( return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo()); } } - default: LOG(FATAL) << "unknown field code"; return nullptr; + default: + LOG(FATAL) << "unknown field code"; + return nullptr; } } @@ -332,8 +298,8 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { for (llvm::Value* v : arg_values) { arg_types.push_back(v->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get( - GetLLVMType(GetRef(op)), arg_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_types, false); // Check if it is available in global function table as injected function. 
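  // Roughly: context functions such as TVMBackendAllocWorkspace listed in
  // gv_func_map_ are not called through their external symbols directly;
  // the call is routed through a per-module "__"-prefixed context global that
  // is filled in at load time, so a statically linked module needs no dynamic
  // symbol lookup.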
auto it = gv_func_map_.find(op->name); if (it != gv_func_map_.end()) { @@ -350,8 +316,8 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } else { llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + op->name.operator llvm::StringRef(), module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -362,12 +328,9 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } } -llvm::GlobalVariable* CodeGenCPU::InitContextPtr( - llvm::Type* p_type, std::string name) { +llvm::GlobalVariable* CodeGenCPU::InitContextPtr(llvm::Type* p_type, std::string name) { llvm::GlobalVariable* gv = new llvm::GlobalVariable( - *module_, p_type, false, - llvm::GlobalValue::LinkOnceAnyLinkage, 0, - name); + *module_, p_type, false, llvm::GlobalValue::LinkOnceAnyLinkage, 0, name); #if TVM_LLVM_VERSION >= 100 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type))); #else @@ -385,9 +348,8 @@ llvm::Value* CodeGenCPU::GetContextPtr(llvm::GlobalVariable* gv) { #else llvm::LoadInst* faddr = builder_->CreateAlignedLoad(gv, gv->getAlignment()); #endif - faddr->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + faddr->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); return faddr; } @@ -400,16 +362,15 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { std::make_pair(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_)); } else { if (!dynamic_lookup) { - gv_tvm_func_call_ = InitContextPtr( - ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall"); - gv_tvm_get_func_from_env_ = InitContextPtr( - ftype_tvm_get_func_from_env_->getPointerTo(), "__TVMBackendGetFuncFromEnv"); - gv_tvm_api_set_last_error_ = InitContextPtr( - ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError"); - gv_tvm_parallel_launch_ = InitContextPtr( - ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch"); - gv_tvm_parallel_barrier_ = InitContextPtr( - ftype_tvm_parallel_barrier_->getPointerTo(), "__TVMBackendParallelBarrier"); + gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall"); + gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(), + "__TVMBackendGetFuncFromEnv"); + gv_tvm_api_set_last_error_ = + InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError"); + gv_tvm_parallel_launch_ = + InitContextPtr(ftype_tvm_parallel_launch_->getPointerTo(), "__TVMBackendParallelLaunch"); + gv_tvm_parallel_barrier_ = InitContextPtr(ftype_tvm_parallel_barrier_->getPointerTo(), + "__TVMBackendParallelBarrier"); // Mark as context functions gv_func_map_["TVMBackendAllocWorkspace"] = nullptr; gv_func_map_["TVMBackendFreeWorkspace"] = nullptr; @@ -420,12 +381,9 @@ void CodeGenCPU::InitGlobalContext(bool dynamic_lookup) { llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { // create emit codes that checks and load the function. 
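  // The emitted control flow is roughly:
  //   %succ = icmp eq i32 %retcode, 0
  //   br i1 %succ, label %call_end, label %call_fail   ; weighted very-likely
  // call_fail:
  //   ret i32 %retcode                                  ; propagate the error
  // call_end:
  //   ...                                               ; codegen continues here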
using llvm::BasicBlock; - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "call_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "call_end", function_); - llvm::Value* succ = builder_->CreateICmpEQ( - retcode, llvm::ConstantInt::get(t_int_, 0)); + BasicBlock* fail_block = BasicBlock::Create(*ctx_, "call_fail", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "call_end", function_); + llvm::Value* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0)); builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_); builder_->SetInsertPoint(fail_block); // return the code. @@ -449,20 +407,15 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { arg_values.push_back(value); arg_types.push_back(value->getType()); } - llvm::FunctionType* ftype = - llvm::FunctionType::get(t_int_, arg_types, false); - llvm::Function* fcompute = - llvm::Function::Create(ftype, - llvm::Function::PrivateLinkage, - op->value.as()->value, - module_.get()); - BasicBlock* compute_call_end = CheckCallSuccess( - builder_->CreateCall(fcompute, arg_values)); + llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false); + llvm::Function* fcompute = llvm::Function::Create( + ftype, llvm::Function::PrivateLinkage, + op->value.as()->value.operator llvm::StringRef(), module_.get()); + BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); // setup compute fuinction. std::unordered_map new_vmap; size_t idx = 0; - for (auto it = fcompute->arg_begin(); - it != fcompute->arg_end(); ++it, ++idx) { + for (auto it = fcompute->arg_begin(); it != fcompute->arg_end(); ++it, ++idx) { llvm::Argument* v = &(*it); const Var& var = vargs[idx]; new_vmap[var.get()] = v; @@ -476,10 +429,21 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { #endif fcompute->addFnAttr(llvm::Attribute::NoInline); } + // Add alignment attribute if needed. 
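+  // Buffers arrive in the private compute function as plain pointers, so any
+  // over-alignment recorded by storage planning would otherwise be lost;
+  // re-attaching it as an llvm::Attribute::Alignment parameter attribute lets
+  // the optimizer keep assuming aligned loads and stores inside this scope.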
+#if TVM_LLVM_VERSION >= 50 + auto f = alloc_storage_info_.find(var.get()); + if (f != alloc_storage_info_.end()) { + unsigned align = f->second.alignment; + if (align > 1) { + auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + fcompute->addParamAttr(idx, attr); + } + } +#endif } std::swap(function_, fcompute); std::swap(new_vmap, var_map_); - BasicBlock *compute_entry = BasicBlock::Create(*ctx_, "entry", function_); + BasicBlock* compute_entry = BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(compute_entry); this->VisitStmt(op->body); builder_->CreateRet(ConstInt32(0)); @@ -504,48 +468,41 @@ llvm::Value* CodeGenCPU::PackClosureData(const Array& vfields, uint64_t* nu llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1)); llvm::Value* zero = ConstInt32(0); for (size_t i = 0; i < vfields.size(); ++i) { - builder_->CreateStore( - var_map_.at(vfields[i].get()), - builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); + builder_->CreateStore(var_map_.at(vfields[i].get()), + builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)})); } *num_bytes = data_layout_->getTypeAllocSize( llvm::cast(cdata->getType())->getElementType()); return cdata; } -void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, - const Array& vfields, +void CodeGenCPU::UnpackClosureData(llvm::Value* cdata, const Array& vfields, std::unordered_map* vmap) { for (size_t i = 0; i < vfields.size(); ++i) { (*vmap)[vfields[i].get()] = - builder_->CreateLoad(builder_->CreateInBoundsGEP( - cdata, {ConstInt32(0), ConstInt32(i)})); + builder_->CreateLoad(builder_->CreateInBoundsGEP(cdata, {ConstInt32(0), ConstInt32(i)})); } } void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { using llvm::BasicBlock; // closure data - llvm::Function* f = llvm::Function::Create( - ftype_tvm_parallel_lambda_, - llvm::Function::PrivateLinkage, - "__tvm_parallel_lambda", module_.get()); + llvm::Function* f = + llvm::Function::Create(ftype_tvm_parallel_lambda_, llvm::Function::PrivateLinkage, + "__tvm_parallel_lambda", module_.get()); // allocate and setup the closure, call the closure. Array vfields = tir::UndefinedVars(body, {}); uint64_t nbytes; llvm::Value* cdata = PackClosureData(vfields, &nbytes); #if TVM_LLVM_VERSION >= 90 - auto launch_callee = llvm::FunctionCallee( - ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); + auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); #else auto launch_callee = RuntimeTVMParallelLaunch(); #endif - BasicBlock* par_launch_end = CheckCallSuccess( - builder_->CreateCall( - launch_callee, - {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); + BasicBlock* par_launch_end = CheckCallSuccess(builder_->CreateCall( + launch_callee, {f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(num_task)})); // Setup the closure function. 
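  // The lambda built here follows the threadpool callback ABI, roughly
  //   typedef int (*FTVMParallelLambda)(int task_id, TVMParallelGroupEnv* penv,
  //                                     void* cdata);
  // cdata is the closure struct packed above and is unpacked again inside the
  // lambda so the loop body sees the captured variables.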
- BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); llvm::Value* task_id = &(*it++); @@ -559,9 +516,8 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { par_env.task_id = Var("task_id", DataType::Int(32)); par_env.num_task = Var("num_task", DataType::Int(32)); new_vmap[par_env.task_id.get()] = task_id; - new_vmap[par_env.num_task.get()] = builder_->CreateLoad( - builder_->CreateInBoundsGEP( - penv, {ConstInt32(0), ConstInt32(1)})); + new_vmap[par_env.num_task.get()] = + builder_->CreateLoad(builder_->CreateInBoundsGEP(penv, {ConstInt32(0), ConstInt32(1)})); par_env.penv = penv; std::swap(function_, f); std::swap(parallel_env_, par_env); @@ -572,16 +528,13 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task) { std::swap(var_map_, new_vmap); std::swap(parallel_env_, par_env); std::swap(function_, f); - CHECK_NE(par_env.parallel_loop_count, 0) - << "Cannot find parallel loop within parallel launch"; + CHECK_NE(par_env.parallel_loop_count, 0) << "Cannot find parallel loop within parallel launch"; builder_->SetInsertPoint(par_launch_end); } llvm::Value* CodeGenCPU::CreateStaticHandle() { llvm::GlobalVariable* gv = new llvm::GlobalVariable( - *module_, t_void_p_, false, - llvm::GlobalValue::PrivateLinkage, 0, - "__tvm_static_handle"); + *module_, t_void_p_, false, llvm::GlobalValue::PrivateLinkage, 0, "__tvm_static_handle"); #if TVM_LLVM_VERSION >= 100 gv->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(t_void_p_))); #else @@ -594,26 +547,23 @@ llvm::Value* CodeGenCPU::CreateStaticHandle() { void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& body) { using llvm::BasicBlock; // closure data - llvm::Function* f = llvm::Function::Create( - ftype_tvm_static_init_callback_, - llvm::Function::PrivateLinkage, - "__tvm_static_init_lambda", module_.get()); + llvm::Function* f = + llvm::Function::Create(ftype_tvm_static_init_callback_, llvm::Function::PrivateLinkage, + "__tvm_static_init_lambda", module_.get()); llvm::Value* gv = CreateStaticHandle(); llvm::Function* finit = module_->getFunction(init_fname); if (finit == nullptr) { - finit = llvm::Function::Create( - ftype_tvm_static_init_, llvm::Function::ExternalLinkage, init_fname, module_.get()); + finit = llvm::Function::Create(ftype_tvm_static_init_, llvm::Function::ExternalLinkage, + init_fname, module_.get()); } // allocate and setup the closure, call the closure. uint64_t nbytes; Array vfields = tir::UndefinedVars(body, {}); llvm::Value* cdata = PackClosureData(vfields, &nbytes); - BasicBlock* init_end = CheckCallSuccess( - builder_->CreateCall( - finit, - {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); + BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( + finit, {gv, f, builder_->CreatePointerCast(cdata, t_void_p_), ConstInt32(nbytes)})); // Setup the closure function. 
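  // The callback f takes the packed closure as its single void* argument
  // (ftype_tvm_static_init_callback_), while the initializer named by
  // init_fname has the same shape as TVMBackendRunOnce, roughly
  //   int init(void** handle, int (*cb)(void*), void* cdata, int nbytes);
  // so the closure runs at most once, guarded by the static handle above.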
- BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f); + BasicBlock* lambda_entry = BasicBlock::Create(*ctx_, "entry", f); builder_->SetInsertPoint(lambda_entry); auto it = f->arg_begin(); cdata = builder_->CreatePointerCast(&(*it++), cdata->getType()); @@ -643,9 +593,9 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { if (it == func_handle_map_.end()) { // create global location for the handle // create the function handle - hptr = new llvm::GlobalVariable( - *module_, t_tvm_func_handle_, false, - llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname); + hptr = + new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false, + llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname); #if TVM_LLVM_VERSION >= 100 hptr->setAlignment(llvm::Align(align)); #else @@ -658,42 +608,34 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { } // create emit codes that checks and load the function. BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* init_block = BasicBlock::Create( - *ctx_, "handle_init", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "handle_init_end", function_); + BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_); #if TVM_LLVM_VERSION >= 110 llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align)); #else llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align); #endif - llvm::Value* handle_not_null = builder_->CreateICmpNE( - handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); - builder_->CreateCondBr( - handle_not_null, end_block, init_block, md_very_likely_branch_); + llvm::Value* handle_not_null = + builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_)); + builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_); // Initialize the handle if needed. 
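  // Roughly the C-level equivalent of:
  //   static TVMFunctionHandle handle = NULL;             // hptr
  //   if (handle == NULL) {                                // init_block
  //     TVMBackendGetFuncFromEnv(mod_ctx, fname, &handle);
  //   }
  //   /* use handle */                                     // end_block, via PHI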
builder_->SetInsertPoint(init_block); - llvm::Value* out = WithFunctionEntry([&]() { - return builder_->CreateAlloca(t_tvm_func_handle_); - }); + llvm::Value* out = + WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); }); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* ctx = builder_->CreateAlignedLoad( - gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); + llvm::LoadInst* ctx = + builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment())); #else - llvm::LoadInst* ctx = builder_->CreateAlignedLoad( - gv_mod_ctx_, gv_mod_ctx_->getAlignment()); + llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment()); #endif - ctx->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); + ctx->setMetadata("tbaa", + md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0)); #if TVM_LLVM_VERSION >= 90 - auto env_callee = llvm::FunctionCallee( - ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); + auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv()); #else auto env_callee = RuntimeTVMGetFuncFromEnv(); #endif - llvm::Value* retcode = builder_->CreateCall( - env_callee, {ctx, GetConstString(fname), out}); + llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out}); init_block = CheckCallSuccess(retcode); #if TVM_LLVM_VERSION >= 110 llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align)); @@ -711,38 +653,33 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const std::string& fname) { return phi; } -llvm::BasicBlock * -CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, - llvm::Value **ret_tcode, const DataType &r_type, - const int64_t begin, const int64_t end) { +llvm::BasicBlock* CodeGenCPU::MakeCallPacked(const Array& args, llvm::Value** rvalue, + llvm::Value** ret_tcode, const DataType& r_type, + const int64_t begin, const int64_t end) { using llvm::BasicBlock; std::string func_name = args[0].as()->value; - llvm::Value *handle = GetPackedFuncHandle(func_name); + llvm::Value* handle = GetPackedFuncHandle(func_name); // call the function int64_t nargs = end - begin; CHECK_GE(nargs, 0); - llvm::Value *stack_value = MakeValue(args[1]); - llvm::Value *stack_tcode = MakeValue(args[2]); - llvm::Value *arg_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), - ConstInt32(begin)); - llvm::Value *arg_tcode = - CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); - llvm::Value *ret_value = builder_->CreateInBoundsGEP( - builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), - ConstInt32(end)); + llvm::Value* stack_value = MakeValue(args[1]); + llvm::Value* stack_tcode = MakeValue(args[2]); + llvm::Value* arg_value = builder_->CreateInBoundsGEP( + builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin)); + llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin)); + llvm::Value* ret_value = builder_->CreateInBoundsGEP( + builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end)); *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end)); #if TVM_LLVM_VERSION >= 90 auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall()); #else auto call_callee = RuntimeTVMFuncCall(); #endif - BasicBlock *end_block = 
CheckCallSuccess(builder_->CreateCall( - call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), - ret_value, *ret_tcode})); + BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall( + call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode})); DataType r_api_type = tir::APIType(r_type); - llvm::Value* load_ptr = builder_->CreatePointerCast( - ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); + llvm::Value* load_ptr = + builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()); #if TVM_LLVM_VERSION >= 110 *rvalue = builder_->CreateAlignedLoad(load_ptr, llvm::Align(8)); #else @@ -752,47 +689,44 @@ CodeGenCPU::MakeCallPacked(const Array &args, llvm::Value **rvalue, return end_block; } -llvm::Value *CodeGenCPU::CreateCallPacked(const CallNode *op) { +llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op) { CHECK_EQ(op->args.size(), 5U); - llvm::Value *rvalue = nullptr; - llvm::Value *ret_tcode = nullptr; - MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, - op->args[3].as()->value, + llvm::Value* rvalue = nullptr; + llvm::Value* ret_tcode = nullptr; + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, op->args[4].as()->value); return rvalue; } -llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) { +llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { using llvm::BasicBlock; CHECK_EQ(op->args.size(), 6U); - llvm::Value *rvalue = nullptr; - llvm::Value *ret_tcode = nullptr; - BasicBlock *end_block = MakeCallPacked( - op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, - op->args[4].as()->value); + llvm::Value* rvalue = nullptr; + llvm::Value* ret_tcode = nullptr; + BasicBlock* end_block = + MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as()->value, + op->args[4].as()->value); // Get traced value. - llvm::Value *traced_value = MakeValue(op->args[5]); + llvm::Value* traced_value = MakeValue(op->args[5]); // The update_block handles case when we need to update the return value. - BasicBlock *update_block = - BasicBlock::Create(*ctx_, "update_block", function_); + BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_); // The continue_block handles case when we need to return original // traced value. - BasicBlock *continue_block = - BasicBlock::Create(*ctx_, "continue_block", function_); + BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_); #if TVM_LLVM_VERSION >= 110 - llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); + llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(8)); #else - llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); + llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); #endif // Check the ret_type_code and create cmp instruction. - llvm::Value *cmp = builder_->CreateICmpNE( - ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); + llvm::Value* cmp = + builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr)); builder_->CreateCondBr(cmp, update_block, continue_block); builder_->SetInsertPoint(update_block); builder_->CreateBr(continue_block); builder_->SetInsertPoint(continue_block); // The return value depends on from what bb we come from. 
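  // phi_rvalue yields the packed call's result when control reaches here via
  // update_block (ret_tcode != kTVMNullptr, i.e. the callback produced a
  // value) and falls back to the original traced_value when it comes straight
  // from end_block.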
- llvm::PHINode *phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); + llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2); phi_rvalue->addIncoming(rvalue, update_block); phi_rvalue->addIncoming(traced_value, end_block); return phi_rvalue; @@ -824,17 +758,14 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { void CodeGenCPU::AddStartupFunction() { if (export_system_symbols_.size() != 0) { llvm::FunctionType* ftype = llvm::FunctionType::get(t_void_, {}, false); - function_ = llvm::Function::Create( - ftype, - llvm::Function::InternalLinkage, - "__tvm_module_startup", module_.get()); + function_ = llvm::Function::Create(ftype, llvm::Function::InternalLinkage, + "__tvm_module_startup", module_.get()); llvm::BasicBlock* startup_entry = llvm::BasicBlock::Create(*ctx_, "entry", function_); builder_->SetInsertPoint(startup_entry); for (const auto& kv : export_system_symbols_) { llvm::Value* name = GetConstString(kv.first); - builder_->CreateCall( - f_tvm_register_system_symbol_, { - name, builder_->CreateBitCast(kv.second, t_void_p_)}); + builder_->CreateCall(f_tvm_register_system_symbol_, + {name, builder_->CreateBitCast(kv.second, t_void_p_)}); } llvm::appendToGlobalCtors(*module_, function_, 65535); builder_->CreateRet(nullptr); @@ -854,9 +785,8 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); int kind = op->args[2].as()->value; - llvm::Value* ref = this->CreateStructRefPtr( - op->dtype, MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + llvm::Value* ref = + this->CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == intrinsic::kArrAddr) { return builder_->CreatePointerCast(ref, t_void_p_); } else { @@ -866,13 +796,11 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; llvm::Value* value = MakeValue(op->args[3]); - llvm::Value* ref = this->CreateStructRefPtr( - op->args[3].dtype(), MakeValue(op->args[0]), - MakeValue(op->args[1]), kind); + llvm::Value* ref = this->CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]), + MakeValue(op->args[1]), kind); CHECK(kind != intrinsic::kArrAddr); if (value->getType()->isPointerTy()) { - value = builder_->CreatePointerCast( - value, ref->getType()->getPointerElementType()); + value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType()); } builder_->CreateStore(value, ref); return ConstInt32(0); @@ -880,22 +808,22 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CHECK_EQ(op->args.size(), 2U); const std::string& type = op->args[0].as()->value; return WithFunctionEntry([&]() -> llvm::AllocaInst* { - const int64_t* pval = as_const_int(op->args[1]); - CHECK(pval) << "require stack alloca to contain constant value"; - llvm::Value* num = ConstInt32(pval[0]); - if (type == "shape") { - return builder_->CreateAlloca(t_tvm_shape_index_, num); - } else if (type == "arg_value") { - return builder_->CreateAlloca(t_tvm_value_, num); - } else if (type == "arg_tcode") { - return builder_->CreateAlloca(t_int_, num); - } else if (type == "array") { - return builder_->CreateAlloca(t_tvm_array_, num); - } else { - LOG(FATAL) << "Unknown stack alloca type " << type; - return nullptr; - } - }); + const int64_t* pval = as_const_int(op->args[1]); + CHECK(pval) << "require stack alloca to contain constant value"; + llvm::Value* num = ConstInt32(pval[0]); + 
if (type == "shape") { + return builder_->CreateAlloca(t_tvm_shape_index_, num); + } else if (type == "arg_value") { + return builder_->CreateAlloca(t_tvm_value_, num); + } else if (type == "arg_tcode") { + return builder_->CreateAlloca(t_int_, num); + } else if (type == "array") { + return builder_->CreateAlloca(t_tvm_array_, num); + } else { + LOG(FATAL) << "Unknown stack alloca type " << type; + return nullptr; + } + }); } else { return CodeGenLLVM::CreateIntrinsic(op); } @@ -910,16 +838,14 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { os << ", " << op->message.as()->value; } llvm::Value* msg = GetConstString(os.str()); - BasicBlock* fail_block = BasicBlock::Create( - *ctx_, "assert_fail", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "assert_end", function_); + BasicBlock* fail_block = BasicBlock::Create(*ctx_, "assert_fail", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "assert_end", function_); builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_); // fail condition. builder_->SetInsertPoint(fail_block); #if TVM_LLVM_VERSION >= 90 - auto err_callee = llvm::FunctionCallee( - ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); + auto err_callee = + llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError()); #else auto err_callee = RuntimeTVMAPISetLastError(); #endif @@ -933,7 +859,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tir::attr::coproc_uop_scope) { this->CreateStaticInit(op->value.as()->value, op->body); - } else if (op->attr_key == tir::attr::compute_scope) { + } else if (op->attr_key == tir::attr::compute_scope) { this->CreateComputeScope(op); } else if (tir::attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { @@ -944,20 +870,18 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == "pragma_parallel_launch_point") { CreateParallelLaunch(op->body, 0); } else if (op->attr_key == "pragma_parallel_barrier_when_finish") { - CHECK(parallel_env_.penv != nullptr) - << "Cannot run barrier without parallel environment"; + CHECK(parallel_env_.penv != nullptr) << "Cannot run barrier without parallel environment"; CHECK(!parallel_env_.in_parallel_loop) << "Cannot not place within parallel loop as the workload may differ, " << " place it between parallel and parallel_launch_point"; this->VisitStmt(op->body); #if TVM_LLVM_VERSION >= 90 - auto bar_callee = llvm::FunctionCallee( - ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); + auto bar_callee = + llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); #else auto bar_callee = RuntimeTVMParallelBarrier(); #endif - builder_->CreateCall( - bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); + builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); } else if (op->attr_key == tir::attr::pragma_import_llvm) { const StringImmNode* value = op->value.as(); CHECK(value != nullptr); @@ -974,15 +898,12 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { void CodeGenCPU::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); - if (op->for_type == ForType::Serial || - op->for_type == ForType::Unrolled) { + if (op->for_type == ForType::Serial || op->for_type == ForType::Unrolled) { CodeGenLLVM::VisitStmt_(op); } else if (op->for_type == ForType::Parallel) { if (parallel_env_.penv == nullptr) { 
CreateParallelLaunch( - ForNode::make( - op->loop_var, op->min, op->extent, - op->for_type, op->device_api, op->body), 0); + For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0); } else { // already in parallel env. CHECK(parallel_env_.task_id.defined()); @@ -995,20 +916,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { << "Nested parallel loop is not supported by threadpool, try fuse them instead"; parallel_env_.in_parallel_loop = true; if (parallel_env_.stride_pattern) { - CreateSerialFor(MakeValue(task_id), - MakeValue(op->extent), - MakeValue(num_task), - op->loop_var, - op->body); + CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task), + op->loop_var, op->body); } else { PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task; - PrimExpr begin = MinNode::make(task_id * step, op->extent); - PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent); - CreateSerialFor(MakeValue(begin), - MakeValue(end), - llvm::ConstantInt::getSigned(GetLLVMType(end), 1), - op->loop_var, - op->body); + PrimExpr begin = min(task_id * step, op->extent); + PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent); + CreateSerialFor(MakeValue(begin), MakeValue(end), + llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body); } parallel_env_.in_parallel_loop = false; ++parallel_env_.parallel_loop_count; diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index aa8371c39a5c..7a14b8fdc959 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -24,11 +24,12 @@ #ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_ #define TVM_TARGET_LLVM_CODEGEN_CPU_H_ -#include -#include #include #include #include +#include +#include + #include "codegen_llvm.h" namespace tvm { @@ -37,11 +38,8 @@ namespace codegen { // CPU host code generation class CodeGenCPU : public CodeGenLLVM { public: - void Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) override; + void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, + bool system_lib, bool dynamic_lookup) override; void AddFunction(const PrimFunc& f) override; void AddMainFunction(const std::string& entry_func_name) override; std::unique_ptr Finish() override; @@ -95,20 +93,18 @@ class CodeGenCPU : public CodeGenLLVM { llvm::Value* RuntimeTVMParallelBarrier(); llvm::Value* CreateStaticHandle(); llvm::Value* GetPackedFuncHandle(const std::string& str); - llvm::Value* PackClosureData(const Array& fields, uint64_t *num_bytes); + llvm::Value* PackClosureData(const Array& fields, uint64_t* num_bytes); llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buffer, llvm::Value* index, int kind); - void UnpackClosureData(llvm::Value*cdata, - const Array& fields, + void UnpackClosureData(llvm::Value* cdata, const Array& fields, std::unordered_map* vmap); // Make packed call. - llvm::BasicBlock *MakeCallPacked(const Array &args, - llvm::Value **rvalue, - llvm::Value **ret_tcode, const DataType &r_type, + llvm::BasicBlock* MakeCallPacked(const Array& args, llvm::Value** rvalue, + llvm::Value** ret_tcode, const DataType& r_type, const int64_t begin, const int64_t end); // create call into tvm packed function. llvm::Value* CreateCallPacked(const CallNode* op); // Create trace call into tvm packed function. 
- llvm::Value* CreateCallTracePacked(const CallNode *op); + llvm::Value* CreateCallTracePacked(const CallNode* op); // Create static initialization void CreateStaticInit(const std::string& init_fname, const Stmt& body); // Create parallel launch diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 14302efe82fc..85e3de5844fd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -22,19 +22,21 @@ */ #ifdef TVM_LLVM_VERSION // Part of the code are adapted from Halide's CodeGen_LLVM -#include +#include "codegen_llvm.h" + #include +#include #include #include -#include "codegen_llvm.h" -#include "codegen_cpu.h" +#include "../../arith/pattern_match.h" #include "../build_common.h" +#include "codegen_cpu.h" namespace tvm { namespace codegen { -std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine *tm) { +std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine* tm) { std::string target = tm->getTarget().getName(); std::string factory_name = "tvm.codegen.llvm.target_" + target; const PackedFunc* f = runtime::Registry::Get(factory_name); @@ -46,11 +48,8 @@ std::unique_ptr CodeGenLLVM::Create(llvm::TargetMachine *tm) { } } -void CodeGenLLVM::Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup) { +void CodeGenLLVM::Init(const std::string& module_name, llvm::TargetMachine* tm, + llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup) { InitializeLLVM(); ctx_ = ctx; builder_.reset(new IRBuilder(*ctx_)); @@ -67,7 +66,7 @@ void CodeGenLLVM::Init(const std::string& module_name, t_int64_ = llvm::Type::getInt64Ty(*ctx_); t_float64_ = llvm::Type::getDoubleTy(*ctx_); // meta data - md_very_likely_branch_ = md_builder_->createBranchWeights(1<<20, 1); + md_very_likely_branch_ = md_builder_->createBranchWeights(1 << 20, 1); md_tbaa_root_ = md_builder_->createTBAARoot("tvm-tbaa"); md_tbaa_alias_set_ = md_builder_->createTBAANode("tvm-alias", md_tbaa_root_); this->InitTarget(tm); @@ -95,9 +94,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) { } } -void CodeGenLLVM::AddFunction(const PrimFunc& f) { - this->AddFunctionInternal(f, false); -} +void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, false); } void CodeGenLLVM::InitFuncState() { var_map_.clear(); @@ -107,7 +104,6 @@ void CodeGenLLVM::InitFuncState() { analyzer_.reset(new arith::Analyzer()); } - void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { this->InitFuncState(); @@ -125,8 +121,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { // TODO(tvm-team): // Update the function type to respect the ret_type field of f. // Once we allow more flexibility in the PrimFunc. - llvm::FunctionType* ftype = llvm::FunctionType::get( - ret_void ? t_void_ : t_int_, param_types, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(ret_void ? 
t_void_ : t_int_, param_types, false); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) @@ -134,9 +130,8 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; - function_ = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, - global_symbol.value().operator std::string(), module_.get()); + function_ = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); @@ -161,6 +156,21 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { builder_->SetInsertPoint(entry); this->VisitStmt(f->body); + // Add alignment attribute if needed. +#if TVM_LLVM_VERSION >= 50 + for (size_t i = 0; i < f->params.size(); ++i) { + const Var& var = f->params[i]; + auto f = alloc_storage_info_.find(var.get()); + if (f != alloc_storage_info_.end()) { + unsigned align = f->second.alignment; + if (align > 1) { + auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); + function_->addParamAttr(i, attr); + } + } + } +#endif + if (ret_void) { builder_->CreateRetVoid(); } else { @@ -168,7 +178,6 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } } - std::unique_ptr CodeGenLLVM::Finish() { this->AddStartupFunction(); for (size_t i = 0; i < link_modules_.size(); ++i) { @@ -181,13 +190,11 @@ std::unique_ptr CodeGenLLVM::Finish() { return std::move(module_); } - void CodeGenLLVM::HandleImport(const std::string& code) { std::unique_ptr mlib; llvm::SMDiagnostic err; if (code.length() >= 3 && - (code.substr(code.length() - 3) == ".ll" || - code.substr(code.length() - 3) == ".bc")) { + (code.substr(code.length() - 3) == ".ll" || code.substr(code.length() - 3) == ".bc")) { mlib = llvm::parseIRFile(code, err, *ctx_); if (mlib.get() == nullptr) { std::string msg = std::string(err.getMessage()); @@ -195,20 +202,19 @@ void CodeGenLLVM::HandleImport(const std::string& code) { << "line " << err.getLineNo() << ":" << msg; } } else { - std::unique_ptr buf = - llvm::MemoryBuffer::getMemBuffer(code); + std::unique_ptr buf = llvm::MemoryBuffer::getMemBuffer(code); mlib = llvm::parseIR(*buf, err, *ctx_); if (mlib.get() == nullptr) { std::string msg = std::string(err.getMessage()); LOG(FATAL) << "Fail to load llvm ir " - << "line " << err.getLineNo() << ":" << msg - << "\ncontent:\n" << code; + << "line " << err.getLineNo() << ":" << msg << "\ncontent:\n" + << code; } } mlib->setTargetTriple(target_machine_->getTargetTriple().str()); mlib->setDataLayout(target_machine_->createDataLayout()); // mark all the functions as force inline - for (llvm::Function &f : mlib->functions()) { + for (llvm::Function& f : mlib->functions()) { f.removeFnAttr(llvm::Attribute::NoInline); f.addFnAttr(llvm::Attribute::AlwaysInline); f.setLinkage(llvm::GlobalValue::AvailableExternallyLinkage); @@ -237,35 +243,27 @@ llvm::Value* CodeGenLLVM::CreateStorageSync(const CallNode* op) { class FPassManager : public llvm::legacy::FunctionPassManager { public: - explicit FPassManager(llvm::Module* m) - : llvm::legacy::FunctionPassManager(m) {} + explicit FPassManager(llvm::Module* m) : llvm::legacy::FunctionPassManager(m) {} // override add to allow messaging - void add(llvm::Pass* p) 
final { - llvm::legacy::FunctionPassManager::add(p); - } + void add(llvm::Pass* p) final { llvm::legacy::FunctionPassManager::add(p); } }; class MPassManager : public llvm::legacy::PassManager { public: // override add to allow messaging - void add(llvm::Pass* p) final { - llvm::legacy::PassManager::add(p); - } + void add(llvm::Pass* p) final { llvm::legacy::PassManager::add(p); } }; -void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) { -} +void CodeGenLLVM::InitPassManagerBuilder(llvm::PassManagerBuilder* builder) {} void CodeGenLLVM::Optimize() { // pass manager FPassManager fpass(module_.get()); MPassManager mpass; mpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : - llvm::TargetIRAnalysis())); + target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); fpass.add(llvm::createTargetTransformInfoWrapperPass( - target_machine_ ? target_machine_->getTargetIRAnalysis() : - llvm::TargetIRAnalysis())); + target_machine_ ? target_machine_->getTargetIRAnalysis() : llvm::TargetIRAnalysis())); // place optimization pass llvm::PassManagerBuilder builder; @@ -299,24 +297,32 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co return native_vector_bits_; } -unsigned CodeGenLLVM::GetGlobalAddressSpace() const { - return 0; -} +unsigned CodeGenLLVM::GetGlobalAddressSpace() const { return 0; } llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { if (dtype.is_handle()) { CHECK_EQ(dtype.lanes(), 1); return t_void_p_; } + if (dtype.is_void()) { + return t_void_; + } llvm::Type* etype = nullptr; if (dtype.is_int() || dtype.is_uint()) { etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); } else if (dtype.is_float()) { switch (dtype.bits()) { - case 16: etype = llvm::Type::getHalfTy(*ctx_); break; - case 32: etype = llvm::Type::getFloatTy(*ctx_); break; - case 64: etype = llvm::Type::getDoubleTy(*ctx_); break; - default: LOG(FATAL) << "do not support " << dtype; + case 16: + etype = llvm::Type::getHalfTy(*ctx_); + break; + case 32: + etype = llvm::Type::getFloatTy(*ctx_); + break; + case 64: + etype = llvm::Type::getDoubleTy(*ctx_); + break; + default: + LOG(FATAL) << "do not support " << dtype; } } if (dtype.lanes() != 1) { @@ -351,39 +357,35 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const { // // This trick comes from Halide's CodeGen_LLVM // -void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, - const VarNode* buffer, - PrimExpr index, +void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, PrimExpr index, DataType type) { if (alias_var_set_.count(buffer) != 0) { // Mark all possibly aliased pointer as same type. llvm::MDNode* meta = md_tbaa_alias_set_; - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); + inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); return; } - int base = 0, width = 0; + + int64_t base = 0, width = 0; + arith::PVar pbase, pstride; + arith::PVar planes; // create meta-data for alias analysis // Use a group of binary tree ranges of memory banks. 
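  // Worked example: Ramp(base=10, stride=2, lanes=4) spans the index window
  // [10, 18) (lanes * stride = 8), so width is first rounded up to 8; the loop
  // below then widens until the base is aligned: 10 % 8 -> base 8, width 16;
  // 8 % 16 -> base 0, width 32.  The TBAA chain descends ...w1024.b0,
  // ...w512.b0, ..., ...w32.b0 under the buffer's node, so accesses whose
  // aligned windows are disjoint never alias.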
if (index.defined()) { - const RampNode* ramp = index.as(); - if (ramp) { - int base, stride; - if (arith::GetConstInt(ramp->base, &base) && - arith::GetConstInt(ramp->stride, &stride)) { - int xwith = ramp->lanes * stride; - width = 1; - while (width < xwith) { - width *= 2; - } - while (base % width) { - base -= base % width; - width *= 2; - } + if (arith::ramp(pbase, pstride, planes).Match(index)) { + base = pbase.Eval()->value; + int64_t xwith = planes.Eval() * pstride.Eval()->value; + width = 1; + while (width < xwith) { + width *= 2; } - } else { - if (arith::GetConstInt(index, &base)) width = 1; + while (base % width) { + base -= base % width; + width *= 2; + } + } else if (auto* ptr = index.as()) { + width = 1; + base = ptr->value; } } llvm::MDNode* meta = md_tbaa_root_; @@ -394,23 +396,18 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta); // create a tree-shape access structure. if (width != 0) { - for (int w = 1024; w >= width; w /= 2) { - int b = (base / w) * w; + for (int64_t w = 1024; w >= width; w /= 2) { + int64_t b = (base / w) * w; std::stringstream os; os << buffer << ".w" << w << ".b" << b; meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); } } - inst->setMetadata( - "tbaa", - md_builder_->createTBAAStructTagNode(meta, meta, 0)); + inst->setMetadata("tbaa", md_builder_->createTBAAStructTagNode(meta, meta, 0)); } -void CodeGenLLVM::GetAlignment(DataType t, - const VarNode* buf_var, - const PrimExpr& index, - int* p_alignment, - int* p_native_bits) { +void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, + int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { @@ -426,11 +423,9 @@ void CodeGenLLVM::GetAlignment(DataType t, int64_t coeff = me->coeff; int align_bits = t.bits(); - while (align_bits < max_align_bits && - base % 2 == 0 && - coeff % 2 == 0) { - base = base / 2; - coeff = coeff / 2; + while (align_bits < max_align_bits && base % 2 == 0 && coeff % 2 == 0) { + base = base / 2; + coeff = coeff / 2; align_bits *= 2; } if (align_bits < 8) { @@ -439,8 +434,7 @@ void CodeGenLLVM::GetAlignment(DataType t, *p_alignment = align_bits / 8; } -std::unique_ptr -CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { +std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { #if TVM_LLVM_VERSION >= 100 auto debug_info = std::make_unique(); debug_info->di_builder_ = std::make_unique(*module); @@ -459,8 +453,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { } llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { - llvm::Constant* undef = llvm::UndefValue::get( - llvm::VectorType::get(value->getType(), lanes)); + llvm::Constant* undef = llvm::UndefValue::get(llvm::VectorType::get(value->getType(), lanes)); llvm::Constant* zero = ConstInt32(0); value = builder_->CreateInsertElement(undef, value, zero); #if TVM_LLVM_VERSION >= 110 @@ -473,7 +466,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) { } llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) { - int num_elems = static_cast(vec->getType()->getVectorNumElements()); + int num_elems = llvm::cast(vec->getType())->getNumElements(); if (extent == num_elems && begin == 0) return vec; CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n"; std::vector indices; @@ -489,8 +482,12 @@ llvm::Value* 
CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent } llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { - int num_elems = static_cast(vec->getType()->getVectorNumElements()); + int num_elems = llvm::cast(vec->getType())->getNumElements(); +#if TVM_LLVM_VERSION >= 110 + std::vector indices; +#else std::vector indices; +#endif for (int i = 0; i < num_elems; ++i) { indices.push_back(num_elems - i - 1); } @@ -498,9 +495,8 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { - llvm::Value* mask = llvm::UndefValue::get( - DTypeToLLVMType(DataType::Int(32, target_lanes))); - int num_elems = static_cast(vec->getType()->getVectorNumElements()); + llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); + int num_elems = llvm::cast(vec->getType())->getNumElements(); if (num_elems == target_lanes) return vec; CHECK_LT(num_elems, target_lanes); for (int i = 0; i < num_elems; ++i) { @@ -514,23 +510,26 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { int total_lanes = 0; for (llvm::Value* v : vecs) { - total_lanes += static_cast( - v->getType()->getVectorNumElements()); + total_lanes += llvm::cast(v->getType())->getNumElements(); } while (vecs.size() > 1) { std::vector new_vecs; for (size_t i = 0; i < vecs.size() - 1; i += 2) { llvm::Value* lhs = vecs[i]; llvm::Value* rhs = vecs[i + 1]; - const size_t lhs_lanes = lhs->getType()->getVectorNumElements(); - const size_t rhs_lanes = rhs->getType()->getVectorNumElements(); + const size_t lhs_lanes = llvm::cast(lhs->getType())->getNumElements(); + const size_t rhs_lanes = llvm::cast(rhs->getType())->getNumElements(); if (lhs_lanes < rhs_lanes) { lhs = CreateVecPad(lhs, rhs_lanes); } else if (rhs_lanes < lhs_lanes) { rhs = CreateVecPad(rhs, lhs_lanes); } const size_t shared_lanes = std::max(lhs_lanes, rhs_lanes); +#if TVM_LLVM_VERSION >= 110 + std::vector mask; +#else std::vector mask; +#endif for (size_t i = 0; i < lhs_lanes; ++i) { mask.push_back(i); } @@ -547,28 +546,21 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { return CreateVecSlice(vecs[0], 0, total_lanes); } - -void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, - llvm::Value* end, - llvm::Value* stride, - const Var& loop_var, - const Stmt& body) { +void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, + const Var& loop_var, const Stmt& body) { using llvm::BasicBlock; BasicBlock* pre_block = builder_->GetInsertBlock(); - BasicBlock* for_begin = BasicBlock::Create( - *ctx_, "for_begin", function_); - BasicBlock* for_body = BasicBlock::Create( - *ctx_, "for_body", function_); - BasicBlock* for_end = BasicBlock::Create( - *ctx_, "for_end", function_); + BasicBlock* for_begin = BasicBlock::Create(*ctx_, "for_begin", function_); + BasicBlock* for_body = BasicBlock::Create(*ctx_, "for_body", function_); + BasicBlock* for_end = BasicBlock::Create(*ctx_, "for_end", function_); builder_->CreateBr(for_begin); builder_->SetInsertPoint(for_begin); llvm::PHINode* loop_value = builder_->CreatePHI(begin->getType(), 2); loop_value->addIncoming(begin, pre_block); CHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), - for_body, for_end, md_very_likely_branch_); + builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, + md_very_likely_branch_); 
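  // The loop is lowered to the usual four-block shape:
  //   pre_block -> for_begin (PHI over loop_value) -> for_body -> for_begin
  // with the branch falling through to for_end once loop_value >= end.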
builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); @@ -580,7 +572,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { - llvm::Type * target = DTypeToLLVMType(to); + llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { return builder_->CreateBitCast(value, target); @@ -617,8 +609,8 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1); - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); + llvm::GlobalVariable* global = + new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str"); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -627,14 +619,12 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str)); llvm::Constant* zero = ConstInt32(0); llvm::Constant* indices[] = {zero, zero}; - llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr( - type, global, indices); + llvm::Constant* ptr = llvm::ConstantExpr::getGetElementPtr(type, global, indices); str_map_[str] = ptr; return ptr; } -llvm::Value* CodeGenLLVM::CreateBufferPtr( - DataType t, llvm::Value* buffer, llvm::Value* index) { +llvm::Value* CodeGenLLVM::CreateBufferPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_EQ(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); @@ -646,13 +636,11 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr( return builder_->CreateInBoundsGEP(buffer, index); } -llvm::Value* CodeGenLLVM::CreateBufferVecPtr( - DataType t, llvm::Value* buffer, llvm::Value* index) { +llvm::Value* CodeGenLLVM::CreateBufferVecPtr(DataType t, llvm::Value* buffer, llvm::Value* index) { CHECK_GT(t.lanes(), 1); llvm::PointerType* btype = llvm::dyn_cast(buffer->getType()); CHECK(btype != nullptr); - llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo( - btype->getAddressSpace()); + llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace()); if (btype != ptype) { buffer = builder_->CreatePointerCast(buffer, ptype); } @@ -672,21 +660,19 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { arg_value.push_back(MakeValue(op->args[i])); arg_type.push_back(arg_value.back()->getType()); } - llvm::FunctionType* ftype = llvm::FunctionType::get( - GetLLVMType(GetRef(op)), arg_type, false); + llvm::FunctionType* ftype = + llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create( - ftype, llvm::Function::ExternalLinkage, - op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + op->name.operator llvm::StringRef(), module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; } -llvm::Function* CodeGenLLVM::GetIntrinsicDecl( - llvm::Intrinsic::ID id, llvm::Type* ret_type, - llvm::ArrayRef arg_types) { +llvm::Function* CodeGenLLVM::GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, + llvm::ArrayRef arg_types) { llvm::Module* module = module_.get(); if 
(!llvm::Intrinsic::isOverloaded(id)) { @@ -701,8 +687,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( auto try_match = [&](llvm::FunctionType* f_ty, bool var_arg) { overload_types.clear(); llvm::ArrayRef ref(infos); - auto match = - llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); + auto match = llvm::Intrinsic::matchIntrinsicSignature(f_ty, ref, overload_types); if (match == llvm::Intrinsic::MatchIntrinsicTypes_Match) { bool error = llvm::Intrinsic::matchIntrinsicVarArg(var_arg, ref); if (error) { @@ -737,7 +722,7 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( // Failed to identify the type. return nullptr; -#else // TVM_LLVM_VERSION +#else // TVM_LLVM_VERSION llvm::ArrayRef ref(infos); // matchIntrinsicType returns true on error. if (llvm::Intrinsic::matchIntrinsicType(ret_type, ref, overload_types)) { @@ -755,9 +740,8 @@ llvm::Function* CodeGenLLVM::GetIntrinsicDecl( llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->is_intrinsic("llvm_intrin")) { CHECK_GE(op->args.size(), 2U); - llvm::Intrinsic::ID id = static_cast( - Downcast(op->args[0])->value); - int64_t num_signature = Downcast(op->args[1])->value; + llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); + int64_t num_signature = Downcast(op->args[1])->value; std::vector arg_value; std::vector arg_type; for (size_t i = 2; i < op->args.size(); ++i) { @@ -773,9 +757,8 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { // mismatch will have to be treated specially here. // TODO(kparzysz-quic): fix this once TVM prefetch uses the same // type as LLVM. - llvm::Type *return_type = (id != llvm::Intrinsic::prefetch) - ? GetLLVMType(GetRef(op)) - : llvm::Type::getVoidTy(*ctx_); + llvm::Type* return_type = (id != llvm::Intrinsic::prefetch) ? 
GetLLVMType(GetRef(op)) + : llvm::Type::getVoidTy(*ctx_); llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type); CHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: " @@ -800,22 +783,18 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->is_intrinsic(intrinsic::tvm_storage_sync)) { return CreateStorageSync(op); } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); - const RampNode *r = l->index.as(); + const RampNode* r = l->index.as(); llvm::Value* ptr; unsigned addrspace; if (!r) { - ptr = CreateBufferPtr( - l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); - addrspace = llvm::dyn_cast( - ptr->getType())->getAddressSpace(); + ptr = CreateBufferPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(l->index)); + addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } else { - PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); - ptr = CreateBufferVecPtr( - l->dtype, MakeValue(l->buffer_var), MakeValue(index)); - addrspace = llvm::dyn_cast( - ptr->getType())->getAddressSpace(); + PrimExpr index = r->base / make_const(DataType::Int(32), r->lanes); + ptr = CreateBufferVecPtr(l->dtype, MakeValue(l->buffer_var), MakeValue(index)); + addrspace = llvm::dyn_cast(ptr->getType())->getAddressSpace(); } return builder_->CreatePointerCast(ptr, t_char_->getPointerTo(addrspace)); } else if (op->is_intrinsic(CallNode::reinterpret) && is_zero(op->args[0])) { @@ -829,15 +808,11 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { uint64_t val = (high << 32U) | low; return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val); } else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { - CHECK_EQ(op->args[0].dtype().lanes(), 1) - << "if_then_else can only take scalar condition"; + CHECK_EQ(op->args[0].dtype().lanes(), 1) << "if_then_else can only take scalar condition"; using llvm::BasicBlock; - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); + BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); + BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_); builder_->CreateCondBr(MakeValue(op->args[0]), then_block, else_block); builder_->SetInsertPoint(then_block); llvm::Value* then_value = MakeValue(op->args[1]); @@ -853,25 +828,29 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { value->addIncoming(else_value, else_value_block); return value; } else if (op->is_intrinsic(CallNode::reinterpret)) { - llvm::Type * target = DTypeToLLVMType(op->dtype); + llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); } else if (op->is_intrinsic(CallNode::isnan)) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); return builder_->CreateFCmpUNO(a, a); } else if (op->is_intrinsic("vectorlow")) { - llvm::Value *v = MakeValue(op->args[0]); - int l = v->getType()->getVectorNumElements(); - return CreateVecSlice(v, 0, l/2); + llvm::Value* v = MakeValue(op->args[0]); + int l = llvm::cast(v->getType())->getNumElements(); + return CreateVecSlice(v, 0, l / 2); } else if (op->is_intrinsic("vectorhigh")) { - llvm::Value *v = 
MakeValue(op->args[0]); - int l = v->getType()->getVectorNumElements(); - return CreateVecSlice(v, l/2, l/2); + llvm::Value* v = MakeValue(op->args[0]); + int l = llvm::cast(v->getType())->getNumElements(); + return CreateVecSlice(v, l / 2, l / 2); } else if (op->is_intrinsic("vectorcombine")) { - llvm::Value *v0 = MakeValue(op->args[0]); - llvm::Value *v1 = MakeValue(op->args[1]); - int num_elems = static_cast(v0->getType()->getVectorNumElements()) * 2; + llvm::Value* v0 = MakeValue(op->args[0]); + llvm::Value* v1 = MakeValue(op->args[1]); + int num_elems = llvm::cast(v0->getType())->getNumElements() * 2; +#if TVM_LLVM_VERSION >= 110 + std::vector indices; +#else std::vector indices; +#endif for (int i = 0; i < num_elems; ++i) { indices.push_back(i); } @@ -882,8 +861,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } } -void CodeGenLLVM::Scalarize(const PrimExpr& e, - std::function f) { +void CodeGenLLVM::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { PrimExpr offset = ramp->base + (ramp->stride * i); @@ -897,11 +875,8 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, } } - // Visitors -llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { - return GetVarValue(op); -} +llvm::Value* CodeGenLLVM::VisitExpr_(const VarNode* op) { return GetVarValue(op); } llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) { return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value)); @@ -914,52 +889,48 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { - return GetConstString(op->value); -} - -#define DEFINE_CODEGEN_BINARY_OP(Op) \ - llvm::Value* CodeGenLLVM::Create ## Op( \ - DataType t, llvm::Value* a, llvm::Value *b) { \ - if (t.is_int()) { \ - if (t.bits() >= 32) { \ - return builder_->CreateNSW ## Op (a, b); \ - } else { \ - return builder_->Create ## Op (a, b); \ - } \ - } else if (t.is_uint()) { \ - if (t.bits() >= 32) { \ - return builder_->CreateNUW ## Op (a, b); \ - } else { \ - return builder_->Create ## Op (a, b); \ - } \ - } else { \ - CHECK(t.is_float()); \ - return builder_->CreateF ## Op (a, b); \ - } \ - } \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ - return Create ## Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ +llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); } + +#define DEFINE_CODEGEN_BINARY_OP(Op) \ + llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_int()) { \ + if (t.bits() >= 32) { \ + return builder_->CreateNSW##Op(a, b); \ + } else { \ + return builder_->Create##Op(a, b); \ + } \ + } else if (t.is_uint()) { \ + if (t.bits() >= 32) { \ + return builder_->CreateNUW##Op(a, b); \ + } else { \ + return builder_->Create##Op(a, b); \ + } \ + } else { \ + CHECK(t.is_float()); \ + return builder_->CreateF##Op(a, b); \ + } \ + } \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_BINARY_OP(Add); DEFINE_CODEGEN_BINARY_OP(Sub); DEFINE_CODEGEN_BINARY_OP(Mul); -#define DEFINE_CODEGEN_CMP_OP(Op) \ - llvm::Value* CodeGenLLVM::Create ## Op( \ - DataType t, llvm::Value* a, llvm::Value* b) { \ - if (t.is_int()) { \ - return builder_->CreateICmpS ## Op (a, b); \ - } else if (t.is_uint()) { \ - return 
builder_->CreateICmpU ## Op (a, b); \ - } else { \ - CHECK(t.is_float()); \ - return builder_->CreateFCmpO ## Op (a, b); \ - } \ -} \ - llvm::Value* CodeGenLLVM::VisitExpr_(const Op ## Node* op) { \ - return Create ## Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ +#define DEFINE_CODEGEN_CMP_OP(Op) \ + llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ + if (t.is_int()) { \ + return builder_->CreateICmpS##Op(a, b); \ + } else if (t.is_uint()) { \ + return builder_->CreateICmpU##Op(a, b); \ + } else { \ + CHECK(t.is_float()); \ + return builder_->CreateFCmpO##Op(a, b); \ + } \ + } \ + llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ } DEFINE_CODEGEN_CMP_OP(LT); @@ -1038,10 +1009,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { - return builder_->CreateSelect( - MakeValue(op->condition), - MakeValue(op->true_value), - MakeValue(op->false_value)); + return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value), + MakeValue(op->false_value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { @@ -1062,8 +1031,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { GetAlignment(t, op->buffer_var.get(), op->index, &alignment, &native_bits); llvm::Value* ptr = CreateBufferPtr(t, buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = - builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); #else llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); #endif @@ -1071,20 +1039,17 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { return load; } else { // vector load - unsigned addrspace = llvm::dyn_cast( - buffer->getType())->getAddressSpace(); + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); CHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr( - t.element_of(), buffer, MakeValue(ramp->base)); - ptr = builder_->CreatePointerCast( - ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); + ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, llvm::Align(alignment), is_volatile); + llvm::LoadInst* load = + builder_->CreateAlignedLoad(ptr, llvm::Align(alignment), is_volatile); #else llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile); #endif @@ -1099,11 +1064,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { auto f = [&](int i, llvm::Value* index) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, llvm::Align(basic_align), is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, llvm::Align(basic_align), is_volatile); #else - llvm::LoadInst* load = builder_->CreateAlignedLoad( - ptr, basic_align, is_volatile); + llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, basic_align, is_volatile); #endif ret = 
builder_->CreateInsertElement(ret, load, ConstInt32(i)); AddAliasInfo(load, op->buffer_var.get(), PrimExpr(), t); @@ -1113,16 +1076,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { return CreateIntrinsic(op); - } else if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { return CreateCallExtern(op); } else { - LOG(FATAL) << "Unknown call type " << - "name= " << op->name << - " call_type= " << op->call_type; + LOG(FATAL) << "Unknown call type " + << "name= " << op->name << " call_type= " << op->call_type; return nullptr; } } @@ -1131,14 +1091,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( - vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), - ConstInt32(i)); + vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)), ConstInt32(i)); } return vec; } llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { - std::vector vecs(op->vectors.size()); + std::vector vecs(op->vectors.size()); int total_lanes = 0; for (int i = 0, e = op->vectors.size(); i < e; ++i) { vecs[i] = VisitExpr(op->vectors[i]); @@ -1147,9 +1106,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { llvm::Value* v0 = CreateVecConcat(vecs); std::vector idx(op->indices.size()); for (int i = 0, e = op->indices.size(); i < e; ++i) { - const int64_t *val = as_const_int(op->indices[i]); - CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " - << "but get " << op->indices[i] << "\n"; + const int64_t* val = as_const_int(op->indices[i]); + CHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " + << "but get " << op->indices[i] << "\n"; idx[i] = *val; } llvm::Value* mask = llvm::ConstantDataVector::get(builder_->getContext(), idx); @@ -1183,15 +1142,13 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { return; } else { // vector store - unsigned addrspace = llvm::dyn_cast( - buffer->getType())->getAddressSpace(); + unsigned addrspace = llvm::dyn_cast(buffer->getType())->getAddressSpace(); if (const RampNode* ramp = op->index.as()) { if (is_one(ramp->stride)) { int alignment, native_bits; GetAlignment(t, op->buffer_var.get(), ramp->base, &alignment, &native_bits); CHECK_EQ(ramp->lanes, t.lanes()); - llvm::Value* ptr = CreateBufferPtr( - t.element_of(), buffer, MakeValue(ramp->base)); + llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, MakeValue(ramp->base)); ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace)); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = @@ -1211,12 +1168,10 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index); #if TVM_LLVM_VERSION >= 110 llvm::StoreInst* store = builder_->CreateAlignedStore( - builder_->CreateExtractElement(value, i), - ptr, llvm::Align(basic_align), is_volatile); + builder_->CreateExtractElement(value, i), ptr, llvm::Align(basic_align), is_volatile); #else - llvm::StoreInst* store = builder_->CreateAlignedStore( - 
builder_->CreateExtractElement(value, i), - ptr, basic_align, is_volatile); + llvm::StoreInst* store = builder_->CreateAlignedStore(builder_->CreateExtractElement(value, i), + ptr, basic_align, is_volatile); #endif AddAliasInfo(store, op->buffer_var.get(), PrimExpr(), op->value.dtype()); }; @@ -1233,21 +1188,16 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { CHECK(op->for_type == ForType::Serial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), - llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), - op->loop_var, op->body); + llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); } - void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { using llvm::BasicBlock; llvm::Value* cond = MakeValue(op->condition); - BasicBlock* then_block = BasicBlock::Create( - *ctx_, "if_then", function_); - BasicBlock* end_block = BasicBlock::Create( - *ctx_, "if_end", function_); + BasicBlock* then_block = BasicBlock::Create(*ctx_, "if_then", function_); + BasicBlock* end_block = BasicBlock::Create(*ctx_, "if_end", function_); if (op->else_case.defined()) { - BasicBlock* else_block = BasicBlock::Create( - *ctx_, "if_else", function_); + BasicBlock* else_block = BasicBlock::Create(*ctx_, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); @@ -1264,39 +1214,35 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { builder_->SetInsertPoint(end_block); } - void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; - } - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the NV devices + if (info.alignment > 16) { + info.alignment = 16; + } + llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - } - info.alignment = alloca->getAlignment(); - buf = alloca; + } + info.alignment = alloca->getAlignment(); + buf = alloca; buf = builder_->CreatePointerCast( - buf, DTypeToLLVMType(op->dtype)->getPointerTo( - buf->getType()->getPointerAddressSpace())); + buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace())); CHECK(!var_map_.count(op->buffer_var.get())); 
var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -1315,12 +1261,15 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = op->node.as(); CHECK(v); alloc_storage_info_[v].scope = - runtime::StorageScope::make(op->value.as()->value); + runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); CHECK(v); - alloc_storage_info_[v].alignment = - static_cast(op->value.as()->value); + alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); + if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) { + builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), + alloc_storage_info_[v].alignment); + } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); @@ -1335,14 +1284,19 @@ void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) { } void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { - CHECK(!var_map_.count(op->var.get())); - if (op->var.dtype().is_handle()) { + const VarNode* v = op->var.get(); + CHECK(!var_map_.count(v)); + if (v->dtype.is_handle()) { if (!is_restricted_) { - alias_var_set_.insert(op->var.get()); + alias_var_set_.insert(v); } } - var_map_[op->var.get()] = MakeValue(op->value); + var_map_[v] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); + if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) { + builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), + alloc_storage_info_[v].alignment); + } this->VisitStmt(op->body); } @@ -1352,9 +1306,7 @@ void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { - MakeValue(op->value); -} +void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 5c7ca6fb622f..0bca2a169ba4 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -25,27 +25,26 @@ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #ifdef TVM_LLVM_VERSION +#include #include #include -#include +#include #include -#include -#include #include +#include +#include #include -#include - #include -#include -#include #include #include #include -#include "llvm_common.h" +#include +#include + #include "../../runtime/thread_storage_scope.h" -#include "../../arith/compute_expr.h" -#include "../../tir/pass/ir_util.h" +#include "../../tir/transforms/ir_util.h" +#include "llvm_common.h" namespace tvm { namespace codegen { @@ -55,9 +54,8 @@ using namespace tir; /*! * \brief A base class to generate a LLVM. */ -class CodeGenLLVM : - public ExprFunctor, - public StmtFunctor { +class CodeGenLLVM : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Create new code generator based on target machine. @@ -74,11 +72,8 @@ class CodeGenLLVM : * \param dynamic_lookup Whether dynamically lookup runtime function * or use the runtime function table passed by caller. */ - virtual void Init(const std::string& module_name, - llvm::TargetMachine* tm, - llvm::LLVMContext* ctx, - bool system_lib, - bool dynamic_lookup); + virtual void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx, + bool system_lib, bool dynamic_lookup); /*! * \brief Compile and add function f to the current module. * \param f The function to be added. 
@@ -104,9 +99,7 @@ class CodeGenLLVM : * \param e The expression to be created value for. * \return created value. */ - llvm::Value* MakeValue(const PrimExpr& e) { - return VisitExpr(e); - } + llvm::Value* MakeValue(const PrimExpr& e) { return VisitExpr(e); } // Short hande code to get a constant int 32 llvm::Constant* ConstInt32(int64_t value) const { return llvm::ConstantInt::getSigned(t_int32_, value); @@ -170,7 +163,7 @@ class CodeGenLLVM : * \tparam F The function to be executed. * \return The result. */ - template + template llvm::AllocaInst* WithFunctionEntry(F falloca) { llvm::BasicBlock* current = builder_->GetInsertBlock(); llvm::BasicBlock* entry = &(function_->getEntryBlock()); @@ -191,8 +184,7 @@ class CodeGenLLVM : virtual void InitPassManagerBuilder(llvm::PassManagerBuilder* builder); // Scalarize by iterating elements of e. // f is a callback that takes index and v. - virtual void Scalarize(const PrimExpr& e, - std::function f); + virtual void Scalarize(const PrimExpr& e, std::function f); // Initialize target virtual void InitTarget(llvm::TargetMachine* tm); // Add module startup function if needed. @@ -205,8 +197,7 @@ class CodeGenLLVM : virtual unsigned GetGlobalAddressSpace() const; void AddFunctionInternal(const PrimFunc& f, bool ret_void); // Create extern call - llvm::CallInst* CreateCallExtern(llvm::Type* ret, - const std::string& name, + llvm::CallInst* CreateCallExtern(llvm::Type* ret, const std::string& name, const std::vector& value); /*! * \brief Get the LLVM Type for a given runtime type. @@ -243,20 +234,18 @@ class CodeGenLLVM : * could not be generated (e.g. if the argument/return types do not * match). */ - llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, - llvm::Type* ret_type, + llvm::Function* GetIntrinsicDecl(llvm::Intrinsic::ID id, llvm::Type* ret_type, llvm::ArrayRef arg_types); // initialize the function state. void InitFuncState(); // Get alignment given index. - void GetAlignment( - DataType t, const VarNode* buf_var, const PrimExpr& index, - int* p_alignment, int* p_native_bits); + void GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, + int* p_native_bits); // Get constant string llvm::Value* GetConstString(const std::string& str); // do a scalarize call with f - llvm::Value* CreateScalarizedCall( - const CallNode* op, llvm::Function* f, const std::vector& args); + llvm::Value* CreateScalarizedCall(const CallNode* op, llvm::Function* f, + const std::vector& args); // handle module import void HandleImport(const std::string& code); // cast operatpr @@ -279,9 +268,7 @@ class CodeGenLLVM : llvm::Value* CreateVecConcat(std::vector vecs); llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes); // Create serial for - void CreateSerialFor(llvm::Value* begin, - llvm::Value* end, - llvm::Value* stride, + void CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, const Var& loop_var, const Stmt& body); // add alias information. 
void AddAliasInfo(llvm::Instruction* load, const VarNode* buffer, PrimExpr index, DataType type); diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 40dc653f742b..bc47ce1b1014 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -24,9 +24,10 @@ #ifdef TVM_LLVM_VERSION #include -#include "codegen_llvm.h" -#include "../build_common.h" + #include "../../runtime/cuda/cuda_module.h" +#include "../build_common.h" +#include "codegen_llvm.h" namespace tvm { namespace codegen { @@ -39,10 +40,9 @@ class CodeGenNVPTX : public CodeGenLLVM { CodeGenLLVM::AddFunctionInternal(f, true); // annotate as kernel function module_->getOrInsertNamedMetadata("nvvm.annotations") - ->addOperand(llvm::MDNode::get(*ctx_, { - llvm::ValueAsMetadata::get(function_), - llvm::MDString::get(*ctx_, "kernel"), - llvm::ValueAsMetadata::get(ConstInt32(1)) })); + ->addOperand(llvm::MDNode::get( + *ctx_, {llvm::ValueAsMetadata::get(function_), llvm::MDString::get(*ctx_, "kernel"), + llvm::ValueAsMetadata::get(ConstInt32(1))})); } void VisitStmt_(const AllocateNode* op) final { @@ -50,8 +50,7 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::Value* buf = nullptr; int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; if (constant_size % 4 == 0 && info.alignment == 0) { info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); @@ -65,9 +64,8 @@ class CodeGenNVPTX : public CodeGenLLVM { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
     llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
-      return builder_->CreateAlloca(
-          DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
-      });
+      return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
+    });
     if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
 #if TVM_LLVM_VERSION >= 100
       alloca->setAlignment(llvm::Align(info.alignment));
@@ -81,12 +79,11 @@ class CodeGenNVPTX : public CodeGenLLVM {
           << "Can only allocate shared or local memory inside kernel";
       // Shared memory: address space == 3
       const unsigned shared_address_space = 3;
-      llvm::Type* type = llvm::ArrayType::get(
-          DTypeToLLVMType(op->dtype), constant_size);
+      llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size);
       // Allocate shared memory in global, address_space = 3
-      llvm::GlobalVariable *global = new llvm::GlobalVariable(
-          *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
-          nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space);
+      llvm::GlobalVariable* global = new llvm::GlobalVariable(
+          *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr,
+          llvm::GlobalValue::NotThreadLocal, shared_address_space);
 #if TVM_LLVM_VERSION >= 100
       global->setAlignment(llvm::Align(info.alignment));
 #else
@@ -96,8 +93,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
     }
     buf = builder_->CreatePointerCast(
-        buf, DTypeToLLVMType(op->dtype)->getPointerTo(
-            buf->getType()->getPointerAddressSpace()));
+        buf, DTypeToLLVMType(op->dtype)->getPointerTo(buf->getType()->getPointerAddressSpace()));
     CHECK(!var_map_.count(op->buffer_var.get()));
     var_map_[op->buffer_var.get()] = buf;
     this->VisitStmt(op->body);
@@ -105,22 +101,36 @@ class CodeGenNVPTX : public CodeGenLLVM {
   // Return the thread index via intrinsics.
   llvm::Value* GetThreadIndex(const IterVar& iv) final {
-    runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag);
+    runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
     llvm::Intrinsic::ID intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
     if (ts.rank == 1) {
       switch (ts.dim_index) {
-        case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x; break;
-        case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y; break;
-        case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z; break;
-        default: LOG(FATAL) << "unknown thread idx";
+        case 0:
+          intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x;
+          break;
+        case 1:
+          intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_y;
+          break;
+        case 2:
+          intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_tid_z;
+          break;
+        default:
+          LOG(FATAL) << "unknown thread idx";
       }
     } else {
       CHECK_EQ(ts.rank, 0);
       switch (ts.dim_index) {
-        case 0: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x; break;
-        case 1: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y; break;
-        case 2: intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z; break;
-        default: LOG(FATAL) << "unknown thread idx";
+        case 0:
+          intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x;
+          break;
+        case 1:
+          intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_y;
+          break;
+        case 2:
+          intrin_id = ::llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_z;
+          break;
+        default:
+          LOG(FATAL) << "unknown thread idx";
       }
     }
     llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), intrin_id);
@@ -133,9 +143,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
       // TODO(tqchen) warp sync in CUDA9
       return nullptr;
     } else if (sync == "shared") {
-      llvm::Function* f = llvm::Intrinsic::getDeclaration(
-          module_.get(),
-          ::llvm::Intrinsic::nvvm_barrier0);
+      llvm::Function* f =
+          llvm::Intrinsic::getDeclaration(module_.get(), ::llvm::Intrinsic::nvvm_barrier0);
       return builder_->CreateCall(f, {});
     } else {
       LOG(FATAL) << "Do not support sync " << sync;
@@ -161,6 +170,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
     CodeGenLLVM::Optimize();
   }

+  llvm::Value* CreateIntrinsic(const CallNode* op) override;
+
 protected:
   void InitTarget(llvm::TargetMachine* tm) final {
     // Maximum vector lane = float4
@@ -169,16 +180,70 @@ class CodeGenNVPTX : public CodeGenLLVM {
   }
 };

+// Check if this is a warp shuffle intrinsic call and match its
+// corresponding nvvm intrinsic. Return true if the match is successful.
+static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID* id) {
+  // Only 32 bit data type is supported.
+  if (op->dtype.is_vector() || op->dtype.bits() != 32) {
+    return false;
+  }
+
+  // Intrinsic lookup table.
+  // It is difficult to emit a _sync version that works on Pascal.
+  // We ignore the mask and only emit the non-sync version for nvptx.
+  llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_shfl_idx_i32,  llvm::Intrinsic::nvvm_shfl_idx_f32,
+      llvm::Intrinsic::nvvm_shfl_up_i32,   llvm::Intrinsic::nvvm_shfl_up_f32,
+      llvm::Intrinsic::nvvm_shfl_down_i32, llvm::Intrinsic::nvvm_shfl_down_f32};
+
+  int offset = 0;
+  if (op->is_intrinsic(intrinsic::tvm_warp_shuffle)) {
+    offset = 0;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_up)) {
+    offset = 2;
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_shuffle_down)) {
+    offset = 4;
+  } else {
+    return false;
+  }
+
+  *id = ids[offset + op->dtype.is_float()];
+  return true;
+}
+
+llvm::Value* CodeGenNVPTX::CreateIntrinsic(const CallNode* op) {
+  llvm::Intrinsic::ID id = llvm::Intrinsic::not_intrinsic;
+  if (GetWarpShuffleIntrinsic(op, &id)) {
+    std::vector<llvm::Value*> arg_value;
+    std::vector<llvm::Type*> arg_type;
+    // Ignore the first mask operand and remove the last
+    // redundant warp_size.
+    size_t n_args = op->args.size() - 1;
+    for (size_t i = 1; i < n_args; ++i) {
+      arg_value.push_back(MakeValue(op->args[i]));
+      arg_type.push_back(arg_value.back()->getType());
+    }
+    llvm::Type* return_type = arg_type[0];
+    llvm::Function* func = GetIntrinsicDecl(id, return_type, arg_type);
+    return builder_->CreateCall(func, arg_value);
+  } else if (op->is_intrinsic(intrinsic::tvm_warp_activemask)) {
+    // Only nvptx target may keep this intrinsic at this point.
+    // PTX assembly: asm "activemask.b32 r1;"
+    auto fty = llvm::FunctionType::get(t_int32_, false);
+    auto val = llvm::InlineAsm::get(fty, "activemask.b32 %0", "=r", true);
+    return builder_->CreateCall(val);
+  }
+  return CodeGenLLVM::CreateIntrinsic(op);
+}
+
 inline int DetectCUDAComputeVersion() {
   TVMContext tvm_ctx;
   tvm_ctx.device_type = kDLGPU;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
-  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
-      tvm_ctx, tvm::runtime::kExist, &val);
+  tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
   if (val.operator int() == 1) {
-    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
-        tvm_ctx, tvm::runtime::kComputeVersion, &val);
+    tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(tvm_ctx, tvm::runtime::kComputeVersion, &val);
     std::string version = val;
     std::istringstream is(version);
     double ver;
@@ -191,28 +256,26 @@ inline int DetectCUDAComputeVersion() {
 runtime::Module BuildNVPTX(IRModule mod, std::string target) {
   InitializeLLVM();
-  CHECK(target.length() >= 5 &&
-        target.substr(0, 5) == "nvptx");
+  CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx");
   int compute_ver = DetectCUDAComputeVersion();
   std::ostringstream config;
-  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_"
-         << compute_ver
+  config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
          << target.substr(5, target.length() - 5);
   std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
-  std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
+  // careful: cg will hold a naked pointer reference to ctx, so it should
+  // have a shorter lifetime than the ctx.
+ std::unique_ptr cg(new CodeGenNVPTX()); cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); cg->AddFunction(f); } - const auto* flibdevice_path = - tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); + const auto* flibdevice_path = tvm::runtime::Registry::Get("tvm_callback_libdevice_path"); if (flibdevice_path != nullptr) { std::string path = (*flibdevice_path)(compute_ver); if (path.length() != 0) { @@ -239,16 +302,14 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { // emit ptx llvm::legacy::PassManager pass; #if TVM_LLVM_VERSION <= 60 - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm->addPassesToEmitFile( - pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) + CHECK(tm->addPassesToEmitFile(pass, dest_ptx, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*module); @@ -256,8 +317,7 @@ runtime::Module BuildNVPTX(IRModule mod, std::string target) { return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll); } -TVM_REGISTER_GLOBAL("target.build.nvptx") -.set_body_typed(BuildNVPTX); +TVM_REGISTER_GLOBAL("target.build.nvptx").set_body_typed(BuildNVPTX); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index 477d79481111..edffda287c7b 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -24,8 +24,8 @@ #ifdef TVM_LLVM_VERSION #include -#include "codegen_cpu.h" +#include "codegen_cpu.h" #include "llvm/MC/MCSubtargetInfo.h" namespace tvm { @@ -89,14 +89,11 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic)), - MakeValue( - tir::BroadcastNode::make( - FloatImm(DataType::Float(32), 0), from.lanes())), - /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), - /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), + MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, + {op->value}, tir::CallNode::PureIntrinsic)), + MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), + /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), + /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); } @@ -108,9 +105,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { return CallVectorIntrin( ::llvm::Intrinsic::x86_vcvtph2ps_256, 8, DTypeToLLVMType(DataType::Float(32, from.lanes())), - {MakeValue(tir::CallNode::make( - DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, - tir::CallNode::PureIntrinsic))}); + {MakeValue(tir::Call(DataType::Int(16, 
from.lanes()), tir::CallNode::reinterpret, + {op->value}, tir::CallNode::PureIntrinsic))}); } #endif } @@ -123,21 +119,20 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr const std::vector& args) { llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {}); - if (intrin_lanes == result_ty->getVectorNumElements()) { + size_t num_elems = llvm::cast(result_ty)->getNumElements(); + if (intrin_lanes == num_elems) { return builder_->CreateCall(f, args); } // Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary), // compute each result, and then concatenate the vectors (slicing the result if necessary). - CHECK_LT(intrin_lanes, result_ty->getVectorNumElements()); + CHECK_LT(intrin_lanes, num_elems); std::vector split_results; - for (size_t i = 0; - i < static_cast(result_ty->getVectorNumElements()); - i += intrin_lanes) { + for (size_t i = 0; i < num_elems; i += intrin_lanes) { std::vector split_args; for (const auto& v : args) { if (v->getType()->isVectorTy()) { - CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements()); + CHECK_EQ(llvm::cast(v->getType())->getNumElements(), num_elems); split_args.push_back(CreateVecSlice(v, i, intrin_lanes)); } else { split_args.push_back(v); @@ -147,14 +142,14 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes), split_args)); } - return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements()); + return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems); } TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - CodeGenLLVM* cg = new CodeGenX86_64(); - *rv = static_cast(cg); - }); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + CodeGenLLVM* cg = new CodeGenX86_64(); + *rv = static_cast(cg); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 58bfb371c577..8804b1e45a6f 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -22,153 +22,143 @@ */ #ifdef TVM_LLVM_VERSION -#include #include "intrin_rule_llvm.h" +#include + namespace tvm { namespace codegen { namespace llvm { TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch") -.set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); + .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = tir::CallNode::make( - x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + 
CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr ln10 = make_const(x.dtype(), 2.302585093); + PrimExpr ret = tir::Call(x.dtype(), "exp", {x * ln10}, tir::CallNode::PureIntrinsic); + *rv = ret; + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr one = make_const(x.dtype(), 1); - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_two = make_const(x.dtype(), -2); - - PrimExpr exp_neg2x = tir::CallNode::make( - x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_pos2x = tir::CallNode::make( - x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); - - PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); - PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - *rv = tir::SelectNode::make( - x >= make_zero(x.dtype()), tanh_pos, tanh_neg); -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr one = make_const(x.dtype(), 1); + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_two = make_const(x.dtype(), -2); + + PrimExpr exp_neg2x = tir::Call(x.dtype(), "exp", {neg_two * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_pos2x = 
tir::Call(x.dtype(), "exp", {two * x}, tir::CallNode::PureIntrinsic); + + PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); + PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); + *rv = tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); CHECK(call != nullptr); const PrimExpr& x = call->args[0]; - PrimExpr sin_x = tir::CallNode::make( - x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); - PrimExpr cos_x = tir::CallNode::make( - x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); + PrimExpr sin_x = tir::Call(x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic); + PrimExpr cos_x = tir::Call(x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic); PrimExpr tan_x = sin_x / cos_x; *rv = tan_x; }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::CallNode::make( - x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make( - x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); - PrimExpr ret = (exp_posx + exp_negx) / two; - *rv = ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx + exp_negx) / two; + *rv = ret; + }); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin") -.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); + .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh") -.set_body([](const TVMArgs& targs, TVMRetValue* rv) { - using tir::make_const; - using tir::make_zero; - PrimExpr e = targs[0]; - const tir::CallNode* call = e.as(); - CHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr two = make_const(x.dtype(), 2); - PrimExpr neg_one = make_const(x.dtype(), -1); - PrimExpr exp_negx = tir::CallNode::make( - x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); - PrimExpr exp_posx = tir::CallNode::make( - x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); - PrimExpr ret = (exp_posx - exp_negx) / two; - *rv = 
ret; -}); + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + using tir::make_const; + using tir::make_zero; + PrimExpr e = targs[0]; + const tir::CallNode* call = e.as(); + CHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr two = make_const(x.dtype(), 2); + PrimExpr neg_one = make_const(x.dtype(), -1); + PrimExpr exp_negx = tir::Call(x.dtype(), "exp", {neg_one * x}, tir::CallNode::PureIntrinsic); + PrimExpr exp_posx = tir::Call(x.dtype(), "exp", {x}, tir::CallNode::PureIntrinsic); + PrimExpr ret = (exp_posx - exp_negx) / two; + *rv = ret; + }); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index bb9ff66c9cb5..5613621d77fb 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -25,17 +25,18 @@ #define TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_ #ifdef TVM_LLVM_VERSION -#include #include - #include +#include + #include + #include "llvm_common.h" namespace tvm { namespace codegen { // num_signature means number of arguments used to query signature -template +template inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -48,11 +49,10 @@ inline void DispatchLLVMPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::PureIntrinsic); } -template +template inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -64,8 +64,7 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); + *rv = tir::Call(call->dtype, "llvm_intrin", cargs, tir::CallNode::Intrinsic); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 0dc1272d7d49..49c2224932a5 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -22,9 +22,9 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include #include +#include + #include namespace tvm { @@ -39,77 +39,54 @@ inline void DispatchExternLibDevice(const TVMArgs& args, TVMRetValue* rv) { std::ostringstream intrinsic_name; intrinsic_name << "__nv_" << call->name; if (call->dtype.bits() == 32) intrinsic_name << "f"; - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, - CallNode::PureExtern); + *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.floor").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.ceil").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.round").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.trunc").set_body(DispatchExternLibDevice); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fabs").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp2").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.exp10").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.erf").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.fma").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log2").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.log10").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sqrt").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cosh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sinh").set_body(DispatchExternLibDevice); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan") -.set_body(DispatchExternLibDevice); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.atan").set_body(DispatchExternLibDevice); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index 3699c9f691b3..3a2b8ac77f82 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -22,9 +22,9 @@ */ #ifdef TVM_LLVM_VERSION -#include -#include #include +#include +#include #include @@ -38,77 +38,105 @@ inline void DispatchExternOCML(const TVMArgs& args, 
TVMRetValue* rv) { CHECK(call != nullptr); std::ostringstream intrinsic_name; intrinsic_name << "__ocml_" << call->name << "_f" << call->dtype.bits(); - *rv = CallNode::make(call->dtype, intrinsic_name.str(), call->args, - CallNode::PureExtern); + *rv = Call(call->dtype, intrinsic_name.str(), call->args, CallNode::PureExtern); +} + +inline void DispatchShuffle(const TVMArgs& targs, TVMRetValue* rv) { + PrimExpr e_call = targs[0]; + using namespace tir; + const CallNode* call = e_call.as(); + CHECK(call != nullptr); + CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + PrimExpr var = call->args[1]; + CHECK_EQ(var.dtype().bits(), 32); + + // get own lane in self (__lane_id) + PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); + PrimExpr zero = tir::make_zero(DataType::Int(32)); + PrimExpr lo = + Call(DataType::Int(32), "llvm.amdgcn.mbcnt.lo", {minus_one, zero}, CallNode::PureExtern); + PrimExpr self = + Call(DataType::Int(32), "llvm.amdgcn.mbcnt.hi", {minus_one, lo}, CallNode::PureExtern); + + // compute lane to get from + PrimExpr width = call->args[3]; + PrimExpr index; + if (call->name == "tvm_warp_shuffle") { + PrimExpr src_lane = call->args[2]; + index = src_lane + (self & ~(width - 1)); + } else if (call->name == "tvm_warp_shuffle_up") { + PrimExpr delta = call->args[2]; + index = self - delta; + index = Select(index < (self & ~(width - 1)), self, index); + } else { + CHECK_EQ(call->name, "tvm_warp_shuffle_down"); + PrimExpr delta = call->args[2]; + index = self + delta; + index = Select((self & (width - 1)) + delta >= width, self, index); + } + PrimExpr res = + Call(var.dtype(), "llvm.amdgcn.ds.bpermute", {index << 2, var}, CallNode::PureExtern); + *rv = res; } namespace llvm { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor") -.set_body(DispatchExternOCML); +// dummy because we don't have the activemask +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_activemask") + .set_body([](const TVMArgs& targs, TVMRetValue* rv) { + PrimExpr zero = tir::make_zero(DataType::Int(32)); + *rv = zero; + }); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle").set_body(DispatchShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_up").set_body(DispatchShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tvm_warp_shuffle_down").set_body(DispatchShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.floor").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.ceil").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.round").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.trunc").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fabs").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp2").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10") -.set_body(DispatchExternOCML); 
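A worked example of the lane-index arithmetic in DispatchShuffle above (the concrete numbers are chosen purely for illustration): take self = 13 and width = 8, so the sub-warp base is self & ~(width - 1) = 8. Then tvm_warp_shuffle with src_lane = 2 reads from lane 2 + 8 = 10; tvm_warp_shuffle_up with delta = 6 computes 13 - 6 = 7, which falls below the base 8 and is clamped back to self = 13; tvm_warp_shuffle_down with delta = 4 has (13 & 7) + 4 = 9 >= 8, so it is likewise clamped to 13. The resulting index is shifted left by 2 before the llvm.amdgcn.ds.bpermute call because that intrinsic addresses lanes in bytes (4 bytes per 32-bit lane).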
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.exp10").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.erf").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.fma").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log2").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.log10").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sqrt").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cosh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sin").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.sinh").set_body(DispatchExternOCML); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan") -.set_body(DispatchExternOCML); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.atan").set_body(DispatchExternOCML); } // namespace llvm } // namespace codegen diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 29e4db3c91da..5534a643676c 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -22,11 +22,13 @@ */ #ifdef TVM_LLVM_VERSION +#include "llvm_common.h" + #include + #include -#include #include -#include "llvm_common.h" +#include namespace tvm { namespace codegen { @@ -56,15 +58,11 @@ void InitializeLLVM() { } } -void ParseLLVMTargetOptions(const std::string& target_str, - std::string* triple, - std::string* mcpu, - std::string* mattr, - llvm::TargetOptions* options) { +void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, + std::string* mattr, llvm::TargetOptions* options) { // setup target triple size_t start = 0; - if (target_str.length() >= 4 && - target_str.substr(0, 4) == "llvm") { + if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") { start = 4; } // simple parser @@ -82,16 +80,13 @@ void ParseLLVMTargetOptions(const std::string& target_str, } size_t pos = key.find('='); if (pos != std::string::npos) { - CHECK_GE(key.length(), pos + 1) - << "invalid argument " << key; + 
CHECK_GE(key.length(), pos + 1) << "invalid argument " << key; value = key.substr(pos + 1, key.length() - 1); key = key.substr(0, pos); } else { - CHECK(is >> value) - << "Unspecified value for option " << key; + CHECK(is >> value) << "Unspecified value for option " << key; } - if (key == "-target" || - key == "-mtriple") { + if (key == "-target" || key == "-mtriple") { *triple = value; } else if (key == "-mcpu") { *mcpu = value; @@ -115,16 +110,15 @@ void ParseLLVMTargetOptions(const std::string& target_str, } } - if (triple->length() == 0 || - *triple == "default") { + if (triple->length() == 0 || *triple == "default") { *triple = llvm::sys::getDefaultTargetTriple(); } // set target option llvm::TargetOptions& opt = *options; opt = llvm::TargetOptions(); - #if TVM_LLVM_VERSION < 50 +#if TVM_LLVM_VERSION < 50 opt.LessPreciseFPMADOption = true; - #endif +#endif opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; @@ -136,21 +130,14 @@ void ParseLLVMTargetOptions(const std::string& target_str, } } - -std::unique_ptr -GetLLVMTargetMachine(const std::string& target_str, - bool allow_null) { +std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, + bool allow_null) { std::string target_triple, mcpu, mattr; llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_str, - &target_triple, - &mcpu, - &mattr, - &opt); + ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt); - if (target_triple.length() == 0 || - target_triple == "default") { + if (target_triple.length() == 0 || target_triple == "default") { target_triple = llvm::sys::getDefaultTargetTriple(); } if (mcpu.length() == 0) { @@ -158,14 +145,13 @@ GetLLVMTargetMachine(const std::string& target_str, } std::string err; - const llvm::Target* target = - llvm::TargetRegistry::lookupTarget(target_triple, err); + const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err); if (target == nullptr) { CHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } - llvm::TargetMachine* tm = target->createTargetMachine( - target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + llvm::TargetMachine* tm = + target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); return std::unique_ptr(tm); } diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 85ee1ee97495..9a4ccfc9cf9c 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -25,14 +25,14 @@ #define TVM_TARGET_LLVM_LLVM_COMMON_H_ #ifdef TVM_LLVM_VERSION -#include - #include #include -#include - -#include +#include +#include #include +#include +#include +#include #if TVM_LLVM_VERSION >= 100 #include #include @@ -42,43 +42,41 @@ #include #include #include -#include #include +#include #include #include #include #include +#include +#include #include #include -#include #include - -#include +#include +#include #include #include -#include -#include #if TVM_LLVM_VERSION >= 100 #include #endif +#include +#include +#include +#include #include #include #include -#include -#include #include #include +#include #include #include -#include -#include -#include - -#include -#include #include +#include +#include namespace tvm { namespace codegen { @@ -97,11 +95,8 @@ void InitializeLLVM(); * \param options the options * \param mattr The attributes */ -void ParseLLVMTargetOptions(const std::string& target_str, - std::string* triple, - std::string* mcpu, - std::string* mattr, - llvm::TargetOptions* options); +void 
ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, + std::string* mattr, llvm::TargetOptions* options); /*! * \brief Get target machine from target_str string. @@ -109,10 +104,16 @@ void ParseLLVMTargetOptions(const std::string& target_str, * \param allow_null Whether allow null to be returned. * \return target machine */ -std::unique_ptr -GetLLVMTargetMachine(const std::string& target_str, bool allow_null = false); +std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, + bool allow_null = false); } // namespace codegen } // namespace tvm + +namespace tvm { +namespace runtime { +inline String::operator llvm::StringRef() const { return llvm::StringRef(get()->data, size()); } +} // namespace runtime +} // namespace tvm #endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_LLVM_COMMON_H_ diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index d1a244d01ff4..1151b33536b5 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -23,23 +23,25 @@ */ #ifdef TVM_LLVM_VERSION +#include #include #include -#include #include + #include -#include "llvm_common.h" -#include "codegen_llvm.h" -#include "codegen_blob.h" + #include "../../runtime/file_util.h" #include "../../runtime/library_module.h" +#include "codegen_blob.h" +#include "codegen_llvm.h" +#include "llvm_common.h" namespace tvm { namespace codegen { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; class LLVMModuleNode final : public runtime::ModuleNode { public: @@ -51,24 +53,15 @@ class LLVMModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { - return "llvm"; - } + const char* type_key() const { return "llvm"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { if (name == "__tvm_is_system_module") { - bool flag = - (mptr_->getFunction("__tvm_module_startup") != nullptr); - return PackedFunc([flag](TVMArgs args, TVMRetValue *rv) { - * rv = flag; - }); + bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr); + return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); } else if (name == "_get_target_triple") { std::string target_triple = tm_->getTargetTriple().str(); - return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) { - *rv = target_triple; - }); + return PackedFunc([target_triple](TVMArgs args, TVMRetValue* rv) { *rv = target_triple; }); } if (ee_ == nullptr) LazyInitJIT(); @@ -76,8 +69,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { TVMBackendPackedCFunc faddr; if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_module_main)); + const char* entry_name = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); CHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; faddr = reinterpret_cast(GetFunctionAddr(entry_name)); @@ -88,13 +81,11 @@ class LLVMModuleNode final : public runtime::ModuleNode { return WrapPackedFunc(faddr, sptr_to_self); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = runtime::GetFileFormat(file_name, format); std::error_code ecode; llvm::raw_fd_ostream dest(file_name, 
ecode, llvm::sys::fs::F_None); - CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name - << " " << ecode.message(); + CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name << " " << ecode.message(); if (fmt == "o" || fmt == "obj") { #if TVM_LLVM_VERSION <= 60 std::unique_ptr m = llvm::CloneModule(mptr_); @@ -104,16 +95,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::legacy::PassManager pass; CHECK(tm_); #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_ObjectFile) == + 0) << "Cannot emit target CGFT_ObjectFile"; #else - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_ObjectFile) == 0) << "Cannot emit target CGFT_ObjectFile"; #endif pass.run(*m); @@ -126,16 +115,14 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::legacy::PassManager pass; CHECK(tm_); #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, llvm::TargetMachine::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) << "Cannot emit target CGFT_AssemblyFile"; #else - CHECK(tm_->addPassesToEmitFile( - pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) + CHECK(tm_->addPassesToEmitFile(pass, dest, nullptr, llvm::CGFT_AssemblyFile) == 0) << "Cannot emit target CGFT_AssemblyFile"; #endif pass.run(*m); @@ -148,8 +135,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::WriteBitcodeToFile(*mptr_, dest); #endif } else { - LOG(FATAL) << "Do not know how to save file " - << file_name << " with format=\'"<< format << "\'"; + LOG(FATAL) << "Do not know how to save file " << file_name << " with format=\'" << format + << "\'"; } dest.close(); } @@ -165,28 +152,26 @@ class LLVMModuleNode final : public runtime::ModuleNode { llvm::raw_svector_ostream rso(str); if (fmt == "s" || fmt == "asm") { - #if TVM_LLVM_VERSION <= 60 - std::unique_ptr m = llvm::CloneModule(mptr_); - #else - std::unique_ptr m = llvm::CloneModule(*mptr_); - #endif - llvm::legacy::PassManager pass; - CHECK(tm_); - #if TVM_LLVM_VERSION <= 60 - CHECK(tm_->addPassesToEmitFile( - pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #elif TVM_LLVM_VERSION <= 90 - CHECK(tm_->addPassesToEmitFile( - pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #else - CHECK(tm_->addPassesToEmitFile( - pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) - << "Cannot emit target CGFT_AssemblyFile"; - #endif - pass.run(*m); - return rso.str().str(); +#if TVM_LLVM_VERSION <= 60 + std::unique_ptr m = llvm::CloneModule(mptr_); +#else + std::unique_ptr m = llvm::CloneModule(*mptr_); +#endif + llvm::legacy::PassManager pass; + CHECK(tm_); +#if TVM_LLVM_VERSION <= 60 
+ CHECK(tm_->addPassesToEmitFile(pass, rso, llvm::TargetMachine::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#elif TVM_LLVM_VERSION <= 90 + CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::TargetMachine::CGFT_AssemblyFile) == + 0) + << "Cannot emit target CGFT_AssemblyFile"; +#else + CHECK(tm_->addPassesToEmitFile(pass, rso, nullptr, llvm::CGFT_AssemblyFile) == 0) + << "Cannot emit target CGFT_AssemblyFile"; +#endif + pass.run(*m); + return rso.str().str(); } else if (fmt == "" || fmt == "ll") { std::string type_str; llvm::raw_string_ostream rso(type_str); @@ -194,8 +179,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_->print(rso, nullptr); return rso.str(); } else { - LOG(FATAL) << "Do not know how to get source code with format: " - << format << "\'"; + LOG(FATAL) << "Do not know how to get source code with format: " << format << "\'"; } return ""; } @@ -209,9 +193,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::vector funcs; std::string entry_func; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "Can only lower IR Module with PrimFuncs"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); @@ -251,8 +234,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { mptr_ = module_.get(); } - void Init(std::unique_ptr module, - std::shared_ptr ctx) { + void Init(std::unique_ptr module, std::shared_ptr ctx) { InitializeLLVM(); ctx_ = ctx; llvm::SMDiagnostic err; @@ -319,20 +301,17 @@ class LLVMModuleNode final : public runtime::ModuleNode { CHECK(layout == mptr_->getDataLayout()) << "Data layout mismatch between module(" << mptr_->getDataLayout().getStringRepresentation() << ")" - << " and ExecutionEngine (" - << layout.getStringRepresentation() << ")"; + << " and ExecutionEngine (" << layout.getStringRepresentation() << ")"; ee_ = builder.create(tm.release()); - CHECK(ee_ != nullptr) - << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); + CHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); ee_->runStaticConstructorsDestructors(false); - if (void** ctx_addr = reinterpret_cast( - GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { + if (void** ctx_addr = + reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { *ctx_addr = this; } - runtime::InitContextFunctions([this](const char *name) { - return reinterpret_cast(GetGlobalAddr(name)); - }); + runtime::InitContextFunctions( + [this](const char* name) { return reinterpret_cast(GetGlobalAddr(name)); }); } // Get global address from execution engine. uint64_t GetGlobalAddr(const std::string& name) const { @@ -357,7 +336,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { // JIT lock std::mutex mutex_; // execution engine - llvm::ExecutionEngine *ee_{nullptr}; + llvm::ExecutionEngine* ee_{nullptr}; // The raw pointer to the module. 
llvm::Module* mptr_{nullptr}; // The target machine @@ -372,17 +351,13 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { return llvm::Function::lookupIntrinsicID(name); } - -TVM_REGISTER_GLOBAL("target.build.llvm") -.set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) { auto n = make_object(); n->Init(mod, target); return runtime::Module(n); }); - -TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); auto target = args[0].operator std::string(); auto module_name = args[1].operator std::string(); @@ -403,35 +378,29 @@ TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") *rv = runtime::Module(n); }); -TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = static_cast(LookupLLVMIntrinsic(args[0])); - }); - -TVM_REGISTER_GLOBAL("target.llvm_version_major") -.set_body([](TVMArgs args, TVMRetValue* rv) { - int major = TVM_LLVM_VERSION / 10; - *rv = major; - }); - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") -.set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->LoadIR(args[0]); - *rv = runtime::Module(n); - }); - -TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") -.set_body([](TVMArgs args, TVMRetValue* rv) { - InitializeLLVM(); - *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); - }); +TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = static_cast(LookupLLVMIntrinsic(args[0])); +}); + +TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) { + int major = TVM_LLVM_VERSION / 10; + *rv = major; +}); + +TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) { + auto n = make_object(); + n->LoadIR(args[0]); + *rv = runtime::Module(n); +}); + +TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) { + InitializeLLVM(); + *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); +}); -TVM_REGISTER_GLOBAL("codegen.codegen_blob") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) { auto n = make_object(); - auto p = CodeGenBlob(args[0].operator std::string(), - args[1].operator bool(), + auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(), args[2].operator std::string()); n->Init(std::move(p.first), p.second); *rv = runtime::Module(n); diff --git a/src/target/opt/build_aocl_off.cc b/src/target/opt/build_aocl_off.cc index 2585ac23b961..9f9d098b7a97 100644 --- a/src/target/opt/build_aocl_off.cc +++ b/src/target/opt/build_aocl_off.cc @@ -20,17 +20,14 @@ /*! 
* Optional module when build aocl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module AOCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module AOCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "AOCL runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "aocl"); } diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index 4f941a504f93..893eb67a268f 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,11 +24,9 @@ namespace tvm { namespace runtime { -Module CUDAModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string cuda_source) { +Module CUDAModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, + std::string cuda_source) { LOG(FATAL) << "CUDA is not enabled"; return Module(); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 2d659e4487e3..c9471d1bfa8d 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -27,30 +27,26 @@ #include #endif #include - #include + #include -#include "../build_common.h" -#include "../source/codegen_cuda.h" #include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_module.h" - +#include "../build_common.h" +#include "../source/codegen_cuda.h" namespace tvm { namespace codegen { -#define NVRTC_CALL(x) \ - { \ - nvrtcResult result = x; \ - if (result != NVRTC_SUCCESS) { \ - LOG(FATAL) \ - << "NvrtcError: " #x " failed with error: " \ - << nvrtcGetErrorString(result); \ - } \ +#define NVRTC_CALL(x) \ + { \ + nvrtcResult result = x; \ + if (result != NVRTC_SUCCESS) { \ + LOG(FATAL) << "NvrtcError: " #x " failed with error: " << nvrtcGetErrorString(result); \ + } \ } - std::string FindCUDAIncludePath() { #if defined(_WIN32) const std::string delimiter = "\\"; @@ -78,7 +74,6 @@ std::string FindCUDAIncludePath() { return cuda_include_path; } - std::string NVRTCCompile(const std::string& code, bool include_path = false) { std::vector compile_params; std::vector param_cstrings{}; @@ -104,16 +99,15 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { } for (const auto& string : compile_params) { - param_cstrings.push_back(string.c_str()); + param_cstrings.push_back(string.c_str()); } - NVRTC_CALL(nvrtcCreateProgram( - &prog, code.c_str(), nullptr, 0, nullptr, nullptr)); - nvrtcResult compile_res = - nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); + NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); + nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); size_t log_size; NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size)); - std::string log; log.resize(log_size); + std::string log; + 
log.resize(log_size); NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0])); CHECK_EQ(compile_res, NVRTC_SUCCESS) << log; size_t ptx_size; @@ -127,15 +121,14 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { return ptx; } -runtime::Module BuildCUDA(IRModule mod) { +runtime::Module BuildCUDA(IRModule mod, std::string target) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenCUDA cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenCUDA: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -161,7 +154,6 @@ runtime::Module BuildCUDA(IRModule mod) { return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.cuda") -.set_body_typed(BuildCUDA); +TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA); } // namespace codegen } // namespace tvm diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index ce06700222ae..c734eeceed6d 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -23,9 +23,8 @@ namespace tvm { namespace runtime { Module HexagonModuleCreate(std::string data, std::string fmt, - std::unordered_map fmap, - std::string asm_str, std::string obj_str, - std::string ir_str, std::string bc_str, + std::unordered_map fmap, std::string asm_str, + std::string obj_str, std::string ir_str, std::string bc_str, const std::set& packed_c_abi) { LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index ff796d818b22..3cfe1316e7ce 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -20,16 +20,14 @@ /*! * Optional module when build metal is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/metal/metal_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module MetalModuleCreate(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module MetalModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "Metal runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal"); } diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 6e796b1edc62..2367500eca92 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -20,17 +20,14 @@ /*! 
* Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module OpenCLModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module OpenCLModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); } diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc index 64ab759a9a24..476e5a88fc6f 100644 --- a/src/target/opt/build_rocm_off.cc +++ b/src/target/opt/build_rocm_off.cc @@ -20,19 +20,15 @@ /*! * Optional module when build rocm is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/rocm/rocm_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module ROCMModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string rocm_source, - std::string assembly) { - +Module ROCMModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string rocm_source, + std::string assembly) { LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; auto fget_source = [rocm_source, assembly](const std::string& format) { if (format.length() == 0) return assembly; @@ -40,8 +36,7 @@ Module ROCMModuleCreate( if (format == "asm") return assembly; return std::string(""); }; - return codegen::DeviceSourceModuleCreate( - data, fmt, fmap, "hsaco", fget_source); + return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hsaco", fget_source); } } // namespace runtime diff --git a/src/target/opt/build_sdaccel_off.cc b/src/target/opt/build_sdaccel_off.cc index 8c58c3f45b78..0de305c2a37c 100644 --- a/src/target/opt/build_sdaccel_off.cc +++ b/src/target/opt/build_sdaccel_off.cc @@ -20,17 +20,14 @@ /*! * Optional module when build opencl is switched to off */ -#include "../source/codegen_source_base.h" #include "../../runtime/opencl/opencl_module.h" +#include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -Module SDAccelModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string source) { +Module SDAccelModuleCreate(std::string data, std::string fmt, + std::unordered_map fmap, std::string source) { LOG(WARNING) << "OpenCL runtime not enabled, return a source module..."; return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "sdaccel"); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 64674e3360dd..2b77869d4819 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -21,28 +21,27 @@ * \file codegen_aocl.cc */ #include -#include + #include -#include "codegen_opencl.h" -#include "../build_common.h" -#include "../../runtime/opencl/aocl/aocl_module.h" +#include + #include "../../runtime/file_util.h" +#include "../../runtime/opencl/aocl/aocl_module.h" +#include "../build_common.h" +#include "codegen_opencl.h" namespace tvm { namespace codegen { -runtime::Module BuildAOCL(IRModule mod, - std::string target_str, - bool emulation) { +runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) { // Get code. 
using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodegenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -80,15 +79,13 @@ runtime::Module BuildAOCL(IRModule mod, return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.aocl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], false); - }); +TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAOCL(args[0], args[1], false); +}); -TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], true); - }); +TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildAOCL(args[0], args[1], true); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 6e7784c81f85..9255d7c80c46 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -20,20 +20,19 @@ /*! * \file codegen_c.cc */ -#include -#include #include "codegen_c.h" -#include "../../arith/compute_expr.h" -#include "../../tir/pass/ir_util.h" + +#include +#include + +#include "../../arith/pattern_match.h" namespace tvm { namespace codegen { using namespace tir; -void CodeGenC::Init(bool output_ssa) { - print_ssa_form_ = output_ssa; -} +void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); @@ -79,8 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { ReserveKeywordsAsUnique(); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); @@ -94,7 +92,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) { auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); - stream << ' '; } PrintType(GetType(v), stream); @@ -125,16 +122,11 @@ void CodeGenC::AddFunction(const PrimFunc& f) { this->stream << "}\n\n"; } -void CodeGenC::PrintFuncPrefix() { - stream << "void"; -} +void CodeGenC::PrintFuncPrefix() { stream << "void"; } -void CodeGenC::PrintFinalReturn() { -} +void CodeGenC::PrintFinalReturn() {} -std::string CodeGenC::Finish() { - return decl_stream.str() + stream.str(); -} +std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { @@ -146,12 +138,10 @@ void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) } } -void CodeGenC::PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) { +void CodeGenC::PrintSSAAssign(const std::string& target, const std::string& src, DataType t) { PrintType(t, stream); stream << ' ' << target << " = "; - if (src.length() > 3 && - src[0] == '(' && 
src[src.length() - 1] == ')') { + if (src.length() > 3 && src[0] == '(' && src[src.length() - 1] == ')') { stream << src.substr(1, src.length() - 2); } else { stream << src; @@ -160,8 +150,7 @@ void CodeGenC::PrintSSAAssign( } // Print a reference expression to a buffer. -std::string CodeGenC::GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index) { +std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) { std::ostringstream os; std::string vid = GetVarID(buffer); std::string scope; @@ -179,7 +168,6 @@ std::string CodeGenC::GetBufferRef( if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, os); } - os << ' '; PrintType(t, os); os << "*)" << vid << ')'; } else { @@ -188,8 +176,7 @@ std::string CodeGenC::GetBufferRef( os << "[("; PrintExpr(index, os); os << ")"; - if (t.bits() == 4 || - (t.bits() == 1 && t.is_int())) { + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { os << " / " << (32 / t.bits()); } os << ']'; @@ -198,10 +185,9 @@ std::string CodeGenC::GetBufferRef( // optimize for case where it is in register, if (HandleTypeMatch(buffer, t) && !is_vol) { // optimize for constant access - int offset; - if (arith::GetConstInt(index, &offset)) { - CHECK_EQ(offset % t.lanes(), 0) - << "Find unaligned vector load to a vector type"; + if (auto* ptr = index.as()) { + int64_t offset = ptr->value; + CHECK_EQ(offset % t.lanes(), 0) << "Find unaligned vector load to a vector type"; os << vid << '[' << (offset / t.lanes()) << ']'; return os.str(); } @@ -213,7 +199,6 @@ std::string CodeGenC::GetBufferRef( if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, os); } - os << ' '; PrintType(t, os); os << "*)("; if (!HandleTypeMatch(buffer, t.element_of())) { @@ -221,15 +206,13 @@ std::string CodeGenC::GetBufferRef( if (!scope.empty() && IsScopePartOfType()) { PrintStorageScope(scope, os); } - os << ' '; PrintType(t.element_of(), os); os << "*)"; } os << vid << " + ("; PrintExpr(index, os); os << ")"; - if (t.bits() == 4 || - (t.bits() == 1 && t.is_int())) { + if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) { os << " / " << (32 / t.bits()); } os << "))[0]"; @@ -238,8 +221,8 @@ std::string CodeGenC::GetBufferRef( } // Print a reference expression to a buffer. -std::string CodeGenC::GetStructRef( - DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { +std::string CodeGenC::GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, + int kind) { if (kind < intrinsic::kArrKindBound_) { std::ostringstream os; os << "(((DLTensor*)"; @@ -256,17 +239,38 @@ std::string CodeGenC::GetStructRef( os << "]."; // other case: get fields. 
switch (kind) { - case intrinsic::kArrData: os << "data"; break; - case intrinsic::kArrShape: os << "shape"; break; - case intrinsic::kArrStrides: os << "strides"; break; - case intrinsic::kArrNDim: os << "ndim"; break; - case intrinsic::kArrTypeCode: os << "dtype.code"; break; - case intrinsic::kArrTypeBits: os << "dtype.bits"; break; - case intrinsic::kArrByteOffset: os << "byte_offset"; break; - case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break; - case intrinsic::kArrDeviceId: os << "ctx.device_id"; break; - case intrinsic::kArrDeviceType: os << "ctx.device_type"; break; - default: LOG(FATAL) << "unknown field code"; + case intrinsic::kArrData: + os << "data"; + break; + case intrinsic::kArrShape: + os << "shape"; + break; + case intrinsic::kArrStrides: + os << "strides"; + break; + case intrinsic::kArrNDim: + os << "ndim"; + break; + case intrinsic::kArrTypeCode: + os << "dtype.code"; + break; + case intrinsic::kArrTypeBits: + os << "dtype.bits"; + break; + case intrinsic::kArrByteOffset: + os << "byte_offset"; + break; + case intrinsic::kArrTypeLanes: + os << "dtype.lanes"; + break; + case intrinsic::kArrDeviceId: + os << "ctx.device_id"; + break; + case intrinsic::kArrDeviceType: + os << "ctx.device_type"; + break; + default: + LOG(FATAL) << "unknown field code"; } os << ')'; return os.str(); @@ -301,32 +305,26 @@ void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { if (it == handle_data_type_.end()) { handle_data_type_[buf_var] = t; } else { - CHECK(it->second == t) - << "conflicting buf var type"; + CHECK(it->second == t) << "conflicting buf var type"; } } -void CodeGenC::PrintVecElemLoad(const std::string& vec, - DataType t, int i, +void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << ".s" << std::hex << i << std::dec; } -void CodeGenC::PrintVecElemStore(const std::string& vec, - DataType t, int i, +void CodeGenC::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); - stream << vec << ".s" << std::hex << i - << " = " << value << ";\n" << std::dec; + stream << vec << ".s" << std::hex << i << " = " << value << ";\n" << std::dec; } -std::string CodeGenC::GetVecLoad( - DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenC::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { return GetBufferRef(t, buffer, base); } -void CodeGenC::PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, +void CodeGenC::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) { std::string ref = GetBufferRef(t, buffer, base); this->PrintIndent(); @@ -342,49 +340,58 @@ std::string CodeGenC::CastFromTo(std::string value, DataType from, DataType targ return os.str(); } -void CodeGenC::BindThreadIndex(const IterVar& iv) { - LOG(FATAL) << "not implemented"; -} +void CodeGenC::BindThreadIndex(const IterVar& iv) { LOG(FATAL) << "not implemented"; } -void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) +void CodeGenC::PrintStorageSync(const CallNode* op) { // NOLINT(*) } -void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_EQ(scope, "global"); } void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) - CHECK_EQ(t.lanes(), 1) - << "do not yet support vector types"; + CHECK_EQ(t.lanes(), 1) << "do not yet support vector types"; 
if (t.is_handle()) { - os << "void*"; return; + os << "void*"; + return; } if (t.is_float()) { if (t.bits() == 32) { - os << "float"; return; + os << "float"; + return; } if (t.bits() == 64) { - os << "double"; return; + os << "double"; + return; } } else if (t.is_uint()) { switch (t.bits()) { - case 8: case 16: case 32: case 64: { - os << "uint" << t.bits() << "_t"; return; + case 8: + case 16: + case 32: + case 64: { + os << "uint" << t.bits() << "_t"; + return; } - case 1: os << "int"; return; + case 1: + os << "int"; + return; } } else if (t.is_int()) { switch (t.bits()) { - case 8: case 16: case 32: case 64: { - os << "int" << t.bits() << "_t"; return; + case 8: + case 16: + case 32: + case 64: { + os << "int" << t.bits() << "_t"; + return; } } } LOG(FATAL) << "Cannot convert type " << t << " to C type"; } - -void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) if (auto* ptr = type.as()) { return PrintType(ptr->dtype, os); } else if (auto* ptr = type.as()) { @@ -397,8 +404,7 @@ void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*) } } - -inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) if (op->dtype == DataType::Int(32)) { std::ostringstream temp; temp << op->value; @@ -411,8 +417,8 @@ inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // } } - -inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, + CodeGenC* p) { // NOLINT(*) if (dtype == DataType::UInt(32)) { std::ostringstream temp; temp << val << "U"; @@ -425,9 +431,10 @@ inline void PrintUIntConst(DataType dtype, uint64_t val, std::ostream& os, CodeG } } -inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) switch (op->dtype.bits()) { - case 64: case 32: { + case 64: + case 32: { std::ostringstream temp; temp << std::scientific << op->value; if (op->dtype.bits() == 32) temp << 'f'; @@ -438,10 +445,11 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenC* p) { case 16: { os << '('; p->PrintType(op->dtype, os); - os << ')' << std::scientific <value << 'f'; + os << ')' << std::scientific << op->value << 'f'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } @@ -449,16 +457,15 @@ void CodeGenC::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(* PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) os << "\"" << op->value << "\""; } -template -inline void PrintBinaryExpr(const T* op, - const char* opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenC* p) { if (op->dtype.lanes() == 1) { @@ -480,10 +487,9 @@ inline void PrintBinaryExpr(const T* op, } } 
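For quick reference, the scalar mapping implemented by CodeGenC::PrintType above is: handle -> void*, float32 -> float, float64 -> double, int8/16/32/64 -> intN_t, uint8/16/32/64 -> uintN_t, and uint1 -> int; any other scalar type hits the fatal error, and vector types fail the lanes == 1 check in this base-class implementation. A small usage sketch (it assumes a concrete CodeGenC-derived generator named cg is in scope, which is not part of this patch):

    std::ostringstream os;
    cg.PrintType(DataType::Int(32), os);   // appends "int32_t"
    os << ' ';
    cg.PrintType(DataType::Handle(), os);  // appends "void*"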
-inline void PrintBinaryIntrinsic(const CallNode* op, - const char* opstr, - std::ostream& os, // NOLINT(*) - CodeGenC* p) { +inline void PrintBinaryIntrinsic(const CallNode* op, const char* opstr, + std::ostream& os, // NOLINT(*) + CodeGenC* p) { if (op->dtype.lanes() == 1) { CHECK_EQ(op->args.size(), 2U); os << '('; @@ -554,8 +560,7 @@ void CodeGenC::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) } void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) - if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { + if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { os << op->name << "("; for (size_t i = 0; i < op->args.size(); i++) { this->PrintExpr(op->args[i], os); @@ -594,19 +599,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) PrintExpr(op->args[2], os); os << ")"; } else if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); os << "(("; this->PrintType(l->dtype.element_of(), os); - os << " *)" << this->GetVarID(l->buffer_var.get()) - << " + "; + os << " *)" << this->GetVarID(l->buffer_var.get()) << " + "; this->PrintExpr(l->index, os); os << ')'; } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) { CHECK_EQ(op->args.size(), 3U); - os << GetStructRef( - op->dtype, op->args[0], op->args[1], - op->args[2].as()->value); + os << GetStructRef(op->dtype, op->args[0], op->args[1], op->args[2].as()->value); } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { CHECK_EQ(op->args.size(), 1U); os << "("; @@ -626,19 +628,16 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) this->PrintExpr(op->args[0], os); os << ")"; } else { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->dtype; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } } } -void CodeGenC::PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) +void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) { // NOLINT(*) if (isalpha(op[0])) { os << op << "("; this->PrintExpr(lhs, os); @@ -646,7 +645,7 @@ void CodeGenC::PrintVecBinaryOp( this->PrintExpr(rhs, os); os << ")"; } else { - os <<"("; + os << "("; this->PrintExpr(lhs, os); os << ' ' << op << ' '; this->PrintExpr(rhs, os); @@ -661,11 +660,11 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) std::string ref = GetBufferRef(op->dtype, op->buffer_var.get(), op->index); HandleVolatileLoads(ref, op, os); } else { - CHECK(is_one(op->predicate)) - << "predicated load is not supported"; - PrimExpr base; - if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) { - std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base); + CHECK(is_one(op->predicate)) << "predicated load is not supported"; + + arith::PVar base; + if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) { + std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval()); HandleVolatileLoads(ref, op, os); } else { std::ostringstream svalue_expr; @@ 
-680,7 +679,6 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) auto it = alloc_storage_scope_.find(op->buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, value_temp); - value_temp << ' '; } } PrintType(elem_type, value_temp); @@ -702,16 +700,15 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { DataType t = op->value.dtype(); if (t.lanes() == 1) { std::string value = this->PrintExpr(op->value); - std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); + std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index); this->PrintIndent(); stream << ref << " = " << value << ";\n"; } else { - CHECK(is_one(op->predicate)) - << "Predicated store is not supported"; - PrimExpr base; - if (GetRamp1Base(op->index, t.lanes(), &base)) { + CHECK(is_one(op->predicate)) << "Predicated store is not supported"; + arith::PVar base; + if (arith::ramp(base, 1, t.lanes()).Match(op->index)) { std::string value = this->PrintExpr(op->value); - this->PrintVecStore(op->buffer_var.get(), t, base, value); + this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value); } else { // The assignment below introduces side-effect, and the resulting value cannot // be reused across multiple expression, thus a new scope is needed @@ -730,7 +727,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { auto it = alloc_storage_scope_.find(op->buffer_var.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); - stream << ' '; } } PrintType(elem_type, stream); @@ -761,9 +757,9 @@ void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) CHECK_EQ(op->base.dtype(), DataType::Int(32)); os << "((int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; - if (i != op->lanes - 1) - os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; } os << "))"; } @@ -772,7 +768,7 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { LOG(FATAL) << "Shuffle: not supported "; } -void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenC::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) LOG(FATAL) << "Broadcast: not supported "; } @@ -793,19 +789,14 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { var_idmap_[op->var.get()] = value; } else { PrintIndent(); - if (op->var.dtype() == DataType::Handle() && - handle_data_type_.count(op->var.get())) { + if (op->var.dtype() == DataType::Handle() && handle_data_type_.count(op->var.get())) { PrintType(handle_data_type_.at(op->var.get()), stream); - stream << "* " - << AllocVarID(op->var.get()) - << " = ("; + stream << "* " << AllocVarID(op->var.get()) << " = ("; PrintType(handle_data_type_.at(op->var.get()), stream); - stream << "*)" << value << ";\n"; + stream << "*)" << value << ";\n"; } else { PrintType(op->var.dtype(), this->stream); - this->stream << ' ' - << AllocVarID(op->var.get()) - << " = " << value << ";\n"; + this->stream << ' ' << AllocVarID(op->var.get()) << " = " << value << ";\n"; } } PrintStmt(op->body); @@ -815,17 +806,14 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); - this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << 
"Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - std::string scope = alloc_storage_scope_.at(buffer); - PrintStorageScope(scope, stream); - stream << ' '; - PrintType(op->dtype, stream); - stream << ' '<< vid << '[' - << constant_size << "];\n"; + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; + const VarNode* buffer = op->buffer_var.as(); + std::string scope = alloc_storage_scope_.at(buffer); + PrintStorageScope(scope, stream); + PrintType(op->dtype, stream); + stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); @@ -847,6 +835,10 @@ void CodeGenC::VisitStmt_(const AttrStmtNode* op) { const VarNode* v = op->node.as(); CHECK(v); volatile_buf_.insert(v); + } else if (op->attr_key == tir::attr::pragma_import_c) { + const StringImmNode* value = op->value.as(); + CHECK(value != nullptr); + decl_stream << value->value; } this->PrintStmt(op->body); } @@ -870,9 +862,7 @@ void CodeGenC::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); stream << "for ("; PrintType(op->loop_var.dtype(), stream); - stream << ' ' << vid << " = 0; " - << vid << " < " << extent - << "; ++" << vid << ") {\n"; + stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n"; int for_scope = BeginScope(); PrintStmt(op->body); this->EndScope(for_scope); @@ -914,15 +904,13 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { const CallNode* call = op->value.as(); if (call) { if (call->is_intrinsic(intrinsic::tvm_storage_sync)) { - this->PrintStorageSync(call); return; + this->PrintStorageSync(call); + return; } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) { CHECK_EQ(call->args.size(), 4); std::string value = PrintExpr(call->args[3]); - std::string ref = GetStructRef( - call->args[3].dtype(), - call->args[0], - call->args[1], - call->args[2].as()->value); + std::string ref = GetStructRef(call->args[3].dtype(), call->args[0], call->args[1], + call->args[2].as()->value); this->PrintIndent(); this->stream << ref << " = " << value << ";\n"; return; @@ -935,8 +923,7 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { } } -void CodeGenC::PrintVecElemLoadExpr( - DataType t, int i, const std::string& value, std::ostream& os) { +void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) { CHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { if (i != 0) { @@ -949,7 +936,7 @@ void CodeGenC::PrintVecElemLoadExpr( if (i == 0) { os << "(("; PrintType(t, os); - os << t.lanes() << ")("; + os << ")("; } os << value; if (i != t.lanes() - 1) { diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index db655beded02..309eb0681607 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -24,16 +24,18 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_H_ +#include +#include #include -#include #include +#include #include -#include -#include + #include -#include #include #include +#include + #include "codegen_source_base.h" namespace tvm { @@ -50,10 +52,9 @@ using namespace tir; * and OpenCL-C. You might find some odd variant features, e.g., type `int3` for * a vector of 3 `int`s. For native C code generator, see `CodeGenLLVM`. 
*/ -class CodeGenC : - public ExprFunctor, - public StmtFunctor, - public CodeGenSourceBase { +class CodeGenC : public ExprFunctor, + public StmtFunctor, + public CodeGenSourceBase { public: /*! * \brief Initialize the code generator. @@ -75,9 +76,7 @@ class CodeGenC : * \brief Print the Stmt n to CodeGenC->stream * \param n The statement to be printed. */ - void PrintStmt(const Stmt& n) { - VisitStmt(n); - } + void PrintStmt(const Stmt& n) { VisitStmt(n); } /*! * \brief Print the expression n(or its ssa id if in ssa mode) into os * \param n The expression to be printed. @@ -99,11 +98,11 @@ class CodeGenC : * * Example: stream << "void"; */ - virtual void PrintFuncPrefix(); // NOLINT(*) + virtual void PrintFuncPrefix(); // NOLINT(*) /*! * \brief Print the final return at the end the function. */ - virtual void PrintFinalReturn(); // NOLINT(*) + virtual void PrintFinalReturn(); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. @@ -115,33 +114,33 @@ class CodeGenC : */ virtual void InitFuncState(const PrimFunc& f); // expression - void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const DivNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AddNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SubNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MulNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const DivNode* 
op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ModNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const EQNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const LENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GTNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const GENode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const AndNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const OrNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const CastNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const NotNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const RampNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const ShuffleNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const IntImmNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const StringImmNode* op, std::ostream& os) override; // NOLINT(*) // statment void VisitStmt_(const LetStmtNode* op) override; @@ -158,36 +157,34 @@ class CodeGenC : * \param t The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) + virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) /*! * Print Type represetnation of type type. * \param type The type representation. * \param os The stream to print the ctype into */ - virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) + virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*) /*! * \brief Print expr representing the thread tag * \param IterVar iv The thread index to be binded; */ - virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) - virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) - virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) + virtual void BindThreadIndex(const IterVar& iv); // NOLINT(*) + virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*) + virtual void PrintStorageSync(const CallNode* op); // NOLINT(*) // Binary vector op. 
- virtual void PrintVecBinaryOp( - const std::string&op, DataType op_type, - PrimExpr lhs, PrimExpr rhs, std::ostream& os); // NOLINT(*) + virtual void PrintVecBinaryOp(const std::string& op, DataType op_type, PrimExpr lhs, PrimExpr rhs, + std::ostream& os); // NOLINT(*) // print vector load virtual std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base); // print vector store - virtual void PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, + virtual void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value); // NOLINT(*) // print load of single element - virtual void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os); // NOLINT(*) + virtual void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os); // NOLINT(*) // print store of single element. - virtual void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value); + virtual void PrintVecElemStore(const std::string& vec, DataType t, int i, + const std::string& value); // Get a cast type from to virtual std::string CastFromTo(std::string value, DataType from, DataType target); // Get load of single element with expression @@ -195,11 +192,9 @@ class CodeGenC : protected: // Print reference to struct location - std::string GetStructRef( - DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); + std::string GetStructRef(DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind); // Print reference to a buffer as type t in index. - virtual std::string GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index); + virtual std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index); /*! * \brief Handle volatile loads. @@ -209,8 +204,7 @@ class CodeGenC : * does not implement volatile member functions. CUDA codegen will cast * away volatile qualifier from CUDA __half types. */ - virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, - std::ostream& os) { + virtual void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) { // By default, do nothing but print the loaded value. os << value; } @@ -223,9 +217,7 @@ class CodeGenC : * or "__constant__" is not part of type but a storage class (like * C/C++ static). */ - virtual bool IsScopePartOfType() const { - return true; - } + virtual bool IsScopePartOfType() const { return true; } /*! * \brief If buffer is allocated as type t. @@ -240,15 +232,12 @@ class CodeGenC : */ void RegisterHandleType(const VarNode* buf_var, DataType t); // override - void PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) final; + void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) final; /*! \brief reserves common C keywords */ void ReserveKeywordsAsUnique(); /*! \brief Check if buf_var is volatile or not. */ - bool IsVolatile(const VarNode *buf_var) const { - return volatile_buf_.count(buf_var) != 0; - } + bool IsVolatile(const VarNode* buf_var) const { return volatile_buf_.count(buf_var) != 0; } /*! \brief restrict keyword */ std::string restrict_keyword_{""}; @@ -257,29 +246,6 @@ class CodeGenC : /*! \brief the data type of allocated buffers */ std::unordered_map handle_data_type_; - /*! - * \brief A RAII utility class for emitting code in a scoped region. - */ - class EnterScopeRAII { - // The codegen context. - CodeGenC* cg; - - // The new scope level. 
- int scope; - - public: - explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) { - cg->PrintIndent(); - cg->stream << "{\n"; - scope = cg->BeginScope(); - } - ~EnterScopeRAII() { - cg->EndScope(scope); - cg->PrintIndent(); - cg->stream << "}\n"; - } - }; - private: /*! \brief whether to print in SSA form */ bool print_ssa_form_{false}; diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index cbdec6201742..b11b3d8fc5f9 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -20,24 +20,26 @@ /*! * \file codegen_c_host.cc */ +#include "codegen_c_host.h" + #include -#include + #include -#include "codegen_c_host.h" +#include + #include "../build_common.h" namespace tvm { namespace codegen { -CodeGenCHost::CodeGenCHost() { - module_name_ = GetUniqueName("__tvm_module_ctx"); -} +CodeGenCHost::CodeGenCHost() { module_name_ = GetUniqueName("__tvm_module_ctx"); } void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) { emit_asserts_ = emit_asserts; + declared_globals_.clear(); decl_stream << "#include \"tvm/runtime/c_runtime_api.h\"\n"; decl_stream << "#include \"tvm/runtime/c_backend_api.h\"\n"; - decl_stream << "extern void* " << module_name_ << " = NULL;\n"; + decl_stream << "void* " << module_name_ << " = NULL;\n"; CodeGenC::Init(output_ssa); } @@ -56,12 +58,13 @@ void CodeGenCHost::PrintFinalReturn() { // NOLINT(*) void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "does not support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "does not support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { @@ -69,37 +72,55 @@ void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) case 16: os << "half"; break; - case 32: os << "float"; break; + case 32: + os << "float"; + break; case 64: os << "double"; break; - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { os << 'u'; } switch (t.bits()) { - case 8: os << "int8_t"; break; - case 16: os << "int16_t"; break; - case 32: os << "int32_t"; break; - case 64: os << "int64_t"; break; - case 1: os << "int32_t"; break; - default: fail = true; break; + case 8: + os << "int8_t"; + break; + case 16: + os << "int16_t"; + break; + case 32: + os << "int32_t"; + break; + case 64: + os << "int64_t"; + break; + case 1: + os << "int32_t"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to C type"; } -void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; PrintType(op->dtype, os); @@ -117,9 +138,8 @@ void CodeGenCHost::PrintGetFuncFromBackend(const std::string& func_name, this->stream << "if (" << packed_func_name << " == NULL) {\n"; int packed_func_if_scope = this->BeginScope(); this->PrintIndent(); - this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ - << ", \"" << func_name << "\"" - << ", &" << 
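Aside on the module-context declaration above: the old form `extern void* __tvm_module_ctx = NULL;` pairs `extern` with an initializer, which is still a definition but draws a diagnostic from common compilers; dropping `extern` leaves a plain definition with the same behavior. Tiny reproduction (identifier names are illustrative):

// Most compilers warn here, e.g. "'ctx_old' initialized and declared 'extern'".
extern void* ctx_old = 0;

// Equivalent definition, no diagnostic.
void* ctx = 0;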
packed_func_name << ") != 0) {\n"; + this->stream << "if (TVMBackendGetFuncFromEnv(" << module_name_ << ", \"" << func_name << "\"" + << ", &" << packed_func_name << ") != 0) {\n"; int get_func_env_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -140,9 +160,12 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "int " << ret_type_code << ";\n"; this->PrintIndent(); this->stream << "if (TVMFuncCall(" << packed_func_name << ", " - << "(TVMValue*) stack_value" << ", " << "(int*) stack_tcode" << ", " - << num_args << ", " << "&" << ret_val << ", " << "&" - << ret_type_code << ") != 0) {\n"; + << "(TVMValue*) stack_value" + << ", " + << "(int*) stack_tcode" + << ", " << num_args << ", " + << "&" << ret_val << ", " + << "&" << ret_type_code << ") != 0) {\n"; int func_call_scope = this->BeginScope(); this->PrintIndent(); this->stream << "return -1;\n"; @@ -151,7 +174,7 @@ void CodeGenCHost::PrintFuncCall(const std::string& packed_func_name, int num_ar this->stream << "}\n"; } -void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) { std::string stack_name = GetUniqueName("stack"); const std::string& type = op->args[0].as()->value; @@ -182,8 +205,15 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( int64_t num_args = end - begin; CHECK_GE(num_args, 0); std::string func_name = s->value; - std::string packed_func_name = GetUniqueName(func_name + "_packed"); - decl_stream << "static void* " << packed_func_name << " = NULL;\n"; + // NOTE: cannot rely on GetUnique for global decl_stream declarations + // because it is reset between AddFunction(). + std::string packed_func_name = func_name + "_packed"; + if (declared_globals_.insert(packed_func_name).second) { + // Still reserve the name among unique names. 
+ CHECK(GetUniqueName(packed_func_name) == packed_func_name) + << "Expected name " << packed_func_name << " to not be taken"; + decl_stream << "static void* " << packed_func_name << " = NULL;\n"; + } this->PrintGetFuncFromBackend(func_name, packed_func_name); this->PrintFuncCall(packed_func_name, num_args); } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) { @@ -194,7 +224,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( } } -void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) +void CodeGenCHost::VisitStmt_(const AssertStmtNode* op) { // NOLINT(*) if (emit_asserts_) { std::string cond = PrintExpr(op->condition); PrintIndent(); @@ -211,18 +241,17 @@ void CodeGenCHost::VisitStmt_(const AssertStmtNode *op) { // NOLINT(*) this->PrintStmt(op->body); } -void CodeGenCHost::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, "<", os); } -void CodeGenCHost::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCHost::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) PrintTernaryCondExpr(op, ">", os); } template -inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, - const char* compare, - std::ostream& os) { // NOLINT(*) +inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare, + std::ostream& os) { // NOLINT(*) std::ostringstream temp_a; VisitExpr(op->a, temp_a); std::string a_id = SSAGetID(temp_a.str(), op->a.dtype()); @@ -241,9 +270,8 @@ runtime::Module BuildCHost(IRModule mod) { CodeGenCHost cg; cg.Init(output_ssa, emit_asserts); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodegenCHost: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodegenCHost: Can only take PrimFunc"; auto f = Downcast(kv.second); cg.AddFunction(f); } @@ -252,9 +280,8 @@ runtime::Module BuildCHost(IRModule mod) { return CSourceModuleCreate(code, "c"); } -TVM_REGISTER_GLOBAL("target.build.c") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildCHost(args[0]); - }); +TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildCHost(args[0]); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 4f9a0a74511f..94a76faabd78 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -24,10 +24,12 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ #define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_ -#include -#include +#include #include + #include "codegen_c.h" +#include "tvm/target/codegen.h" +#include "tvm/tir/expr.h" namespace tvm { namespace codegen { @@ -37,22 +39,24 @@ class CodeGenCHost final : public CodeGenC { CodeGenCHost(); void Init(bool output_ssa, bool emit_asserts); - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintFuncPrefix() final; // NOLINT(*) - void PrintFinalReturn() final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintFuncPrefix() final; // NOLINT(*) + void PrintFinalReturn() final; // NOLINT(*) // overload visitor functions - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const CallNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + 
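The declared_globals_ guard above is a plain emit-once idiom: std::set::insert reports via .second whether the element was newly inserted, so each packed-function cache variable is declared a single time even when several functions in the module call the same packed function. The idiom in isolation:

#include <iostream>
#include <set>
#include <string>

int main() {
  std::set<std::string> declared_globals;
  // "fadd_packed" appears twice below but is declared only once.
  for (const char* name : {"fadd_packed", "fmul_packed", "fadd_packed"}) {
    if (declared_globals.insert(name).second) {
      std::cout << "static void* " << name << " = NULL;\n";
    }
  }
  return 0;
}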
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) // overload min and max to use the ternary operator, so we don't rely on the // standard library implementations - void VisitExpr_(const MinNode *op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const MaxNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const MinNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const MaxNode* op, std::ostream& os) final; // NOLINT(*) - void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*) + void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) private: std::string module_name_; + /* \brief tracks declared global variables which live despite GetUniqueName */ + std::set declared_globals_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; @@ -67,8 +71,7 @@ class CodeGenCHost final : public CodeGenC { * \param os stream reference to print into */ template - inline void PrintTernaryCondExpr(const T* op, - const char* compare, + inline void PrintTernaryCondExpr(const T* op, const char* compare, std::ostream& os); // NOLINT(*) }; diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 02b5b413562e..cf7a74f1dcc0 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -21,21 +21,21 @@ * \file codegen_cuda.cc */ +#include "codegen_cuda.h" + #include #include +#include #include #include -#include + #include "literal/cuda_half_t.h" -#include "codegen_cuda.h" namespace tvm { namespace codegen { -CodeGenCUDA::CodeGenCUDA() { - restrict_keyword_ = "__restrict__"; -} +CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; } void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); @@ -44,10 +44,7 @@ void CodeGenCUDA::Init(bool output_ssa) { CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); } - -void CodeGenCUDA::PrintFuncPrefix() { - stream << "extern \"C\" __global__ void"; -} +void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; } std::string CodeGenCUDA::Finish() { if (enable_fp16_) { @@ -64,6 +61,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << _cuda_half_util; } + if (enable_warp_shuffle_) { + decl_stream << _cuda_warp_intrinsic_util; + } + if (enable_int8_) { decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n"; decl_stream << "#include \n"; @@ -92,16 +93,15 @@ void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { void CodeGenCUDA::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); - var_idmap_[iv->var.get()] = - CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype()); } void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } bool fail = false; if (t.is_float()) { @@ -126,22 +126,31 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) fail = true; } break; - case 32: os << "float"; break; - case 64: os << "double"; break; - default: fail = true; break; + case 32: + os << "float"; + break; + case 64: + os << "double"; + break; + default: + fail = true; + break; } if (!fail && (lanes == 1 || t.bits() == 16)) return; if (!fail && (lanes >= 2 && lanes <= 
4)) { - os << lanes; return; + os << lanes; + return; } } else if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } else if (t.is_vector_bool()) { // CUDA does not support bool vectors. // Use ushort vectors to represent instead. int n = t.lanes(); if (n <= 4) { - os << "ushort" << n; return; + os << "ushort" << n; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -154,31 +163,41 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) switch (t.bits()) { case 1: { if (t.lanes() == 1) { - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 8) { - os << "int8_t"; return; + os << "int8_t"; + return; } else if (t.lanes() == 16) { - os << "int16_t"; return; + os << "int16_t"; + return; } else if (t.lanes() == 32) { - os << "int"; return; + os << "int"; + return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } } case 4: { if (t.lanes() == 1) { - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 4) { - os << "int16_t"; return; + os << "int16_t"; + return; } else if (t.lanes() == 8) { // directly 8 4-bit int in integer. - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 16) { - os << "int2"; return; + os << "int2"; + return; } else if (t.lanes() == 32) { - os << "int4"; return; + os << "int4"; + return; } else if (t.lanes() == 64) { - os << "int8"; return; + os << "int8"; + return; } else { LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!"; } @@ -191,59 +210,71 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*) // We use int for int8x4 instead of char4 because using char4 is // likely to produce extra instructions to pack four int8 elements // into 32-bit data. - os << "int"; return; + os << "int"; + return; } else if (t.lanes() == 8) { enable_int8_ = true; - os << "int2"; return; + os << "int2"; + return; } else if (t.lanes() == 16) { enable_int8_ = true; - os << "int4"; return; + os << "int4"; + return; } else if (!t.is_uint() && t.lanes() == 1) { - os << "signed char"; break; + os << "signed char"; + break; } else { - os << "char"; break; + os << "char"; + break; } } - case 16: os << "short"; break; - case 32: os << "int"; break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; case 64: { - if (sizeof(long) != 8) { // NOLINT(*) + if (sizeof(long) != 8) { // NOLINT(*) if (t.lanes() == 1) { - os << "long long"; break; + os << "long long"; + break; } else if (t.lanes() == 2) { - os << "longlong"; break; + os << "longlong"; + break; } else { // No longlong3, longlong4 LOG(FATAL) << "Cannot convert type " << t << " to CUDA type on a L32 platform"; break; } } else { - os << "long"; break; + os << "long"; + break; } } - default: fail = true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) { return; } if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to CUDA type"; } -void CodeGenCUDA::PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) { // NOLINT(*) // Delcare the result. std::string sret = GetUniqueName("_"); this->PrintIndent(); this->PrintType(t, stream); stream << ' ' << sret << ";\n"; { - EnterScopeRAII scope(this); - // Unpack into individual ops. 
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); @@ -269,37 +300,54 @@ void CodeGenCUDA::PrintVecBinaryOp( os << sret; } -void CodeGenCUDA::PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) { // NOLINT(*) + if (t.is_scalar()) { + os << vec; + return; + } + static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); if ((t.is_int()) && t.bits() == 8) { - os << "((char)(" << vec << " >> " << i * 8 << "))"; + if (t.lanes() == 2 || t.lanes() == 3) { + os << vec << "." << access[i % t.lanes()]; + } else { + os << "((char)(" << vec << " >> " << i * 8 << "))"; + } } else if ((t.is_uint()) && t.bits() == 8) { - os << "((unsigned char)(" << vec << " >> " << i * 8 << "))"; + if (t.lanes() == 2 || t.lanes() == 3) { + os << vec << "." << access[i % t.lanes()]; + } else { + os << "((unsigned char)(" << vec << " >> " << i * 8 << "))"; + } } else if (t.is_float16()) { - os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2]; + os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]; } else { os << vec << "." << access[i]; } } -void CodeGenCUDA::PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) { +void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i, + const std::string& value) { this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4)); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - stream << vec << "="; - // Do not read the first undef lane. - if (i != 0) { - stream << vec << " & ~(0x000000ff << " << i * 8 << ") |"; + if (t.lanes() == 2 || t.lanes() == 3) { + stream << vec << '.' << access[i % t.lanes()] << "=" + << "(" << value << ");\n"; + } else { + stream << vec << "="; + // Do not read the first undef lane. + if (i != 0) { + stream << vec << " & ~(0x000000ff << " << i * 8 << ") |"; + } + stream << "(" << value << " << " << i * 8 << ");\n"; } - stream << "(" << value << " << " << i * 8 << ");\n"; } else if (t.is_float16()) { - stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" - << access[i % 2] << " = " << value << ";\n"; + stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = " + << value << ";\n"; } else { stream << vec << "." 
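For the 8-bit paths above: lane counts 4/8/16 keep the existing packed representation (int/int2/int4), so a lane is extracted or updated with shifts and masks, while the new lanes==2/3 branches address the CUDA vector-type members directly. Roughly, the emitted accesses compare as follows (illustrative, not captured from a real run):

// int8x4 held in an 'int' named v4: lane 2 via shift/mask
//   load:  ((char)(v4 >> 16))
//   store: v4 = v4 & ~(0x000000ff << 16) | (x << 16);
// int8x2 held in a CUDA 'char2' named v2: lane 1 via a struct member
//   load:  v2.y
//   store: v2.y = (x);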
<< access[i] << " = " << value << ";\n"; } @@ -315,8 +363,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { } else if (sync == "global") { if (!need_global_barrier_) { need_global_barrier_ = true; - this->decl_stream << "extern \"C\" __device__ unsigned " - << vid_global_barrier_state_ << ";\n"; + this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_ + << ";\n"; } // global synchronizer std::string is_load = PrintExpr(op->args[1]); @@ -324,33 +372,31 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) { this->PrintIndent(); // In theory only threadfence is needed // but we observed problems with only threadfence - this->stream <<"__threadfence_system();\n"; + this->stream << "__threadfence_system();\n"; this->PrintIndent(); - this->stream <<"if (" << is_load << ") {\n"; + this->stream << "if (" << is_load << ") {\n"; int wb = this->BeginScope(); this->PrintIndent(); this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n"; this->PrintIndent(); std::string ptr = GetUniqueName("pf"); - this->stream << "volatile unsigned* " - << ptr << " = &" << vid_global_barrier_state_<< ";\n"; + this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n"; this->PrintIndent(); this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n"; this->PrintIndent(); - this->stream <<"while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; + this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n"; this->EndScope(wb); this->PrintIndent(); - this->stream <<"}\n"; + this->stream << "}\n"; this->PrintIndent(); - this->stream <<"__syncthreads();\n"; + this->stream << "__syncthreads();\n"; } } -void CodeGenCUDA::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) CHECK_NE(scope, "global"); if (scope == "shared") { - os << "__shared__"; + os << "__shared__ "; } } @@ -360,8 +406,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { CHECK_EQ(target_ty.lanes(), from_ty.lanes()); // Emit simple C-style type conversion. - if (from_ty.is_scalar()) - return CodeGenC::VisitExpr_(op, os); + if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os); // We could emit make_float4 like calls, but the emitted code looks // too compact to read. Emit this as vectorized unary ops. @@ -370,7 +415,6 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { this->PrintType(target_ty, stream); stream << ' ' << sret << ";\n"; { - EnterScopeRAII scope(this); std::string src = SSAGetID(PrintExpr(op->value), from_ty); for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { std::ostringstream val; @@ -385,7 +429,14 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { os << sret; } -void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { +void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { + // This is only for backward compatibility with __shfl_{up/down}. + // A macro will be used to replace *_sync calls to legacy ones. 
+ if (op->is_intrinsic("__shfl_sync") || op->is_intrinsic("__shfl_up_sync") || + op->is_intrinsic("__shfl_down_sync")) { + enable_warp_shuffle_ = true; + } + if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) { need_mma_h_ = true; CHECK_EQ(op->args.size(), 6U); @@ -419,7 +470,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintExpr(op->args[4], os); os << "], "; this->PrintExpr(op->args[6], os); - if (const StringImmNode *str = op->args[7].as()) { + if (const StringImmNode* str = op->args[7].as()) { os << ", nvcuda::wmma::mem_" << str->value; } else { LOG(FATAL) << "Invalid parameters"; @@ -433,7 +484,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); - os << "]" << ((i < 3) ? ", ": ")"); + os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->is_intrinsic(intrinsic::tvm_bmma_sync)) { need_mma_h_ = true; @@ -443,7 +494,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintExpr(op->args[i * 2], os); os << "["; this->PrintExpr(op->args[i * 2 + 1], os); - os << "]" << ((i < 3) ? ", ": ")"); + os << "]" << ((i < 3) ? ", " : ")"); } } else if (op->call_type == CallNode::PureExtern && op->dtype.is_vector()) { // @@ -470,8 +521,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { this->PrintType(op->dtype, stream); stream << ' ' << sret << ";\n"; { - EnterScopeRAII scope(this); - // Load arguments. std::vector sargs; for (size_t i = 0; i < op->args.size(); ++i) { @@ -484,8 +533,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { std::ostringstream scall; scall << op->name << "("; for (size_t j = 0; j < op->args.size(); ++j) { - if (j > 0) - scall << ", "; + if (j > 0) scall << ", "; PrintVecElemLoad(sargs[j], op->args[j].dtype(), i, scall); } scall << ")"; @@ -517,46 +565,39 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; const VarNode* buffer = op->buffer_var.as(); std::string scope = alloc_storage_scope_.at(buffer); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Int(8) || - op->dtype == DataType::UInt(8) || - op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) - << "Matrix_a and matrix_b only support half or char or unsigned char " - << "or uint4 or int4 or int1 type for now"; + CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) || + op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1)) + << "Matrix_a and matrix_b only support half or char or unsigned char " + << "or uint4 or int4 or int1 type for now"; } else { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Float(32) || + CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) || op->dtype == DataType::Int(32)) - << "Accumulator only support half, float and int type for now"; + << "Accumulator only support half, float and int type for now"; } constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } 
else { PrintStorageScope(scope, stream); - stream << ' '; PrintType(op->dtype, stream); } - if ((op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && scope == "shared") { + if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) && + scope == "shared") { constant_size = constant_size / (32 / op->dtype.bits()); } - stream << ' '<< vid << '[' - << constant_size << "];\n"; + stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } -void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) { +void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) { if (is_const(op->value)) return; const CallNode* call = op->value.as(); if (call && call->is_intrinsic(intrinsic::tvm_global_barrier_kinit)) { @@ -576,17 +617,17 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode *op) { void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) { os << "((make_int" << op->lanes << ")("; for (int i = 0; i < op->lanes; i++) { - os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")"; - if (i != op->lanes - 1) - os << ", "; + os << "(" << PrintExpr(op->base) << ")" + << "+(" << PrintExpr(op->stride) << "*" << i << ")"; + if (i != op->lanes - 1) os << ", "; } os << "))"; } -void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) { // make_int8x4 - const int64_t *p = as_const_int(op->value); + const int64_t* p = as_const_int(op->value); CHECK(p); int64_t v = *p & 0xFF; v = (v << 24) | (v << 16) | (v << 8) | v; @@ -605,7 +646,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << '('; for (int i = 0; i < op->lanes / 2; ++i) { if (i != 0) os << ", "; - os << "__pack_half2(" << v << ", " << v << ")"; + os << "__pack_half2(" << v << ", " << v << ")"; } os << ')'; return; @@ -622,7 +663,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N os << ')'; } -void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) { +void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) { std::vector to_shuffle(op->vectors.size()); for (int i = 0, e = op->vectors.size(); i < e; ++i) { CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!"; @@ -632,15 +673,15 @@ void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream &os) { PrintType(op->dtype, os); os << '('; for (int i = 0, e = op->indices.size(); i < e; ++i) { - const int64_t *val = as_const_int(op->indices[i]); - CHECK(val && *val >= 0 && (int) *val < (int) to_shuffle.size()); + const int64_t* val = as_const_int(op->indices[i]); + CHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size()); if (i != 0) os << ", "; os << to_shuffle[*val]; } os << ')'; } -void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { +void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) { // Non-vector cases. if (!op->dtype.is_vector()) { CodeGenC::VisitExpr_(op, os); @@ -648,8 +689,7 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { } // Codegen vector condition case by serializing the select op. 
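Quick worked example for the sub-byte branch above: Int(4) elements in "shared" scope are stored packed, 32 / 4 = 8 lanes per 32-bit word, so an allocation of 256 elements has its constant_size rescaled to 256 / 8 = 32 and the declaration comes out roughly as (buffer name illustrative):

__shared__ int buf[32];  // 256 int4 lanes packed 8 per 32-bit word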
- CHECK(op->false_value->dtype == op->dtype && - op->true_value->dtype == op->dtype && + CHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype && op->dtype.lanes() == op->condition.dtype().lanes()); std::string r_var = GetUniqueName("_"); @@ -657,8 +697,6 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { this->PrintType(op->dtype, stream); stream << ' ' << r_var << ";\n"; { - EnterScopeRAII scope(this); - std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); @@ -682,9 +720,10 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { os << r_var; } -inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) +inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) { // NOLINT(*) switch (op->dtype.bits()) { - case 64: case 32: { + case 64: + case 32: { std::ostringstream temp; if (std::isinf(op->value)) { if (op->value < 0) { @@ -708,17 +747,17 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) os << '(' << std::scientific << op->value << 'f' << ')'; break; } - default: LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; + default: + LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; } } - -void CodeGenCUDA::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) PrintConst(op, os, this); } -void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, - const VarNode* variable, std::ostream &os) { +void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + std::ostream& os) { std::stringstream type; PrintType(t, type); std::string shape_str = fragment_shapes[variable]; @@ -743,22 +782,22 @@ void CodeGenCUDA::PrintWmmaScope(const std::string &scope, DataType t, if (scope == "wmma.matrix_a") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.matrix_b") { need_mma_h_ = true; std::string layout_str = fragment_layouts[variable]; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } else if (scope == "wmma.accumulator") { need_mma_h_ = true; - os << "nvcuda::wmma::fragment"; + os << "nvcuda::wmma::fragment"; } } -int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, - const VarNode* variable, int32_t size) { +int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, + int32_t size) { std::string shape_str = fragment_shapes[variable]; size_t m, n, k; size_t last_pos = 0, pos = 0; @@ -779,8 +818,8 @@ int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string &scope, return 0; } -void CodeGenCUDA::HandleVolatileLoads(const std::string& value, - const LoadNode* op, std::ostream& os) { +void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op, + std::ostream& os) { // Cast away volatile qualifier for fp16 types. That is, only loads and // stores are volatile. The loaded objects are not marked as volatile. 
// @@ -793,15 +832,17 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, } } -void CodeGenCUDA::PrintVecElemLoadExpr( - DataType t, int i, const std::string& value, std::ostream& os) { +void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value, + std::ostream& os) { CHECK_GT(t.lanes(), 1); if (t.bits() == 8 && (t.is_int() || t.is_uint())) { - if (i != 0) { - os << "|"; + if (!(t.lanes() == 2 || t.lanes() == 3)) { + if (i != 0) { + os << "|"; + } + os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; + return; } - os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))"; - return; } if (t.is_float16()) { diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index d1db7047b1b6..f9ab0ade2cf2 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -26,8 +26,10 @@ #include #include + #include #include + #include "codegen_c.h" namespace tvm { @@ -46,37 +48,32 @@ class CodeGenCUDA final : public CodeGenC { void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintVecBinaryOp( - const std::string& op, DataType t, - PrimExpr lhs, PrimExpr rhs, std::ostream& os) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) - void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, + std::ostream& os) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) final; // NOLINT(*) + void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; // overload visitor - void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; - void VisitExpr_(const CallNode *op, std::ostream& os) final; + void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; + void VisitExpr_(const CallNode* op, std::ostream& os) final; void VisitExpr_(const CastNode* op, std::ostream& os) final; - void VisitStmt_(const EvaluateNode *op) final; - void VisitStmt_(const AllocateNode *op) final; - void VisitStmt_(const AttrStmtNode *op) final; + void VisitStmt_(const EvaluateNode* op) final; + void VisitStmt_(const AllocateNode* op) final; + void VisitStmt_(const AttrStmtNode* op) final; private: // Handle volatile loads - void HandleVolatileLoads(const std::string& 
value, const LoadNode* op, - std::ostream& os) final; + void HandleVolatileLoads(const std::string& value, const LoadNode* op, std::ostream& os) final; // Whether scope such as "__shared__" or "__constant__" is part of type. - bool IsScopePartOfType() const final { - return false; - } + bool IsScopePartOfType() const final { return false; } // Whether global barrier is needed. bool need_global_barrier_{false}; @@ -88,6 +85,8 @@ class CodeGenCUDA final : public CodeGenC { bool enable_fp16_{false}; // whether enable int8 bool enable_int8_{false}; + // whether enable warp shuffle intrinsics + bool enable_warp_shuffle_{false}; // whether need math_constants.h bool need_math_constants_h_{false}; // whether need mma.h @@ -96,10 +95,9 @@ class CodeGenCUDA final : public CodeGenC { std::unordered_map fragment_shapes; std::unordered_map fragment_layouts; friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); - void PrintWmmaScope( - const std::string& scope, DataType t, const VarNode* variable, std::ostream& os); - int32_t GetWmmaFragmentSize( - const std::string &scope, const VarNode* variable, int32_t size); + void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, + std::ostream& os); + int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); }; } // namespace codegen diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index ea49d33351a0..2c26ee977639 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -20,13 +20,15 @@ /*! * \file codegen_metal.cc */ -#include -#include -#include #include "codegen_metal.h" -#include "../build_common.h" + +#include +#include +#include + #include "../../runtime/metal/metal_module.h" #include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { @@ -57,8 +59,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // add to alloc buffer type. auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; @@ -67,14 +68,13 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t num_buffer = 0; for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { Var v = f->params[i]; - if (!v.dtype().is_handle()) break; + if (!v.dtype().is_handle()) break; stream << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, stream); } - stream << ' '; PrintType(GetType(v), stream); // Register handle data type // TODO(tvm-team): consider simply keep type info in the @@ -84,17 +84,15 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { RegisterHandleType(v.get(), prim->dtype); } } - stream << ' ' << vid - << " [[ buffer(" << i << ") ]],\n"; + stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. 
size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = - static_cast(global_symbol.value()) + "_args_t"; - stream << " constant " << arg_buf_type << "& " << varg - << " [[ buffer(" << num_buffer << ") ]],\n"; + std::string arg_buf_type = static_cast(global_symbol.value()) + "_args_t"; + stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer + << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; for (size_t i = num_buffer; i < f->params.size(); ++i) { @@ -121,11 +119,10 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>( - tir::attr::kDeviceThreadAxis).value(); + auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); work_dim = std::max(work_dim, scope.dim_index + 1); } if (work_dim != 0) { @@ -165,23 +162,31 @@ void CodeGenMetal::BindThreadIndex(const IterVar& iv) { void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { switch (t.bits()) { - case 16: os << "half"; break; - case 32: os << "float"; break; - default: fail = true; break; + case 16: + os << "half"; + break; + case 32: + os << "float"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -189,18 +194,30 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. 
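For the Metal argument handling above: pointer parameters are bound one per [[ buffer(i) ]] slot, and the remaining scalar arguments are folded into a single constant struct bound at the next slot. An illustrative shape of the resulting entry point, with names and the thread-index tail approximated rather than taken from real output:

// kernel void myadd(device float* A [[ buffer(0) ]],
//                   device float* B [[ buffer(1) ]],
//                   constant myadd_args_t& arg [[ buffer(2) ]],
//                   /* thread/threadgroup index parameters follow */ ...) { ... }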
- os << "int"; return; + os << "int"; + return; } switch (t.bits()) { - case 8: os << "char"; break; - case 16: os << "short"; break; - case 32: os << "int"; break; - case 1: os << "bool"; break; - default: fail = true; break; + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 1: + os << "bool"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 4)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to Metal type"; @@ -219,32 +236,29 @@ void CodeGenMetal::PrintStorageSync(const CallNode* op) { } } -void CodeGenMetal::PrintVecElemLoad(const std::string& vec, - DataType t, int i, +void CodeGenMetal::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << "[" << i << "]"; } -void CodeGenMetal::PrintVecElemStore(const std::string& vec, - DataType t, int i, +void CodeGenMetal::PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) { this->PrintIndent(); stream << vec << "[" << i << "]" << " = " << value << ";\n"; } -void CodeGenMetal::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { - os << "device"; + os << "device "; } else if (scope == "shared") { - os << "threadgroup"; + os << "threadgroup "; } else { - os << "thread"; + os << "thread "; } } -void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); PrintType(op->dtype, os); os << "("; @@ -274,9 +288,8 @@ runtime::Module BuildMetal(IRModule mod) { CodeGenMetal cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenMetal: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -295,9 +308,8 @@ runtime::Module BuildMetal(IRModule mod) { return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); } -TVM_REGISTER_GLOBAL("target.build.metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildMetal(args[0]); - }); +TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildMetal(args[0]); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 644c962ab2d6..26abe34d998e 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -25,7 +25,9 @@ #define TVM_TARGET_SOURCE_CODEGEN_METAL_H_ #include + #include + #include "codegen_c.h" namespace tvm { @@ -36,22 +38,21 @@ class CodeGenMetal final : public CodeGenC { CodeGenMetal(); // override print thread tag. 
void PrintArgUnionDecl(); - void AddFunction(const PrimFunc& f); // NOLINT(*) + void AddFunction(const PrimFunc& f); // NOLINT(*) void InitFuncState(const PrimFunc& f) final; - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) // print load of single element - void PrintVecElemLoad( - const std::string& vec, DataType t, int i, std::ostream& os) final; // NOLINT(*) + void PrintVecElemLoad(const std::string& vec, DataType t, int i, + std::ostream& os) final; // NOLINT(*) // print store of single element. - void PrintVecElemStore( - const std::string& vec, DataType t, int i, const std::string& value) final; + void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; // overload visitor - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) // overload visitor - void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) // reuse parent's function. using CodeGenC::PrintType; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index d5b89609e514..8616853d8883 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -20,20 +20,20 @@ /*! 
* \file codegen_opencl.cc */ +#include "codegen_opencl.h" + #include -#include #include -#include "codegen_opencl.h" -#include "../build_common.h" -#include "../../runtime/thread_storage_scope.h" +#include + #include "../../runtime/opencl/opencl_module.h" +#include "../../runtime/thread_storage_scope.h" +#include "../build_common.h" namespace tvm { namespace codegen { -CodeGenOpenCL::CodeGenOpenCL() { - restrict_keyword_ = "restrict"; -} +CodeGenOpenCL::CodeGenOpenCL() { restrict_keyword_ = "restrict"; } void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); @@ -44,34 +44,30 @@ void CodeGenOpenCL::InitFuncState(const PrimFunc& f) { } } -void CodeGenOpenCL::PrintFuncPrefix() { - stream << "__kernel void"; -} +void CodeGenOpenCL::PrintFuncPrefix() { stream << "__kernel void"; } std::string CodeGenOpenCL::Finish() { // inject extension enable pragma for fp16 and fp64 if (enable_fp16_) { - decl_stream - << "#ifdef cl_khr_fp16\n" - "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" - "#elif defined(cl_amd_fp16)\n" - "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n" - "#else\n" - "#error \"Half precision floating point not supported" - "by OpenCL implementation on your device.\" \n" - "#endif\n\n"; + decl_stream << "#ifdef cl_khr_fp16\n" + "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" + "#elif defined(cl_amd_fp16)\n" + "#pragma OPENCL EXTENSION cl_amd_fp16 : enable\n" + "#else\n" + "#error \"Half precision floating point not supported" + "by OpenCL implementation on your device.\" \n" + "#endif\n\n"; } if (enable_fp64_) { - decl_stream - << "#ifdef cl_khr_fp64\n" - "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" - "#elif defined(cl_amd_fp64)\n" - "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n" - "#else\n" - "#error \"Double precision floating point not supported" - "by OpenCL implementation on your device.\" \n" - "#endif\n\n"; + decl_stream << "#ifdef cl_khr_fp64\n" + "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n" + "#elif defined(cl_amd_fp64)\n" + "#pragma OPENCL EXTENSION cl_amd_fp64 : enable\n" + "#else\n" + "#error \"Double precision floating point not supported" + "by OpenCL implementation on your device.\" \n" + "#endif\n\n"; } return CodeGenC::Finish(); @@ -79,26 +75,26 @@ std::string CodeGenOpenCL::Finish() { void CodeGenOpenCL::BindThreadIndex(const IterVar& iv) { CHECK(!var_idmap_.count(iv->var.get())); - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); std::ostringstream os; if (ts.rank == 1) { os << "get_local_id(" << ts.dim_index << ")"; } else { os << "get_group_id(" << ts.dim_index << ")"; } - var_idmap_[iv->var.get()] = - CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); + var_idmap_[iv->var.get()] = CastFromTo(os.str(), DataType::UInt(64), iv->var.dtype()); } void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { - CHECK_EQ(lanes, 1) - << "do not yet support vector types"; - os << "void*"; return; + CHECK_EQ(lanes, 1) << "do not yet support vector types"; + os << "void*"; + return; } if (t == DataType::Bool()) { - os << "bool"; return; + os << "bool"; + return; } bool fail = false; if (t.is_float()) { @@ -107,16 +103,21 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) os << "half"; enable_fp16_ = true; break; - case 32: os << "float"; break; + case 32: + os << "float"; + break; case 64: os << "double"; enable_fp64_ = true; break; - default: fail = 
true; break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } else if (t.is_uint() || t.is_int()) { if (t.is_uint()) { @@ -124,41 +125,53 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } if (t.bits() == 8 && t.lanes() == 4) { // directly 4 8 bit int in integer. - os << "int"; return; + os << "int"; + return; } switch (t.bits()) { - case 8: os << "char"; break; - case 16: os << "short"; break; - case 32: os << "int"; break; - case 64: os << "long"; break; - case 1: os << "int"; break; - default: fail = true; break; + case 8: + os << "char"; + break; + case 16: + os << "short"; + break; + case 32: + os << "int"; + break; + case 64: + os << "long"; + break; + case 1: + os << "int"; + break; + default: + fail = true; + break; } if (!fail && lanes == 1) return; if (!fail && (lanes >= 2 && lanes <= 16)) { - os << lanes; return; + os << lanes; + return; } } LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type"; } -void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, - PrimExpr base, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + std::ostream& os) { // NOLINT(*) if (!HandleTypeMatch(buffer, t.element_of())) { os << '('; auto it = alloc_storage_scope_.find(buffer); if (it != alloc_storage_scope_.end()) { PrintStorageScope(it->second, os); } - os << ' '; PrintType(t.element_of(), os); os << "*)"; } os << GetVarID(buffer) << " + "; PrintExpr(base, os); } -std::string CodeGenOpenCL::GetVecLoad( - DataType t, const VarNode* buffer, PrimExpr base) { +std::string CodeGenOpenCL::GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) { std::ostringstream os; os << "vload" << t.lanes() << "(0, "; PrintVecAddr(buffer, t, base, os); @@ -166,8 +179,7 @@ std::string CodeGenOpenCL::GetVecLoad( return os.str(); } -void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, +void CodeGenOpenCL::PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) { this->PrintIndent(); stream << "vstore" << t.lanes() << "(" << value << ", 0, "; @@ -188,12 +200,11 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { } } -void CodeGenOpenCL::PrintStorageScope( - const std::string& scope, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) if (scope == "global") { - os << "__global"; + os << "__global "; } else if (scope == "shared") { - os << "__local"; + os << "__local "; } } @@ -213,7 +224,7 @@ std::string CodeGenOpenCL::CastFromTo(std::string value, DataType from, DataType return os.str(); } -void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) std::string v = PrintExpr(op->value); os << "(("; PrintType(op->dtype, os); @@ -225,7 +236,7 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // os << "))"; } -void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NOLINT(*) +void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) if (std::isinf(op->value)) { if (op->value < 0) { os << "-"; @@ -238,15 +249,14 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO } } -runtime::Module BuildOpenCL(IRModule mod) { 
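Illustrative sketch (not part of the diff): the PrintStorageScope hunk above now appends a trailing space to "__global"/"__local" while PrintVecAddr stops printing its own separator, so the emitted cast still comes out as one well-formed token sequence. A minimal standalone C++ sketch of that emission, using hypothetical helper and variable names:

#include <iostream>
#include <sstream>
#include <string>

// Mimics the codegen path: scope string already carries its trailing space.
static void EmitCastedPointer(const std::string& scope_with_space,
                              const std::string& elem_type,
                              const std::string& var, std::ostringstream& os) {
  os << '(' << scope_with_space << elem_type << "*)" << var;
}

int main() {
  std::ostringstream os;
  EmitCastedPointer("__global ", "float", "A", os);
  std::cout << os.str() << "\n";  // prints: (__global float*)A
}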
+runtime::Module BuildOpenCL(IRModule mod, std::string target) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -261,7 +271,6 @@ runtime::Module BuildOpenCL(IRModule mod) { return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.opencl") -.set_body_typed(BuildOpenCL); +TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.h b/src/target/source/codegen_opencl.h index cc1fe994739f..32a98e4d87ea 100644 --- a/src/target/source/codegen_opencl.h +++ b/src/target/source/codegen_opencl.h @@ -25,7 +25,9 @@ #define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_ #include + #include + #include "codegen_c.h" namespace tvm { @@ -38,24 +40,22 @@ class CodeGenOpenCL final : public CodeGenC { // override print thread tag. void InitFuncState(const PrimFunc& f) final; - void PrintFuncPrefix() final; // NOLINT(*) - void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) - void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) - void PrintStorageSync(const CallNode* op) final; // NOLINT(*) - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - std::string GetVecLoad(DataType t, const VarNode* buffer, - PrimExpr base) final; - void PrintVecStore(const VarNode* buffer, - DataType t, PrimExpr base, + void PrintFuncPrefix() final; // NOLINT(*) + void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) + void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) + void PrintStorageSync(const CallNode* op) final; // NOLINT(*) + void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) + std::string GetVecLoad(DataType t, const VarNode* buffer, PrimExpr base) final; + void PrintVecStore(const VarNode* buffer, DataType t, PrimExpr base, const std::string& value) final; // NOLINT(*) // the address of load/store - void PrintVecAddr(const VarNode* buffer, DataType t, - PrimExpr base, std::ostream& os); // NOLINT(*) - std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) + void PrintVecAddr(const VarNode* buffer, DataType t, PrimExpr base, + std::ostream& os); // NOLINT(*) + std::string CastFromTo(std::string value, DataType from, DataType target); // NOLINT(*) // overload visitor - void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) private: // whether enable fp16 and fp64 extension diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc deleted file mode 100644 index 946b483a1dd9..000000000000 --- a/src/target/source/codegen_opengl.cc +++ /dev/null @@ -1,315 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file codegen_opengl.cc - * - * We are targeting OpenGL 3.3. The reason of not targeting a recent version - * of OpenGL is to have better compatibility of WebGL 2. - */ -#include -#include -#include -#include -#include "codegen_opengl.h" -#include "../build_common.h" -#include "../../runtime/thread_storage_scope.h" - -namespace tvm { -namespace codegen { - -CodeGenOpenGL::CodeGenOpenGL() - : output_(nullptr), output_iter_var_(nullptr) {} - -void CodeGenOpenGL::InitFuncState(const PrimFunc& f) { - CodeGenC::InitFuncState(f); - output_ = nullptr; - inputs_.clear(); - output_iter_var_ = nullptr; - thread_extent_var_ = ""; - this->decl_stream.str(""); - this->stream.str(""); -} - -void CodeGenOpenGL::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - - this->decl_stream << "#version 300 es\n"; - this->decl_stream << "precision highp float;\n"; - - // skip the first underscore, so SSA variable starts from _1 - GetUniqueName("_"); - - // Allocate argument names. Store in `var_idmap_`. - for (auto arg : f->params) { - auto arg_name = GetUniqueName(arg.get()->name_hint); - var_idmap_[arg.get()] = arg_name; - - if (auto* ptr = arg->type_annotation.as()) { - if (auto* prim = ptr->element_type.as()) { - RegisterHandleType(arg.get(), prim->dtype); - } - } - } - - thread_extent_var_ = GetUniqueName("thread_extent"); - this->decl_stream << "uniform int " << thread_extent_var_ << ";\n"; - - this->stream << "void main() {\n"; - - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->EndScope(func_scope); - - this->PrintIndent(); - this->stream << "}\n\n"; - - // Declare arguments. - for (auto arg : f->params) { - if (this->inputs_.find(arg.get()) != this->inputs_.cend()) { - // Declare input texture. - // Format: - // - Float: "uniform sampler2D {name};" - // - Int: "uniform isampler2D {name};" - // - UInt: "uniform usampler2D {name};" - - auto arg_name = GetVarID(arg.get()); - - auto type_it = this->handle_data_type_.find(arg.get()); - CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type."; - DLDataType type = type_it->second; - CHECK_EQ(type.lanes, 1) << "Vector type not supported."; - - switch (type.code) { - case kDLInt: - this->decl_stream << "uniform isampler2D " << arg_name << ";\n"; - break; - case kDLUInt: - this->decl_stream << "uniform usampler2D " << arg_name << ";\n"; - break; - case kDLFloat: - this->decl_stream << "uniform sampler2D " << arg_name << ";\n"; - break; - default: - LOG(FATAL) << "Unsupported type code."; - } - - } else if (this->output_ == arg.get()) { - // Declare output texture. 
- // Format: "out {type} {name};" - - auto arg_name = GetVarID(arg.get()); - - auto type_it = this->handle_data_type_.find(arg.get()); - CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type."; - auto type = type_it->second; - - this->decl_stream << "out "; - PrintType(type, this->decl_stream); - this->decl_stream << " " << arg_name << ";\n"; - - } else { - // Declare uniform value. - // Format: "uniform {type} {name};" - - auto arg_name = GetVarID(arg.get()); - auto type = arg.get()->dtype; - - this->decl_stream << "uniform "; - PrintType(type, this->decl_stream); - this->decl_stream << " " << arg_name << ";\n"; - } - } - - std::vector arg_names; - std::vector arg_kinds; - for (auto arg : f->params) { - std::string name = GetVarID(arg.get()); - - runtime::OpenGLArgKind kind; - if (inputs_.find(arg.get()) != inputs_.cend()) { - kind = runtime::OpenGLArgKind::kInputTexture; - } else if (output_ == arg.get()) { - kind = runtime::OpenGLArgKind::kOutputTexture; - } else { - kind = runtime::OpenGLArgKind::kUniform; - } - - arg_names.push_back(name); - arg_kinds.push_back(kind); - } - - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - - shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( - this->decl_stream.str() + this->stream.str(), - std::move(arg_names), std::move(arg_kinds), - this->thread_extent_var_); -} - -std::unordered_map CodeGenOpenGL::Finish() { - return shaders_; -} - -void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) { - CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x"; - CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end()) - << "Only support one thread iter var"; - CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var"; - - var_idmap_[iv->var.get()] = iv->thread_tag; - output_iter_var_ = iv->var.get(); - - // Declare threadIdx local variable. - this->PrintIndent(); - this->stream << "ivec2 threadIdx = ivec2(" << runtime::kTextureRowSize - << " * int(gl_FragCoord.y) + int(gl_FragCoord.x), 0);\n"; - - // Return directly if threadIdx.x >= thread_extent. - this->PrintIndent(); - this->stream << "if (threadIdx.x >= " << thread_extent_var_ << ") {\n"; - this->PrintIndent(); - this->stream << " return;\n"; - this->PrintIndent(); - this->stream << "}\n"; -} - -void CodeGenOpenGL::VisitStmt_(const StoreNode* op) { - LOG(FATAL) << "Store statement not supported in OpenGL." - << " Texture store should be a Call statement."; -} - -// texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r -std::string CodeGenOpenGL::TexelFetch(const VarNode* buffer, PrimExpr index) { - std::ostringstream os; - os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int("; - PrintExpr(index, os); - os << ") & " << runtime::kTextureRowMask << ", int("; - PrintExpr(index, os); - os << ") >> " << runtime::kTextureRowBits << "), 0).r"; - return os.str(); -} - -// Print a reference expression to a buffer. -// Format: texelFetch(buffer, index, 0).r -std::string CodeGenOpenGL::GetBufferRef( - DataType t, const VarNode* buffer, PrimExpr index) { - CHECK_EQ(t.lanes(), 1) << "Vector type not supported."; - CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported."; - - if (buffer == this->output_) { - // This is the output texture. - return GetVarID(buffer); - } else { - // This is an input texture. 
- this->inputs_.insert(buffer); - return TexelFetch(buffer, index); - } -} - -void CodeGenOpenGL::PrintType(DataType t, std::ostream& os) { - switch (t.code()) { - case kDLInt: - CHECK_EQ(t.bits(), 32) << "Only support 32-bit int."; - os << "int"; - break; - case kDLUInt: - CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint."; - os << "uint"; - break; - case kDLFloat: - CHECK_EQ(t.bits(), 32) << "Only support 32-bit float."; - os << "float"; - break; - default: - LOG(FATAL) << "Unsupported type code."; - } -} - -// Codegen for immediate values - -void CodeGenOpenGL::VisitExpr_(const IntImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::Int(32)) << "GLSL 3.0 only supports 32-bit ints."; - CodeGenC::VisitExpr_(op, os); -} - -void CodeGenOpenGL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { - CHECK_EQ(op->dtype, DataType::Float(32)) << "GLSL 3.0 only supports 32-bit floats."; - CodeGenC::VisitExpr_(op, os); -} - -void CodeGenOpenGL::VisitExpr_(const StringImmNode*, std::ostream& os) { - LOG(FATAL) << "GLSL 3.0 doesn't support strings."; -} - -void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) { - auto call = op->value.as(); - if (call == nullptr || call->name != CallNode::glsl_texture_store) { - // Fallback to normal logic. - CodeGenC::VisitStmt_(op); - } - - CHECK_EQ(call->args.size(), 2); - auto buffer = call->args[0].as(); - auto value = call->args[1]; - - // Doesn't support store to vector. - auto type = value.dtype(); - CHECK_EQ(type.lanes(), 1) - << "Vectorized store not implemented, type = " << type; - - CHECK(inputs_.find(buffer) == inputs_.cend()) - << "Texture has been read from before. Must not store to it."; - if (output_ == nullptr) { - output_ = buffer; // Record that this texture is the output. - } else { - CHECK(output_ == buffer) << "GLSL can only write to 1 texture."; - } - - this->PrintIndent(); - this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n"; -} - -runtime::Module BuildOpenGL(IRModule mod) { - bool output_ssa = false; - CodeGenOpenGL cg; - cg.Init(output_ssa); - - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenGL: Can only take PrimFunc"; - auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); - } - - auto shaders = cg.Finish(); - return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod)); -} - -TVM_REGISTER_GLOBAL("target.build.opengl") -.set_body_typed(BuildOpenGL); - -} // namespace codegen -} // namespace tvm diff --git a/src/target/source/codegen_opengl.h b/src/target/source/codegen_opengl.h deleted file mode 100644 index 954806bbca59..000000000000 --- a/src/target/source/codegen_opengl.h +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file codegen_opengl.h - * \brief Generate OpenGL device code. - */ -#ifndef TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_ -#define TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_ - -#include -#include -#include -#include -#include "codegen_c.h" -#include "../../runtime/opengl/opengl_module.h" - -namespace tvm { -namespace codegen { - -class CodeGenOpenGL final : public CodeGenC { - public: - CodeGenOpenGL(); - std::unordered_map Finish(); - - void AddFunction(const PrimFunc& f); - void InitFuncState(const PrimFunc& f) final; - void BindThreadIndex(const IterVar& iv) final; - void VisitStmt_(const StoreNode* op) final; - std::string TexelFetch(const VarNode* buffer, PrimExpr index); - std::string GetBufferRef(DataType t, const VarNode* buffer, PrimExpr index) final; - void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) - - // Codegen for immediate values - void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const StringImmNode* op, std::ostream& os) final; // NOLINT(*) - - // Match glsl_texture_store Call. - void VisitStmt_(const EvaluateNode* op) final; // NOLINT(*) - - private: - const VarNode* output_{nullptr}; - std::unordered_set inputs_; - const VarNode* output_iter_var_{nullptr}; - std::unordered_map shaders_; - std::string thread_extent_var_; -}; - -} // namespace codegen -} // namespace tvm - -#endif // TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_ diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 0859428aa58b..9b2f0345864f 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -70,8 +70,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { } std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { - CHECK(!var_idmap_.count(v)) - << "Need input to be in SSA form dup " << v->name_hint; + CHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; std::string vid = GetUniqueName(key); var_idmap_[v] = vid; @@ -80,8 +79,7 @@ std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { auto it = var_idmap_.find(v); - CHECK(it != var_idmap_.end()) - << "Find undefined Variable " << v->name_hint; + CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 6723767b401f..39016590abdc 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -24,13 +24,15 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ +#include #include #include -#include -#include -#include + #include +#include #include +#include + #include "../../runtime/meta_data.h" namespace tvm { @@ -103,8 +105,7 @@ class CodeGenSourceBase { * \param src The source expression. 
* \param t The type of target. */ - virtual void PrintSSAAssign( - const std::string& target, const std::string& src, DataType t) = 0; + virtual void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) = 0; /*! \brief the declaration stream */ std::ostringstream decl_stream; @@ -147,11 +148,8 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt); * \param fget_source a closure to replace default get source behavior. */ runtime::Module DeviceSourceModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, - std::function fget_source = nullptr); + std::string data, std::string fmt, std::unordered_map fmap, + std::string type_key, std::function fget_source = nullptr); } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 71c36264afa4..e60e1f5027d7 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -20,11 +20,13 @@ /*! * \file codegen_vhls.cc */ -#include -#include #include "codegen_vhls.h" -#include "../build_common.h" + +#include +#include + #include "../../runtime/opencl/sdaccel/sdaccel_module.h" +#include "../build_common.h" namespace tvm { namespace codegen { @@ -40,37 +42,45 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { if (t.is_uint()) { switch (t.bits()) { case 8: - os << "unsigned char"; break; + os << "unsigned char"; + break; case 16: - os << "unsigned short"; break; + os << "unsigned short"; + break; case 32: - os << "unsigned int"; break; + os << "unsigned int"; + break; case 64: - os << "unsigned long long"; break; + os << "unsigned long long"; + break; default: - os << "ap_uint<" << t.bits() << ">"; break; + os << "ap_uint<" << t.bits() << ">"; + break; } } else if (t.is_int()) { switch (t.bits()) { case 8: - os << "char"; break; + os << "char"; + break; case 16: - os << "short"; break; + os << "short"; + break; case 32: - os << "int"; break; + os << "int"; + break; case 64: - os << "long long"; break; + os << "long long"; + break; default: - os << "ap_int<" << t.bits() << ">"; break; + os << "ap_int<" << t.bits() << ">"; + break; } } else { CodeGenC::PrintType(t, os); } } -void CodeGenVivadoHLS::PrintFuncPrefix() { - stream << "extern \"C\" void"; -} +void CodeGenVivadoHLS::PrintFuncPrefix() { stream << "extern \"C\" void"; } void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { for (size_t i = 0; i < f->params.size(); ++i) { @@ -84,9 +94,8 @@ void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) { this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n"; } -template -inline void PrintBinaryExpr(const T* op, - const char *opstr, +template +inline void PrintBinaryExpr(const T* op, const char* opstr, std::ostream& os, // NOLINT(*) CodeGenVivadoHLS* p) { os << opstr << '('; @@ -96,35 +105,38 @@ inline void PrintBinaryExpr(const T* op, os << ')'; } -void CodeGenVivadoHLS::VisitExpr_(const MinNode *op, std::ostream& os) { // NOLINT(*) - const char *opstr = "std::min"; +void CodeGenVivadoHLS::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) + const char* opstr = "std::min"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { case 32: - opstr = "fminf"; break; + opstr = "fminf"; + break; case 64: - opstr = "fmin"; break; + opstr = "fmin"; + break; } } PrintBinaryExpr(op, opstr, os, this); } -void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& 
os) { // NOLINT(*) - const char *opstr = "std::max"; +void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) + const char* opstr = "std::max"; if (op->dtype.is_float()) { switch (op->dtype.bits()) { case 32: - opstr = "fmaxf"; break; + opstr = "fmaxf"; + break; case 64: - opstr = "fmax"; break; + opstr = "fmax"; + break; } } PrintBinaryExpr(op, opstr, os, this); } - runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { using tvm::runtime::Registry; bool output_ssa = false; @@ -133,9 +145,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { // Generate source code for get_source(). cg.Init(output_ssa); - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenVHLS: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -148,9 +159,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { // Generate source code for compilation. Array > kernel_info; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenOpenCL: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); CodeGenVivadoHLS cg; cg.Init(output_ssa); @@ -176,8 +186,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code); } -TVM_REGISTER_GLOBAL("target.build.sdaccel") -.set_body_typed(BuildSDAccel); +TVM_REGISTER_GLOBAL("target.build.sdaccel").set_body_typed(BuildSDAccel); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_vhls.h b/src/target/source/codegen_vhls.h index 10f9ea7679b6..b9bec516bae9 100644 --- a/src/target/source/codegen_vhls.h +++ b/src/target/source/codegen_vhls.h @@ -27,7 +27,9 @@ #include #include #include + #include + #include "codegen_c.h" namespace tvm { @@ -40,8 +42,8 @@ class CodeGenVivadoHLS final : public CodeGenC { void PrintFuncPrefix() final; void PreFunctionBody(const PrimFunc& f) final; - void VisitExpr_(const MinNode *op, std::ostream& os) final; - void VisitExpr_(const MaxNode *op, std::ostream& os) final; + void VisitExpr_(const MinNode* op, std::ostream& os) final; + void VisitExpr_(const MaxNode* op, std::ostream& os) final; }; } // namespace codegen diff --git a/src/target/source/intrin_rule_aocl.cc b/src/target/source/intrin_rule_aocl.cc index 6317a2fab381..0cafd0255a86 100644 --- a/src/target/source/intrin_rule_aocl.cc +++ b/src/target/source/intrin_rule_aocl.cc @@ -27,73 +27,49 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round") 
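Illustrative sketch (not part of the diff): the Vivado HLS PrintType reformatted above keeps the same rule as before, falling back to arbitrary-precision ap_int/ap_uint for bit widths that have no native C type. A hypothetical standalone sketch of that mapping:

#include <iostream>
#include <string>

static std::string HlsIntType(int bits, bool is_unsigned) {
  switch (bits) {
    case 8:  return is_unsigned ? "unsigned char" : "char";
    case 16: return is_unsigned ? "unsigned short" : "short";
    case 32: return is_unsigned ? "unsigned int" : "int";
    case 64: return is_unsigned ? "unsigned long long" : "long long";
    default:
      // Non-standard widths use Xilinx arbitrary-precision types.
      return (is_unsigned ? "ap_uint<" : "ap_int<") + std::to_string(bits) + ">";
  }
}

int main() {
  std::cout << HlsIntType(12, false) << "\n";  // ap_int<12>
  std::cout << HlsIntType(32, true) << "\n";   // unsigned int
}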
-.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl.popcount").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.sqrt") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.pow").set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.aocl_sw_emu.popcount").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index d9441203edc0..45746b8ef721 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -31,10 +31,14 @@ struct CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { switch (t.bits()) { - case 64: return name; - case 32: return name + 'f'; - case 16: return 'h' + name; - default: return ""; + case 64: + return name; + case 32: + return name + 'f'; + case 16: + return 'h' + name; + default: + return ""; } } return ""; @@ -55,14 +59,18 @@ 
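Illustrative sketch (not part of the diff): CUDAMath in the hunk above picks the CUDA math symbol by float width, so the double variant keeps the plain name, the float variant gets an 'f' suffix, and the half variant gets an 'h' prefix. A hypothetical standalone sketch of the same rule:

#include <iostream>
#include <string>

static std::string CudaMathName(int float_bits, const std::string& name) {
  switch (float_bits) {
    case 64: return name;         // e.g. exp
    case 32: return name + 'f';   // e.g. expf
    case 16: return 'h' + name;   // e.g. hexp
    default: return "";
  }
}

int main() {
  std::cout << CudaMathName(32, "exp") << "\n";  // expf
  std::cout << CudaMathName(16, "exp") << "\n";  // hexp
}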
struct CUDAFastMath : public CUDAMath { struct CUDAFastMathTan : public CUDAMath { std::string operator()(DataType t, std::string name) const { if (t.is_float()) { - switch (t.bits()) { - case 64: return name; - // `__tanf` seems to produce some values too deviant from numpy tan version. - // So, let's use just `tanf` instead. - case 32: return name + 'f'; - case 16: LOG(FATAL) << "cuda tan unsupported for float16"; - default: return ""; - } + switch (t.bits()) { + case 64: + return name; + // `__tanf` seems to produce some values too deviant from numpy tan version. + // So, let's use just `tanf` instead. + case 32: + return name + 'f'; + case 16: + LOG(FATAL) << "cuda tan unsupported for float16"; + default: + return ""; + } } return ""; } @@ -72,92 +80,104 @@ struct CUDAPopcount { std::string operator()(DataType t, std::string name) const { if (t.is_uint()) { switch (t.bits()) { - case 32: return "__popc"; - case 64: return "__popcll"; - default: return ""; + case 32: + return "__popc"; + case 64: + return "__popcll"; + default: + return ""; } } return ""; } }; -struct CUDAShuffle { - std::string operator()(DataType t, std::string name) const { - return "__shfl"; +struct CUDAWarpIntrinsic { + const char* operator()(DataType t, const std::string& name) const { + if (name == intrinsic::tvm_warp_shuffle) { + return "__shfl_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_up) { + return "__shfl_up_sync"; + } + if (name == intrinsic::tvm_warp_shuffle_down) { + return "__shfl_down_sync"; + } + if (name == intrinsic::tvm_warp_activemask) { + return "__activemask"; + } + return ""; } }; -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor") -.set_body(DispatchExtern); +template +static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + Array cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; + const char* name = T()(call->dtype, call->name); + *rv = Call(call->dtype, name, cuda_args, CallNode::PureExtern); +} + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log") -.set_body(DispatchExtern); 
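Illustrative sketch (not part of the diff): DispatchCUDAShuffle above forwards only the first four intrinsic arguments (mask, value, lane, width) and drops the trailing warp_size, since the __shfl*_sync builtins take four arguments. A hypothetical standalone sketch of the name mapping it relies on:

#include <iostream>
#include <string>

static std::string CudaShuffleName(const std::string& tvm_intrin) {
  if (tvm_intrin == "tvm_warp_shuffle") return "__shfl_sync";
  if (tvm_intrin == "tvm_warp_shuffle_up") return "__shfl_up_sync";
  if (tvm_intrin == "tvm_warp_shuffle_down") return "__shfl_down_sync";
  if (tvm_intrin == "tvm_warp_activemask") return "__activemask";
  return "";
}

int main() {
  // e.g. tvm_warp_shuffle_down(mask, v, delta, width, warp_size) is expected
  // to lower to __shfl_down_sync(mask, v, delta, width).
  std::cout << CudaShuffleName("tvm_warp_shuffle_down") << "\n";
}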
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cosh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.atan").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount").set_body(DispatchExtern); TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle") -.set_body(DispatchExtern); + .set_body(DispatchCUDAShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up") + .set_body(DispatchCUDAShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down") + .set_body(DispatchCUDAShuffle); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask") + .set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index 8bc87d2b280f..00fb9f9a95de 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -27,65 +27,45 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round") -.set_body(DispatchExtern); 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.fmod").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.metal.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 1a4f52e4dfd1..8453b33f8a43 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -21,82 +21,69 @@ * \file intrin_rule_opencl.cc * \brief OpenCL intrinsic rules. 
*/ +#include + #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.log10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.fmod").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.cosh").set_body(DispatchExtern); // There is no warp shuffle instruction in standard OpenCL // When shuffle is used, we assume it is intel's shuffle extension -struct IntelShuffle { - std::string operator()(DataType t, std::string name) const { - return "intel_sub_group_shuffle"; - } -}; - 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle") -.set_body(DispatchExtern); +static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) { + PrimExpr e = args[0]; + const CallNode* call = e.as(); + CHECK(call != nullptr); + CHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + arith::Analyzer analyzer; + CHECK(analyzer.CanProve(call->args[3] == call->args[4])) + << "Intel warp shuffle dose not support width != warp_size"; + Array opencl_args{{call->args[1], call->args[2]}}; + *rv = Call(call->dtype, "intel_sub_group_shuffle", opencl_args, CallNode::PureExtern); +} + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle").set_body(DispatchIntelShuffle); } // namespace intrin } // namespace codegen diff --git a/src/target/source/intrin_rule_opengl.cc b/src/target/source/intrin_rule_opengl.cc deleted file mode 100644 index 1710d45d8bd6..000000000000 --- a/src/target/source/intrin_rule_opengl.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file intrin_rule_opencl.cc - * \brief OpenCL intrinsic rules. 
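Illustrative sketch (not part of the diff): DispatchIntelShuffle above can drop the mask and width arguments only because it first proves width == warp_size; intel_sub_group_shuffle takes just the value and the source lane. A hypothetical standalone sketch of that precondition and argument selection:

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

static std::string LowerIntelShuffle(const std::vector<std::string>& args,
                                     bool width_equals_warp_size) {
  assert(args.size() == 5 && "expects mask, value, lane, width, warp_size");
  assert(width_equals_warp_size && "sub-group shuffle covers the whole warp only");
  // Keep only the value and the source lane.
  return "intel_sub_group_shuffle(" + args[1] + ", " + args[2] + ")";
}

int main() {
  std::cout << LowerIntelShuffle({"mask", "v", "lane", "w", "w"}, true) << "\n";
}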
- */ -#include "../intrin_rule.h" - -namespace tvm { -namespace codegen { -namespace intrin { - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.floor") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.ceil") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp2") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.exp10") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log2") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.log10") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.tanh") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sqrt") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.pow") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.popcount") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sin") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.sinh") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cos") -.set_body(DispatchExtern); - -TVM_REGISTER_GLOBAL("tvm.intrin.rule.opengl.cosh") -.set_body(DispatchExtern); - -} // namespace intrin -} // namespace codegen -} // namespace tvm diff --git a/src/target/source/intrin_rule_vhls.cc b/src/target/source/intrin_rule_vhls.cc index 41e76f260ff4..fb01d6566dab 100644 --- a/src/target/source/intrin_rule_vhls.cc +++ b/src/target/source/intrin_rule_vhls.cc @@ -27,62 +27,43 @@ namespace tvm { namespace codegen { namespace intrin { -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp10").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log2").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log10").set_body(DispatchExtern); 
-TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sin").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sinh").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cos").set_body(DispatchExtern); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh") -.set_body(DispatchExtern); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.cosh").set_body(DispatchExtern); } // namespace intrin } // namespace codegen diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index 858ac8572a08..baf4ba733dce 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) { } )"; +static constexpr const char* _cuda_warp_intrinsic_util = R"( +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700) +#define __shfl_sync(mask, var, lane, width) \ + __shfl((var), (lane), (width)) + +#define __shfl_down_sync(mask, var, offset, width) \ + __shfl_down((var), (offset), (width)) + +#define __shfl_up_sync(mask, var, offset, width) \ + __shfl_up((var), (offset), (width)) +#endif + +)"; + #endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 5f133212140c..ba7f075d0045 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -23,43 +23,36 @@ */ #include #include -#include "codegen_source_base.h" + #include "../../runtime/file_util.h" #include "../../runtime/meta_data.h" +#include "codegen_source_base.h" namespace tvm { namespace codegen { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; +using runtime::FunctionInfo; using runtime::GetFileFormat; using runtime::GetMetaFilePath; -using runtime::FunctionInfo; using runtime::SaveBinaryToFile; // Simulator function class SourceModuleNode : public runtime::ModuleNode { public: - SourceModuleNode(std::string code, - std::string fmt) - : code_(code), fmt_(fmt) {} - const char* type_key() const { - return "source"; - } + SourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + const char* type_key() const { return "source"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); } - std::string GetSource(const std::string& format) final { - return code_; - } + std::string 
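Illustrative sketch (not part of the diff): the _cuda_warp_intrinsic_util literal added above gives pre-sm_70 devices a fallback in which the *_sync shuffles degrade to the legacy forms and the mask argument is ignored. A hypothetical host-side C++ sketch of that degradation, using stand-in names so it compiles off-device:

#include <iostream>

// Hypothetical stand-in for the legacy intrinsic.
static int shfl_down_legacy(int var, unsigned delta, int width) {
  (void)delta; (void)width;
  return var;  // real hardware would fetch the value from lane (lane_id + delta)
}

// Mirrors the fallback macro: the mask is dropped on old architectures.
#define shfl_down_sync_compat(mask, var, delta, width) \
  shfl_down_legacy((var), (delta), (width))

int main() {
  std::cout << shfl_down_sync_compat(0xffffffffu, 42, 16, 32) << "\n";  // 42
}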
GetSource(const std::string& format) final { return code_; } protected: std::string code_; @@ -74,35 +67,25 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { // Simulator function class CSourceModuleNode : public runtime::ModuleNode { public: - CSourceModuleNode(std::string code, - std::string fmt) - : code_(code), fmt_(fmt) {} - const char* type_key() const { - return "c"; - } + CSourceModuleNode(std::string code, std::string fmt) : code_(code), fmt_(fmt) {} + const char* type_key() const { return "c"; } - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "C Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); } - std::string GetSource(const std::string& format) final { - return code_; - } + std::string GetSource(const std::string& format) final { return code_; } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); if (fmt == "cc") { CHECK_NE(code_.length(), 0); SaveBinaryToFile(file_name, code_); } else { - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; } } @@ -119,20 +102,12 @@ runtime::Module CSourceModuleCreate(std::string code, std::string fmt) { // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: - DeviceSourceModuleNode(std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, + DeviceSourceModuleNode(std::string data, std::string fmt, + std::unordered_map fmap, std::string type_key, std::function fget_source) - : data_(data), - fmt_(fmt), - fmap_(fmap), - type_key_(type_key), - fget_source_(fget_source) {} - - PackedFunc GetFunction( - const std::string& name, - const ObjectPtr& sptr_to_self) final { + : data_(data), fmt_(fmt), fmap_(fmap), type_key_(type_key), fget_source_(fget_source) {} + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { LOG(FATAL) << "Source module cannot execute, to get executable module" << " build TVM with \'" << fmt_ << "\' runtime support"; return PackedFunc(); @@ -146,15 +121,11 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { } } - const char* type_key() const { - return type_key_.c_str(); - } + const char* type_key() const { return type_key_.c_str(); } - void SaveToFile(const std::string& file_name, - const std::string& format) final { + void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); - CHECK_EQ(fmt, fmt_) - << "Can only save to format=" << fmt_; + CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; std::string meta_file = GetMetaFilePath(file_name); SaveMetaDataToFile(meta_file, fmap_); SaveBinaryToFile(file_name, data_); @@ -175,19 +146,14 @@ class DeviceSourceModuleNode final : public runtime::ModuleNode { }; runtime::Module DeviceSourceModuleCreate( - std::string data, - std::string fmt, - std::unordered_map fmap, - std::string type_key, - std::function fget_source) { + std::string data, std::string fmt, std::unordered_map fmap, + std::string type_key, std::function 
fget_source) { auto n = make_object(data, fmt, fmap, type_key, fget_source); return runtime::Module(n); } -TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate") -.set_body_typed(SourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); -TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") -.set_body_typed(CSourceModuleCreate); +TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate").set_body_typed(CSourceModuleCreate); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 161c1ca3bab1..86d1614dc863 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -22,44 +22,37 @@ * \brief Build SPIRV block */ // Use libspirv for parsing and validating code. -#include #include -#include - -#include "codegen_spirv.h" -#include "../build_common.h" +#include +#include -#include "../../runtime/vulkan/vulkan_shader.h" #include "../../runtime/vulkan/vulkan_module.h" +#include "../../runtime/vulkan/vulkan_shader.h" +#include "../build_common.h" +#include "codegen_spirv.h" namespace tvm { namespace codegen { class SPIRVTools { public: - SPIRVTools() { - ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); - } - ~SPIRVTools() { - spvContextDestroy(ctx_); - } + SPIRVTools() { ctx_ = spvContextCreate(SPV_ENV_VULKAN_1_0); } + ~SPIRVTools() { spvContextDestroy(ctx_); } std::string BinaryToText(const std::vector& bin) { spv_text text = nullptr; spv_diagnostic diagnostic; spv_const_binary_t spv_bin{bin.data(), bin.size()}; spv_result_t res; - res = spvBinaryToText( - ctx_, spv_bin.code, spv_bin.wordCount, - SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | - SPV_BINARY_TO_TEXT_OPTION_INDENT, - &text, &diagnostic); + res = + spvBinaryToText(ctx_, spv_bin.code, spv_bin.wordCount, + SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT, + &text, &diagnostic); - CHECK_EQ(res, SPV_SUCCESS) - << " line=" << diagnostic->position.line - << " column=" << diagnostic->position.column - << " index=" << diagnostic->position.index - << " error:" << diagnostic->error; + CHECK_EQ(res, SPV_SUCCESS) << " line=" << diagnostic->position.line + << " column=" << diagnostic->position.column + << " index=" << diagnostic->position.index + << " error:" << diagnostic->error; std::string ret(text->str); spvTextDestroy(text); @@ -70,7 +63,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod) { +runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -80,11 +73,12 @@ runtime::Module BuildSPIRV(IRModule mod) { const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); + mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + CodeGenSPIRV cg; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenSPIRV: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) @@ -94,9 +88,16 @@ runtime::Module BuildSPIRV(IRModule mod) { << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); - f = PointerValueTypeRewrite(std::move(f)); + VulkanShader shader; - shader.data = cg.BuildFunction(f); + std::string entry = webgpu_restriction ? 
"main" : f_name; + shader.data = cg.BuildFunction(f, entry); + + if (webgpu_restriction) { + for (auto param : f->params) { + CHECK(param.dtype().is_handle()) << "WebGPU does not yet support non-buffer arguments"; + } + } if (postproc != nullptr) { TVMByteArray arr; @@ -112,12 +113,16 @@ runtime::Module BuildSPIRV(IRModule mod) { smap[f_name] = std::move(shader); } - return runtime::VulkanModuleCreate( - smap, ExtractFuncInfo(mod), code_data.str()); + return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str()); } -TVM_REGISTER_GLOBAL("target.build.vulkan") -.set_body_typed(BuildSPIRV); +TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) { + return BuildSPIRV(mod, target, false); +}); + +TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) { + return BuildSPIRV(mod, target, true); +}); } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1d8004e9938f..699d3953f04c 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -21,20 +21,20 @@ * \file codegen_spirv.cc * \brief Generate SPIRV block */ -#include -#include +#include "codegen_spirv.h" + #include +#include +#include + #include -#include "codegen_spirv.h" -#include "../../arith/compute_expr.h" namespace tvm { namespace codegen { -std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { +std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) - << "SPIRV only takes restricted memory model"; + CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t num_buffer = 0; @@ -45,8 +45,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { auto* prim = ptr->element_type.as(); CHECK(prim); DataType value_type = prim->dtype; - spirv::Value arg_value = builder_->BufferArgument( - builder_->GetSType(value_type), 0, num_buffer); + spirv::Value arg_value = + builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer); storage_info_[arg.get()].UpdateContentType(value_type); var_map_[arg.get()] = arg_value; } else { @@ -68,8 +68,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { } spirv::Value ptr = builder_->DeclarePushConstant(value_types); for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetPushConstant( - ptr, value_types[i], static_cast(i)); + spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; } } @@ -78,12 +77,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { builder_->MakeInst(spv::OpReturn); builder_->MakeInst(spv::OpFunctionEnd); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) - << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - - builder_->CommitKernelFunction( - func_ptr, static_cast(global_symbol.value())); + builder_->CommitKernelFunction(func_ptr, name); return builder_->Finalize(); } @@ -97,17 +91,16 @@ void CodeGenSPIRV::InitFuncState() { builder_->InitHeader(); } -spirv::Value CodeGenSPIRV::GetThreadIndex( - const IterVar& iv, const PrimExpr& extent) { - runtime::ThreadScope ts = runtime::ThreadScope::make(iv->thread_tag); +spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) { + runtime::ThreadScope ts 
= runtime::ThreadScope::Create(iv->thread_tag); spirv::Value v; if (ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); - int size = 0; - CHECK(arith::GetConstInt(extent, &size)) - << "SPIRV only allows constant thread group size " << " get " << extent; + auto* sizeptr = extent.as(); + CHECK(sizeptr) << "SPIRV only allows constant thread group size " + << " get " << extent; CHECK_LT(ts.dim_index, 3); - workgroup_size_[ts.dim_index] = static_cast(size); + workgroup_size_[ts.dim_index] = static_cast(sizeptr->value); } else { v = builder_->GetWorkgroupID(ts.dim_index); } @@ -122,12 +115,12 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { } else if (sync == "shared") { auto type_int = builder_->GetSType(DataType::Int(32)); builder_->MakeInst( - spv::OpControlBarrier, - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), - builder_->IntImm(type_int, static_cast( - spv::MemorySemanticsSequentiallyConsistentMask | - spv::MemorySemanticsWorkgroupMemoryMask))); + spv::OpControlBarrier, + builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), + builder_->IntImm(type_int, static_cast(spv::ScopeWorkgroup)), + builder_->IntImm(type_int, + static_cast(spv::MemorySemanticsSequentiallyConsistentMask | + spv::MemorySemanticsWorkgroupMemoryMask))); } else { LOG(FATAL) << "Do not support sync " << sync; } @@ -231,8 +224,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) { } spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) { - return builder_->Select(MakeValue(op->condition), - MakeValue(op->true_value), + return builder_->Select(MakeValue(op->condition), MakeValue(op->true_value), MakeValue(op->false_value)); } @@ -246,14 +238,12 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { if (op->is_intrinsic("spirv_glsl450")) { CHECK_GE(op->args.size(), 2U); - uint32_t inst_id = static_cast( - op->args[0].as()->value); + uint32_t inst_id = static_cast(op->args[0].as()->value); std::vector values; for (size_t i = 1; i < op->args.size(); ++i) { values.push_back(MakeValue(op->args[i])); } - return builder_->CallGLSL450( - builder_->GetSType(op->dtype), inst_id, values); + return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values); } else if (op->is_intrinsic(CallNode::bitwise_and)) { CHECK_EQ(op->args.size(), 2U); spirv::Value a = MakeValue(op->args[0]); @@ -304,10 +294,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Label then_label = builder_->NewLabel(); spirv::Label else_label = builder_->NewLabel(); spirv::Label merge_label = builder_->NewLabel(); - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, else_label); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); // then block, must get label after we see the value builder_->StartLabel(then_label); spirv::Value then_value = MakeValue(op->args[1]); @@ -325,19 +313,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { phi.SetIncoming(1, else_value, else_value_label); return phi; } else if (op->is_intrinsic("popcount")) { - return builder_->MakeValue( - spv::OpBitCount, - builder_->GetSType(op->dtype), - MakeValue(op->args[0])); + return builder_->MakeValue(spv::OpBitCount, 
builder_->GetSType(op->dtype), + MakeValue(op->args[0])); } else { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { - LOG(FATAL) << "Unresolved intrinsic " << op->name - << " with return type " << op->dtype; - } else if (op->call_type == CallNode::Extern || - op->call_type == CallNode::PureExtern) { - LOG(FATAL) << "Unresolved extern " << op->name - << " with return type " << op->dtype; + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { + LOG(FATAL) << "Unresolved intrinsic " << op->name << " with return type " << op->dtype; + } else if (op->call_type == CallNode::Extern || op->call_type == CallNode::PureExtern) { + LOG(FATAL) << "Unresolved extern " << op->name << " with return type " << op->dtype; } else { LOG(FATAL) << "Unresolved call type " << op->call_type; } @@ -351,8 +333,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { for (int i = 0; i < op->lanes; ++i) { spirv::Value v = base; if (i != 0) { - spirv::Value offset = MakeValue( - make_const(op->stride.dtype(), i) * op->stride); + spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); @@ -380,8 +361,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { spirv::SType content_type = builder_->GetSType(info.content_type); spirv::Value buffer = MakeValue(op->buffer_var); - spirv::SType ptr_type = builder_->GetPointerType( - content_type, buffer.stype.storage_class); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; if (info.is_volatile) { @@ -391,18 +371,15 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK_EQ(info.content_type, op->dtype) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } else { if (op->dtype.element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. 
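// (Editorial aside, not part of this patch: a worked example of the two vector
//  load paths below. When the buffer content type is the element type, a
//  float32x4 load is emitted as four scalar OpLoads that are then Concat'ed.
//  The later Ramp branch instead relies on the modular_set proof
//  (coeff % lanes == 0 && base % lanes == 0); e.g. for Ramp(base = i*4,
//  stride = 1, lanes = 4) the analyzer reports coeff = 4, base = 0, so the
//  access simplifies to a single vector load at index (i*4)/4 = i.)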
std::vector values; auto f = [&](int i, spirv::Value index) { - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); - values.emplace_back( - builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); + values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask)); }; this->Scalarize(op->index, f); return builder_->Concat(values); @@ -411,13 +388,11 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->dtype.lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); - CHECK((me->coeff % ramp->lanes) == 0 && - (me->base % ramp->lanes) == 0) + CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = tir::Simplify( - ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, MakeValue(vec_index)); + PrimExpr vec_index = + analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask); } } @@ -428,8 +403,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { return spirv::Value(); } -void CodeGenSPIRV::Scalarize(const PrimExpr& e, - std::function f) { +void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function f) { if (const RampNode* ramp = e.as()) { for (int i = 0; i < ramp->dtype.lanes(); ++i) { PrimExpr offset = ramp->base + ramp->stride * i; @@ -439,8 +413,7 @@ void CodeGenSPIRV::Scalarize(const PrimExpr& e, spirv::SType etype = builder_->GetSType(e.dtype().element_of()); spirv::Value value = MakeValue(e); for (int i = 0; i < e.dtype().lanes(); ++i) { - f(i, builder_->MakeValue( - spv::OpCompositeExtract, etype, value, i)); + f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i)); } } } @@ -458,8 +431,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { spirv::SType content_type = builder_->GetSType(info.content_type); spirv::Value buffer = MakeValue(op->buffer_var); spirv::Value value = MakeValue(op->value); - spirv::SType ptr_type = builder_->GetPointerType( - content_type, buffer.stype.storage_class); + spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class); uint32_t mask = spv::MemoryAccessMaskNone; if (info.is_volatile) { @@ -470,17 +442,14 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK_EQ(info.content_type, op->value.dtype()) << "Vulkan only allow one type access to the same buffer"; spirv::Value index = MakeValue(op->index); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, value, mask); } else { if (op->value.dtype().element_of() == info.content_type) { // because content type is element type, we can only do scalarize load. 
auto f = [&](int i, spirv::Value index) { - spirv::Value elem = builder_->MakeValue( - spv::OpCompositeExtract, content_type, value, i); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, index); + spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index); builder_->MakeInst(spv::OpStore, ptr, elem, mask); }; this->Scalarize(op->index, f); @@ -489,13 +458,11 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->value.dtype().lanes()); arith::ModularSet me = analyzer_->modular_set(ramp->base); - CHECK((me->coeff % ramp->lanes) == 0 && - (me->base % ramp->lanes) == 0) + CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = tir::Simplify( - ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); - spirv::Value ptr = builder_->StructArrayAccess( - ptr_type, buffer, MakeValue(vec_index)); + PrimExpr vec_index = + analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); + spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index)); builder_->MakeInst(spv::OpStore, ptr, value, mask); return; } @@ -523,14 +490,11 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, extent_value); - uint32_t control = ( - op->for_type == ForType::Unrolled ? - spv::LoopControlUnrollMask : spv::LoopControlMaskNone); - builder_->MakeInst( - spv::OpLoopMerge, merge_label, continue_label, control); - builder_->MakeInst( - spv::OpBranchConditional, loop_cond, body_label, merge_label, - weight_likely_branch_, 1); + uint32_t control = + (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); + builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); + builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, + weight_likely_branch_, 1); // loop body builder_->StartLabel(body_label); @@ -540,10 +504,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // loop continue builder_->StartLabel(continue_label); - spirv::Value one = - op->loop_var.dtype().is_int() ? - builder_->IntImm(loop_var.stype, 1) : - builder_->UIntImm(loop_var.stype, 1); + spirv::Value one = op->loop_var.dtype().is_int() ? 
builder_->IntImm(loop_var.stype, 1) + : builder_->UIntImm(loop_var.stype, 1); spirv::Value next_value = builder_->Add(loop_var, one); loop_var.SetIncoming(1, next_value, builder_->CurrentLabel()); builder_->MakeInst(spv::OpBranch, head_label); @@ -557,10 +519,8 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { spirv::Label merge_label = builder_->NewLabel(); if (op->else_case.defined()) { spirv::Label else_label = builder_->NewLabel(); - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, else_label); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); // then block builder_->StartLabel(then_label); this->VisitStmt(op->then_case); @@ -570,11 +530,9 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { this->VisitStmt(op->else_case); builder_->MakeInst(spv::OpBranch, merge_label); } else { - builder_->MakeInst( - spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); - builder_->MakeInst( - spv::OpBranchConditional, cond, then_label, merge_label, - weight_likely_branch_, 1); + builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); + builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label, + weight_likely_branch_, 1); // then block builder_->StartLabel(then_label); this->VisitStmt(op->then_case); @@ -588,23 +546,20 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); CHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; + CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; StorageInfo& info = storage_info_[op->buffer_var.get()]; spirv::SType etype = builder_->GetSType(op->dtype); if (info.scope.rank == runtime::StorageRank::kLocal) { - buf = builder_->Allocate( - etype, static_cast(constant_size), - spv::StorageClassFunction); + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); } else { // shared memory CHECK(info.scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory - buf = builder_->Allocate( - etype, static_cast(constant_size), - spv::StorageClassWorkgroup); + buf = + builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); } CHECK(!info.content_fixed); info.UpdateContentType(op->dtype); @@ -625,8 +580,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_scope) { const VarNode* v = op->node.as(); CHECK(v); - storage_info_[v].scope = - runtime::StorageScope::make(op->value.as()->value); + storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); CHECK(v); @@ -654,9 +608,7 @@ void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { - MakeValue(op->value); -} +void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index f50760711dec..a8af29a194d5 100644 --- 
a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -26,14 +26,16 @@ #include #include +#include #include -#include #include +#include #include +#include -#include "ir_builder.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_builder.h" namespace tvm { namespace codegen { @@ -43,24 +45,22 @@ using namespace tir; /*! * \brief Code generator into SPIRV */ -class CodeGenSPIRV: - public ExprFunctor, - public StmtFunctor { +class CodeGenSPIRV : public ExprFunctor, + public StmtFunctor { public: /*! * \brief Compile and add function f to the current module. * \param f The function to be added. + * \param name The name of the target function. * \return The final spirv module. */ - virtual std::vector BuildFunction(const PrimFunc& f); + virtual std::vector BuildFunction(const PrimFunc& f, const std::string& name); /*! * \brief Create Value for expression e * \param e The expression to be created value for. * \return created value. */ - spirv::Value MakeValue(const PrimExpr& e) { - return VisitExpr(e); - } + spirv::Value MakeValue(const PrimExpr& e) { return VisitExpr(e); } // override codegen spirv::Value VisitExpr_(const VarNode* op) override; spirv::Value VisitExpr_(const CastNode* op) override; @@ -115,8 +115,7 @@ class CodeGenSPIRV: // Update content type if it hasn't beenupdated. void UpdateContentType(DataType type) { if (content_fixed) { - CHECK_EQ(type, content_type) - << "Cannot use two different content type in GLSL model"; + CHECK_EQ(type, content_type) << "Cannot use two different content type in GLSL model"; } else { this->content_type = type; content_fixed = true; @@ -128,8 +127,7 @@ class CodeGenSPIRV: // Get the thread index spirv::Value GetThreadIndex(const IterVar& iv, const PrimExpr& extent); spirv::Value CreateStorageSync(const CallNode* op); - void Scalarize(const PrimExpr& e, - std::function f); + void Scalarize(const PrimExpr& e, std::function f); // The builder std::unique_ptr builder_; // Work group size of three @@ -147,5 +145,4 @@ class CodeGenSPIRV: } // namespace codegen } // namespace tvm - #endif // TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_ diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index ead6952b434e..a6b254770daa 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -20,9 +20,9 @@ /*! 
* \file intrin_rule_spirv.cc */ +#include #include #include -#include namespace tvm { namespace codegen { @@ -31,7 +31,7 @@ namespace spirv { using namespace runtime; // num_signature means number of arguments used to query signature -template +template inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { PrimExpr e = targs[0]; const tir::CallNode* call = e.as(); @@ -43,39 +43,55 @@ inline void DispatchGLSLPureIntrin(const TVMArgs& targs, TVMRetValue* rv) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - *rv = tir::CallNode::make( - call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); + *rv = tir::Call(call->dtype, "spirv_glsl450", cargs, tir::CallNode::PureIntrinsic); } TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.floor") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.ceil").set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.round") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.trunc") -.set_body(DispatchGLSLPureIntrin); + .set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh").set_body(DispatchGLSLPureIntrin); + +// WebGPU rules. 
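// (Editorial note, not part of this patch: the WebGPU registrations below reuse
//  the same GLSL.std.450 lowering as the Vulkan rules above. As a rough,
//  hypothetical illustration, DispatchGLSLPureIntrin rewrites tanh(x) into a
//  "spirv_glsl450" call whose first argument is the GLSL.std.450 opcode and
//  whose remaining arguments are the original operands; CodeGenSPIRV's CallNode
//  visitor then emits that call as an OpExtInst against the imported extended
//  instruction set.)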
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.floor") + .set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.ceil").set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.round") + .set_body(DispatchGLSLPureIntrin); + +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.trunc") + .set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.fabs") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.fabs").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.exp") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.exp").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.log") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.log").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.sqrt") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.sqrt").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.pow") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.pow").set_body(DispatchGLSLPureIntrin); -TVM_REGISTER_GLOBAL("tvm.intrin.rule.vulkan.tanh") -.set_body(DispatchGLSLPureIntrin); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.webgpu.tanh").set_body(DispatchGLSLPureIntrin); } // namespace spirv } // namespace codegen diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index bf43f11cce02..305464ac398b 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -32,10 +32,14 @@ namespace spirv { void IRBuilder::InitHeader() { CHECK_EQ(header_.size(), 0U); header_.push_back(spv::MagicNumber); - // Use SPIR-V v1.0. This needs to be kept in sync (or at least behind) - // `VkApplicationInfo.apiVersion` in `vulkan.cc` to ensure Vulkan API - // validation passes. + + // Use the spirv version as indicated in the SDK. 
+#if SPV_VERSION >= 0x10300 + header_.push_back(0x10300); +#else header_.push_back(0x10000); +#endif + // generator: set to 0, unknown header_.push_back(0U); // Bound: set during Finalize @@ -45,9 +49,9 @@ void IRBuilder::InitHeader() { // shader ib_.Begin(spv::OpCapability).Add(spv::CapabilityShader).Commit(&header_); // memory model - ib_.Begin(spv::OpMemoryModel).AddSeq( - spv::AddressingModelLogical, - spv::MemoryModelGLSL450).Commit(&entry_); + ib_.Begin(spv::OpMemoryModel) + .AddSeq(spv::AddressingModelLogical, spv::MemoryModelGLSL450) + .Commit(&entry_); this->InitPreDefs(); } @@ -62,8 +66,7 @@ void IRBuilder::InitPreDefs() { t_void_.id = id_counter_++; ib_.Begin(spv::OpTypeVoid).Add(t_void_).Commit(&global_); t_void_func_.id = id_counter_++; - ib_.Begin(spv::OpTypeFunction) - .AddSeq(t_void_func_, t_void_).Commit(&global_); + ib_.Begin(spv::OpTypeFunction).AddSeq(t_void_func_, t_void_).Commit(&global_); } SType IRBuilder::GetSType(const DataType& dtype) { @@ -89,8 +92,7 @@ SType IRBuilder::GetSType(const DataType& dtype) { return t; } -SType IRBuilder::GetPointerType(const SType& value_type, - spv::StorageClass storage_class) { +SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) { CHECK_NE(storage_class, spv::StorageClassMax); auto key = std::make_pair(value_type.id, storage_class); auto it = pointer_type_tbl_.find(key); @@ -102,14 +104,12 @@ SType IRBuilder::GetPointerType(const SType& value_type, t.type = DataType::Handle(); t.element_type_id = value_type.id; t.storage_class = storage_class; - ib_.Begin(spv::OpTypePointer) - .AddSeq(t, storage_class, value_type).Commit(&global_); + ib_.Begin(spv::OpTypePointer).AddSeq(t, storage_class, value_type).Commit(&global_); pointer_type_tbl_[key] = t; return t; } -SType IRBuilder::GetStructArrayType(const SType& value_type, - uint32_t num_elems) { +SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems) { auto key = std::make_pair(value_type.id, num_elems); auto it = struct_array_type_tbl_.find(key); if (it != struct_array_type_tbl_.end()) { @@ -123,54 +123,50 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, if (num_elems != 0) { Value length = UIntImm(GetSType(DataType::UInt(32)), num_elems); - ib_.Begin(spv::OpTypeArray) - .AddSeq(arr_type, value_type, length).Commit(&global_); + ib_.Begin(spv::OpTypeArray).AddSeq(arr_type, value_type, length).Commit(&global_); } else { - ib_.Begin(spv::OpTypeRuntimeArray) - .AddSeq(arr_type, value_type).Commit(&global_); + ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); } int nbits = value_type.type.bits() * value_type.type.lanes(); CHECK_EQ(nbits % 8, 0); uint32_t nbytes = static_cast(nbits) / 8; // decorate the array type. - this->Decorate(spv::OpDecorate, - arr_type, spv::DecorationArrayStride, nbytes); + this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); // declare struct of array SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); struct_type.element_type_id = value_type.id; - ib_.Begin(spv::OpTypeStruct) - .AddSeq(struct_type, arr_type).Commit(&global_); + ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); // decorate the array type. ib_.Begin(spv::OpMemberDecorate) .AddSeq(struct_type, 0, spv::DecorationOffset, 0) .Commit(&decorate_); + +#if SPV_VERSION < 0x10300 + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. 
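// (Editorial note, not part of this patch: the net effect of the SPV_VERSION
//  switches in this file is, roughly,
//
//    SPIR-V < 1.3 : runtime-array struct decorated BufferBlock,
//                   buffer arguments bound via StorageClassUniform
//    SPIR-V >= 1.3: struct decorated Block,
//                   buffer arguments bound via StorageClassStorageBuffer
//
//  BufferArgument further down applies the matching storage class.)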
// runtime array are always decorated as BufferBlock(shader storage buffer) if (num_elems == 0) { - this->Decorate(spv::OpDecorate, - struct_type, spv::DecorationBufferBlock); + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBufferBlock); } +#else + this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); +#endif struct_array_type_tbl_[key] = struct_type; return struct_type; } -Value IRBuilder::StructArrayAccess(const SType& res_type, - Value buffer, - Value index) { +Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) { CHECK(buffer.flag == kStructArrayPtr); - return MakeValue(spv::OpInBoundsAccessChain, - res_type, buffer, - const_i32_zero_, index); + return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, const_i32_zero_, index); } Value IRBuilder::IntImm(const SType& dtype, int64_t value) { return GetConst_(dtype, reinterpret_cast(&value)); } -Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { - return GetConst_(dtype, &value); -} +Value IRBuilder::UIntImm(const SType& dtype, uint64_t value) { return GetConst_(dtype, &value); } Value IRBuilder::FloatImm(const SType& dtype, double value) { if (dtype.type.bits() == 64) { @@ -182,23 +178,28 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { return GetConst_(dtype, &data); } else { CHECK_EQ(dtype.type.bits(), 16); - return Cast(dtype, - FloatImm(GetSType(DataType::Float(32)), value)); + return Cast(dtype, FloatImm(GetSType(DataType::Float(32)), value)); } } -Value IRBuilder::BufferArgument(const SType& value_type, - uint32_t descriptor_set, +Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding) { + // NOTE: BufferBlock was deprecated in SPIRV 1.3 + // use StorageClassStorageBuffer instead. 
+#if SPV_VERSION >= 0x10300 + spv::StorageClass storage_class = spv::StorageClassStorageBuffer; +#else + spv::StorageClass storage_class = spv::StorageClassUniform; +#endif + SType sarr_type = GetStructArrayType(value_type, 0); - SType ptr_type = GetPointerType(sarr_type, spv::StorageClassUniform); + SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, spv::StorageClassUniform).Commit(&global_); - this->Decorate(spv::OpDecorate, - val, spv::DecorationDescriptorSet, descriptor_set); - this->Decorate(spv::OpDecorate, - val, spv::DecorationBinding, binding); + + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); + + this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set); + this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); return val; } @@ -220,37 +221,30 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { .Commit(&decorate_); DataType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); - CHECK_EQ(nbits % 8 , 0); + CHECK_EQ(nbits % 8, 0); offset += nbits / 8; } // Decorate push constants as UBO this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); - SType ptr_type = GetPointerType( - struct_type, spv::StorageClassPushConstant); + SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant); Value val = NewValue(ptr_type, kPushConstantPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); return val; } -Value IRBuilder::GetPushConstant( - Value ptr_push_const, const SType& v_type, uint32_t index) { +Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) { SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant); - Value ptr = this->MakeValue( - spv::OpAccessChain, ptr_vtype, ptr_push_const, - IntImm(t_int32_, static_cast(index))); + Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, + IntImm(t_int32_, static_cast(index))); return this->MakeValue(spv::OpLoad, v_type, ptr); } -Value IRBuilder::NewFunction() { - return NewValue(t_void_func_, kFunction); -} +Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { CHECK_EQ(func.flag, kFunction); - ib_.Begin(spv::OpEntryPoint) - .AddSeq(spv::ExecutionModelGLCompute, func, name); + ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name); if (workgroup_id_.id != 0) { ib_.Add(workgroup_id_); } @@ -262,34 +256,31 @@ void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) void IRBuilder::StartFunction(const Value& func) { CHECK_EQ(func.flag, kFunction); - this->MakeInst( - spv::OpFunction, t_void_, func, 0, t_void_func_); + // add function declaration to the header. 
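// (Editorial note, not part of this patch: SPIR-V requires OpVariable with
//  Function storage class to appear at the top of a function's first block, so
//  OpFunction and the entry OpLabel are now committed to the new func_header_
//  segment; Allocate() commits function-local variables there as well, and
//  Finalize() splices the segments back together in order, roughly
//
//    ... debug_, decorate_, global_, func_header_, function_
//
//  as shown by the Finalize() change later in this diff.)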
+ ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_); + spirv::Label start_label = this->NewLabel(); - this->StartLabel(start_label); + ib_.Begin(spv::OpLabel).AddSeq(start_label).Commit(&func_header_); + curr_label_ = start_label; } -void IRBuilder::SetLocalSize(const Value& func, - uint32_t local_size[3]) { +void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { CHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpExecutionMode) - .AddSeq(func, spv::ExecutionModeLocalSize, - local_size[0], local_size[1], local_size[2]) + .AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2]) .Commit(&exec_mode_); } -Value IRBuilder::Allocate(const SType& value_type, - uint32_t num_elems, +Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class) { CHECK_NE(num_elems, 0U); SType sarr_type = GetStructArrayType(value_type, num_elems); SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); if (storage_class == spv::StorageClassFunction) { - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&function_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&func_header_); } else { - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, val, storage_class).Commit(&global_); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); } return val; } @@ -297,19 +288,16 @@ Value IRBuilder::Allocate(const SType& value_type, Value IRBuilder::GetWorkgroupID(uint32_t dim_index) { if (workgroup_id_.id == 0) { SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); - SType ptr_type = this->GetPointerType( - vec3_type, spv::StorageClassInput); + SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); workgroup_id_ = NewValue(ptr_type, kVectorPtr); ib_.Begin(spv::OpVariable) .AddSeq(ptr_type, workgroup_id_, spv::StorageClassInput) .Commit(&global_); - this->Decorate(spv::OpDecorate, workgroup_id_, - spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); + this->Decorate(spv::OpDecorate, workgroup_id_, spv::DecorationBuiltIn, spv::BuiltInWorkgroupId); } SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue( - spv::OpAccessChain, pint_type, workgroup_id_, - IntImm(t_int32_, static_cast(dim_index))); + Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, workgroup_id_, + IntImm(t_int32_, static_cast(dim_index))); return this->MakeValue(spv::OpLoad, t_int32_, ptr); } @@ -318,16 +306,13 @@ Value IRBuilder::GetLocalID(uint32_t dim_index) { SType vec3_type = this->GetSType(DataType::Int(32).with_lanes(3)); SType ptr_type = this->GetPointerType(vec3_type, spv::StorageClassInput); local_id_ = NewValue(ptr_type, kVectorPtr); - ib_.Begin(spv::OpVariable) - .AddSeq(ptr_type, local_id_, spv::StorageClassInput) - .Commit(&global_); - this->Decorate(spv::OpDecorate, local_id_, - spv::DecorationBuiltIn, spv::BuiltInLocalInvocationId); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, local_id_, spv::StorageClassInput).Commit(&global_); + this->Decorate(spv::OpDecorate, local_id_, spv::DecorationBuiltIn, + spv::BuiltInLocalInvocationId); } SType pint_type = this->GetPointerType(t_int32_, spv::StorageClassInput); - Value ptr = this->MakeValue( - spv::OpAccessChain, pint_type, local_id_, - UIntImm(t_int32_, static_cast(dim_index))); + Value ptr = this->MakeValue(spv::OpAccessChain, pint_type, local_id_, + UIntImm(t_int32_, 
static_cast(dim_index))); return this->MakeValue(spv::OpLoad, t_int32_, ptr); } @@ -354,9 +339,8 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (dtype.type.bits() > 32) { if (dtype.type.is_int()) { int64_t sign_mask = 0xFFFFFFFFL; - const int64_t* sign_ptr = - reinterpret_cast(pvalue); - ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); + const int64_t* sign_ptr = reinterpret_cast(pvalue); + ib_.Add(static_cast((sign_ptr[0] >> 32L) & sign_mask)); } else { ib_.Add(static_cast((pvalue[0] >> 32UL) & mask)); } @@ -390,8 +374,7 @@ SType IRBuilder::DeclareType(const DataType& dtype) { t.id = id_counter_++; t.type = dtype; SType base_type = GetSType(dtype.element_of()); - ib_.Begin(spv::OpTypeVector).AddSeq( - t, base_type, dtype.lanes()).Commit(&global_); + ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); return t; } } @@ -411,12 +394,10 @@ PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) { return phi; } -Value IRBuilder::CallGLSL450(const SType& ret_type, - uint32_t inst_id, +Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector& args) { Value val = NewValue(ret_type, kNormal); - ib_.Begin(spv::OpExtInst) - .AddSeq(ret_type, val, ext_glsl450_, inst_id); + ib_.Begin(spv::OpExtInst).AddSeq(ret_type, val, ext_glsl450_, inst_id); for (const Value& v : args) { ib_.Add(v); } @@ -486,14 +467,12 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { return MakeValue(spv::OpUConvert, dst_type, value); } else if (from.is_uint() && to.is_int()) { if (from.bits() != to.bits()) { - value = MakeValue( - spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpUConvert, GetSType(from.with_bits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); } else if (from.is_int() && to.is_uint()) { if (from.bits() != to.bits()) { - value = MakeValue( - spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); + value = MakeValue(spv::OpSConvert, GetSType(from.with_bits(to.bits())), value); } return MakeValue(spv::OpBitcast, dst_type, value); } else if (from.is_float() && to.is_int()) { @@ -507,21 +486,20 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (from.is_float() && to.is_float()) { return MakeValue(spv::OpFConvert, dst_type, value); } else { - LOG(FATAL) << "do not support type cast from " - << from << " to " << to; + LOG(FATAL) << "do not support type cast from " << from << " to " << to; return Value(); } } -#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI ## _Op, a.stype, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpF ## _Op, a.stype, a, b); \ - } \ +#define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, a.stype, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpF##_Op, a.stype, a, b); \ + } \ } #define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ @@ -554,19 +532,19 @@ Value IRBuilder::Mod(Value a, Value b) { } } -#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - 
CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ +#define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int()) { \ - return MakeValue(spv::OpS##_Op, bool_type, a, b); \ - } else if (a.stype.type.is_uint()) { \ - return MakeValue(spv::OpU##_Op, bool_type, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ + if (a.stype.type.is_int()) { \ + return MakeValue(spv::OpS##_Op, bool_type, a, b); \ + } else if (a.stype.type.is_uint()) { \ + return MakeValue(spv::OpU##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_OP(LT, LessThan); @@ -574,17 +552,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual); DEFINE_BUILDER_CMP_OP(GT, GreaterThan); DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); -#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ - Value IRBuilder::_OpName(Value a, Value b) { \ - CHECK_EQ(a.stype.id, b.stype.id); \ - CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ +#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + CHECK_EQ(a.stype.id, b.stype.id); \ + CHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \ - if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ - return MakeValue(spv::OpI##_Op, bool_type, a, b); \ - } else { \ - CHECK(a.stype.type.is_float()); \ - return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ - } \ + if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ + return MakeValue(spv::OpI##_Op, bool_type, a, b); \ + } else { \ + CHECK(a.stype.type.is_float()); \ + return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ + } \ } DEFINE_BUILDER_CMP_UOP(EQ, Equal); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index bdfea4ff7f1c..c52f92fd7c20 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -27,14 +27,15 @@ #include #include +// clang-format off #include -#include -#include -#include #include +#include #include - +#include +#include #include +// clang-format on namespace tvm { namespace codegen { @@ -85,9 +86,7 @@ struct Label { class Instr { public: /*! \return the word count */ - uint32_t WordCount() const { - return word_count_; - } + uint32_t WordCount() const { return word_count_; } /*! * \brief Access idx-th word of instruction * \param idx The index @@ -122,9 +121,7 @@ struct PhiValue : public Value { * \param value The value to come * \param parent The parent label. 
*/ - void SetIncoming(uint32_t index, - const Value& value, - const Label& parent) { + void SetIncoming(uint32_t index, const Value& value, const Label& parent) { CHECK_EQ(this->stype.id, value.stype.id); instr[3 + index * 2] = value.id; instr[3 + index * 2 + 1] = parent.id; @@ -203,12 +200,10 @@ class InstrBuilder { */ InstrBuilder& Add(const std::string& v) { const uint32_t kWordSize = sizeof(uint32_t); - uint32_t nwords = - (static_cast(v.length()) + kWordSize) / kWordSize; + uint32_t nwords = (static_cast(v.length()) + kWordSize) / kWordSize; size_t begin = data_.size(); data_.resize(begin + nwords, 0U); - std::copy(v.begin(), v.end(), - reinterpret_cast(&data_[begin])); + std::copy(v.begin(), v.end(), reinterpret_cast(&data_[begin])); return *this; } /*! @@ -217,8 +212,8 @@ class InstrBuilder { * \return reference to self. * \tparams Args The positional arguments */ - template - InstrBuilder& AddSeq(Args&& ...args) { + template + InstrBuilder& AddSeq(Args&&... args) { AddSeqHelper helper; helper.builder = this; runtime::detail::for_each(helper, std::forward(args)...); @@ -252,7 +247,7 @@ class InstrBuilder { // The reference to builder InstrBuilder* builder; // invoke function - template + template void operator()(size_t, const T& v) const { builder->Add(v); } @@ -301,6 +296,7 @@ class IRBuilder { data.insert(data.end(), debug_.begin(), debug_.end()); data.insert(data.end(), decorate_.begin(), decorate_.end()); data.insert(data.end(), global_.begin(), global_.end()); + data.insert(data.end(), func_header_.begin(), func_header_.end()); data.insert(data.end(), function_.begin(), function_.end()); return data; } @@ -322,17 +318,15 @@ class IRBuilder { curr_label_ = label; } /*! \return The current label */ - Label CurrentLabel() const { - return curr_label_; - } + Label CurrentLabel() const { return curr_label_; } /*! * \brief Add code to debug segment. * \param op The operator * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void Debug(spv::Op op, Args&& ...args) { + template + void Debug(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&debug_); } /*! @@ -341,10 +335,9 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void ExecutionMode(Value func, Args&& ...args) { - ib_.Begin(spv::OpExecutionMode).AddSeq( - func, std::forward(args)...).Commit(&exec_mode_); + template + void ExecutionMode(Value func, Args&&... args) { + ib_.Begin(spv::OpExecutionMode).AddSeq(func, std::forward(args)...).Commit(&exec_mode_); } /*! * \brief Add code to decorate segment. @@ -352,8 +345,8 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void Decorate(spv::Op op, Args&& ...args) { + template + void Decorate(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); } /*! @@ -362,8 +355,8 @@ class IRBuilder { * \param args The instruction sequence * \tparams Args The positional arguments */ - template - void DeclareGlobal(spv::Op op, Args&& ...args) { + template + void DeclareGlobal(spv::Op op, Args&&... args) { ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&decorate_); } /*! @@ -374,8 +367,8 @@ class IRBuilder { * \return The result SSA value. * \tparams Args The positional arguments */ - template - Instr MakeInst(spv::Op op, Args&& ...args) { + template + Instr MakeInst(spv::Op op, Args&&... 
args) { return ib_.Begin(op).AddSeq(std::forward(args)...).Commit(&function_); } /*! @@ -387,8 +380,8 @@ class IRBuilder { * \return The result SSA value. * \tparams Args The positional arguments */ - template - Value MakeValue(spv::Op op, const SType& out_type, Args&& ...args) { + template + Value MakeValue(spv::Op op, const SType& out_type, Args&&... args) { Value val = NewValue(out_type, kNormal); MakeInst(op, out_type, val, std::forward(args)...); return val; @@ -409,9 +402,7 @@ class IRBuilder { * \param args The arguments * \return The result value. */ - Value CallGLSL450(const SType& ret_type, - uint32_t inst_id, - const std::vector& args); + Value CallGLSL450(const SType& ret_type, uint32_t inst_id, const std::vector& args); /*! * \brief Build vector by concatenating components * @@ -431,8 +422,7 @@ class IRBuilder { * \param storage_class The storage class * \return The corresponding spirv type. */ - SType GetPointerType(const SType& value_type, - spv::StorageClass storage_class); + SType GetPointerType(const SType& value_type, spv::StorageClass storage_class); /*! * \brief Get a struct{ value_type[num_elems] } type. * \param value_type the content value type. @@ -441,17 +431,14 @@ class IRBuilder { * * \return The corresponding spirv type. */ - SType GetStructArrayType(const SType& value_type, - uint32_t num_elems); + SType GetStructArrayType(const SType& value_type, uint32_t num_elems); /*! * \brief Get a struct array access with a given index. * \param ptr_type The pointer type. * \param buffer The buffer ptr to struct array * \param index The array index. */ - Value StructArrayAccess(const SType& ptr_type, - Value buffer, - Value index); + Value StructArrayAccess(const SType& ptr_type, Value buffer, Value index); /*! * \brief Create a cast that cast value to dst_type * \param dst_type The target type. @@ -485,9 +472,7 @@ class IRBuilder { * \param binding The binding locaiton in descriptor set. * \param The argument type. */ - Value BufferArgument(const SType& value_type, - uint32_t descriptor_set, - uint32_t binding); + Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding); /*! * \brief Declare POD arguments through push constants. * @@ -533,9 +518,7 @@ class IRBuilder { * \param num_elems Number of elements to allocate. * \param storage_class The storage class we want to store to. */ - Value Allocate(const SType& value_type, - uint32_t num_elems, - spv::StorageClass storage_class); + Value Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class); /* * \brief Get the i-th workgroup id. * \return The value representing the workgroup id. @@ -610,8 +593,10 @@ class IRBuilder { std::vector debug_; /*! \brief Annotation segment */ std::vector decorate_; - /*! \brief Global segment: types, variables, types */ + /*! \brief Global segment: types, variables, types */ std::vector global_; + /*! \brief Function header segment */ + std::vector func_header_; /*! \brief Function segment */ std::vector function_; }; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index b28b6a1f5fb4..6dd2ca0ecb6c 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -20,14 +20,17 @@ /*! 
* \file codegen_stackvm.cc */ -#include -#include +#include "codegen_stackvm.h" + #include -#include +#include +#include #include +#include + #include #include -#include "codegen_stackvm.h" + #include "../../runtime/stackvm/stackvm_module.h" namespace tvm { @@ -40,19 +43,32 @@ using namespace tir; StackVM::StructFieldKind MapFieldKind(int64_t kind) { auto val = static_cast(kind); switch (val) { - case intrinsic::kArrData: return StackVM::kArrData; - case intrinsic::kArrShape: return StackVM::kArrShape; - case intrinsic::kArrAddr: return StackVM::kArrAddr; - case intrinsic::kArrStrides: return StackVM::kArrStrides; - case intrinsic::kArrNDim: return StackVM::kArrNDim; - case intrinsic::kArrTypeCode: return StackVM::kArrTypeCode; - case intrinsic::kArrTypeBits: return StackVM::kArrTypeBits; - case intrinsic::kArrTypeLanes: return StackVM::kArrTypeLanes; - case intrinsic::kArrByteOffset: return StackVM::kArrByteOffset; - case intrinsic::kArrDeviceId: return StackVM::kArrDeviceId; - case intrinsic::kArrDeviceType: return StackVM::kArrDeviceType; - case intrinsic::kTVMValueContent: return StackVM::kTVMValueContent; - default: LOG(FATAL) << "Do not know how to map field " << kind; + case intrinsic::kArrData: + return StackVM::kArrData; + case intrinsic::kArrShape: + return StackVM::kArrShape; + case intrinsic::kArrAddr: + return StackVM::kArrAddr; + case intrinsic::kArrStrides: + return StackVM::kArrStrides; + case intrinsic::kArrNDim: + return StackVM::kArrNDim; + case intrinsic::kArrTypeCode: + return StackVM::kArrTypeCode; + case intrinsic::kArrTypeBits: + return StackVM::kArrTypeBits; + case intrinsic::kArrTypeLanes: + return StackVM::kArrTypeLanes; + case intrinsic::kArrByteOffset: + return StackVM::kArrByteOffset; + case intrinsic::kArrDeviceId: + return StackVM::kArrDeviceId; + case intrinsic::kArrDeviceType: + return StackVM::kArrDeviceType; + case intrinsic::kTVMValueContent: + return StackVM::kTVMValueContent; + default: + LOG(FATAL) << "Do not know how to map field " << kind; } return StackVM::kArrData; } @@ -84,8 +100,7 @@ void CodeGenStackVM::PushOp(StackVM::OpCode opcode) { } void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) { - CHECK(operand >= std::numeric_limits::min() && - operand <= std::numeric_limits::max()); + CHECK(operand >= std::numeric_limits::min() && operand <= std::numeric_limits::max()); vm_.code.at(operand_index).v_int = static_cast(operand); } @@ -120,8 +135,7 @@ int CodeGenStackVM::AllocVarID(const VarNode* v) { int CodeGenStackVM::GetVarID(const VarNode* v) const { auto it = var_idmap_.find(v); - CHECK(it != var_idmap_.end()) - << "Find undefined Variable " << v->name_hint; + CHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; } @@ -161,7 +175,7 @@ void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { void CodeGenStackVM::VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); CHECK(op->args.size() == 1 && l); this->PushOp(StackVM::LOAD_HEAP, GetVarID(l->buffer_var.get())); this->Push(l->index); @@ -261,9 +275,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { } } -void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, - const PrimExpr& a, - const PrimExpr& b) { +void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b) { this->Push(a); this->Push(b); DataType t = a.dtype(); @@ -295,7 +307,7 @@ void CodeGenStackVM::VisitExpr_(const 
IntImmNode* op) { CHECK(op->value >= std::numeric_limits::min() && op->value <= std::numeric_limits::max()) << "Int constant exceed bound"; - this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); + this->PushOp(StackVM::PUSH_I64, static_cast(op->value)); } void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) { @@ -312,25 +324,15 @@ void CodeGenStackVM::VisitExpr_(const CastNode* op) { PushCast(op->dtype, op->value.dtype()); } -void CodeGenStackVM::VisitExpr_(const AddNode* op) { - PushBinary(StackVM::ADD_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const SubNode* op) { - PushBinary(StackVM::SUB_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const MulNode* op) { - PushBinary(StackVM::MUL_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const DivNode* op) { - PushBinary(StackVM::DIV_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const ModNode* op) { - PushBinary(StackVM::MOD_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const MinNode* op) { this->Push(op->a); @@ -350,22 +352,16 @@ void CodeGenStackVM::VisitExpr_(const MaxNode* op) { this->PushOp(StackVM::SELECT); } -void CodeGenStackVM::VisitExpr_(const EQNode* op) { - PushBinary(StackVM::EQ_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const EQNode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); } -void CodeGenStackVM::VisitExpr_(const LENode* op) { - PushBinary(StackVM::LE_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const LENode* op) { PushBinary(StackVM::LE_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const NENode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitExpr_(const LTNode* op) { - PushBinary(StackVM::LT_I64, op->a, op->b); -} +void CodeGenStackVM::VisitExpr_(const LTNode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); } void CodeGenStackVM::VisitExpr_(const GENode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); @@ -431,7 +427,7 @@ void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { } } -void CodeGenStackVM::VisitStmt_(const EvaluateNode *ev) { +void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) { if (is_const(ev->value)) return; const CallNode* op = ev->value.as(); if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) { @@ -482,9 +478,7 @@ void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) { this->Push(op->body); } -void CodeGenStackVM::VisitExpr_(const RampNode* op) { - LOG(FATAL) << "Ramp is not supported"; -} +void CodeGenStackVM::VisitExpr_(const RampNode* op) { LOG(FATAL) << "Ramp is not supported"; } void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) { LOG(FATAL) << "Broadcast is not supported"; @@ -506,9 +500,7 @@ void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) { this->Push(op->body); } -void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { - this->Push(op->body); -} +void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { this->Push(op->body); } void CodeGenStackVM::VisitExpr_(const LetNode* op) { this->Push(op->value); @@ -517,21 +509,19 @@ void 
CodeGenStackVM::VisitExpr_(const LetNode* op) { this->Push(op->body); } -runtime::Module BuildStackVM(const IRModule& mod) { +runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) { std::unordered_map fmap; std::string entry_func; - for (auto kv : mod->functions) { - CHECK(kv.second->IsInstance()) - << "CodeGenStackVM: Can only take PrimFunc"; + for (auto kv : mod->functions) { + CHECK(kv.second->IsInstance()) << "CodeGenStackVM: Can only take PrimFunc"; auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); - CHECK(!fmap.count(f_name)) - << "Function name " << f_name << "already exist in list"; + CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; fmap[f_name] = std::move(vm); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { @@ -542,7 +532,6 @@ runtime::Module BuildStackVM(const IRModule& mod) { return runtime::StackVMModuleCreate(fmap, entry_func); } -TVM_REGISTER_GLOBAL("target.build.stackvm") -.set_body_typed(BuildStackVM); +TVM_REGISTER_GLOBAL("target.build.stackvm").set_body_typed(BuildStackVM); } // namespace codegen } // namespace tvm diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 31036822649d..b77c40696de6 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -24,12 +24,14 @@ #ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ #define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_ +#include #include +#include #include -#include + #include -#include #include +#include #include "../../runtime/stackvm/stackvm.h" @@ -44,11 +46,10 @@ using runtime::StackVM; * This module is used to generate host wrapper * into device function when only device JIT is available. */ -class CodeGenStackVM - : public ExprFunctor, - public StmtFunctor { +class CodeGenStackVM : public ExprFunctor, + public StmtFunctor { public: - /*! + /*! * \brief Generate a stack VM representing * \param f The function to be compiled * \param device_funcs The extern device functions to be linked. @@ -59,9 +60,7 @@ class CodeGenStackVM /*! \brief Push stmt to generate new code */ void Push(const Stmt& n); /*! \brief Push expr to generate new code */ - void Push(const PrimExpr& n) { - VisitExpr(n); - } + void Push(const PrimExpr& n) { VisitExpr(n); } /*! * \brief Push the opcode to the code. * \param opcode The code to be pushed. @@ -81,9 +80,7 @@ class CodeGenStackVM */ void SetOperand(int64_t operand_index, int64_t operand); /*! \return The current program pointer */ - int64_t GetPC() const { - return static_cast(vm_.code.size()); - } + int64_t GetPC() const { return static_cast(vm_.code.size()); } /*! * \brief Get string id in vm * \param key The string to get id. 
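Editor's note: all of the binary visitors collapsed in this file (AddNode, SubNode, MulNode, the comparisons, and so on) delegate to PushBinary, which emits operand a, then operand b, then the opcode, so every expression compiles to a postfix opcode stream. Below is a self-contained toy interpreter for such a stream, kept deliberately tiny; it is not TVM's actual StackVM, and the opcode names and SELECT semantics are illustrative stand-ins:

#include <cassert>
#include <cstdint>
#include <vector>

// Minimal opcode set; the real StackVM has many more.
enum OpCode { PUSH_I64, ADD_I64, MUL_I64, LT_I64, SELECT_OP };

struct Code { OpCode op; int64_t imm; };

int64_t Run(const std::vector<Code>& code) {
  std::vector<int64_t> stack;
  for (const Code& c : code) {
    switch (c.op) {
      case PUSH_I64: stack.push_back(c.imm); break;
      case ADD_I64: { int64_t b = stack.back(); stack.pop_back(); stack.back() += b; break; }
      case MUL_I64: { int64_t b = stack.back(); stack.pop_back(); stack.back() *= b; break; }
      case LT_I64:  { int64_t b = stack.back(); stack.pop_back(); stack.back() = stack.back() < b; break; }
      case SELECT_OP: {  // pops cond, then y, then x; pushes cond ? x : y
        int64_t cond = stack.back(); stack.pop_back();
        int64_t y = stack.back(); stack.pop_back();
        int64_t x = stack.back(); stack.pop_back();
        stack.push_back(cond ? x : y);
        break;
      }
    }
  }
  return stack.back();
}

int main() {
  // min(a, b) compiles to: push a, push b, push a, push b, LT, SELECT
  std::vector<Code> code = {{PUSH_I64, 4}, {PUSH_I64, 7}, {PUSH_I64, 4}, {PUSH_I64, 7},
                            {LT_I64, 0},  {SELECT_OP, 0}};
  assert(Run(code) == 4);
  return 0;
}

The real codegen additionally folds small immediates into the opcode stream via PushOp(opcode, operand), but the evaluation order is the same as in this sketch.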
@@ -103,9 +100,7 @@ class CodeGenStackVM */ int GetVarID(const VarNode* v) const; // Push binary operator - void PushBinary(StackVM::OpCode op_int64, - const PrimExpr& a, - const PrimExpr& b); + void PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b); // push cast; void PushCast(DataType dst, DataType src); // overloadable functions diff --git a/src/target/target.cc b/src/target/target.cc index 50856d62af30..2104c2e4bdc1 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -21,11 +21,9 @@ * \file src/target/target.cc */ #include - -#include #include +#include #include - #include #include @@ -33,28 +31,27 @@ namespace tvm { +using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::PackedFunc; TVM_REGISTER_NODE_TYPE(TargetNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->str(); - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->str(); + }); /*! -* \brief Construct a Target node from the given name and options. -* \param target_name The major target name. Should be one of -* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", -* "metal", "nvptx", "opencl", "opengl", "rocm", "sdaccel", "stackvm", "vulkan"} -* \param options Additional options appended to the target -* \return The constructed Target -*/ -Target CreateTarget(const std::string& target_name, - const std::vector& options) { + * \brief Construct a Target node from the given name and options. + * \param target_name The major target name. Should be one of + * {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm", + * "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"} + * \param options Additional options appended to the target + * \return The constructed Target + */ +Target CreateTarget(const std::string& target_name, const std::vector& options) { auto t = make_object(); t->target_name = target_name; @@ -101,8 +98,9 @@ Target CreateTarget(const std::string& target_name, // For now assume rocm schedule for opencl if (target_name == "opencl") { t->device_type = kDLOpenCL; - } else { + } else { // rocm t->device_type = kDLROCM; + t->thread_warp_size = 64; } t->keys_array.push_back(target_name); t->keys_array.push_back("gpu"); @@ -110,11 +108,13 @@ Target CreateTarget(const std::string& target_name, if (t->device_name == "intel_graphics") { t->thread_warp_size = 16; } - } else if (target_name == "metal" || target_name == "vulkan") { + } else if (target_name == "metal" || target_name == "vulkan" || target_name == "webgpu") { if (target_name == "metal") { t->device_type = kDLMetal; - } else { + } else if (target_name == "vulkan") { t->device_type = kDLVulkan; + } else { + t->device_type = kDLWebGPU; } t->keys_array.push_back(target_name); t->keys_array.push_back("gpu"); @@ -127,9 +127,6 @@ Target CreateTarget(const std::string& target_name, t->device_type = kDLAOCL; t->keys_array.push_back("aocl"); t->keys_array.push_back("hls"); - } else if (target_name == "opengl") { - t->device_type = kOpenGL; - t->keys_array.push_back("opengl"); } else if (target_name == "stackvm") { t->device_type = kDLCPU; } else if (target_name == "ext_dev") { @@ -139,16 +136,18 @@ Target CreateTarget(const std::string& target_name, } else if (target_name == "hexagon") { t->keys_array.push_back("hexagon"); t->device_type = kDLHexagon; + } else if 
(target_name == "webgpu") { + t->keys_array.push_back("webgpu"); + t->device_type = kDLWebGPU; } else { - LOG(ERROR) << "Unknown target name " << target_name; + LOG(ERROR) << "Unknown target name " << target_name << "; falling back to stackvm"; return target::stackvm(); } return Target(t); } -TVM_REGISTER_GLOBAL("target.TargetCreate") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.TargetCreate").set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_name = args[0]; std::vector options; for (int i = 1; i < args.num_args; ++i) { @@ -157,13 +156,12 @@ TVM_REGISTER_GLOBAL("target.TargetCreate") } *ret = CreateTarget(target_name, options); - }); +}); -TVM_REGISTER_GLOBAL("target.TargetFromString") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.TargetFromString").set_body([](TVMArgs args, TVMRetValue* ret) { std::string target_str = args[0]; *ret = Target::Create(target_str); - }); +}); std::vector TargetNode::keys() const { std::vector result; @@ -193,14 +191,13 @@ const std::string& TargetNode::str() const { if (str_repr_.length() != 0) return str_repr_; std::ostringstream result; result << target_name; - for (const auto &x : options()) { + for (const auto& x : options()) { result << " " << x; } str_repr_ = result.str(); return str_repr_; } - bool StartsWith(const std::string& str, const std::string& pattern) { return str.compare(0, pattern.length(), pattern) == 0; } @@ -250,219 +247,71 @@ struct TVMTargetThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMTargetThreadLocalStore; void Target::EnterWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); entry->context_stack.push(*this); } void Target::ExitWithScope() { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); entry->context_stack.pop(); } tvm::Target Target::Current(bool allow_not_defined) { - TVMTargetThreadLocalEntry *entry = TVMTargetThreadLocalStore::Get(); + TVMTargetThreadLocalEntry* entry = TVMTargetThreadLocalStore::Get(); if (entry->context_stack.size() > 0) { return entry->context_stack.top(); } CHECK(allow_not_defined) - << "Target context required. Please set it by constructing a TargetContext"; + << "Target context required. 
Please set it by constructing a TargetContext"; return Target(); } -TVM_REGISTER_GLOBAL("target.GetCurrentTarget") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("target.GetCurrentTarget").set_body([](TVMArgs args, TVMRetValue* ret) { bool allow_not_defined = args[0]; *ret = Target::Current(allow_not_defined); - }); +}); class Target::Internal { public: - static void EnterScope(Target target) { - target.EnterWithScope(); - } - static void ExitScope(Target target) { - target.ExitWithScope(); - } + static void EnterScope(Target target) { target.EnterWithScope(); } + static void ExitScope(Target target) { target.ExitWithScope(); } }; -TVM_REGISTER_GLOBAL("target.EnterTargetScope") -.set_body_typed(Target::Internal::EnterScope); +TVM_REGISTER_GLOBAL("target.EnterTargetScope").set_body_typed(Target::Internal::EnterScope); -TVM_REGISTER_GLOBAL("target.ExitTargetScope") -.set_body_typed(Target::Internal::ExitScope); +TVM_REGISTER_GLOBAL("target.ExitTargetScope").set_body_typed(Target::Internal::ExitScope); namespace target { std::vector MergeOptions(std::vector opts, - const std::vector& new_opts) { + const std::vector& new_opts) { opts.insert(opts.end(), new_opts.begin(), new_opts.end()); return opts; } -Target llvm(const std::vector& options) { - return CreateTarget("llvm", options); -} +Target llvm(const std::vector& options) { return CreateTarget("llvm", options); } -Target cuda(const std::vector& options) { - return CreateTarget("cuda", options); -} +Target cuda(const std::vector& options) { return CreateTarget("cuda", options); } -Target rocm(const std::vector& options) { - return CreateTarget("rocm", options); -} +Target rocm(const std::vector& options) { return CreateTarget("rocm", options); } -Target opencl(const std::vector& options) { - return CreateTarget("opencl", options); -} +Target opencl(const std::vector& options) { return CreateTarget("opencl", options); } -Target metal(const std::vector& options) { - return CreateTarget("metal", options); -} +Target metal(const std::vector& options) { return CreateTarget("metal", options); } Target mali(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=mali" - })); + return CreateTarget("opencl", MergeOptions(options, {"-device=mali"})); } Target intel_graphics(const std::vector& options) { - return CreateTarget("opencl", MergeOptions(options, { - "-device=intel_graphics" - })); + return CreateTarget("opencl", MergeOptions(options, {"-device=intel_graphics"})); } -Target stackvm(const std::vector& options) { - return CreateTarget("stackvm", options); -} +Target stackvm(const std::vector& options) { return CreateTarget("stackvm", options); } -Target ext_dev(const std::vector& options) { - return CreateTarget("ext_dev", options); -} +Target ext_dev(const std::vector& options) { return CreateTarget("ext_dev", options); } -Target hexagon(const std::vector& options) { - return CreateTarget("hexagon", options); -} +Target hexagon(const std::vector& options) { return CreateTarget("hexagon", options); } } // namespace target - -BuildConfig BuildConfig::Create() { - return BuildConfig(make_object()); -} - -/*! \brief Entry to hold the BuildConfig context stack. */ -struct TVMBuildConfigThreadLocalEntry { - /*! \brief The default build config if the stack is empty */ - BuildConfig default_config; - - /*! \brief The current build config context */ - std::stack context_stack; - - TVMBuildConfigThreadLocalEntry() : - default_config(BuildConfig::Create()) { - } -}; - -/*! 
\brief Thread local store to hold the BuildConfig context stack. */ -typedef dmlc::ThreadLocalStore TVMBuildConfigThreadLocalStore; - -void BuildConfig::EnterWithScope() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); - entry->context_stack.push(*this); -} - -void BuildConfig::ExitWithScope() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); - CHECK(!entry->context_stack.empty()); - CHECK(entry->context_stack.top().same_as(*this)); - entry->context_stack.pop(); -} - -tvm::BuildConfig BuildConfig::Current() { - TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get(); - if (entry->context_stack.size() > 0) { - return entry->context_stack.top(); - } - - return entry->default_config; -} - -TVM_REGISTER_NODE_TYPE(BuildConfigNode); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "build_config("; - p->stream << "data_alignment=" << op->data_alignment << ", "; - p->stream << "offset_factor=" << op->offset_factor << ", "; - p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", "; - p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", "; - p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", "; - p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", "; - p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; - p->stream << "restricted_func=" << op->restricted_func << ", "; - p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; - p->stream << "partition_const_loop=" << op->partition_const_loop << ", "; - p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", "; - p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; - p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; - p->stream << "disable_vectorize=" << op->disable_vectorize; - p->stream << "disable_assert=" << op->disable_assert; - p->stream << ")"; -}); - -TVM_REGISTER_GLOBAL("target.GetCurrentBuildConfig") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = BuildConfig::Current(); - }); - -class BuildConfig::Internal { - public: - static void EnterScope(BuildConfig target) { - target.EnterWithScope(); - } - static void ExitScope(BuildConfig target) { - target.ExitWithScope(); - } -}; - -TVM_REGISTER_GLOBAL("target.EnterBuildConfigScope") -.set_body_typed(BuildConfig::Internal::EnterScope); - -TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope") -.set_body_typed(BuildConfig::Internal::ExitScope); - -TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - BuildConfig cfg = args[0]; - std::vector< std::pair > add_lower_pass; - CHECK_EQ(args.size() % 2, 1); - for (int i = 1; i < args.size(); i += 2) { - add_lower_pass.push_back(std::make_pair( - args[i].operator int(), - args[i + 1].operator tvm::runtime::PackedFunc())); - } - cfg->add_lower_pass = add_lower_pass; - }); - -TVM_REGISTER_GLOBAL("target.BuildConfigGetAddLowerPassInfo") -.set_body([](TVMArgs args, TVMRetValue* ret) { - // Return one of the following: - // * Size of add_lower_pass if num_args == 1 - // * Phase index of pass if args are (config, index, true) - // * Function of pass if args are (config, index, false) - BuildConfig cfg = args[0]; - if (args.num_args == 1) { - *ret = static_cast(cfg->add_lower_pass.size()); - } else { - int index = args[1]; - bool get_phase = 
args[2]; - auto item = cfg->add_lower_pass[index]; - if (get_phase) { - *ret = item.first; - } else { - *ret = item.second; - } - } -}); - } // namespace tvm diff --git a/src/target/target_info.cc b/src/target/target_info.cc index 73fe011cc936..5ebb7edc80dc 100644 --- a/src/target/target_info.cc +++ b/src/target/target_info.cc @@ -20,21 +20,21 @@ /*! * \file target/target_info.cc */ -#include #include +#include #include namespace tvm { TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "mem-info(" - << "unit_bits=" << op->unit_bits << ", " - << "max_num_bits=" << op->max_num_bits << ", " - << "max_simd_bits=" << op->max_simd_bits << ", " - << "head_address=" << op->head_address << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "mem-info(" + << "unit_bits=" << op->unit_bits << ", " + << "max_num_bits=" << op->max_num_bits << ", " + << "max_simd_bits=" << op->max_simd_bits << ", " + << "head_address=" << op->head_address << ")"; + }); TVM_REGISTER_NODE_TYPE(MemoryInfoNode); diff --git a/src/te/autodiff/ad_util.cc b/src/te/autodiff/ad_util.cc index 3a90beff4822..89ff96d4724b 100644 --- a/src/te/autodiff/ad_util.cc +++ b/src/te/autodiff/ad_util.cc @@ -21,10 +21,12 @@ * \file ad_util.cc * \brief Utility for tensor-level auto-differentiation. */ +#include "ad_util.h" + #include -#include +#include + #include -#include "ad_util.h" namespace tvm { namespace te { @@ -33,9 +35,7 @@ std::pair, Map> CloneIterVars(const Array Array new_vars; Map vmap; for (const IterVar& iv : vars) { - IterVar new_v = - IterVarNode::make(iv->dom, iv->var.copy_with_suffix(""), - iv->iter_type, iv->thread_tag); + IterVar new_v = IterVar(iv->dom, iv->var.copy_with_suffix(""), iv->iter_type, iv->thread_tag); new_vars.push_back(new_v); vmap.Set(iv->var, new_v->var); } @@ -53,8 +53,8 @@ PrimExpr CloneReduction(const PrimExpr& expr) { src_with_newaxis.push_back(tir::Substitute(src, vmap)); } - return ReduceNode::make(red->combiner, src_with_newaxis, - new_axis, tir::Substitute(red->condition, vmap), red->value_index); + return Reduce(red->combiner, src_with_newaxis, new_axis, tir::Substitute(red->condition, vmap), + red->value_index); } else { return expr; } diff --git a/src/te/autodiff/ad_util.h b/src/te/autodiff/ad_util.h index 7e511b1c5a22..56ab6c18b929 100644 --- a/src/te/autodiff/ad_util.h +++ b/src/te/autodiff/ad_util.h @@ -24,11 +24,12 @@ #ifndef TVM_TE_AUTODIFF_AD_UTIL_H_ #define TVM_TE_AUTODIFF_AD_UTIL_H_ -#include #include -#include +#include + #include #include +#include namespace tvm { namespace te { diff --git a/src/te/autodiff/adjoint.cc b/src/te/autodiff/adjoint.cc index 0c54764e601a..772213da5cca 100644 --- a/src/te/autodiff/adjoint.cc +++ b/src/te/autodiff/adjoint.cc @@ -30,11 +30,12 @@ * (3) and sum them together to get the adjoint of the input itself. * The three steps are computed recursively. 
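Editor's note: the adjoint.cc file comment summarises reverse-mode differentiation: for every tensor, take each consumer's adjoint, multiply it by that consumer's Jacobian with respect to the tensor (the vector-Jacobian product), and sum the contributions. Here is a scalar, dependency-free sketch of that recursion with the same memoisation idea as the adjoints map used further down in Gradient; the graph, node names, and double-valued weights are invented for illustration and stand in for tensors and VectorJacobianProduct:

#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Toy graph: for each node, who consumes it and the local derivative
// d(consumer)/d(node).  Stands in for reverse_dependencies plus Jacobian().
struct Edge { std::string consumer; double local_grad; };

int main() {
  std::map<std::string, std::vector<Edge>> consumers = {
      {"x", {{"y", 3.0}, {"z", 2.0}}},   // y = 3x, z = 2x
      {"y", {{"out", 1.0}}},             // out = y + 5z
      {"z", {{"out", 5.0}}},
      {"out", {}}};

  std::map<std::string, double> adjoints;  // memoised, like the adjoints map in Gradient
  std::function<double(const std::string&)> compute_adjoint =
      [&](const std::string& node) -> double {
    auto it = adjoints.find(node);
    if (it != adjoints.end()) return it->second;
    double adj = (node == "out") ? 1.0 : 0.0;  // the head adjoint is the identity
    for (const Edge& e : consumers[node]) {
      // part = (adjoint of consumer) * d(consumer)/d(node), summed over consumers
      adj += compute_adjoint(e.consumer) * e.local_grad;
    }
    adjoints[node] = adj;
    return adj;
  };

  // d(out)/dx = 1*3 + 5*2 = 13
  assert(compute_adjoint("x") == 13.0);
  return 0;
}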
*/ +#include +#include #include #include #include -#include -#include + #include #include @@ -47,27 +48,25 @@ Tensor Identity(const Tensor& output) { // add extra dimension for Jacobian shape.push_back(e); } - auto func = - [&output](const Array& input_indices) { - PrimExpr res = const_true(); - for (size_t i = 0; i < output->shape.size(); ++i) { - res = res && (PrimExpr(input_indices[i]) == - PrimExpr(input_indices[output->shape.size() + i])); - } - return CastNode::make(output->dtype, res); - }; + auto func = [&output](const Array& input_indices) { + PrimExpr res = const_true(); + for (size_t i = 0; i < output->shape.size(); ++i) { + res = + res && (PrimExpr(input_indices[i]) == PrimExpr(input_indices[output->shape.size() + i])); + } + return Cast(output->dtype, res); + }; return te::compute(shape, func, "identity"); } -Tensor VectorJacobianProduct(const Tensor &output, const Tensor &input, const Tensor &head) { +Tensor VectorJacobianProduct(const Tensor& output, const Tensor& input, const Tensor& head) { Tensor jac = Jacobian(output, input); Tensor result = topi::tensordot(head, jac, /*axes=*/output->shape.size(), output->op->name + "." + input->op->name + ".grad"); return result; } -Array Gradient(const Tensor& output, - const Array& inputs, +Array Gradient(const Tensor& output, const Array& inputs, const Tensor& head_or_null) { // Diagonal identity tensor Tensor head = head_or_null.get() ? head_or_null : Identity(output); @@ -95,41 +94,40 @@ Array Gradient(const Tensor& output, // This is a recursive function that does all the work. It computes the adjoint for a given // tensor, adds it to the map, and returns it std::function compute_adjoint; - compute_adjoint = - [&compute_adjoint, &adjoints, &reverse_dependencies, &head, &output] - (const Tensor& tensor) { - if (!adjoints.count(tensor)) { - // Here the adjoint hasn't been computed yet - Tensor res_adjoint; - std::vector direct_consumers = reverse_dependencies[tensor]; - if (direct_consumers.empty()) { - // No reverse dependencies means that the output does not depend on this tensor, - // return a zero tensor of the appropriate shape - // (i.e., output shape + tensor shape, aka shape of Jacobian) - Array result_shape(head->shape.begin(), - head->shape.end() + (-output->shape.size())); - for (auto e : tensor->shape) { - result_shape.push_back(e); - } - res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); - } else { - // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied - // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian - // and the multiplication is done in the function VectorJacobianProduct - for (const Tensor& direct_consumer : direct_consumers) { - // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) - Tensor part = VectorJacobianProduct( - direct_consumer, tensor, compute_adjoint(direct_consumer)); - res_adjoint = res_adjoint.get() ? 
topi::add(res_adjoint, part) : part; - } + compute_adjoint = [&compute_adjoint, &adjoints, &reverse_dependencies, &head, + &output](const Tensor& tensor) { + if (!adjoints.count(tensor)) { + // Here the adjoint hasn't been computed yet + Tensor res_adjoint; + std::vector direct_consumers = reverse_dependencies[tensor]; + if (direct_consumers.empty()) { + // No reverse dependencies means that the output does not depend on this tensor, + // return a zero tensor of the appropriate shape + // (i.e., output shape + tensor shape, aka shape of Jacobian) + Array result_shape(head->shape.begin(), + head->shape.end() + (-output->shape.size())); + for (auto e : tensor->shape) { + result_shape.push_back(e); } - - adjoints[tensor] = res_adjoint; - return res_adjoint; + res_adjoint = topi::full(result_shape, output->dtype, make_zero(output->dtype)); } else { - return adjoints[tensor]; + // The new adjoint is computed as a sum of the reverse dependencies' adjoints multiplied + // by the corresponding "local" jacobians (dDep/dTensor). The computation of the jacobian + // and the multiplication is done in the function VectorJacobianProduct + for (const Tensor& direct_consumer : direct_consumers) { + // part = (adjoint of direct_consumer) * Jacobian(direct_consumer, tensor) + Tensor part = + VectorJacobianProduct(direct_consumer, tensor, compute_adjoint(direct_consumer)); + res_adjoint = res_adjoint.get() ? topi::add(res_adjoint, part) : part; + } } - }; + + adjoints[tensor] = res_adjoint; + return res_adjoint; + } else { + return adjoints[tensor]; + } + }; // Adjoints corresponding to inputs Array result; @@ -141,15 +139,14 @@ Array Gradient(const Tensor& output, return result; } -TVM_REGISTER_GLOBAL("te.Gradient") -.set_body([](TVMArgs args, TVMRetValue *ret) { - LOG(WARNING) << "te.Gradient is an experimental feature."; - if (args.size() == 2) { - *ret = Gradient(args[0], args[1]); - } else if (args.size() == 3) { - *ret = Gradient(args[0], args[1], args[2]); - } - }); +TVM_REGISTER_GLOBAL("te.Gradient").set_body([](TVMArgs args, TVMRetValue* ret) { + LOG(WARNING) << "te.Gradient is an experimental feature."; + if (args.size() == 2) { + *ret = Gradient(args[0], args[1]); + } else if (args.size() == 3) { + *ret = Gradient(args[0], args[1], args[2]); + } +}); } // namespace te } // namespace tvm diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 1a324588537f..1834aa3decf7 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -23,18 +23,23 @@ * X must be direct input tensor of Y. * The result Jacobian shape will be (Y.shape, X.shape) */ -#include +#include #include +#include #include -#include + #include + #include "ad_util.h" namespace tvm { namespace te { -#define NOT_IMPLEMENTED \ - { LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); throw; } +#define NOT_IMPLEMENTED \ + { \ + LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef(op); \ + throw; \ + } /*! \brief Differentiate an expression wrt a variable or a tensor element */ class JacobianMutator : public ExprMutator { @@ -45,7 +50,7 @@ class JacobianMutator : public ExprMutator { * \param indices The indices of the element with respect to which to differentiate. */ explicit JacobianMutator(Tensor input, Array indices) - : input_(input), indices_(indices) {} + : input_(input), indices_(indices) {} /*! * \brief Differentiate wrt the input variable. * \param input The input variable. 
@@ -70,114 +75,93 @@ class JacobianMutator : public ExprMutator { } } - PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED; + + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + auto tensor = Downcast(op->producer); + if (input_.get() && tensor == input_) { + // Tensor(indices) + CHECK_EQ(indices_.size(), op->indices.size()); + PrimExpr condition = const_true(); + for (size_t i = 0; i < input_.ndim(); ++i) { + condition = And(condition, EQ(indices_[i], op->indices[i])); + } + return Cast(op->dtype, condition); + } else { + return make_zero(op->dtype); + } + } PrimExpr VisitExpr_(const CallNode* op) { PrimExpr expr = GetRef(op); - if (op->call_type == CallNode::CallType::Halide) { - if (input_.get() && op->func.same_as(input_->op) && - op->value_index == input_->value_index) { - // Tensor(indices) - CHECK_EQ(indices_.size(), op->args.size()); - PrimExpr condition = const_true(); - for (size_t i = 0; i < input_.ndim(); ++i) { - condition = AndNode::make(condition, EQNode::make(indices_[i], op->args[i])); - } - return CastNode::make(op->dtype, condition); - } else { - return make_zero(op->dtype); - } - } else if (op->call_type == CallNode::CallType::PureIntrinsic) { + if (op->call_type == CallNode::CallType::PureIntrinsic) { static std::unordered_set piecewise_const = {"floor", "ceil", "trunc", "round"}; if (op->name == "exp") { - return MulNode::make(Mutate(op->args[0]), expr); + return Mul(Mutate(op->args[0]), expr); } else if (op->name == "log") { - return DivNode::make(Mutate(op->args[0]), op->args[0]); + return Div(Mutate(op->args[0]), op->args[0]); } else if (op->name == "sigmoid") { - return MulNode::make(Mutate(op->args[0]), - MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0), expr))); + return Mul(Mutate(op->args[0]), Mul(expr, Sub(FloatImm(expr.dtype(), 1.0), expr))); } else if (op->name == "sqrt") { - return DivNode::make(Mutate(op->args[0]), - MulNode::make(expr, FloatImm(expr.dtype(), 2.0))); + return Div(Mutate(op->args[0]), Mul(expr, FloatImm(expr.dtype(), 2.0))); } else if (op->name == "tanh") { - return MulNode::make(Mutate(op->args[0]), - SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr, expr))); + return Mul(Mutate(op->args[0]), Sub(FloatImm(expr.dtype(), 1.0), Mul(expr, expr))); } else if (op->name == "pow") { auto x = op->args[0], y = op->args[1]; - return expr * (Mutate(y)*log(x) + Mutate(x)*y/x); + return expr * (Mutate(y) * log(x) + Mutate(x) * y / x); } else if (op->name == "fabs") { auto type = op->args[0].dtype(); - return MulNode::make(Mutate(op->args[0]), - SelectNode::make(GENode::make(op->args[0], make_zero(type)), - FloatImm(type, 1.0), FloatImm(type, -1.0))); + return Mul(Mutate(op->args[0]), Select(GE(op->args[0], make_zero(type)), + FloatImm(type, 1.0), FloatImm(type, -1.0))); } else if (op->name == intrinsic::tvm_if_then_else) { - Array new_args = {op->args[0], - Mutate(op->args[1]), - Mutate(op->args[2])}; - return CallNode::make(op->dtype, op->name, new_args, - op->call_type, op->func, op->value_index); + Array new_args = {op->args[0], Mutate(op->args[1]), Mutate(op->args[2])}; + return Call(op->dtype, op->name, new_args, op->call_type); } else if (piecewise_const.count(op->name)) { return FloatImm(expr.dtype(), 0.0); } else { throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name); } } - NOT_IMPLEMENTED + NOT_IMPLEMENTED; 
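Editor's note: the intrinsic cases above (exp, log, sigmoid, sqrt, tanh, pow, fabs) and the arithmetic visitors immediately below encode the standard differentiation rules, only restated with the new Mul/Div/Sub/Select constructors in place of the old *Node::make helpers. Written out as math, with x' denoting the derivative of the argument with respect to the chosen input element:

\[
\begin{aligned}
(e^{x})' &= x'\,e^{x}, & (\log x)' &= \tfrac{x'}{x}, & (\sqrt{x})' &= \tfrac{x'}{2\sqrt{x}},\\
(\sigma(x))' &= x'\,\sigma(x)\,(1-\sigma(x)), & (\tanh x)' &= x'\,(1-\tanh^{2}x), & (|x|)' &= x'\,\mathrm{sign}(x),\\
(x^{y})' &= x^{y}\!\left(y'\log x + \tfrac{x'\,y}{x}\right), & (ab)' &= a'b + ab', & \left(\tfrac{a}{b}\right)' &= \tfrac{a'b - ab'}{b^{2}}.
\end{aligned}
\]

floor, ceil, trunc, and round are treated as piecewise constant (derivative 0), and min/max are differentiated through a Select on the corresponding comparison, as in the MinNode/MaxNode visitors that follow.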
} - PrimExpr VisitExpr_(const AddNode* op) { - return AddNode::make(Mutate(op->a), Mutate(op->b)); - } + PrimExpr VisitExpr_(const AddNode* op) { return Add(Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const SubNode* op) { - return SubNode::make(Mutate(op->a), Mutate(op->b)); - } + PrimExpr VisitExpr_(const SubNode* op) { return Sub(Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MulNode* op) { - return AddNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))); + return Add(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))); } PrimExpr VisitExpr_(const DivNode* op) { - return DivNode::make( - SubNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))), - MulNode::make(op->b, op->b)); + return Div(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b)); } - PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const FloorDivNode* op) { - return FloorDivNode::make( - SubNode::make( - MulNode::make(Mutate(op->a), op->b), - MulNode::make(op->a, Mutate(op->b))), - MulNode::make(op->b, op->b)); + return FloorDiv(Sub(Mul(Mutate(op->a), op->b), Mul(op->a, Mutate(op->b))), Mul(op->b, op->b)); } - PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const MinNode* op) { - return SelectNode::make(LENode::make(op->a, op->b), - Mutate(op->a), Mutate(op->b)); + return Select(LE(op->a, op->b), Mutate(op->a), Mutate(op->b)); } PrimExpr VisitExpr_(const MaxNode* op) { - return SelectNode::make(GENode::make(op->a, op->b), - Mutate(op->a), Mutate(op->b)); + return Select(GE(op->a, op->b), Mutate(op->a), Mutate(op->b)); } - PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const ReduceNode* op) { // This case is relatively difficult because a reduction expression @@ -229,12 +213,12 @@ class JacobianMutator : public ExprMutator { for (size_t i = 0; i < new_op->combiner->lhs.size(); ++i) { PrimExpr res_di = Derivative(res, new_op->combiner->lhs[i]); // new_lhs[i] is the derivative of lhs[i] (wrt our input tensor) - new_res = AddNode::make(new_res, MulNode::make(new_lhs[i], res_di)); + new_res = Add(new_res, Mul(new_lhs[i], res_di)); } for (size_t i = 0; i < new_op->combiner->rhs.size(); ++i) { PrimExpr res_di = Derivative(res, new_op->combiner->rhs[i]); // new_rhs[i] is the derivative of rhs[i] (wrt our input tensor) - new_res = AddNode::make(new_res, MulNode::make(new_rhs[i], res_di)); + new_res = Add(new_res, Mul(new_rhs[i], res_di)); } new_result.push_back(new_res); } @@ 
-261,47 +245,42 @@ class JacobianMutator : public ExprMutator { new_source.push_back(src); } - CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + CommReducer new_combiner = CommReducer(new_lhs, new_rhs, new_result, new_identity); // Also simplify the resulting combiner // (mostly to get rid of unused components, e.g., the original expressions) - return Simplify( - ReduceNode::make(new_combiner, new_source, new_op->axis, - new_op->condition, new_op->value_index)); + return analyzer_.Simplify( + Reduce(new_combiner, new_source, new_op->axis, new_op->condition, new_op->value_index)); } PrimExpr VisitExpr_(const CastNode* op) { if (op->dtype.is_float()) { - return CastNode::make(op->dtype, Mutate(op->value)); + return Cast(op->dtype, Mutate(op->value)); } else { return make_zero(op->dtype); } } - PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const NotNode* op) NOT_IMPLEMENTED; PrimExpr VisitExpr_(const SelectNode* op) { - return SelectNode::make(op->condition, - Mutate(op->true_value), Mutate(op->false_value)); + return Select(op->condition, Mutate(op->true_value), Mutate(op->false_value)); } - PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED - PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const RampNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const BroadcastNode* op) NOT_IMPLEMENTED; + PrimExpr VisitExpr_(const ShuffleNode* op) NOT_IMPLEMENTED; - PrimExpr VisitExpr_(const IntImmNode* op) { - return IntImm(op->dtype, 0); - } + PrimExpr VisitExpr_(const IntImmNode* op) { return IntImm(op->dtype, 0); } - PrimExpr VisitExpr_(const FloatImmNode* op) { - return FloatImm(op->dtype, 0); - } + PrimExpr VisitExpr_(const FloatImmNode* op) { return FloatImm(op->dtype, 0); } - PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED + PrimExpr VisitExpr_(const StringImmNode* op) NOT_IMPLEMENTED; private: Tensor input_; Array indices_; Var input_var_; + arith::Analyzer analyzer_; }; PrimExpr Derivative(const PrimExpr& expr, const Var& var) { @@ -334,18 +313,18 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { Array input_indices; size_t i = 0; for (PrimExpr ext : input->shape) { - IterVar new_v = IterVarNode::make(Range(0, ext), Var("jac_i" + std::to_string(i++)), - IterVarType::kDataPar); + IterVar new_v = + IterVar(Range(0, ext), Var("jac_i" + std::to_string(i++)), IterVarType::kDataPar); // Append jacobian iter to new_axis new_axis.push_back(new_v); // Differentiate wrt input[input_indices] input_indices.push_back(new_v); } - + arith::Analyzer analzyer; // Compute Jacobian - PrimExpr new_body = Jacobian( - Substitute(op->body[output->value_index], vmap), input, input_indices); - new_body = Simplify(new_body); + PrimExpr new_body = + Jacobian(Substitute(op->body[output->value_index], vmap), input, input_indices); + new_body = analzyer.Simplify(new_body); int value_index = 0; Array new_bodies; @@ -355,15 +334,13 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { if (const ReduceNode* red = new_body.as()) { value_index = red->value_index; for (size_t idx = 0; idx < red->source.size(); ++idx) { - new_bodies.push_back( - ReduceNode::make(red->combiner, red->source, red->axis, red->condition, idx)); + new_bodies.push_back(Reduce(red->combiner, red->source, red->axis, red->condition, idx)); } } else { new_bodies.push_back(new_body); } - auto new_op = ComputeOpNode::make( - op->name + ".jacobian", 
op->tag, op->attrs, new_axis, new_bodies); + auto new_op = ComputeOp(op->name + ".jacobian", op->tag, op->attrs, new_axis, new_bodies); // Jacobian shape = output.shape + input.shape Array new_shape = output->shape; @@ -371,7 +348,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { new_shape.push_back(e); } - return TensorNode::make(new_shape, output->dtype, new_op, value_index); + return Tensor(new_shape, output->dtype, new_op, value_index); } } // namespace te diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 6f703c9ec4e3..1fc0520143fb 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -4,7 +4,7 @@ * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance +5B * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 @@ -21,46 +21,44 @@ * \brief Compute Op. * \file compute_op.cc */ +#include "compute_op.h" + +#include #include #include -#include +#include #include -#include #include -#include + #include +#include #include -#include "compute_op.h" -#include "op_util.h" -#include "../schedule/message_passing.h" -#include "../../arith/compute_expr.h" + #include "../../arith/interval_set.h" +#include "../schedule/message_passing.h" +#include "op_util.h" namespace tvm { namespace te { using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "compute(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "compute(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. -static void VerifyComputeOp(const ComputeOpNode *op); +static void VerifyComputeOp(const ComputeOpNode* op); inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && - (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && - (a->condition.same_as(b->condition)); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)); } -int ComputeOpNode::num_outputs() const { - return body.size(); -} +int ComputeOpNode::num_outputs() const { return body.size(); } Array BaseComputeOpNode::root_iter_vars() const { if (reduce_axis.size() == 0) return axis; @@ -87,11 +85,8 @@ Array BaseComputeOpNode::output_shape(size_t idx) const { return shape; } -Tensor compute(Array shape, - FCompute fcompute, - std::string name, - std::string tag, - Map attrs) { +Tensor compute(Array shape, FCompute fcompute, std::string name, std::string tag, + Map attrs) { auto op_node = make_object(); // compute dimension. 
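Editor's note: both compute() overloads re-wrapped in this hunk follow the same recipe: build one data-parallel IterVar per output dimension (named ax0, ax1, ...), collect their Vars into args, and pass args to the user-supplied fcompute to obtain the op body. A dependency-free analogue of that shape/index-lambda contract in plain C++, with a hypothetical Compute helper that is not TVM's te::compute:

#include <cassert>
#include <functional>
#include <vector>

// Materialise an n-dimensional "tensor" (flattened, row-major) from an index
// function: one index variable per axis, the same contract as compute(shape, fcompute).
std::vector<int> Compute(const std::vector<int>& shape,
                         const std::function<int(const std::vector<int>&)>& fcompute) {
  int total = 1;
  for (int e : shape) total *= e;
  std::vector<int> out(total);
  std::vector<int> idx(shape.size(), 0);
  for (int flat = 0; flat < total; ++flat) {
    out[flat] = fcompute(idx);
    // advance the multi-dimensional index (row-major order)
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
      if (++idx[d] < shape[d]) break;
      idx[d] = 0;
    }
  }
  return out;
}

int main() {
  // C[i][j] = i * 10 + j over a 2x3 grid.
  auto C = Compute({2, 3}, [](const std::vector<int>& i) { return i[0] * 10 + i[1]; });
  assert(C[0] == 0 && C[2] == 2 && C[5] == 12);
  return 0;
}

The reformatted code above does the same thing symbolically: the generated axes become the op's root_iter_vars and the result of fcompute(args) becomes its body.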
size_t ndim = shape.size(); @@ -100,20 +95,15 @@ Tensor compute(Array shape, for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } - return ComputeOpNode::make( - name, tag, attrs, axis, {fcompute(args)}).output(0); + return ComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0); } -Array compute(Array shape, - FBatchCompute fcompute, - std::string name, - std::string tag, - Map attrs) { +Array compute(Array shape, FBatchCompute fcompute, std::string name, + std::string tag, Map attrs) { auto op_node = make_object(); // compute dimension. size_t ndim = shape.size(); @@ -122,12 +112,11 @@ Array compute(Array shape, for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "ax" << i; - axis.emplace_back(IterVarNode::make( - Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); + axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar)); args.push_back(axis.back()->var); } - Operation op = ComputeOpNode::make(name, tag, attrs, axis, fcompute(args)); + Operation op = ComputeOp(name, tag, attrs, axis, fcompute(args)); Array outputs; for (int idx = 0; idx < op->num_outputs(); ++idx) { outputs.push_back(op.output(idx)); @@ -135,13 +124,10 @@ Array compute(Array shape, return outputs; } -Operation ComputeOpNode::make(std::string name, - std::string tag, - Map attrs, - Array axis, - Array body) { +ComputeOp::ComputeOp(std::string name, std::string tag, Map attrs, + Array axis, Array body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } auto n = make_object(); n->name = std::move(name); @@ -154,12 +140,13 @@ Operation ComputeOpNode::make(std::string name, n->reduce_axis = reduce->axis; } VerifyComputeOp(n.get()); - return Operation(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.ComputeOp") -.set_body_typed(ComputeOpNode::make); - + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array axis, + Array body) { return ComputeOp(name, tag, attrs, axis, body); }); // The schedule related logics Array ComputeOpNode::InputTensors() const { @@ -167,22 +154,20 @@ Array ComputeOpNode::InputTensors() const { std::unordered_set visited; for (auto& e : body) { tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); - if (!visited.count(t)) { - ret.push_back(t); - visited.insert(t); - } + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); + if (!visited.count(t)) { + ret.push_back(t); + visited.insert(t); } - }); + } + }); } return ret; } -Operation ComputeOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ComputeOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); VerifyComputeOp(this); Array arr; @@ -202,28 +187,23 @@ Operation ComputeOpNode::ReplaceInputs( arr = this->body; } } else { - arr = UpdateArray(this->body, [&rmap] (const PrimExpr& e) { - return te::ReplaceTensor(e, rmap); - }); + arr = + UpdateArray(this->body, [&rmap](const PrimExpr& e) { return te::ReplaceTensor(e, rmap); }); } if (!arr.same_as(this->body)) { - return ComputeOpNode::make( - this->name, this->tag, 
this->attrs, this->axis, arr); + return ComputeOp(this->name, this->tag, this->attrs, this->axis, arr); } else { return self; } } -void ComputeOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ComputeOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); auto fvisit = [&dom_map, out_dom_map, analyzer](const ObjectRef& n) { - auto *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); if (t->op.defined() && out_dom_map->count(t)) { TensorDom& dom = out_dom_map->at(t); for (size_t i = 0; i < t.ndim(); ++i) { @@ -231,7 +211,7 @@ void ComputeOpNode::PropBoundToInputs( // undefined behaviour), so we can intersect the estimated set of the argument with the // range expected by the tensor. However, intersection may result in overly complex // expressions, so we perform a more relaxed form of intersection. - IntSet arg_intset = EvalSet(call->args[i], dom_map); + IntSet arg_intset = analyzer->int_set(pload->indices[i], ConvertDomMap(dom_map)); const arith::IntervalSetNode* arg_interval = arg_intset.as(); if (arg_interval) { PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); @@ -239,12 +219,14 @@ void ComputeOpNode::PropBoundToInputs( PrimExpr min_value = arg_interval->min_value; PrimExpr max_value = arg_interval->max_value; // Prefer the shape bounds only when we can prove they are tighter. - if (arith::is_neg_inf(min_value) || - analyzer->CanProve(shape_i_min_value >= min_value)) { + // We must update bound's ends in pairs. Here is an counter example: shape_i is + // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is + // [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0], + // awkward for further analysis. 
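Editor's note: the comment added just above is the substantive change in PropBoundToInputs: an argument's estimated interval is replaced by the tensor's shape bounds only when both ends are provably no looser, never one end at a time, which avoids crossed bounds such as [threadIdx.y, 0]. A stand-alone illustration of that paired rule, where a SymBound known only to lie in [lo, hi] plays the role of a symbolic bound and simple comparisons stand in for analyzer->CanProve:

#include <cassert>

// A symbolic interval end that is only known to lie in [lo, hi],
// e.g. threadIdx.y whose range is [0, 7].
struct SymBound { long lo, hi; };

// "Provable" means the comparison holds for every value the symbolic end can take.
bool ProvablyGE(long a, SymBound b) { return a >= b.hi; }
bool ProvablyLE(long a, SymBound b) { return a <= b.lo; }

// Decide whether an estimated index range [est_min, est_max] may be replaced by
// the shape bounds [0, extent-1]; both ends must be provably tighter, as a pair.
bool UseShapeBounds(long extent, SymBound est_min, SymBound est_max) {
  long shape_min = 0, shape_max = extent - 1;
  return ProvablyGE(shape_min, est_min) && ProvablyLE(shape_max, est_max);
}

int main() {
  SymBound tid{0, 7};  // threadIdx.y in [0, 7]

  // Counter-example from the comment: shape is [0, 0], estimate is [tid, tid].
  // Only the upper end is provably tighter, so the estimate is kept as a whole
  // instead of degenerating into the crossed interval [tid, 0].
  assert(!UseShapeBounds(/*extent=*/1, /*est_min=*/tid, /*est_max=*/tid));

  // When the estimate really is looser on both ends, the shape bounds win as a pair.
  SymBound loose_min{-8, -1}, loose_max{16, 32};
  assert(UseShapeBounds(/*extent=*/16, loose_min, loose_max));
  return 0;
}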
+ if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) || + (analyzer->CanProve(shape_i_min_value >= min_value) && + analyzer->CanProve(shape_i_max_value <= max_value))) { min_value = shape_i_min_value; - } - if (arith::is_pos_inf(max_value) || - analyzer->CanProve(shape_i_max_value <= max_value)) { max_value = shape_i_max_value; } dom.data[i].push_back(IntSet::interval(min_value, max_value)); @@ -258,10 +240,9 @@ void ComputeOpNode::PropBoundToInputs( for (auto& e : body) tir::PostOrderVisit(e, fvisit); } -void BaseComputeOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { +void BaseComputeOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); const TensorDom& tdom = tensor_dom.at(self.output(0)); for (size_t i = 0; i < this->axis.size(); ++i) { @@ -275,10 +256,9 @@ void BaseComputeOpNode::GatherBound( } } -Stmt BaseComputeOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { CHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -286,23 +266,19 @@ Stmt BaseComputeOpNode::BuildRealize( } Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { - Tensor t = stage->op.output(i-1); - realize = tir::RealizeNode::make(t->op, t->value_index, - t->dtype, bounds, const_true(), realize); + Tensor t = stage->op.output(i - 1); + realize = tir::ProducerRealize(t, bounds, const_true(), realize); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); if (it != stage->iter_var_attrs.end()) { IterVarAttr attr = (*it).second; if (attr->dim_align_factor != 0) { - Array tuple = {static_cast(i), - attr->dim_align_factor, - attr->dim_align_offset}; - realize = tir::AttrStmtNode::make( + Array tuple = {static_cast(i), attr->dim_align_factor, + attr->dim_align_offset}; + realize = tir::AttrStmt( t, tir::attr::buffer_dim_align, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), realize); } } @@ -311,16 +287,12 @@ Stmt BaseComputeOpNode::BuildRealize( return realize; } -size_t ComputeOpNode::num_schedulable_dims() const { - return axis.size(); -} +size_t ComputeOpNode::num_schedulable_dims() const { return axis.size(); } // Build a reduction body. 
-void MakeReduction(const ComputeOpNode* op, - const Array& tensors, - Stmt* init, +void MakeReduction(const ComputeOpNode* op, const Array& tensors, Stmt* init, Stmt* provide) { - Array args; + Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } @@ -339,34 +311,30 @@ void MakeReduction(const ComputeOpNode* op, Array update_value = (*combiner)(lhs, reduce->source); for (size_t i = 0; i < size; ++i) { Tensor t = tensors[i]; - inits.emplace_back(ProvideNode::make( - t->op, t->value_index, init_value[i], args)); - provides.emplace_back(ProvideNode::make( - t->op, t->value_index, update_value[i], args)); + inits.emplace_back(ProducerStore(t, init_value[i], args)); + provides.emplace_back(ProducerStore(t, update_value[i], args)); } *init = SeqStmt::Flatten(inits); *provide = SeqStmt::Flatten(provides); if (!is_one(reduce->condition)) { - *provide = IfThenElseNode::make(reduce->condition, *provide); + *provide = IfThenElse(reduce->condition, *provide); } } // Normal computation. -Stmt MakeProvide(const ComputeOpNode* op, - const Tensor& t) { +Stmt MakeProvide(const ComputeOpNode* op, const Tensor& t) { Array args; for (IterVar iv : op->axis) { args.push_back(iv->var); } - return ProvideNode::make(t->op, t->value_index, op->body[t->value_index], args); + return ProducerStore(t, op->body[t->value_index], args); } -Stmt MakeComputeStmt(const ComputeOpNode* self, - const Stage& stage, +Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { // grab the nest structure - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); // Normal loop structure n.init_nest.emplace_back(MakeIfNest(n.init_predicates)); n.main_nest.emplace_back(MakeIfNest(n.main_predicates)); @@ -381,10 +349,10 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, init = MergeNest(n.init_nest, init); init = Substitute(init, n.init_vmap); // common nest - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > reduce( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > reduce(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.end()); provide = MergeNest(reduce, provide); if (debug_keep_trivial_loop) { provide = MergeNest(common, provide); @@ -407,14 +375,9 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, } } -enum class ComputeType { - kNormal, - kCrossThreadReduction, - kTensorize -}; +enum class ComputeType { kNormal, kCrossThreadReduction, kTensorize }; -ComputeType DetectComputeType(const ComputeOpNode* self, - const Stage& stage) { +ComputeType DetectComputeType(const ComputeOpNode* self, const Stage& stage) { // Verify correctness of leaf nest. 
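Editor's note: MakeReduction above splits a reduction body into an init store (the combiner's identity) and an update store, wrapping the update in IfThenElse when the reduce condition is not trivially true; MakeComputeStmt then places the init under the common loops and the update under the trailing reduction loops. The loop structure this produces for a simple guarded sum, written as a plain C-style sketch rather than generated TIR:

#include <cassert>

int main() {
  const int n = 8, m = 6;
  int A[n][m];
  for (int i = 0; i < n; ++i)
    for (int k = 0; k < m; ++k) A[i][k] = i + k;

  int B[n];
  // Common (data-parallel) loop over the compute axes.
  for (int i = 0; i < n; ++i) {
    // init: B[i] = identity of the combiner (0 for a sum)
    B[i] = 0;
    // reduce loop: guarded update, mirroring IfThenElse(reduce->condition, provide)
    for (int k = 0; k < m; ++k) {
      if (k % 2 == 0) {  // the reduce condition; dropped entirely when it is const true
        B[i] = B[i] + A[i][k];
      }
    }
  }
  assert(B[0] == 0 + 2 + 4);
  return 0;
}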
int normal_red = 0, thread_red = 0, tensorize = 0; @@ -434,13 +397,11 @@ ComputeType DetectComputeType(const ComputeOpNode* self, ++normal_red; } } else { - CHECK_EQ(thread_red, 0) - << "Cross thread reduce cannot swap with normal data axis"; + CHECK_EQ(thread_red, 0) << "Cross thread reduce cannot swap with normal data axis"; } } if (tensorize != 0) { - CHECK(thread_red == 0) - << "Cannot mix cross thread reduction with Tensorize"; + CHECK(thread_red == 0) << "Cannot mix cross thread reduction with Tensorize"; return ComputeType::kTensorize; } if (thread_red != 0) { @@ -451,10 +412,9 @@ ComputeType DetectComputeType(const ComputeOpNode* self, } // implement the provide utility. -Stmt ComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ComputeOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); ComputeType ctype = DetectComputeType(this, stage); if (ctype == ComputeType::kCrossThreadReduction) { @@ -467,20 +427,16 @@ Stmt ComputeOpNode::BuildProvide( } } -ComputeLoopNest ComputeLoopNest::make( - const BaseComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { +ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { CHECK_EQ(stage->op.operator->(), self); ComputeLoopNest ret; // make main loop nest - ret.main_nest = MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set(), &ret.main_vmap, - debug_keep_trivial_loop); - ret.main_predicates = MakeBoundCheck( - stage, dom_map, ret.main_vmap, false, - std::unordered_set()); + ret.main_nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set(), + &ret.main_vmap, debug_keep_trivial_loop); + ret.main_predicates = + MakeBoundCheck(stage, dom_map, ret.main_vmap, false, std::unordered_set()); for (auto& e : ret.main_predicates) { e = likely(e); } @@ -506,7 +462,8 @@ ComputeLoopNest ComputeLoopNest::make( auto iv = leaf_iter_vars[i]; int flag = update_state.at(iv); if ((flag & 2) != 0) { - begin_loop = i; break; + begin_loop = i; + break; } ret.init_vmap[iv] = ret.main_vmap.at(iv); } @@ -517,11 +474,9 @@ ComputeLoopNest ComputeLoopNest::make( int flag = kv.second; if (flag == 2) skip_iter.insert(kv.first); } - ret.init_nest = MakeLoopNest( - stage, dom_map, begin_loop, true, - skip_iter, &(ret.init_vmap), debug_keep_trivial_loop); - ret.init_predicates = MakeBoundCheck( - stage, dom_map, ret.init_vmap, true, skip_iter); + ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap), + debug_keep_trivial_loop); + ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter); for (auto& e : ret.init_predicates) { e = likely(e); } @@ -561,14 +516,12 @@ class ComputeVerifier final : protected tir::ExprVisitor { for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions const tir::ReduceNode* reduce = e.as(); - CHECK((reduce && reduce_) || (!reduce && !reduce_)) - << "All ComputeOp should be consistent " - << "with being Reduce operation or not."; + CHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent " + << "with being Reduce operation or not."; if (reduce && reduce_) { - CHECK(ReduceEqual(reduce, reduce_)) - << "The Reduce inputs of ComputeOp should " - << "have the same 
attribute except value_index"; + CHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } level_ = 0; @@ -587,16 +540,15 @@ class ComputeVerifier final : protected tir::ExprVisitor { void VisitExpr_(const tir::ReduceNode* op) final { // Check for non top level reductions - CHECK(0 == level_) - << "Reductions are only allowed at the top level of compute. " - << "Please create another tensor for further composition."; + CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " + << "Please create another tensor for further composition."; } //@} private: - const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify - const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation - int level_{0}; ///< Level of op being processed + const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify + const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation + int level_{0}; ///< Level of op being processed }; } // namespace @@ -606,11 +558,8 @@ static void VerifyComputeOp(const ComputeOpNode* op) { v.Run(); } -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update) { +Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, + const ComputeLoopNest& n, Stmt body, Stmt update) { Array conds; std::unordered_set banned; for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { @@ -630,15 +579,18 @@ Stmt TransformUpdate(const Stage& stage, banned.insert(iv->var.get()); } } + + auto fbanned = [&](const VarNode* node) { return banned.count(node); }; + for (const PrimExpr& pred : n.main_predicates) { - if (tir::ExprUseVar(pred, banned)) { - LOG(FATAL) << "Tensorize update transform failed, the condition " - << pred << " has a conflict with the reset condition"; + if (tir::ExprUseVar(pred, fbanned)) { + LOG(FATAL) << "Tensorize update transform failed, the condition " << pred + << " has a conflict with the reset condition"; } } - return IfThenElseNode::make(arith::ComputeReduce(conds, const_true(1)), - update, body); + auto cond = foldl([](PrimExpr a, PrimExpr b) { return a || b; }, const_false(1), conds); + return IfThenElse(cond, update, body); } } // namespace te diff --git a/src/te/operation/compute_op.h b/src/te/operation/compute_op.h index 08db74f0d9a5..2661eb976f2e 100644 --- a/src/te/operation/compute_op.h +++ b/src/te/operation/compute_op.h @@ -24,10 +24,11 @@ #ifndef TVM_TE_OPERATION_COMPUTE_OP_H_ #define TVM_TE_OPERATION_COMPUTE_OP_H_ -#include #include -#include +#include + #include +#include namespace tvm { namespace te { @@ -58,11 +59,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The constructed loop nest */ - static ComputeLoopNest make( - const BaseComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); + static ComputeLoopNest Create(const BaseComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); }; /*! @@ -73,11 +72,9 @@ struct ComputeLoopNest { * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. 
*/ -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); +Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop); /*! * \brief Build body of compute for tensorization. @@ -87,10 +84,8 @@ Stmt MakeCrossThreadReduction( * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 * \return The created statement. */ -Stmt MakeTensorize(const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop); +Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, bool debug_keep_trivial_loop); /*! * \brief Transform the update part when there is no init func in tensorizing @@ -101,11 +96,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, * \param update The update func in tensorize intrin * \return Transformed result. */ -Stmt TransformUpdate(const Stage& stage, - const std::unordered_map& dom_map, - const ComputeLoopNest& n, - Stmt body, - Stmt update); +Stmt TransformUpdate(const Stage& stage, const std::unordered_map& dom_map, + const ComputeLoopNest& n, Stmt body, Stmt update); } // namespace te } // namespace tvm diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index 1b3d87d57006..e834ff279d05 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -21,7 +21,6 @@ * \brief Logics related to cross thread reduction, used by ComputeOpNode. * \file cross_thread_reduction.cc */ -#include #include "compute_op.h" #include "op_util.h" @@ -29,21 +28,66 @@ namespace tvm { namespace te { using namespace tir; -Stmt MakeCrossThreadReduction( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) { - Array args; +// +// Cross thread reduction transformation. +// +// The input loop nest in generic form (single reduction/thread case) +// +// let m be the reduction extent +// let N be the thread extent +// let input_pred be the predicate on the reduction +// +// B[..] = 0 +// for (tid, 0, N) +// for (i, 0, floordiv(m+N-1, N)) +// if (i + tid * floordiv(m+N-1, N) < m) +// if (input_pred) +// B[..] = op(B[..], A[i + tid * floordiv(m+N-1,N)]) +// +// The threaded reduction looks like +// +// (1) normal reductions (leaves) +// for (i, 0, floordiv(m+N-1, N)) +// if (i + tid * floordiv(m+N-1, N) < m) +// if (input_pred) +// B_temp[0] = op(B_temp[0], A[i + tid * floordiv(m+N-1,N)]) +// +// (2) threaded reduction does not require predicates as an identity +// element will be filled if out of bounds. +// +// tvm_thread_allreduce(size, B_temp, (bool)1, tid) +// +// The last step is to write the final reduction variable, +// which should be predicated by the existing input_pred, if any. +// The consequence is that input_pred should be independent of +// the reduction axis. Otherwise, we need to separate it into +// dependent part and independent one. +// +// (3) write back +// if (input_pred) +// B[..] = B_temp[0] +// +// In summary, we are going to need two predicates +// +// * the original input_pred from reduction itself +// +// * the normal reduction axis predicate +// normal_pred = (i + tid * floordiv(m+N-1,N)) < m +// this predicate depends on the normal reduction variable.
+// +// input_pred will be applied to both normal reduction and +// the writeback step. +// +Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) { + Array args; for (IterVar iv : self->axis) { args.push_back(iv->var); } std::unordered_map value_map; - auto nest = MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set(), &value_map, debug_keep_trivial_loop); - auto conds = MakeBoundCheck( - stage, dom_map, value_map, false, - std::unordered_set()); + auto nest = MakeLoopNest(stage, dom_map, 0, false, std::unordered_set(), &value_map, + debug_keep_trivial_loop); size_t size = self->body.size(); CHECK_GT(size, 0); @@ -53,10 +97,17 @@ Stmt MakeCrossThreadReduction( CHECK(reduce); reduces[i] = reduce; } - PrimExpr cond = reduces[0]->condition; - for (PrimExpr v : conds) { - cond = cond && v; - } + + // This computes the bound checking predicates in normal reduction. + auto normal_preds = + MakeBoundCheck(stage, dom_map, value_map, false, std::unordered_set()); + + // normal_pred = input_pred && normal_pred + PrimExpr input_pred = reduces[0]->condition; + normal_preds.push_back(input_pred); + normal_preds.erase(std::remove_if(normal_preds.begin(), normal_preds.end(), + [](const PrimExpr& e) { return !e.defined(); }), + normal_preds.end()); std::vector> common, normal_red; for (size_t i = 0, n = stage->leaf_iter_vars.size(); i < n; ++i) { @@ -91,16 +142,16 @@ Stmt MakeCrossThreadReduction( for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), DataType::Handle()); - lhs.push_back(LoadNode::make(t, normal_res_handles[i], 0, const_true(t.lanes()))); + lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array init_value = combiner->identity_element; Array update_value = (*combiner)(lhs, reduces[0]->source); for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; - normal_init.emplace_back(StoreNode::make( - normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); - normal_update.emplace_back(StoreNode::make( - normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); + normal_init.emplace_back( + Store(normal_res_handles[i], init_value[i], 0, const_true(t.lanes()))); + normal_update.emplace_back( + Store(normal_res_handles[i], update_value[i], 0, const_true(t.lanes()))); } } @@ -109,13 +160,15 @@ Stmt MakeCrossThreadReduction( for (size_t i = 0; i < size; ++i) { if (!normal_red.empty()) { DataType t = reduces[i]->dtype; - freduce_args.push_back(LoadNode::make( - t, normal_res_handles[i], 0, const_true(t.lanes()))); + freduce_args.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } else { freduce_args.push_back(reduces[0]->source[i]); } } - freduce_args.push_back(cond); + + // No constraints on the thread reduction step. It may have redundant + // computation for rare cases. TODO(tvm-team): revisit this.
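// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): a sequential C++ analogue of the
// three steps described in the comment block above -- (1) a per-thread
// "normal" reduction guarded by normal_pred, (2) the reduction across the N
// threads, emulated here by a plain loop, and (3) the predicated write-back.
// The names m, N, A, B_temp and input_pred are assumptions for the example;
// real generated code runs step (1) per bound thread and uses
// tvm_thread_allreduce for step (2).
#include <cstdio>

int main() {
  const int m = 10;                      // reduction extent
  const int N = 4;                       // thread extent
  const int chunk = (m + N - 1) / N;     // floordiv(m + N - 1, N)
  float A[m];
  for (int i = 0; i < m; ++i) A[i] = 1.0f;

  float B_temp[N];                       // per-thread normal_reduce_temp
  for (int tid = 0; tid < N; ++tid) {    // (1) normal reduction in each thread
    B_temp[tid] = 0.0f;                  // identity element
    for (int i = 0; i < chunk; ++i) {
      int idx = i + tid * chunk;
      if (idx < m) B_temp[tid] += A[idx];  // normal_pred bound check
    }
  }

  float B = 0.0f;                        // (2) cross-thread reduction; no
  for (int tid = 0; tid < N; ++tid) {    // predicate is needed because
    B += B_temp[tid];                    // out-of-range threads hold the
  }                                      // identity element

  bool input_pred = true;                // (3) predicated write-back
  if (input_pred) printf("sum = %g (expected %d)\n", B, m);
  return 0;
}
// ---------------------------------------------------------------------------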
+ freduce_args.push_back(const_true(1)); std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { res_handles[idx] = Var("reduce_temp" + std::to_string(idx), DataType::Handle()); @@ -125,58 +178,52 @@ Stmt MakeCrossThreadReduction( for (IterVar iv : stage->leaf_iter_vars) { if (iv->iter_type == kCommReduce) { auto it = stage->iter_var_attrs.find(iv); - if (it != stage->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { + if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { IterVar tv = (*it).second->bind_thread; freduce_args.push_back(tv->var); } } } + // Checks for the thread. - std::vector thread_head_check; + std::vector output_preds; if (stage->store_predicate.defined()) { - thread_head_check.emplace_back(stage->store_predicate); + output_preds.emplace_back(stage->store_predicate); } - Stmt reduce_body = EvaluateNode::make(CallNode::make( - DataType::Handle(), - tir::intrinsic::tvm_thread_allreduce, - freduce_args, CallNode::Intrinsic)); - reduce_body = AttrStmtNode::make( - reduces[0]->combiner, - tir::attr::reduce_scope, - make_zero(DataType::Handle()), - reduce_body); + // Apply the existing input predicate if any. + output_preds.push_back(input_pred); + + Stmt reduce_body = Evaluate(Call(DataType::Handle(), tir::intrinsic::tvm_thread_allreduce, + freduce_args, CallNode::Intrinsic)); + reduce_body = AttrStmt(reduces[0]->combiner, tir::attr::reduce_scope, + make_zero(DataType::Handle()), reduce_body); if (!normal_red.empty()) { Stmt init_body = SeqStmt::Flatten(normal_init); Stmt update_body = SeqStmt::Flatten(normal_update); + update_body = MergeNest(MakeIfNest(normal_preds), update_body); update_body = MergeNest(normal_red, update_body); reduce_body = SeqStmt::Flatten(init_body, update_body, reduce_body); - reduce_body = MergeNest(MakeIfNest(conds), reduce_body); } std::vector assigns(size); for (size_t idx = 0; idx < size; ++idx) { DataType t = reduces[idx]->dtype; - assigns[idx] = ProvideNode::make( - stage->op, idx, - LoadNode::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + assigns[idx] = ProducerStore(stage->op.output(idx), + Load(t, res_handles[idx], 0, const_true(t.lanes())), args); } Stmt assign_body = SeqStmt::Flatten(assigns); - assign_body = MergeNest(MakeIfNest(thread_head_check), assign_body); - assign_body = MergeNest(MakeIfNest(conds), assign_body); + assign_body = MergeNest(MakeIfNest(output_preds), assign_body); Stmt body = SeqStmt::Flatten(reduce_body, assign_body); for (size_t idx = size; idx != 0; --idx) { - body = AllocateNode::make( - res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make( - res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); + body = Allocate(res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = AttrStmt(res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); if (!normal_red.empty()) { - body = AllocateNode::make( - normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); - body = AttrStmtNode::make( - normal_res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body); + body = + Allocate(normal_res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); + body = + AttrStmt(normal_res_handles[idx - 1], tir::attr::storage_scope, StringImm("local"), body); } } body = Substitute(body, value_map); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 
9d95e329c8f2..ef55c44241b0 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -21,11 +21,13 @@ * \brief External computation rule. * \file extern_op.cc */ +#include #include #include -#include #include + #include + #include "op_util.h" namespace tvm { @@ -33,39 +35,26 @@ namespace te { using namespace tir; // ExternOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "extern(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "extern(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ExternOpNode); -int ExternOpNode::num_outputs() const { - return static_cast(output_placeholders.size()); -} - -Array ExternOpNode::root_iter_vars() const { - return {}; -} +int ExternOpNode::num_outputs() const { return static_cast(output_placeholders.size()); } -DataType ExternOpNode::output_dtype(size_t i) const { - return output_placeholders[i]->dtype; -} +Array ExternOpNode::root_iter_vars() const { return {}; } -Array ExternOpNode::output_shape(size_t i) const { - return output_placeholders[i]->shape; -} +DataType ExternOpNode::output_dtype(size_t i) const { return output_placeholders[i]->dtype; } +Array ExternOpNode::output_shape(size_t i) const { return output_placeholders[i]->shape; } -Operation ExternOpNode::make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array input_placeholders, - Array output_placeholders, - Stmt body) { +ExternOp::ExternOp(std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } auto n = make_object(); n->name = std::move(name); @@ -76,7 +65,7 @@ Operation ExternOpNode::make(std::string name, CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size()); for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) { - CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); + CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim])); } CHECK_EQ(input_placeholders[i]->strides.size(), 0U); } @@ -84,20 +73,20 @@ Operation ExternOpNode::make(std::string name, n->input_placeholders = std::move(input_placeholders); n->output_placeholders = std::move(output_placeholders); n->body = std::move(body); - return Operation(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.ExternOp") -.set_body_typed(ExternOpNode::make); + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array inputs, Array input_placeholders, + Array output_placeholders, Stmt body) { + return ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body); + }); +Array ExternOpNode::InputTensors() const { return inputs; } -Array ExternOpNode::InputTensors() const { - return inputs; -} - -Operation ExternOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ExternOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = ReplaceTensor(this->body, rmap); @@ -108,65 +97,52 @@ Operation ExternOpNode::ReplaceInputs( } } - if (body.same_as(n->body) && - inputs.same_as(n->inputs)) { + if (body.same_as(n->body) && inputs.same_as(n->inputs)) { 
return self; } else { return Operation(n); } } -void ExternOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ExternOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { for (Tensor t : this->inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; TensorDom& dom = it->second; for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i]))); + Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } -void ExternOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { -} +void ExternOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const {} -Stmt ExternOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt ExternOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { Tensor t = stage->op.output(k); Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { - bounds.push_back( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i])); + bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::RealizeNode::make( - t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); } return realize_body; } -Stmt ExternOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ExternOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make( - make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { Array bind_spec; Array tuple; @@ -176,9 +152,8 @@ Stmt ExternOpNode::BuildProvide( tuple.push_back(make_const(buffer->shape[k].dtype(), 0)); tuple.push_back(buffer->shape[k]); } - ret = AttrStmtNode::make( - bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); + ret = AttrStmt(bind_spec, tir::attr::buffer_bind_scope, + Call(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); }; for (size_t i = output_placeholders.size(); i != 0; --i) { f_push_bind(output_placeholders[i - 1], stage->op.output(i - 1)); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 4da127ea0a85..9be474d7d941 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -21,57 +21,46 @@ * \brief Hybrid computation rule. 
* \file hybrid_op.cc */ +#include "hybrid_op.h" + +#include #include #include -#include -#include -#include -#include #include +#include #include -#include +#include + #include +#include #include + #include "op_util.h" -#include "hybrid_op.h" namespace tvm { namespace te { using namespace tir; // HybridOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "hybrid(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "hybrid(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(HybridOpNode); -int HybridOpNode::num_outputs() const { - return static_cast(outputs.size()); -} - -Array HybridOpNode::root_iter_vars() const { - return this->axis; -} +int HybridOpNode::num_outputs() const { return static_cast(outputs.size()); } -DataType HybridOpNode::output_dtype(size_t i) const { - return outputs[i]->dtype; -} +Array HybridOpNode::root_iter_vars() const { return this->axis; } -Array HybridOpNode::output_shape(size_t i) const { - return outputs[i]->shape; -} +DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } +Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } -Operation HybridOpNode::make(std::string name, - std::string tag, - Map attrs, - Array inputs, - Array outputs, - Stmt body) { +HybridOp::HybridOp(std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, Stmt body) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } auto n = make_object(); n->name = std::move(name); @@ -81,13 +70,13 @@ Operation HybridOpNode::make(std::string name, n->outputs = std::move(outputs); n->axis = te::GatherLoopVars(body); n->body = std::move(body); - Operation res = Operation(n); - return res; + data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.HybridOp") -.set_body_typed(HybridOpNode::make); - + .set_body_typed([](std::string name, std::string tag, Map attrs, + Array inputs, Array outputs, + Stmt body) { return HybridOp(name, tag, attrs, inputs, outputs, body); }); Array HybridOpNode::InputTensors() const { // Because input tensors could be potentially inlined into hybrid scripts, @@ -99,21 +88,19 @@ Array HybridOpNode::InputTensors() const { std::unordered_set visited; Array curr_inputs; tir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - Tensor t = Downcast(call->func).output(call->value_index); - if (orig_inputs.count(t) && !visited.count(t)) { - curr_inputs.push_back(t); - visited.insert(t); - } + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); + if (orig_inputs.count(t) && !visited.count(t)) { + curr_inputs.push_back(t); + visited.insert(t); } + } }); return curr_inputs; } -Operation HybridOpNode::ReplaceInputs( - const Operation &self, - const std::unordered_map &rmap) const { +Operation HybridOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = te::ReplaceTensor(this->body, rmap); @@ -124,46 +111,40 @@ Operation HybridOpNode::ReplaceInputs( } } - if (body.same_as(n->body) && - inputs.same_as(n->inputs)) { + if (body.same_as(n->body) && inputs.same_as(n->inputs)) { return self; } else { return Operation(n); } } -void HybridOpNode::PropBoundToInputs( - const 
Operation &self, - arith::Analyzer* analyzer, - const std::unordered_map &dom_map, - std::unordered_map* out_dom_map) const { +void HybridOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { auto curr_inputs = InputTensors(); for (Tensor t : curr_inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; - TensorDom &dom = it->second; + TensorDom& dom = it->second; for (size_t i = 0; i < t->shape.size(); ++i) { dom.data[i].emplace_back(IntSet::range( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i]))); + Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i]))); } } } -void HybridOpNode::GatherBound( - const Operation &self, - const std::unordered_map &tensor_dom, - std::unordered_map* out_dom_map) const { +void HybridOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { for (auto iter_var : axis) { CHECK(!out_dom_map->count(iter_var)); out_dom_map->operator[](iter_var) = iter_var->dom; } } -Stmt HybridOpNode::BuildRealize( - const Stage &stage, - const std::unordered_map &realize_map, - const Stmt &body) const { +Stmt HybridOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. CHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -171,24 +152,18 @@ Stmt HybridOpNode::BuildRealize( Tensor t = stage->op.output(k); Region bounds; for (size_t i = 0; i < t->shape.size(); ++i) { - bounds.push_back( - Range::make_by_min_extent( - make_const(t->shape[i].dtype(), 0), t->shape[i])); + bounds.push_back(Range::make_by_min_extent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::RealizeNode::make( - t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); } return realize_body; } -Stmt HybridOpNode::BuildProvide( - const Stage &stage, - const std::unordered_map &dom_map, - bool debug_keep_trivial_loop) const { +Stmt HybridOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt ret = AttrStmtNode::make( - make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); + Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { rmap[outputs[i]] = stage->op.output(i); @@ -224,45 +199,44 @@ Stmt HybridOpNode::BuildProvide( return ret; } -Stmt ApplyLoopShapes(const Stage &stage, - const std::unordered_map &dom_map, Stmt stmt) { +Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt) { class LoopSpliter : public StmtExprMutator { PrimExpr factor; - const VarNode *parent; + const VarNode* parent; IterVar inner, outer; public: bool splitted; - LoopSpliter(const SplitNode *split, - const std::unordered_map &dom_map) : - factor(split->factor), splitted(false) { + LoopSpliter(const SplitNode* split, const std::unordered_map& dom_map) + : factor(split->factor), splitted(false) { parent = split->parent->var.get(); - auto &inner_ = split->inner; + auto& inner_ = split->inner; CHECK(dom_map.count(inner_)); - auto &inner_dom = dom_map.find(inner_)->second; + auto& 
inner_dom = dom_map.find(inner_)->second; CHECK(is_const_int(inner_dom->min, 0)); - auto &outer_ = split->outer; + auto& outer_ = split->outer; CHECK(dom_map.count(outer_)); - auto &outer_dom = dom_map.find(outer_)->second; + auto& outer_dom = dom_map.find(outer_)->second; CHECK(is_const_int(outer_dom->min, 0)); - inner = IterVarNode::make(inner_dom, inner_->var, inner_->iter_type); - outer = IterVarNode::make(outer_dom, outer_->var, outer_->iter_type); + inner = IterVar(inner_dom, inner_->var, inner_->iter_type); + outer = IterVar(outer_dom, outer_->var, outer_->iter_type); } - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == parent) { - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = inner + outer * factor; Stmt ret = tir::Substitute(op->body, rmap); PrimExpr cond = likely(outer * factor < (op->extent - inner)); - ret = IfThenElseNode::make(cond, ret); - ret = ForNode::make(inner->var, PrimExpr(0), inner->dom->extent, - IterVarTypeToForType(inner->iter_type), op->device_api, ret); - ret = ForNode::make(outer->var, PrimExpr(0), outer->dom->extent, - IterVarTypeToForType(outer->iter_type), op->device_api, ret); + ret = IfThenElse(cond, ret); + ret = For(inner->var, PrimExpr(0), inner->dom->extent, + IterVarTypeToForType(inner->iter_type), op->device_api, ret); + ret = For(outer->var, PrimExpr(0), outer->dom->extent, + IterVarTypeToForType(outer->iter_type), op->device_api, ret); splitted = true; return ret; } @@ -271,24 +245,27 @@ Stmt ApplyLoopShapes(const Stage &stage, }; class LoopFuser : public StmtExprMutator { - const IterVar &parent; - const VarNode *inner; - const VarNode *outer; + const IterVar& parent; + const VarNode* inner; + const VarNode* outer; bool under_outer; PrimExpr extent; public: bool fused; - explicit LoopFuser(const FuseNode *fuse_) - : parent(fuse_->fused), inner(fuse_->inner->var.get()), - outer(fuse_->outer->var.get()), under_outer(false), - extent(0), fused(false) {} + explicit LoopFuser(const FuseNode* fuse_) + : parent(fuse_->fused), + inner(fuse_->inner->var.get()), + outer(fuse_->outer->var.get()), + under_outer(false), + extent(0), + fused(false) {} // TODO(@were): Handle imperfect loops Stmt VisitStmt_(const ForNode* op) final { if (op->loop_var.get() == inner) { CHECK(under_outer); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(parent, op->extent); extent = op->extent; fused = true; @@ -296,15 +273,15 @@ Stmt ApplyLoopShapes(const Stage &stage, } else if (op->loop_var.get() == outer) { under_outer = true; Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return ForNode::make(parent->var, PrimExpr(0), extent * op->extent, - op->for_type, op->device_api, body); + return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api, + body); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = indexmod(indexdiv(parent, extent), op->extent); body = tir::Substitute(body, rmap); extent = extent * op->extent; @@ -314,12 +291,12 @@ Stmt ApplyLoopShapes(const Stage &stage, } }; - for (auto &rel : stage->relations) { - if (const SplitNode *split = rel.as()) { + for (auto& rel : stage->relations) { + if (const SplitNode* split = rel.as()) { LoopSpliter Spliter(split, 
dom_map); stmt = Spliter(stmt); CHECK(Spliter.splitted); - } else if (const FuseNode *fuse = rel.as()) { + } else if (const FuseNode* fuse = rel.as()) { LoopFuser Fuser(fuse); stmt = Fuser(stmt); CHECK(Fuser.fused); @@ -329,45 +306,45 @@ Stmt ApplyLoopShapes(const Stage &stage, return stmt; } -Stmt ApplyLoopAnnotations(const Stage &stage, - const std::unordered_map &rebased, Stmt stmt) { +Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map& rebased, + Stmt stmt) { class LoopAnnotator : public StmtMutator { - const VarNode *var; - const IterVarAttr &attr; + const VarNode* var; + const IterVarAttr& attr; public: - LoopAnnotator(const VarNode *var_, const IterVarAttr &attr_) : var(var_), attr(attr_) {} + LoopAnnotator(const VarNode* var_, const IterVarAttr& attr_) : var(var_), attr(attr_) {} - Stmt VisitStmt_(const ForNode *op) final { + Stmt VisitStmt_(const ForNode* op) final { tir::ExprDeepEqual expr_equal; if (op->loop_var.get() == var) { if (attr->bind_thread.defined()) { - const auto &iter_var = attr->bind_thread; + const auto& iter_var = attr->bind_thread; if (iter_var->dom.defined()) { CHECK(is_const_int(iter_var->dom->min, 0)); CHECK(expr_equal(iter_var->dom->extent, op->extent)) - << "Thread extent and loop extent mismatch!\n"; + << "Thread extent and loop extent mismatch!\n"; } - std::unordered_map rmap; + std::unordered_map rmap; rmap[op->loop_var.get()] = iter_var; Stmt body = tir::Substitute(op->body, rmap); - return AttrStmtNode::make(iter_var, "thread_extent", op->extent, body); + return AttrStmt(iter_var, "thread_extent", op->extent, body); } else { - return ForNode::make(op->loop_var, op->min, op->extent, - IterVarTypeToForType(attr->iter_type), op->device_api, op->body); + return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type), + op->device_api, op->body); } } return StmtMutator::VisitStmt_(op); } }; - for (auto &iter_var : stage->leaf_iter_vars) { + for (auto& iter_var : stage->leaf_iter_vars) { bool need_change = false; int found = 0; - const IterVar &actual = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; - const VarNode *var = actual->var.get(); + const IterVar& actual = rebased.count(iter_var) ? 
rebased.find(iter_var)->second : iter_var; + const VarNode* var = actual->var.get(); ForType expected = IterVarTypeToForType(iter_var->iter_type); IterVarAttr attr; if (stage->iter_var_attrs.count(iter_var)) { @@ -375,9 +352,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, expected = IterVarTypeToForType(attr->iter_type); } - PostOrderVisit(stmt, - [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { - if (const ForNode *op = node.as()) { + PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { + if (const ForNode* op = node.as()) { if (op->loop_var.get() == var) { ++found; need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined()); @@ -393,23 +369,21 @@ Stmt ApplyLoopAnnotations(const Stage &stage, return stmt; } -Stmt ApplyLoopOrder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &rebased, Stmt stmt) { +Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& rebased, Stmt stmt) { std::vector current_order; PostOrderVisit(stmt, [¤t_order](const ObjectRef& node) { - if (const ForNode *op = node.as()) - current_order.push_back(op->loop_var.get()); + if (const ForNode* op = node.as()) current_order.push_back(op->loop_var.get()); }); std::reverse(current_order.begin(), current_order.end()); - auto &required_ord = stage->leaf_iter_vars; + auto& required_ord = stage->leaf_iter_vars; CHECK_EQ(current_order.size(), required_ord.size()) << "Cannot reorder the loops!"; - std::unordered_map reorder; + std::unordered_map reorder; bool need_reorder = false; for (size_t i = 0; i < current_order.size(); ++i) { - auto ¤t = current_order[i]; - const IterVar &iter_var = required_ord[i]; - const IterVar &required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; + auto& current = current_order[i]; + const IterVar& iter_var = required_ord[i]; + const IterVar& required = rebased.count(iter_var) ? rebased.find(iter_var)->second : iter_var; CHECK(required->dom.defined() || dom_map.count(required)) << required << "\n"; reorder[current] = required; if (current != required->var.get()) { @@ -418,15 +392,14 @@ Stmt ApplyLoopOrder(const Stage &stage, } class LoopReorder : public StmtMutator { - const Stage &stage; - const std::unordered_map &dom_map; - const std::unordered_map &reorder; + const Stage& stage; + const std::unordered_map& dom_map; + const std::unordered_map& reorder; public: - LoopReorder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &reorder) - : stage(stage), dom_map(dom_map), reorder(reorder) {} + LoopReorder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& reorder) + : stage(stage), dom_map(dom_map), reorder(reorder) {} Stmt VisitStmt_(const ForNode* op) final { // Reorder from in to out @@ -435,25 +408,23 @@ Stmt ApplyLoopOrder(const Stage &stage, auto target = reorder.find(op->loop_var.get())->second; if (body_.same_as(op->body) && op->loop_var.get() == target->var.get()) return GetRef(op); - const Stmt &body = op->body.same_as(body_) ? op->body : body_; + const Stmt& body = op->body.same_as(body_) ? op->body : body_; ForType for_type = IterVarTypeToForType(target->iter_type); if (stage->iter_var_attrs.count(target)) { for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type); } - const Range &range = target->dom.defined() ? 
target->dom : dom_map.find(target)->second; - return ForNode::make(target->var, range->min, range->extent, - for_type, DeviceAPI::None, body); + const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second; + return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body); } }; - if (need_reorder) - return LoopReorder(stage, dom_map, reorder)(stmt); + if (need_reorder) return LoopReorder(stage, dom_map, reorder)(stmt); return stmt; } -Stmt ApplySchedule(const Stage &stage, - const std::unordered_map &dom_map, Stmt stmt) { +Stmt ApplySchedule(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt) { // TODO(@were): Eliminate loop rebase in script parser and move the burden here // Gather rebased variables std::unordered_map rebased; @@ -474,10 +445,10 @@ std::vector GatherLoopVars(Stmt stmt) { // TODO(@were): Write a comprehensive pass to analyze iter var types std::vector res_; PostOrderVisit(stmt, [&res_](const ObjectRef& node) { - if (const ForNode *op = node.as()) { + if (const ForNode* op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::make_by_min_extent(op->min, op->extent); - res_.push_back(IterVarNode::make(dom, loop_var, ForTypeToIterVarType(op->for_type))); + res_.push_back(IterVar(dom, loop_var, ForTypeToIterVarType(op->for_type))); } }); std::reverse(res_.begin(), res_.end()); @@ -487,15 +458,13 @@ std::vector GatherLoopVars(Stmt stmt) { // replacer to replace tensors' usage in Provide class ProviderReplacer : public tir::StmtMutator { public: - explicit ProviderReplacer(const std::unordered_map &vmap) - : vmap_(vmap) {} + explicit ProviderReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - Stmt VisitStmt_(const tir::ProvideNode* op) final { - Tensor t = Downcast(op->func).output(op->value_index); + Stmt VisitStmt_(const tir::ProducerStoreNode* op) final { + Tensor t = Downcast(op->producer); auto it = vmap_.find(t); if (it != vmap_.end()) { - Stmt ret = tir::ProvideNode::make( - it->second->op, it->second->value_index, op->value, op->args); + Stmt ret = tir::ProducerStore(it->second, op->value, op->indices); found = true; return this->VisitStmt(ret); } @@ -506,11 +475,10 @@ class ProviderReplacer : public tir::StmtMutator { bool found{false}; private: - const std::unordered_map &vmap_; + const std::unordered_map& vmap_; }; -Stmt ReplaceProvideTensor(Stmt stmt, - const std::unordered_map &replace) { +Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map& replace) { ProviderReplacer repl(replace); Stmt ret = repl(stmt); return repl.found ? ret : stmt; diff --git a/src/te/operation/hybrid_op.h b/src/te/operation/hybrid_op.h index a7b2cb16c080..a11ae89e23f7 100644 --- a/src/te/operation/hybrid_op.h +++ b/src/te/operation/hybrid_op.h @@ -24,16 +24,16 @@ #ifndef TVM_TE_OPERATION_HYBRID_OP_H_ #define TVM_TE_OPERATION_HYBRID_OP_H_ -#include #include +#include #include #include #include +#include "../../tir/transforms/arg_binder.h" +#include "../../tir/transforms/ir_util.h" #include "../schedule/message_passing.h" -#include "../../tir/pass/ir_util.h" -#include "../../tir/pass/arg_binder.h" namespace tvm { namespace te { @@ -49,8 +49,7 @@ std::vector GatherLoopVars(Stmt stmt); * \param stmt The statement to be processed. * \param replace The replacement rule. */ -Stmt ReplaceProvideTensor(Stmt stmt, - const std::unordered_map& replace); +Stmt ReplaceProvideTensor(Stmt stmt, const std::unordered_map& replace); /*! * \brief Apply the schedule manipulation on the function body. 
@@ -58,8 +57,8 @@ Stmt ReplaceProvideTensor(Stmt stmt, * \param dom_map The extents of the iterative variables may be used. * \param stage The schedule information to be applied. */ -Stmt ApplySchedule(const Stage& stage, - const std::unordered_map& dom_map, Stmt stmt); +Stmt ApplySchedule(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt); /*! * \brief Apply loop splits and fuses in the schedule on the function body. @@ -67,9 +66,8 @@ Stmt ApplySchedule(const Stage& stage, * \param dom_map The extents of the iterative variables may be used. * \param stmt The statement to be processed. */ -Stmt ApplyLoopShapes(const Stage &stage, - const std::unordered_map& dom_map, Stmt stmt); - +Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map& dom_map, + Stmt stmt); /*! * \brief Apply loop annotation in the schedule on the function body. @@ -77,8 +75,8 @@ Stmt ApplyLoopShapes(const Stage &stage, * \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables. * \param stmt The statement to be processed. */ -Stmt ApplyLoopAnnotations(const Stage &stage, - const std::unordered_map& rebased, Stmt stmt); +Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map& rebased, + Stmt stmt); /*! * \brief Apply loop order in the schedule on the function body. @@ -87,9 +85,8 @@ Stmt ApplyLoopAnnotations(const Stage &stage, * \param rebased The map specifies the rebase, a.k.a rename, relationship of these variables. * \param stmt The statement to be processed. */ -Stmt ApplyLoopOrder(const Stage &stage, - const std::unordered_map &dom_map, - const std::unordered_map &rebased, Stmt stmt); +Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map& dom_map, + const std::unordered_map& rebased, Stmt stmt); } // namespace te } // namespace tvm diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc index 4ecfe9472901..61b782629d19 100644 --- a/src/te/operation/op_util.cc +++ b/src/te/operation/op_util.cc @@ -21,15 +21,16 @@ * \brief Utility to make loop nest. 
* \file op_util.cc */ +#include "op_util.h" + +#include #include -#include #include -#include + #include -#include "op_util.h" -#include "../schedule/message_passing.h" -#include "../../arith/compute_expr.h" + #include "../../runtime/thread_storage_scope.h" +#include "../schedule/message_passing.h" namespace tvm { namespace te { @@ -37,16 +38,14 @@ namespace te { using namespace arith; using namespace tir; -std::vector > -MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, - bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop) { +std::vector > MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; - Stmt no_op = EvaluateNode::make(0); + Stmt no_op = Evaluate(0); // create the loop nest std::vector > nest; nest.resize(leaf_iter_vars.size() + 1); @@ -85,14 +84,21 @@ MakeLoopNest(const Stage& stage, } if (it_attr.defined()) { switch (it_attr->iter_type) { - case kUnrolled: for_type = ForType::Unrolled; break; - case kVectorized: for_type = ForType::Vectorized; break; - case kParallelized: for_type = ForType::Parallel; break; - case kDataPar: break; - case kTensorized: break; - default: LOG(FATAL) << "Unknown iter type" - << it_attr->iter_type - << " in the iter_var_attrs"; + case kUnrolled: + for_type = ForType::Unrolled; + break; + case kVectorized: + for_type = ForType::Vectorized; + break; + case kParallelized: + for_type = ForType::Parallel; + break; + case kDataPar: + break; + case kTensorized: + break; + default: + LOG(FATAL) << "Unknown iter type" << it_attr->iter_type << " in the iter_var_attrs"; } CHECK_EQ(it_attr->pragma_keys.size(), it_attr->pragma_values.size()); for (size_t k = 0; k < it_attr->pragma_keys.size(); ++k) { @@ -102,49 +108,37 @@ MakeLoopNest(const Stage& stage, pvalue = make_const(DataType::Int(32), 1); } nest[i + 1].emplace_back( - AttrStmtNode::make(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); + AttrStmt(iv, tir::attr::pragma_scope_prefix + pkey, pvalue, no_op)); } } if (!debug_keep_trivial_loop && is_one(dom->extent)) { - nest[i + 1].emplace_back( - LetStmtNode::make(var, dom->min, no_op)); + nest[i + 1].emplace_back(LetStmt(var, dom->min, no_op)); value_map[iv] = dom->min; } else if (is_zero(dom->min)) { - nest[i + 1].emplace_back( - ForNode::make(var, 0, dom->extent, - for_type, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); value_map[iv] = var; } else { Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype()); - nest[i + 1].emplace_back( - ForNode::make(idx, 0, dom->extent, - for_type, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; - nest[i + 1].emplace_back( - LetStmtNode::make(var, new_value, no_op)); + nest[i + 1].emplace_back(LetStmt(var, new_value, no_op)); } if (it_attr.defined() && it_attr->prefetch_data.size() != 0) { - CHECK(!is_one(dom->extent)) - << "Cannot prefetch on trivial loop with extent=1"; - CHECK_EQ(it_attr->prefetch_data.size(), - it_attr->prefetch_offset.size()); + CHECK(!is_one(dom->extent)) << "Cannot prefetch on trivial loop with extent=1"; + CHECK_EQ(it_attr->prefetch_data.size(), 
it_attr->prefetch_offset.size()); for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) { - nest[i + 1].emplace_back( - AttrStmtNode::make(it_attr->prefetch_data[j], - tir::attr::prefetch_scope, - it_attr->prefetch_offset[j], no_op)); + nest[i + 1].emplace_back(AttrStmt(it_attr->prefetch_data[j], tir::attr::prefetch_scope, + it_attr->prefetch_offset[j], no_op)); } } - } else if (bind_iv->thread_tag == "vthread" || - bind_iv->thread_tag == "cthread") { + } else if (bind_iv->thread_tag == "vthread" || bind_iv->thread_tag == "cthread") { // virtual thread // Always restrict threaded IterVar to starts from 0. CHECK(is_zero(dom->min)); CHECK(is_positive_const(dom->extent)); // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::virtual_thread, dom->extent, no_op)); value_map[iv] = var; } else if (bind_iv->thread_tag == "pipeline") { // pipeline marker. @@ -152,21 +146,32 @@ MakeLoopNest(const Stage& stage, CHECK(is_one(dom->extent)); // annotate the extent of the IterVar nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); + AttrStmt(bind_iv, tir::attr::pipeline_exec_scope, dom->extent, no_op)); value_map[iv] = dom->min; } else { // Always restrict threaded IterVar to starts from 0. CHECK(is_zero(dom->min)); // annotate the extent of the IterVar - nest[i + 1].emplace_back( - AttrStmtNode::make(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); + nest[i + 1].emplace_back(AttrStmt(bind_iv, tir::attr::thread_extent, dom->extent, no_op)); if (!debug_keep_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else { - runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag); - if (stage->scope == "" || stage->scope == "warp" || - static_cast(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) { + runtime::ThreadScope ts = runtime::ThreadScope::Create(bind_iv->thread_tag); + if (stage->scope == "" || + static_cast(runtime::StorageScope::Create(stage->scope).rank) <= ts.rank) { value_map[iv] = var; + } else if (stage->scope == "warp" && ts.rank == 1) { + // To determine whether a thread index is inside or outside a warp, we need + // to know the thread extent. We leave a warning for now. + if (ts.dim_index == 0) { + value_map[iv] = var; + } else { + LOG(WARNING) + << "WARNING: threadIdx.y or threadIdx.z accessing warp-scope memory detected. " + << "TVM assumes only threadIdx.x indicates threads inside a warp, " + << "while threadIdx.y and threadIdx.z indicates different warps."; + value_map[iv] = dom->min; + } } else { value_map[iv] = dom->min; } @@ -174,8 +179,7 @@ MakeLoopNest(const Stage& stage, } // annotate the extent of the IterVar if (!new_loop_var) { - nest[i + 1].emplace_back( - AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op)); + nest[i + 1].emplace_back(AttrStmt(iv, tir::attr::loop_scope, iv->var, no_op)); } } // message passing to get offset of root iter vars. 
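// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch) of the indexing assumption behind
// the warp-scope branch above: only threadIdx.x is treated as the lane index
// inside a warp, while threadIdx.y and threadIdx.z select different warps.
// The concrete numbers (warp_size, block_dim_y, the sample thread index) are
// assumptions and presume blockDim.x equals the warp size.
#include <cstdio>

struct ThreadIdx3 {
  int x, y, z;
};

int main() {
  const int warp_size = 32;     // assumed blockDim.x
  const int block_dim_y = 2;    // assumed blockDim.y
  ThreadIdx3 t{5, 1, 3};        // a sample thread index within the block
  int lane_id = t.x;                       // position inside its warp
  int warp_id = t.z * block_dim_y + t.y;   // which warp within the block
  printf("lane = %d, warp = %d, flat = %d\n", lane_id, warp_id,
         warp_id * warp_size + lane_id);
  return 0;
}
// ---------------------------------------------------------------------------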
@@ -184,10 +188,10 @@ MakeLoopNest(const Stage& stage, } std::vector MakeIfNest(const std::vector& predicates) { - Stmt no_op = EvaluateNode::make(0); + Stmt no_op = Evaluate(0); std::vector nest; for (const PrimExpr& cond : predicates) { - nest.emplace_back(IfThenElseNode::make(cond, no_op)); + nest.emplace_back(IfThenElse(cond, no_op)); } return nest; } @@ -195,22 +199,21 @@ std::vector MakeIfNest(const std::vector& predicates) { // replacer to replace tensors class TensorReplacer : public tir::StmtExprMutator { public: - explicit TensorReplacer(const std::unordered_map& vmap) - : vmap_(vmap) {} + explicit TensorReplacer(const std::unordered_map& vmap) : vmap_(vmap) {} - PrimExpr VisitExpr_(const tir::CallNode* op) final { - if (op->call_type == tir::CallNode::Halide) { - Tensor t = Downcast(op->func).output(op->value_index); - auto it = vmap_.find(t); - if (it != vmap_.end()) { - PrimExpr ret = tir::CallNode::make( - op->dtype, it->second->op->name, op->args, - op->call_type, it->second->op, it->second->value_index); - found = true; - return this->VisitExpr(ret); - } + PrimExpr VisitExpr_(const tir::ProducerLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + CHECK(op != nullptr); + + Tensor t = Downcast(op->producer); + auto it = vmap_.find(t); + if (it != vmap_.end()) { + found = true; + return tir::ProducerLoad(it->second, op->indices); + } else { + return expr; } - return StmtExprMutator::VisitExpr_(op); } // whether it is found. @@ -220,22 +223,18 @@ class TensorReplacer : public tir::StmtExprMutator { const std::unordered_map& vmap_; }; -Stmt ReplaceTensor(Stmt stmt, - const std::unordered_map& replace) { +Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace) { TensorReplacer repl(replace); Stmt ret = repl(stmt); return repl.found ? ret : stmt; } -PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace) { +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace) { TensorReplacer repl(replace); PrimExpr ret = repl(expr); return repl.found ? 
ret : expr; } - -Stmt Substitute(Stmt s, - const std::unordered_map& value_map) { +Stmt Substitute(Stmt s, const std::unordered_map& value_map) { std::unordered_map init; for (const auto& kv : value_map) { init[kv.first->var.get()] = kv.second; @@ -245,31 +244,31 @@ Stmt Substitute(Stmt s, IterVarType ForTypeToIterVarType(tir::ForType for_type) { switch (for_type) { - case ForType::Serial: - return kDataPar; - case ForType::Parallel: - return kParallelized; - case ForType::Vectorized: - return kVectorized; - case ForType::Unrolled: - return kUnrolled; - default: - return kDataPar; + case ForType::Serial: + return kDataPar; + case ForType::Parallel: + return kParallelized; + case ForType::Vectorized: + return kVectorized; + case ForType::Unrolled: + return kUnrolled; + default: + return kDataPar; } } tir::ForType IterVarTypeToForType(IterVarType iter_type) { switch (iter_type) { - case kDataPar: - return ForType::Serial; - case kParallelized: - return ForType::Parallel; - case kVectorized: - return ForType::Vectorized; - case kUnrolled: - return ForType::Unrolled; - default: - return ForType::Serial; + case kDataPar: + return ForType::Serial; + case kParallelized: + return ForType::Parallel; + case kVectorized: + return ForType::Vectorized; + case kUnrolled: + return ForType::Unrolled; + default: + return ForType::Serial; } } diff --git a/src/te/operation/op_util.h b/src/te/operation/op_util.h index 5e16b8e4a879..6c864fca67d5 100644 --- a/src/te/operation/op_util.h +++ b/src/te/operation/op_util.h @@ -24,13 +24,15 @@ #ifndef TVM_TE_OPERATION_OP_UTIL_H_ #define TVM_TE_OPERATION_OP_UTIL_H_ -#include #include +#include + #include #include #include -#include "../../tir/pass/ir_util.h" -#include "../../tir/pass/arg_binder.h" + +#include "../../tir/transforms/arg_binder.h" +#include "../../tir/transforms/ir_util.h" #include "../schedule/message_passing.h" namespace tvm { @@ -49,14 +51,12 @@ using tir::MergeNest; * \param p_value_map The result value of each IterVar. * \param debug_keep_trivial_loop Whether keep trivial loops with extent of 1 */ -std::vector > -MakeLoopNest(const Stage& stage, - const std::unordered_map& dom_map, - size_t begin_iter_pos, - bool new_loop_var, - const std::unordered_set& skip_iter, - std::unordered_map* p_value_map, - bool debug_keep_trivial_loop); +std::vector > MakeLoopNest(const Stage& stage, + const std::unordered_map& dom_map, + size_t begin_iter_pos, bool new_loop_var, + const std::unordered_set& skip_iter, + std::unordered_map* p_value_map, + bool debug_keep_trivial_loop); /*! * \brief Create a nest of if checking the predicates. @@ -71,15 +71,13 @@ std::vector MakeIfNest(const std::vector& predicates); * \param stmt The statement to be processed. * \param replace The replacement rule. */ -Stmt ReplaceTensor(Stmt stmt, - const std::unordered_map& replace); +Stmt ReplaceTensor(Stmt stmt, const std::unordered_map& replace); /*! * \brief Replace the tensor reference (especially in Call's) in stmt by the replace map. * \param expr The expression to be processed. * \param replace The replacement rule. */ -PrimExpr ReplaceTensor(PrimExpr expr, - const std::unordered_map& replace); +PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& replace); /*! * \brief Substitute the variables of stmt by value map. @@ -87,8 +85,7 @@ PrimExpr ReplaceTensor(PrimExpr expr, * \param value_map The value map. * \return Substituted result. 
*/ -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map); +Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); /*! * \brief Converts Halide ForType to its corresponding IterVarType diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index d48be4c53668..5b7ede314e49 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -29,20 +29,16 @@ namespace te { // PlaceholderOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "placeholder(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "placeholder(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(PlaceholderOpNode); -int PlaceholderOpNode::num_outputs() const { - return 1; -} +int PlaceholderOpNode::num_outputs() const { return 1; } -Array PlaceholderOpNode::root_iter_vars() const { - return {}; -} +Array PlaceholderOpNode::root_iter_vars() const { return {}; } DataType PlaceholderOpNode::output_dtype(size_t i) const { CHECK_EQ(i, 0U); @@ -54,59 +50,48 @@ Array PlaceholderOpNode::output_shape(size_t i) const { return shape; } -Operation PlaceholderOpNode::make(std::string name, - Array shape, - DataType dtype) { +PlaceholderOp::PlaceholderOp(std::string name, Array shape, DataType dtype) { auto n = make_object(); n->name = name; n->shape = shape; n->dtype = dtype; - return Operation(n); + data_ = std::move(n); } Tensor placeholder(Array shape, DataType dtype, std::string name) { - return PlaceholderOpNode::make(name, shape, dtype).output(0); + return PlaceholderOp(name, shape, dtype).output(0); } TVM_REGISTER_GLOBAL("te.Placeholder") -.set_body_typed([](Array shape, DataType dtype, std::string name) { - return placeholder(shape, dtype, name); -}); + .set_body_typed([](Array shape, DataType dtype, std::string name) { + return placeholder(shape, dtype, name); + }); -Array PlaceholderOpNode::InputTensors() const { - return {}; -} +Array PlaceholderOpNode::InputTensors() const { return {}; } -Operation PlaceholderOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation PlaceholderOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { return self; } void PlaceholderOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, + const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { -} + std::unordered_map* out_dom_map) const {} -void PlaceholderOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { -} +void PlaceholderOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const {} -Stmt PlaceholderOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const { +Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, + const std::unordered_map& realize_map, + const Stmt& body) const { return body; } -Stmt PlaceholderOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt PlaceholderOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { return Stmt(); } } 
// namespace te diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 2ee5b273d4f6..cc86d0f46e3b 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -24,28 +24,22 @@ #include #include #include -#include -#include "op_util.h" + #include "../schedule/graph.h" +#include "op_util.h" namespace tvm { namespace te { using namespace tir; TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "scan(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "scan(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(ScanOpNode); -inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) { - return is_zero(tir::Simplify(lhs - rhs)); -} - -int ScanOpNode::num_outputs() const { - return static_cast(update.size()); -} +int ScanOpNode::num_outputs() const { return static_cast(update.size()); } Array ScanOpNode::root_iter_vars() const { Array ret{scan_axis}; for (IterVar iv : spatial_axis_) { @@ -54,60 +48,52 @@ Array ScanOpNode::root_iter_vars() const { return ret; } -DataType ScanOpNode::output_dtype(size_t i) const { - return update[i]->dtype; -} +DataType ScanOpNode::output_dtype(size_t i) const { return update[i]->dtype; } Array ScanOpNode::output_shape(size_t i) const { CHECK_LT(i, state_placeholder.size()); return state_placeholder[i]->shape; } -Operation ScanOpNode::make(std::string name, - std::string tag, - Map attrs, - IterVar axis, - Array init, - Array update, - Array state_placeholder, - Array inputs) { +ScanOp::ScanOp(std::string name, std::string tag, Map attrs, IterVar axis, + Array init, Array update, Array state_placeholder, + Array inputs) { if (!attrs.defined()) { - attrs = Map(); + attrs = Map(); } auto n = make_object(); CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), state_placeholder.size()); + arith::Analyzer analyzer; + auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) { + return is_zero(analyzer.Simplify(lhs - rhs)); + }; for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); CHECK_EQ(init[i]->dtype, update[i]->dtype); CHECK(prove_equal(init[i]->shape[0], axis->dom->min)) << "init.shape[0] need to match scan_axis.dom.min"; - CHECK(prove_equal( - state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) + CHECK(prove_equal(state_placeholder[i]->shape[0], axis->dom->min + axis->dom->extent)) << "state_placeholder.shape[0] need to match" << " scan_axis.dom.min + scan_axis.dom.extent"; CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim()) << "The dimension of init need to match state_placeholder"; CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim()) << "The update.ndim need to be state_placeholder.ndim - 1"; - for (size_t k = 0; k < update[i].ndim(); ++k) { - CHECK(prove_equal( - update[i]->shape[k], state_placeholder[i]->shape[k])); + for (size_t k = 0; k < update[i].ndim(); ++k) { + CHECK(prove_equal(update[i]->shape[k], state_placeholder[i]->shape[k])); if (k != 0) { // setup spatial axis std::ostringstream spatial_name; spatial_name << name << ".out" << i << ".i" << k; - n->spatial_axis_.push_back( - IterVarNode::make( - Range::make_by_min_extent(0, update[i]->shape[k]), - Var(spatial_name.str()), kOpaque)); + n->spatial_axis_.push_back(IterVar(Range::make_by_min_extent(0, update[i]->shape[k]), + Var(spatial_name.str()), kOpaque)); } } - for (size_t k = 1; k < 
init[i].ndim(); ++k) { - CHECK(prove_equal( - init[i]->shape[k], state_placeholder[i]->shape[k])); + for (size_t k = 1; k < init[i].ndim(); ++k) { + CHECK(prove_equal(init[i]->shape[k], state_placeholder[i]->shape[k])); } } n->name = std::move(name); @@ -118,28 +104,23 @@ Operation ScanOpNode::make(std::string name, n->update = std::move(update); n->state_placeholder = std::move(state_placeholder); n->inputs = std::move(inputs); - return Operation(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.ScanOp") -.set_body_typed(ScanOpNode::make); - + .set_body_typed([](std::string name, std::string tag, Map attrs, + IterVar axis, Array init, Array update, + Array state_placeholder, Array inputs) { + return ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs); + }); -Array scan(Array init, - Array update, - Array state_placeholder, - Array inputs, - std::string name, - std::string tag, - Map attrs) { +Array scan(Array init, Array update, Array state_placeholder, + Array inputs, std::string name, std::string tag, + Map attrs) { IterVar scan_axis = - IterVarNode::make( - Range::make_by_min_extent( - init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), - Var(name + ".idx"), kOrdered); - Operation op = ScanOpNode::make( - name, tag, attrs, scan_axis, - init, update, state_placeholder, inputs); + IterVar(Range::make_by_min_extent(init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]), + Var(name + ".idx"), kOrdered); + Operation op = ScanOp(name, tag, attrs, scan_axis, init, update, state_placeholder, inputs); Array res; for (int i = 0; i < op->num_outputs(); ++i) { res.push_back(op.output(i)); @@ -158,9 +139,8 @@ Array ScanOpNode::InputTensors() const { return ret; } -Operation ScanOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation ScanOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); for (size_t i = 0; i < n->init.size(); ++i) { @@ -171,19 +151,16 @@ Operation ScanOpNode::ReplaceInputs( n->update.Set(i, rmap.at(n->update[i])); } } - if (!n->init.same_as(init) || - !n->update.same_as(update)) { + if (!n->init.same_as(init) || !n->update.same_as(update)) { return Operation(n); } else { return self; } } -void ScanOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, - const std::unordered_map& dom_map, - std::unordered_map* out_dom_map) const { +void ScanOpNode::PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer, + const std::unordered_map& dom_map, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); for (size_t i = 0, sp_idx = 0; i < this->init.size(); ++i) { TensorDom* init_dom = nullptr; @@ -196,8 +173,8 @@ void ScanOpNode::PropBoundToInputs( } // first dimension, always needed. 
if (init_dom) { - init_dom->data[0].push_back(IntSet::range( - Range::make_by_min_extent(0, this->init[i]->shape[0]))); + init_dom->data[0].push_back( + IntSet::range(Range::make_by_min_extent(0, this->init[i]->shape[0]))); } if (update_dom) { update_dom->data[0].push_back(dom_map.at(this->scan_axis->var.get())); @@ -215,10 +192,9 @@ void ScanOpNode::PropBoundToInputs( } } -void ScanOpNode::GatherBound( - const Operation& self, - const std::unordered_map& tensor_dom, - std::unordered_map* out_dom_map) const { +void ScanOpNode::GatherBound(const Operation& self, + const std::unordered_map& tensor_dom, + std::unordered_map* out_dom_map) const { CHECK_EQ(self.operator->(), this); CHECK(!out_dom_map->count(this->scan_axis)); std::vector output(this->num_outputs()); @@ -232,10 +208,11 @@ void ScanOpNode::GatherBound( time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end()); } CHECK(!out_dom_map->count(this->scan_axis)); + arith::Analyzer analyzer; Range sdom = this->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); - (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( - sdom->min, tir::Simplify(r->extent + r->min - sdom->min)); + (*out_dom_map)[this->scan_axis] = + Range::make_by_min_extent(sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; @@ -256,14 +233,12 @@ void ScanOpNode::GatherBound( } } -Stmt ScanOpNode::BuildRealize( - const Stage& stage, - const std::unordered_map& dom_map, - const Stmt& body) const { +Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, + const Stmt& body) const { + arith::Analyzer analyzer; CHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); - Range tdom = Range::make_by_min_extent( - 0, tir::Simplify(sdom->extent + sdom->min)); + Range tdom = Range::make_by_min_extent(0, analyzer.Simplify(sdom->extent + sdom->min)); Stmt ret = body; size_t sp_idx = 0; for (size_t i = 0; i < update.size(); ++i) { @@ -275,25 +250,19 @@ Stmt ScanOpNode::BuildRealize( IterVar sp_ax = this->spatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::RealizeNode::make(t->op, t->value_index, t->dtype, - bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret); } return ret; } -Stmt ScanOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt ScanOpNode::BuildProvide(const Stage& stage, const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); - Stmt provide = AttrStmtNode::make( - stage->op, tir::attr::scan_update_scope, this->scan_axis->var, - EvaluateNode::make(0)); - Stmt init = AttrStmtNode::make( - stage->op, tir::attr::scan_init_scope, 0, - EvaluateNode::make(0)); + Stmt provide = + AttrStmt(stage->op, tir::attr::scan_update_scope, this->scan_axis->var, Evaluate(0)); + Stmt init = AttrStmt(stage->op, tir::attr::scan_init_scope, 0, Evaluate(0)); size_t begin_scan = 0; - for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { if (stage->leaf_iter_vars[i]->iter_type == kThreadIndex) { CHECK_EQ(begin_scan, i); begin_scan = i + 1; @@ -301,12 +270,9 @@ Stmt ScanOpNode::BuildProvide( } std::unordered_map vmap; std::unordered_set empty; - auto nest = MakeLoopNest( - stage, dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop); + auto nest = MakeLoopNest(stage, 
dom_map, 0, false, empty, &vmap, debug_keep_trivial_loop); nest[begin_scan].push_back(init); - nest.push_back( - MakeIfNest( - MakeBoundCheck(stage, dom_map, vmap, false, empty))); + nest.push_back(MakeIfNest(MakeBoundCheck(stage, dom_map, vmap, false, empty))); return MergeNest(nest, provide); } } // namespace te diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 4cdc9e1f8d32..8d5265bcb14f 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -21,25 +21,26 @@ * \brief Tensor Compute Op. * \file tensor_compute_op.cc */ +#include #include #include -#include #include -#include +#include + #include -#include "./op_util.h" + #include "./compute_op.h" -#include "../../arith/compute_expr.h" +#include "./op_util.h" namespace tvm { namespace te { using namespace tir; // TensorComputeOpNode TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "tensor_compute_op(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorComputeOpNode); @@ -51,15 +52,10 @@ DataType TensorComputeOpNode::output_dtype(size_t i) const { return this->intrin->buffers[this->inputs.size() + i]->dtype; } -Operation TensorComputeOpNode::make(std::string name, - std::string tag, - Array axis, - Array reduce_axis, - int schedulable_ndim, - TensorIntrin intrin, - Array tensors, - Array regions, - Array scalar_inputs) { +TensorComputeOp::TensorComputeOp(std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, + TensorIntrin intrin, Array tensors, Array regions, + Array scalar_inputs) { auto n = make_object(); n->name = std::move(name); n->tag = std::move(tag); @@ -70,20 +66,22 @@ Operation TensorComputeOpNode::make(std::string name, n->inputs = std::move(tensors); n->input_regions = std::move(regions); n->scalar_inputs = std::move(scalar_inputs); - return Operation(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("te.TensorComputeOp") -.set_body_typed(TensorComputeOpNode::make); - + .set_body_typed([](std::string name, std::string tag, Array axis, + Array reduce_axis, int schedulable_ndim, TensorIntrin intrin, + Array tensors, Array regions, + Array scalar_inputs) { + return TensorComputeOp(name, tag, axis, reduce_axis, schedulable_ndim, intrin, tensors, + regions, scalar_inputs); + }); -Array TensorComputeOpNode::InputTensors() const { - return inputs; -} +Array TensorComputeOpNode::InputTensors() const { return inputs; } -Operation TensorComputeOpNode::ReplaceInputs( - const Operation& self, - const std::unordered_map& rmap) const { +Operation TensorComputeOpNode::ReplaceInputs(const Operation& self, + const std::unordered_map& rmap) const { CHECK_EQ(self.operator->(), this); auto n = make_object(*this); auto intrin = make_object(*(this->intrin.operator->())); @@ -103,8 +101,7 @@ Operation TensorComputeOpNode::ReplaceInputs( if (intrin->body.same_as(n->intrin->body) && intrin->reduce_init.same_as(n->intrin->reduce_init) && - intrin->reduce_update.same_as(n->intrin->reduce_update) && - inputs.same_as(n->inputs)) { + intrin->reduce_update.same_as(n->intrin->reduce_update) && inputs.same_as(n->inputs)) { return self; } else { n->intrin = TensorIntrin(intrin); @@ -113,8 +110,7 @@ Operation TensorComputeOpNode::ReplaceInputs( } 
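A pattern repeated across scan_op.cc, tensor_compute_op.cc, tensorize.cc, bound.cc and message_passing.cc in this patch is dropping the free-standing tir::Simplify(expr, vrange) in favour of a locally scoped arith::Analyzer: ranges are bound onto the analyzer first, and the same instance then serves the Simplify / CanProve / int_set queries. A minimal sketch of the idiom, assuming the tvm/arith/analyzer.h and tvm/tir/var.h header paths of this revision; the function and variable names are illustrative only.

// Sketch of the analyzer idiom used throughout this patch; only the
// Bind/Simplify calls mirror the diff, everything else is illustrative.
#include <tvm/arith/analyzer.h>
#include <tvm/tir/var.h>

tvm::PrimExpr SimplifyUnderRange(const tvm::tir::Var& i, const tvm::PrimExpr& body) {
  tvm::arith::Analyzer analyzer;
  // Bind the iteration range up front; it plays the role the old
  // map-of-ranges argument to tir::Simplify used to play.
  analyzer.Bind(i, tvm::Range::make_by_min_extent(0, 16));
  return analyzer.Simplify(body);
}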
void TensorComputeOpNode::PropBoundToInputs( - const Operation& self, - arith::Analyzer* analyzer, + const Operation& self, arith::Analyzer* analyzer, const std::unordered_map& dom_map, std::unordered_map* out_dom_map) const { for (size_t i = 0; i < this->inputs.size(); ++i) { @@ -130,18 +126,15 @@ void TensorComputeOpNode::PropBoundToInputs( } } -size_t TensorComputeOpNode::num_schedulable_dims() const { - return schedulable_ndim; -} +size_t TensorComputeOpNode::num_schedulable_dims() const { return schedulable_ndim; } -Stmt TensorComputeOpNode::BuildProvide( - const Stage& stage, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) const { +Stmt TensorComputeOpNode::BuildProvide(const Stage& stage, + const std::unordered_map& dom_map, + bool debug_keep_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); // Start bind data. - Stmt nop = EvaluateNode::make(0); + Stmt nop = Evaluate(0); std::vector input_bind_nest, output_bind_nest; Array inputs = this->InputTensors(); @@ -158,11 +151,9 @@ Stmt TensorComputeOpNode::BuildProvide( tuple.push_back(region[i]->min); tuple.push_back(region[i]->extent); } - input_bind_nest.emplace_back(AttrStmtNode::make( + input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding @@ -184,11 +175,9 @@ Stmt TensorComputeOpNode::BuildProvide( } } - output_bind_nest.emplace_back(AttrStmtNode::make( + output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap @@ -209,11 +198,10 @@ Stmt TensorComputeOpNode::BuildProvide( binder.BindArray(sp_expr, user_expr, this->name); size_t tloc = stage->leaf_iter_vars.size(); - ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(this, stage, dom_map, debug_keep_trivial_loop); if (this->reduce_axis.size() == 0) { - std::vector > nest( - n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); CHECK(this->intrin->body.defined()) @@ -223,24 +211,23 @@ Stmt TensorComputeOpNode::BuildProvide( body = tir::Substitute(body, vmap); body = MergeNest(binder.asserts(), body); body = te::Substitute(body, n.main_vmap); - Stmt ret = MergeNest(nest, body); + Stmt ret = MergeNest(nest, body); return ret; } else { // Need to split reduction - CHECK(this->intrin->reduce_update.defined()) - << "Reduction update op is not defined"; + CHECK(this->intrin->reduce_update.defined()) << "Reduction update op is not defined"; // Need init and update steps CHECK_NE(this->reduce_axis.size(), 0U); - std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > update_nest( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.begin() + tloc + 1); 
update_nest.emplace_back(MakeIfNest(n.main_predicates)); if (this->intrin->reduce_init.defined()) { // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + std::vector > init_nest(n.init_nest.begin(), + n.init_nest.begin() + tloc + 1); init_nest.emplace_back(MakeIfNest(n.init_predicates)); Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init); init = te::Substitute(init, n.init_vmap); @@ -255,11 +242,9 @@ Stmt TensorComputeOpNode::BuildProvide( return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(this->intrin->body.defined()) - << "Normal body op is not defined"; - Stmt update = TransformUpdate(stage, dom_map, n, - this->intrin->body, - this->intrin->reduce_update); + CHECK(this->intrin->body.defined()) << "Normal body op is not defined"; + Stmt update = + TransformUpdate(stage, dom_map, n, this->intrin->body, this->intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = tir::Substitute(update, vmap); diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 6064f5c4e008..82832c927785 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -21,15 +21,14 @@ * \brief Logics related to tensorize, used by ComputeOpNode. * \file tensorize.cc */ +#include +#include #include #include -#include -#include -#include -#include "op_util.h" -#include "compute_op.h" #include "../schedule/message_passing.h" +#include "compute_op.h" +#include "op_util.h" namespace tvm { namespace te { @@ -40,12 +39,10 @@ using namespace tir; // out_dom: the domain of root iter vars in output op // in_region: region of each input tensor. // return The location of the tensorized scope start. -size_t InferTensorizeRegion( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - std::unordered_map* out_dom, - std::unordered_map >* in_region) { +size_t InferTensorizeRegion(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + std::unordered_map* out_dom, + std::unordered_map >* in_region) { // Get the bound of the tensorized scope. bool found_point = false; size_t loc_scope = 0; @@ -53,8 +50,7 @@ size_t InferTensorizeRegion( // Loop over the leafs for (size_t i = stage->leaf_iter_vars.size(); i != 0; --i) { IterVar iv = stage->leaf_iter_vars[i - 1]; - CHECK(iv->iter_type == kDataPar || - iv->iter_type == kCommReduce); + CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce); auto vit = dom_map.find(iv); CHECK(vit != dom_map.end()); const Range& vrange = vit->second; @@ -70,8 +66,7 @@ size_t InferTensorizeRegion( if (iit != stage->iter_var_attrs.end()) { const IterVarAttr& attr = (*iit).second; if (!found_point) { - CHECK(!attr->bind_thread.defined()) - << "Do not allow thread in tensorize scope"; + CHECK(!attr->bind_thread.defined()) << "Do not allow thread in tensorize scope"; } if (attr->iter_type == kTensorized) { CHECK(!found_point) << "Do not allow two tensorized point"; @@ -114,18 +109,15 @@ size_t InferTensorizeRegion( return loc_scope; } -void VerifyTensorizeLoopNest(const ComputeOpNode* self, - const Stage& stage, - const ComputeLoopNest& n, - size_t tloc) { +void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage, + const ComputeLoopNest& n, size_t tloc) { // Veirfication step. 
std::unordered_set banned; CHECK_EQ(n.main_nest.size(), stage->leaf_iter_vars.size() + 1); - CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || - n.init_nest.size() == 0); + CHECK(n.init_nest.size() == stage->leaf_iter_vars.size() + 1 || n.init_nest.size() == 0); auto f_push_banned = [&banned](const Stmt& s) { if (const ForNode* op = s.as()) { - banned.insert(op->loop_var.get()); + banned.insert(op->loop_var.get()); } else if (const AttrStmtNode* op = s.as()) { if (const IterVarNode* iv = op->node.as()) { banned.insert(iv->var.get()); @@ -144,16 +136,19 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, } } } + + auto fbanned = [&](const VarNode* node) { return banned.count(node); }; + for (const PrimExpr& pred : n.main_predicates) { - if (tir::ExprUseVar(pred, banned)) { - LOG(FATAL) << "Tensorize failed, split condition " - << pred << " relies on var defined inside tensorize scope"; + if (tir::ExprUseVar(pred, fbanned)) { + LOG(FATAL) << "Tensorize failed, split condition " << pred + << " relies on var defined inside tensorize scope"; } } for (const PrimExpr& pred : n.init_predicates) { - if (tir::ExprUseVar(pred, banned)) { - LOG(FATAL) << "Tensorize failed, split condition " - << pred << " relies on var defined inside tensorize scope"; + if (tir::ExprUseVar(pred, fbanned)) { + LOG(FATAL) << "Tensorize failed, split condition " << pred + << " relies on var defined inside tensorize scope"; } } } @@ -161,23 +156,19 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, // Remap the tensor placeholder, index and inline things. class TensorIntrinMatcher final : public StmtExprMutator { public: - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op->call_type == CallNode::Halide) { - Tensor t = Downcast(op->func).output(op->value_index); - auto it = in_remap_.find(t); - if (it != in_remap_.end()) { - const InputEntry& e = it->second; - CHECK_EQ(op->args.size(), e.region.size()); - Array args; - for (size_t i = e.start; i < e.region.size(); ++i) { - args.push_back(op->args[i] - e.region[i]->min); - } - return CallNode::make( - op->dtype, e.tensor->op->name, args, - op->call_type, e.tensor->op, e.tensor->value_index); + op = expr.as(); + auto t = Downcast(op->producer); + auto it = in_remap_.find(t); + if (it != in_remap_.end()) { + const InputEntry& e = it->second; + CHECK_EQ(op->indices.size(), e.region.size()); + Array indices; + for (size_t i = e.start; i < e.region.size(); ++i) { + indices.push_back(op->indices[i] - e.region[i]->min); } + return ProducerLoad(e.tensor, indices); } return expr; } @@ -201,16 +192,13 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis.push_back(it->second); } } - return ReduceNode::make( - op->combiner, op->source, axis, op->condition, op->value_index); + return Reduce(op->combiner, op->source, axis, op->condition, op->value_index); } - void Init(const ComputeOpNode* self, - const Stage& stage, + void Init(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin, + const std::unordered_map >& in_region, const TensorIntrin& intrin, Map* compute_intrin_iter_space) { CHECK(self == stage->op.get()); @@ -222,6 +210,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { compute_intrin_iter_space->Set(iv->var, vrange); } } + 
analyzer_.Bind(*compute_intrin_iter_space); // input remap. Array inputs = self->InputTensors(); @@ -234,12 +223,11 @@ class TensorIntrinMatcher final : public StmtExprMutator { // Enable fuzzy matching, to match [1, n, m] to [n, m] e.start = e.region.size() - e.tensor.ndim(); for (size_t j = 0; j < e.start; ++j) { - auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space); + auto canonical_extent = analyzer_.Simplify(e.region[j]->extent); CHECK(is_one(canonical_extent)) << "Tensorize " << intrin->name << ":" << " Input dimension mismatch with tensor intrin " - << " expected shape=" << e.tensor->shape - << ", given region=" << e.region; + << " expected shape=" << e.tensor->shape << ", given region=" << e.region; } in_remap_[inputs[i]] = e; } @@ -252,10 +240,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { size_t axis_start = self->axis.size() - intrin_compute->axis.size(); for (size_t i = 0; i < axis_start; ++i) { Range r = out_dom.at(self->axis[i]); - CHECK(is_one(r->extent)) - << "Tensorize: Output mismatch with tensor intrin " - << " intrin-dim=" << intrin_compute->axis.size() - << ", tensorize-dim=" << self->axis.size(); + CHECK(is_one(r->extent)) << "Tensorize: Output mismatch with tensor intrin " + << " intrin-dim=" << intrin_compute->axis.size() + << ", tensorize-dim=" << self->axis.size(); var_remap_[self->axis[i]->var.get()] = r->min; } // Assume we tensorize at regin axis i [min, min + extent) @@ -275,10 +262,9 @@ class TensorIntrinMatcher final : public StmtExprMutator { axis_start = self->reduce_axis.size() - intrin_compute->reduce_axis.size(); for (size_t i = 0; i < axis_start; ++i) { Range r = out_dom.at(self->reduce_axis[i]); - CHECK(is_one(r->extent)) - << "Tensorize: Reduction mismatch with tensor intrin " - << " intrin-dim=" << intrin_compute->reduce_axis.size() - << ", tensorize-dim=" << self->reduce_axis.size(); + CHECK(is_one(r->extent)) << "Tensorize: Reduction mismatch with tensor intrin " + << " intrin-dim=" << intrin_compute->reduce_axis.size() + << ", tensorize-dim=" << self->reduce_axis.size(); var_remap_[self->reduce_axis[i]->var.get()] = r->min; } for (size_t i = axis_start; i < self->reduce_axis.size(); ++i) { @@ -304,17 +290,17 @@ class TensorIntrinMatcher final : public StmtExprMutator { std::unordered_map var_remap_; // IterVar remap. 
std::unordered_map axis_remap_; + // arith analyzer + arith::Analyzer analyzer_; }; // Try to match tensor dataflow of the stage with the intrinsic -Array MatchTensorizeBody( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin, - Map* compute_intrin_iter_space) { +Array MatchTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + const std::unordered_map& out_dom, + const std::unordered_map >& in_region, + const TensorIntrin& intrin, + Map* compute_intrin_iter_space) { TensorIntrinMatcher matcher; matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space); Array ret; @@ -324,60 +310,51 @@ Array MatchTensorizeBody( return ret; } -void VerifyTensorizeBody( - const ComputeOpNode* self, - const Stage& stage, - const std::unordered_map& dom_map, - const std::unordered_map& out_dom, - const std::unordered_map >& in_region, - const TensorIntrin& intrin) { +void VerifyTensorizeBody(const ComputeOpNode* self, const Stage& stage, + const std::unordered_map& dom_map, + const std::unordered_map& out_dom, + const std::unordered_map >& in_region, + const TensorIntrin& intrin) { StructuralEqual expr_equal; Map compute_intrin_iter_space; Array body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin, - &compute_intrin_iter_space); + &compute_intrin_iter_space); const ComputeOpNode* intrin_compute = intrin->op.as(); CHECK(intrin_compute) << "Only support compute intrinsic for now"; - CHECK_EQ(body.size(), intrin_compute->body.size()) - << "Tensorize failed: body size mismatch"; + CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; + arith::Analyzer ana; + ana.Bind(compute_intrin_iter_space); + for (size_t i = 0; i < body.size(); ++i) { - PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space); - lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); - PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); - rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); + PrimExpr lhs = ana.Simplify(body[i]); + PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); if (lhs.dtype() != rhs.dtype()) { - LOG(FATAL) - << "Failed to match the data type with TensorIntrin " - << intrin->name << "'s declaration " - << " provided=" << lhs.dtype() - << ", intrin=" << rhs.dtype(); + LOG(FATAL) << "Failed to match the data type with TensorIntrin " << intrin->name + << "'s declaration " + << " provided=" << lhs.dtype() << ", intrin=" << rhs.dtype(); } - CHECK(expr_equal(lhs, rhs)) - << "Failed to match the compute with TensorIntrin " - << intrin->name << "'s declaration " - << " provided= " << lhs - << ", intrin= " << rhs; + CHECK(expr_equal(lhs, rhs)) << "Failed to match the compute with TensorIntrin " << intrin->name + << "'s declaration " + << " provided= " << lhs << ", intrin= " << rhs; } } -Stmt MakeTensorize(const ComputeOpNode* self, - const Stage& stage, +Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) { std::unordered_map out_dom; std::unordered_map > in_region; size_t tloc = InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); - TensorIntrin intrin = stage->iter_var_attrs.at( - stage->leaf_iter_vars[tloc])->tensor_intrin; + TensorIntrin intrin = stage->iter_var_attrs.at(stage->leaf_iter_vars[tloc])->tensor_intrin; 
CHECK(intrin.defined()); - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop); + ComputeLoopNest n = ComputeLoopNest::Create(self, stage, dom_map, debug_keep_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin); // Start bind data. - Stmt nop = EvaluateNode::make(0); + Stmt nop = Evaluate(0); std::vector input_bind_nest, output_bind_nest; Array inputs = self->InputTensors(); - CHECK_EQ(inputs.size(), intrin->inputs.size()) - << "Tensorize failed: input size mismatch "; + CHECK_EQ(inputs.size(), intrin->inputs.size()) << "Tensorize failed: input size mismatch "; // input binding for (size_t i = 0; i < intrin->inputs.size(); ++i) { Tensor tensor = inputs[i]; @@ -391,11 +368,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, tuple.push_back(r->min); tuple.push_back(r->extent); } - input_bind_nest.emplace_back(AttrStmtNode::make( + input_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // output binding const ComputeOpNode* intrin_compute = intrin->op.as(); @@ -413,11 +388,9 @@ Stmt MakeTensorize(const ComputeOpNode* self, Tensor tensor = stage->op.output(i - intrin->inputs.size()); Buffer buffer = intrin->buffers[i]; Array bind_spec{buffer, tensor}; - output_bind_nest.emplace_back(AttrStmtNode::make( + output_bind_nest.emplace_back(AttrStmt( bind_spec, tir::attr::buffer_bind_scope, - CallNode::make(DataType::Handle(), - tir::intrinsic::tvm_tuple, - tuple, CallNode::Intrinsic), nop)); + Call(DataType::Handle(), tir::intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), nop)); } // Check variable remap std::unordered_map vmap; @@ -429,8 +402,7 @@ Stmt MakeTensorize(const ComputeOpNode* self, IterVar iv = self->reduce_axis[i]; auto it = out_dom.find(iv); CHECK(it != out_dom.end()); - CHECK(is_one(it->second->extent)) - << "Tensorization fail: reduction axis size do not match"; + CHECK(is_one(it->second->extent)) << "Tensorization fail: reduction axis size do not match"; } for (size_t i = start; i < self->reduce_axis.size(); ++i) { IterVar iv = self->reduce_axis[i]; @@ -439,17 +411,14 @@ Stmt MakeTensorize(const ComputeOpNode* self, CHECK(it != out_dom.end()); binder.Bind(target->dom->min, make_const(iv->dom->min.dtype(), 0), "tensir_intrin.reduction.min"); - binder.Bind(target->dom->extent, it->second->extent, - "tensir_intrin.reduction.extent"); + binder.Bind(target->dom->extent, it->second->extent, "tensir_intrin.reduction.extent"); } if (tloc <= n.num_common_loop) { // Do no need to split reduction - std::vector > nest( - n.main_nest.begin(), n.main_nest.begin() + tloc + 1); + std::vector > nest(n.main_nest.begin(), n.main_nest.begin() + tloc + 1); nest.emplace_back(MakeIfNest(n.main_predicates)); CHECK_EQ(n.init_predicates.size(), 0U); - CHECK(intrin->body.defined()) - << "Normal store op for intrin " << intrin << " is not defined"; + CHECK(intrin->body.defined()) << "Normal store op for intrin " << intrin << " is not defined"; Stmt body = MergeNest(output_bind_nest, intrin->body); body = MergeNest(input_bind_nest, body); body = tir::Substitute(body, vmap); @@ -462,16 +431,16 @@ Stmt MakeTensorize(const ComputeOpNode* self, << "Reduction update op for intrin " << intrin << " is not defined"; // Need init and update steps CHECK_NE(self->reduce_axis.size(), 0U); - 
std::vector > common( - n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); - std::vector > update_nest( - n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1); + std::vector > common(n.main_nest.begin(), + n.main_nest.begin() + n.num_common_loop + 1); + std::vector > update_nest(n.main_nest.begin() + n.num_common_loop + 1, + n.main_nest.begin() + tloc + 1); update_nest.emplace_back(MakeIfNest(n.main_predicates)); if (intrin->reduce_init.defined()) { // init nest - std::vector > init_nest( - n.init_nest.begin(), n.init_nest.begin() + tloc + 1); + std::vector > init_nest(n.init_nest.begin(), + n.init_nest.begin() + tloc + 1); init_nest.emplace_back(MakeIfNest(n.init_predicates)); Stmt init = MergeNest(output_bind_nest, intrin->reduce_init); init = te::Substitute(init, n.init_vmap); @@ -486,11 +455,8 @@ Stmt MakeTensorize(const ComputeOpNode* self, return MergeNest(common, SeqStmt::Flatten(init, update)); } else { // When init op is not available, use body op for reset in the first iter. - CHECK(intrin->body.defined()) - << "Normal body op for intrin " << intrin << " is not defined"; - Stmt update = TransformUpdate(stage, dom_map, n, - intrin->body, - intrin->reduce_update); + CHECK(intrin->body.defined()) << "Normal body op for intrin " << intrin << " is not defined"; + Stmt update = TransformUpdate(stage, dom_map, n, intrin->body, intrin->reduce_update); update = MergeNest(output_bind_nest, update); update = MergeNest(input_bind_nest, update); update = tir::Substitute(update, vmap); @@ -503,36 +469,26 @@ Stmt MakeTensorize(const ComputeOpNode* self, } // Register functions for unittests -TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Stage stage = args[0]; - Map dmap = args[1]; - std::unordered_map out_dom; - std::unordered_map > in_region; - CHECK(stage->op.as()); - InferTensorizeRegion(stage->op.as(), - stage, - as_unordered_map(dmap), - &out_dom, &in_region); - *ret = Array{Map(out_dom), - Map >(in_region)}; - }); +TVM_REGISTER_GLOBAL("test.op.InferTensorizeRegion").set_body([](TVMArgs args, TVMRetValue* ret) { + Stage stage = args[0]; + Map dmap = args[1]; + std::unordered_map out_dom; + std::unordered_map > in_region; + CHECK(stage->op.as()); + InferTensorizeRegion(stage->op.as(), stage, as_unordered_map(dmap), &out_dom, + &in_region); + *ret = Array{Map(out_dom), Map >(in_region)}; +}); -TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Stage stage = args[0]; - Map out_dom = args[1]; - Map > in_region = args[2]; - TensorIntrin intrin = args[3]; - Map vrange; - CHECK(stage->op.as()); - *ret = MatchTensorizeBody(stage->op.as(), - stage, - {{}}, - as_unordered_map(out_dom), - as_unordered_map(in_region), - intrin, - &vrange); - }); +TVM_REGISTER_GLOBAL("test.op.MatchTensorizeBody").set_body([](TVMArgs args, TVMRetValue* ret) { + Stage stage = args[0]; + Map out_dom = args[1]; + Map > in_region = args[2]; + TensorIntrin intrin = args[3]; + Map vrange; + CHECK(stage->op.as()); + *ret = MatchTensorizeBody(stage->op.as(), stage, {{}}, as_unordered_map(out_dom), + as_unordered_map(in_region), intrin, &vrange); +}); } // namespace te } // namespace tvm diff --git a/src/te/schedule/auto_inline_elem_wise.cc b/src/te/schedule/auto_inline_elem_wise.cc index 6d79f4a8d1d6..e2b7215158b2 100644 --- a/src/te/schedule/auto_inline_elem_wise.cc +++ b/src/te/schedule/auto_inline_elem_wise.cc @@ -21,8 +21,8 @@ * \file auto_inline_elem_wise.cc */ #include -#include 
#include +#include #include namespace tvm { @@ -61,7 +61,6 @@ class ElemWiseDetector : public tir::ExprVisitor { Array axis_; }; - bool IsElemWise(const Operation& op) { if (const ComputeOpNode* compute = op.as()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); @@ -112,12 +111,9 @@ void AutoInlineInjective(Schedule sch) { } } -TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise") -.set_body_typed(AutoInlineElemWise); - +TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise").set_body_typed(AutoInlineElemWise); -TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective") -.set_body_typed(AutoInlineInjective); +TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective").set_body_typed(AutoInlineInjective); } // namespace te } // namespace tvm diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc index 50cbafd2b654..099f4882f16c 100644 --- a/src/te/schedule/bound.cc +++ b/src/te/schedule/bound.cc @@ -22,14 +22,15 @@ * \brief The bound inference logic. */ #include -#include #include -#include +#include + #include #include + +#include "../../runtime/thread_storage_scope.h" #include "graph.h" #include "message_passing.h" -#include "../../runtime/thread_storage_scope.h" namespace tvm { namespace te { @@ -50,41 +51,35 @@ struct GraphContext { std::unordered_map op2stage_; }; -bool NeedRelax(const IterVar& iv, - bool found_attach, +bool NeedRelax(const IterVar& iv, bool found_attach, const std::unordered_map& bind_map, const runtime::StorageScope& scope) { auto it = bind_map.find(iv); - const std::string& tag = ( - it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); + const std::string& tag = (it != bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag.length() == 0 || tag == "pipeline") { return !found_attach; } - ThreadScope ts = ThreadScope::make(tag); + ThreadScope ts = ThreadScope::Create(tag); // When there is warp memory // threadIdx.x must be set to be warp index. - if (scope.rank == StorageRank::kWarp && - ts.rank == 1 && - ts.dim_index == 0) { + if (scope.rank == StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) { return true; } return static_cast(scope.rank) <= ts.rank; } // infer storage scope, if not given -StorageScope InferStorageScope( - const Stage& stage, const GraphContext& ctx) { +StorageScope InferStorageScope(const Stage& stage, const GraphContext& ctx) { if (stage->scope.length() != 0) { - return StorageScope::make(stage->scope); + return StorageScope::Create(stage->scope); } int max_rank = -1; for (IterVar iv : ctx.attach_path.at(stage->op)) { auto it = ctx.bind_map.find(iv); - const std::string& tag = ( - it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); + const std::string& tag = (it != ctx.bind_map.end() ? it->second->thread_tag : iv->thread_tag); if (tag != "pipeline" && tag.length() != 0) { - max_rank = std::max(max_rank, ThreadScope::make(tag).rank); + max_rank = std::max(max_rank, ThreadScope::Create(tag).rank); } } StorageScope s; @@ -92,20 +87,16 @@ StorageScope InferStorageScope( return s; } - -void InferRootBound(const Stage& stage, - const GraphContext& ctx, +void InferRootBound(const Stage& stage, const GraphContext& ctx, std::unordered_map* rmap) { - CHECK_NE(stage->attach_type, kInline) - << "call schedule.normalize before scheduleops"; + CHECK_NE(stage->attach_type, kInline) << "call schedule.normalize before scheduleops"; if (stage->attach_type == kInlinedAlready) return; if (stage->is_output) { // verify correctness. 
- CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) - << "Output must be attached at root"; + CHECK_EQ(stage.GetAttachSpec()->attach_type, kGroupRoot) << "Output must be attached at root"; } if (stage->is_output || stage->op.as()) { - for (auto iv : stage->op->root_iter_vars()) { + for (auto iv : stage->op->root_iter_vars()) { CHECK(iv->dom.defined()); CHECK(!rmap->count(iv)); (*rmap)[iv] = iv->dom; @@ -138,7 +129,7 @@ void InferRootBound(const Stage& stage, Array stage_attach = ctx.attach_path.at(stage->op); // The parent set. for (const Operation& op : consumers) { - std::unordered_map relax_set; + Map relax_set; std::unordered_map up_state; bool found_attach = false; CHECK(ctx.op2stage_.count(op.get())); @@ -155,9 +146,8 @@ void InferRootBound(const Stage& stage, if (is_one(vrange->extent)) { up_state[iv] = IntSet::single_point(vrange->min); } else if (!NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - CHECK(is_zero(vrange->min)) - << "InferBound requires every leaf iter var's min equals 0, " - << " call schedule.normalize to achieve this. "; + CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " + << " call schedule.normalize to achieve this. "; if (ctx.bind_map.count(iv)) { up_state[iv] = IntSet::single_point(ctx.bind_map.at(iv)->var); } else { @@ -173,13 +163,12 @@ void InferRootBound(const Stage& stage, found_attach = true; } Range vrange = rmap->at(iv); - CHECK(is_zero(vrange->min)) - << "InferBound requires every leaf iter var's min equals 0, " - << "call schedule.normalize to achieve this."; + CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " + << "call schedule.normalize to achieve this."; if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) { - relax_set[iv->var.get()] = IntSet::range(vrange); + relax_set.Set(iv->var, IntSet::range(vrange)); if (ctx.bind_map.count(iv)) { - relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange); + relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange)); } } } @@ -191,6 +180,9 @@ void InferRootBound(const Stage& stage, // Relax if needed. 
std::unordered_map dom_map; arith::Analyzer analyzer; + for (auto entry : *rmap) { + analyzer.Bind(entry.first->var, entry.second); + } for (auto iv : op->root_iter_vars()) { Range r; if (up_state.count(iv)) { @@ -199,11 +191,13 @@ void InferRootBound(const Stage& stage, r = iv->dom; } if (relax_set.size() != 0) { - dom_map[iv->var.get()] = EvalSet(r, relax_set); + dom_map[iv->var.get()] = + IntSet::interval(analyzer.int_set(r->min, relax_set).min(), + analyzer.int_set(r->min + r->extent - 1, relax_set).max()); } else { dom_map[iv->var.get()] = IntSet::range(r); } - analyzer.Bind(iv->var, r); + analyzer.Bind(iv->var, r, true); } op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); } @@ -253,15 +247,13 @@ Map InferBound(const Schedule& sch) { } } for (auto& p : ret) { - ret[p.first] = Range::make_by_min_extent( - analyzer.Simplify(p.second->min), - analyzer.Simplify(p.second->extent)); + ret[p.first] = Range::make_by_min_extent(analyzer.Simplify(p.second->min), + analyzer.Simplify(p.second->extent)); } return Map(ret.begin(), ret.end()); } -TVM_REGISTER_GLOBAL("schedule.InferBound") -.set_body_typed(InferBound); +TVM_REGISTER_GLOBAL("schedule.InferBound").set_body_typed(InferBound); } // namespace te } // namespace tvm diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc index 9dce36f220ef..09e899581d14 100644 --- a/src/te/schedule/graph.cc +++ b/src/te/schedule/graph.cc @@ -21,40 +21,32 @@ * \file graph.cc * \brief Utilities to get information about schedule graph. */ +#include "graph.h" + #include +#include #include #include -#include -#include -#include + #include -#include "graph.h" +#include +#include namespace tvm { namespace te { // key to specific tensor dimension. struct TensorDimKey { - tir::FunctionRef f; + Operation op; int value_index; int dim; TensorDimKey() {} - TensorDimKey(const tir::CallNode* op, int dim) - : f(op->func), value_index(op->value_index), dim(dim) { - } - TensorDimKey(const Tensor& t, int dim) - : f(t->op), value_index(t->value_index), dim(dim) { - } + TensorDimKey(const Tensor& t, int dim) : op(t->op), value_index(t->value_index), dim(dim) {} TensorDimKey(const Tensor& t, size_t dim) - : f(t->op), value_index(t->value_index), dim(static_cast(dim)) { - } + : op(t->op), value_index(t->value_index), dim(static_cast(dim)) {} inline bool operator==(const TensorDimKey& other) const { - return f == other.f && - value_index == other.value_index && - dim == other.dim; - } - inline bool operator!=(const TensorDimKey& other) const { - return !operator==(other); + return op == other.op && value_index == other.value_index && dim == other.dim; } + inline bool operator!=(const TensorDimKey& other) const { return !operator==(other); } }; } // namespace te } // namespace tvm @@ -63,16 +55,14 @@ namespace std { template <> struct hash<::tvm::te::TensorDimKey> { std::size_t operator()(const ::tvm::te::TensorDimKey& k) const { - size_t lhs = ::tvm::ObjectHash()(k.f); - size_t rhs = static_cast(k.value_index) << 16UL | - static_cast(k.dim); + size_t lhs = ::tvm::ObjectPtrHash()(k.op); + size_t rhs = static_cast(k.value_index) << 16UL | static_cast(k.dim); lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); return lhs; } }; } // namespace std - namespace tvm { namespace te { @@ -105,12 +95,9 @@ ReadGraph CreateReadGraph(const Array& roots) { // Do DFS visit to get the subgraph. // Return if op is inside the subgraph. 
-bool GetSubGraphByPostDFS_( - const Operation& op, - const std::unordered_set& boundary, - bool include_bounary, - std::unordered_map* visited, - Array* result) { +bool GetSubGraphByPostDFS_(const Operation& op, const std::unordered_set& boundary, + bool include_bounary, std::unordered_map* visited, + Array* result) { if (visited->count(op.get())) { return visited->at(op.get()); } @@ -127,9 +114,7 @@ bool GetSubGraphByPostDFS_( // check if we can reach boundary. bool reach_boundary = false; for (Tensor t : op->InputTensors()) { - if (GetSubGraphByPostDFS_(t->op, boundary, - include_bounary, - visited, result)) { + if (GetSubGraphByPostDFS_(t->op, boundary, include_bounary, visited, result)) { reach_boundary = true; } } @@ -140,8 +125,7 @@ bool GetSubGraphByPostDFS_( return reach_boundary; } -Array GetSubGraph(const Array& outputs, - const Array& inputs, +Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs) { Array result; std::unordered_set boundary; @@ -150,16 +134,12 @@ Array GetSubGraph(const Array& outputs, } std::unordered_map visited; for (Tensor t : outputs) { - GetSubGraphByPostDFS_(t->op, boundary, include_inputs, - &visited, &result); + GetSubGraphByPostDFS_(t->op, boundary, include_inputs, &visited, &result); } return result; } - -void PostDFSOrder(const Operation& op, - const ReadGraph& g, - std::unordered_set* visited, +void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, Array* post_order) { if (visited->count(op)) return; visited->insert(op); @@ -169,9 +149,7 @@ void PostDFSOrder(const Operation& op, post_order->push_back(op); } -Array PostDFSOrder( - const Array& roots, - const ReadGraph& g) { +Array PostDFSOrder(const Array& roots, const ReadGraph& g) { std::unordered_set visited; Array post_order; for (Operation op : roots) { @@ -196,8 +174,7 @@ AttachPath CreateAttachPath(Schedule sch) { std::unordered_set visited; Array path; for (Stage s = stage; s.defined();) { - CHECK(!visited.count(s.get())) - << "Find loop in compute_at attach group"; + CHECK(!visited.count(s.get())) << "Find loop in compute_at attach group"; visited.insert(s.get()); Stage spec = s.GetAttachSpec(); bool start_attach; @@ -221,9 +198,8 @@ AttachPath CreateAttachPath(Schedule sch) { } if (start_attach) path.push_back(iv); } - CHECK(start_attach) - << "Invalid Schedule: cannot find attach point " << attach_ivar - << " in the schedule of " << s->op; + CHECK(start_attach) << "Invalid Schedule: cannot find attach point " << attach_ivar + << " in the schedule of " << s->op; } if (!ret.count(stage->op)) { ret.Set(stage->op, path); @@ -233,7 +209,7 @@ AttachPath CreateAttachPath(Schedule sch) { } // graph of push reach relation of tensor dimensions -using ReachGraph = std::unordered_map >; +using ReachGraph = std::unordered_map>; ReachGraph GetReachGraph(const Array& ops) { ReachGraph reach; @@ -249,10 +225,8 @@ ReachGraph GetReachGraph(const Array& ops) { for (size_t i = 0; i < update.size(); ++i) { Tensor t = op.output(i); for (int k = 1; k < static_cast(update[i]->shape.size()); ++k) { - reach[TensorDimKey(t, k)].emplace_back( - TensorDimKey(update[i], k)); - reach[TensorDimKey(t, k)].emplace_back( - TensorDimKey(init[i], k)); + reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(update[i], k)); + reach[TensorDimKey(t, k)].emplace_back(TensorDimKey(init[i], k)); } } } else if (const auto* compute_op = op.as()) { @@ -264,19 +238,19 @@ ReachGraph GetReachGraph(const Array& ops) { reach[TensorDimKey(t, i)] = {}; } auto fvisit = [&vmap, &reach, 
&bset](const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - if (!bset.count(call->func.get())) return; - for (size_t i = 0; i < call->args.size(); ++i) { - TensorDimKey dkey(call, static_cast(i)); + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); + if (!bset.count(t->op.get())) return; + for (size_t i = 0; i < pload->indices.size(); ++i) { + TensorDimKey dkey(t, static_cast(i)); auto fpush = [&dkey, &vmap, &reach](const ObjectRef& node) { - const VarNode *v = node.as(); + const VarNode* v = node.as(); auto it = vmap.find(v); if (it != vmap.end()) { reach[it->second].push_back(dkey); } }; - tir::PostOrderVisit(call->args[i], fpush); + tir::PostOrderVisit(pload->indices[i], fpush); } } }; @@ -315,8 +289,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } // merge exact reach - auto f_merge_key = [&exact_reach, &fail_set]( - const TensorDimKey& dst, const TensorDimKey& src) { + auto f_merge_key = [&exact_reach, &fail_set](const TensorDimKey& dst, const TensorDimKey& src) { auto sit = exact_reach.find(src); if (sit == exact_reach.end()) return; auto dit = exact_reach.find(dst); @@ -343,7 +316,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } } } else if (const auto* compute_op = op.as()) { - std::unordered_map > vmap; + std::unordered_map> vmap; const auto& axis = compute_op->axis; for (size_t i = 0; i < axis.size(); ++i) { std::vector keys; @@ -352,13 +325,12 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { } vmap[axis[i]->var.get()] = std::move(keys); } - auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( - const ObjectRef& n) { - const tir::CallNode *call = n.as(); - if (call != nullptr && call->func.defined()) { - for (size_t i = 0; i < call->args.size(); ++i) { - auto it = vmap.find(call->args[i].get()); - TensorDimKey src(call, static_cast(i)); + auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](const ObjectRef& n) { + if (auto* pload = n.as()) { + Tensor t = Downcast(pload->producer); + for (size_t i = 0; i < pload->indices.size(); ++i) { + auto it = vmap.find(pload->indices[i].get()); + TensorDimKey src(t, static_cast(i)); if (it != vmap.end()) { const std::vector& keys = it->second; for (const auto& key : keys) { @@ -391,8 +363,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { TensorDimKey key(scan->update[i], k); TensorDimKey target(scan->state_placeholder[i], k); IterVar sp_iv = scan->spatial_axis_[sp_idx]; - if (fail_set.count(sp_iv.get()) || - !exact_reach.count(key) || + if (fail_set.count(sp_iv.get()) || !exact_reach.count(key) || exact_reach.at(key) != sp_iv.get()) { ret.Set(sp_iv, make_const(DataType::Int(32), 0)); } else { @@ -407,7 +378,7 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { if (k != target && place_holder_ref.count(k)) break; stack.pop_back(); if (!reach.count(k)) { - LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim; + LOG(FATAL) << "cannot find reach of " << k.op << "-" << k.dim; } for (TensorDimKey kk : reach.at(k)) { @@ -430,24 +401,18 @@ Map ScanFixPointAnalysis(const Operation& scan_op) { return ret; } - -TVM_REGISTER_GLOBAL("schedule.CreateReadGraph") -.set_body_typed(CreateReadGraph); +TVM_REGISTER_GLOBAL("schedule.CreateReadGraph").set_body_typed(CreateReadGraph); TVM_REGISTER_GLOBAL("schedule.PostDFSOrder") -.set_body_typed([](const Array& roots, - const ReadGraph& g) { - return PostDFSOrder(roots, g); -}); + .set_body_typed([](const Array& roots, const ReadGraph& g) { + return PostDFSOrder(roots, g); 
+ }); -TVM_REGISTER_GLOBAL("schedule.CreateAttachPath") -.set_body_typed(CreateAttachPath); +TVM_REGISTER_GLOBAL("schedule.CreateAttachPath").set_body_typed(CreateAttachPath); -TVM_REGISTER_GLOBAL("schedule.ScanGetBody") -.set_body_typed(ScanGetBody); +TVM_REGISTER_GLOBAL("schedule.ScanGetBody").set_body_typed(ScanGetBody); -TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis") -.set_body_typed(ScanFixPointAnalysis); +TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis").set_body_typed(ScanFixPointAnalysis); } // namespace te } // namespace tvm diff --git a/src/te/schedule/graph.h b/src/te/schedule/graph.h index c3478c705145..bb98ff4b706d 100644 --- a/src/te/schedule/graph.h +++ b/src/te/schedule/graph.h @@ -24,9 +24,10 @@ #ifndef TVM_TE_SCHEDULE_GRAPH_H_ #define TVM_TE_SCHEDULE_GRAPH_H_ -#include -#include #include +#include +#include + #include #include #include @@ -72,8 +73,7 @@ ReadGraph CreateReadGraph(const Array& roots); * * \return The subgraph. */ -Array GetSubGraph(const Array& outputs, - const Array& inputs, +Array GetSubGraph(const Array& outputs, const Array& inputs, bool include_inputs); /*! @@ -85,8 +85,7 @@ Array GetSubGraph(const Array& outputs, * \note PostDFSOrder is a special case of Topoligical order, * and can be used when topoligical order is needed. */ -Array PostDFSOrder( - const Array& roots, const ReadGraph& g); +Array PostDFSOrder(const Array& roots, const ReadGraph& g); /*! * \brief Create feedgraph for given Schedule diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 4ff8586bccdf..55593be34212 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -21,33 +21,28 @@ * \file message_passing.cc * \brief The message passing domain. */ +#include "message_passing.h" + #include #include -#include -#include "message_passing.h" -#include "../../arith/compute_expr.h" namespace tvm { namespace te { using namespace tir; -void Update(std::unordered_map* p_state, - const IterVar& iv, - Range r, +void Update(std::unordered_map* p_state, const IterVar& iv, Range r, arith::Analyzer* analyzer) { auto it = p_state->find(iv); if (it == p_state->end()) { (*p_state)[iv] = r; analyzer->Bind(iv->var, r); } else { - bool match = is_zero(it->second->min) && - analyzer->CanProve(r->extent - it->second->extent == 0); - CHECK(match) - << iv - << " domain already inferred," - << " cannot prove their extents are the same " - << it->second->extent << " vs " << r->extent; + bool match = + is_zero(it->second->min) && analyzer->CanProve(r->extent - it->second->extent == 0); + CHECK(match) << iv << " domain already inferred," + << " cannot prove their extents are the same " << it->second->extent << " vs " + << r->extent; } } @@ -90,10 +85,8 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map* } } -void PassDownDomain(const Stage& stage, - std::unordered_map* p_state, - arith::Analyzer* actx, - bool allow_missing) { +void PassDownDomain(const Stage& stage, std::unordered_map* p_state, + arith::Analyzer* actx, bool allow_missing) { auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); @@ -101,7 +94,7 @@ void PassDownDomain(const Stage& stage, return actx->Simplify(indexdiv(a + (b - 1), b)); }; - auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { + auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { if (actx->CanProve(a < b)) { return actx->Simplify(a); } @@ -139,20 +132,16 @@ void 
PassDownDomain(const Stage& stage, }; if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->inner, r->factor)), + Range::make_by_min_extent(0, resolve_min_extent_for_split(r->inner, r->factor)), actx); Update(p_state, r->outer, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->factor)), actx); + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->factor)), actx); } else { Update(p_state, r->outer, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->outer, r->nparts)), + Range::make_by_min_extent(0, resolve_min_extent_for_split(r->outer, r->nparts)), actx); Update(p_state, r->inner, - Range::make_by_min_extent( - 0, ceil_div(range_parent->extent, r->nparts)), actx); + Range::make_by_min_extent(0, ceil_div(range_parent->extent, r->nparts)), actx); } } else if (const FuseNode* r = rel.as()) { if (!state.count(r->outer) || !state.count(r->inner)) { @@ -161,16 +150,13 @@ void PassDownDomain(const Stage& stage, } const Range& range_outer = state.at(r->outer); const Range& range_inner = state.at(r->inner); - state[r->fused] = Range::make_by_min_extent( - 0, range_outer->extent * range_inner->extent); + state[r->fused] = Range::make_by_min_extent(0, range_outer->extent * range_inner->extent); } else if (const RebaseNode* r = rel.as()) { if (!state.count(r->parent)) { CHECK(allow_missing); continue; } - Update(p_state, r->rebased, - Range::make_by_min_extent( - 0, state.at(r->parent)->extent), actx); + Update(p_state, r->rebased, Range::make_by_min_extent(0, state.at(r->parent)->extent), actx); } else if (const SingletonNode* s = rel.as()) { Update(p_state, s->iter, Range::make_by_min_extent(0, 1), actx); } else { @@ -186,10 +172,8 @@ void PassDownDomain(const Stage& stage, } } -void PassUpIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing) { +void PassUpIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; @@ -245,10 +229,8 @@ void PassUpIndex(const Stage& stage, } } -void PassDownIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing) { +void PassDownIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { if (const SplitNode* s = rel.as()) { @@ -293,16 +275,10 @@ void PassDownIndex(const Stage& stage, } // Domain message passing. 
-void PassUpDomain(const SplitNode* s, - const std::unordered_map& dom_map, - const IntSet& outer, - const IntSet& inner, - IntSet* parent) { - if (dom_map.count(s->outer) && - dom_map.count(s->inner) && - dom_map.count(s->parent) && - outer.match_range(dom_map.at(s->outer)) && - inner.match_range(dom_map.at(s->inner))) { +void PassUpDomain(const SplitNode* s, const std::unordered_map& dom_map, + const IntSet& outer, const IntSet& inner, IntSet* parent) { + if (dom_map.count(s->outer) && dom_map.count(s->inner) && dom_map.count(s->parent) && + outer.match_range(dom_map.at(s->outer)) && inner.match_range(dom_map.at(s->inner))) { *parent = IntSet::range(dom_map.at(s->parent)); return; } @@ -311,19 +287,16 @@ void PassUpDomain(const SplitNode* s, CHECK(outer.defined()); CHECK(inner.defined()); CHECK(factor.defined()); - *parent = arith::EvalSet( - s->outer->var * factor + s->inner->var + parent_min, - {{s->outer, outer}, {s->inner, inner}}); + *parent = arith::EvalSet(s->outer->var * factor + s->inner->var + parent_min, + {{s->outer, outer}, {s->inner, inner}}); } -void PassUpDomain(const FuseNode* s, - const std::unordered_map& dom_map, - const IntSet& fused, - IntSet* outer, - IntSet* inner) { +void PassUpDomain(const FuseNode* s, const std::unordered_map& dom_map, + const IntSet& fused, IntSet* outer, IntSet* inner) { CHECK(dom_map.count(s->outer)); CHECK(dom_map.count(s->inner)); CHECK(dom_map.count(s->fused)); + arith::Analyzer ana; if (fused.match_range(dom_map.at(s->fused))) { *outer = IntSet::range(dom_map.at(s->outer)); @@ -336,8 +309,8 @@ void PassUpDomain(const FuseNode* s, if (fused.is_single_point()) { PrimExpr value = fused.point_value(); PrimExpr factor = dom_map.at(s->inner)->extent; - PrimExpr v_outer = indexdiv(value, factor); - PrimExpr v_inner = indexmod(value, factor); + PrimExpr v_outer = indexdiv(value, factor); + PrimExpr v_inner = indexmod(value, factor); if (!is_zero(outer_min)) v_outer = v_outer + outer_min; if (!is_zero(inner_min)) v_inner = v_inner + inner_min; *outer = IntSet::single_point(v_outer); @@ -345,20 +318,19 @@ void PassUpDomain(const FuseNode* s, } else { PrimExpr fused_extent = (fused.max() - fused.min() + 1); PrimExpr inner_extent = dom_map.at(s->inner)->extent; - *outer = IntSet::interval( - outer_min + indexdiv(fused.min(), inner_extent), - outer_min + indexdiv(fused.max(), inner_extent)); - if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) && - is_zero(Simplify(indexmod(fused.min(), fused_extent)))) { + *outer = IntSet::interval(outer_min + indexdiv(fused.min(), inner_extent), + outer_min + indexdiv(fused.max(), inner_extent)); + if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) && + is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) { // fused never spans multiple rows, make a tight bounding box // there may be other cases when bounding box could be tightened *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent), inner_min + indexmod(fused.max(), inner_extent)); } else { // fused may span multiple rows, use full row widths - if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) || - !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) { - LOG(WARNING) << - "fused and original axes are not aligned, this may cause redundant computations"; + if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) || + !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) { + LOG(WARNING) + << "fused and original axes are not aligned, this may cause redundant computations"; } *inner = 
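
Upward domain propagation is the inverse index algebra: a split recombines parent = outer * factor + inner, and a fuse decomposes the fused index with division and modulo by the inner extent. A standalone sketch with concrete integers (illustrative only; the pass operates on IntSet/PrimExpr):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t factor = 4;

  // SplitNode: recombine (outer, inner) into the parent index, as in
  // s->outer->var * factor + s->inner->var (+ parent_min).
  int64_t outer = 3, inner = 2;
  int64_t parent = outer * factor + inner;
  assert(parent == 14);

  // FuseNode single-point case: decompose the fused value back into
  // (outer, inner) via indexdiv / indexmod by the inner extent.
  const int64_t inner_extent = 4;
  int64_t fused = 14;
  assert(fused / inner_extent == 3);
  assert(fused % inner_extent == 2);
  return 0;
}
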
IntSet::range(dom_map.at(s->inner)); } @@ -366,44 +338,34 @@ void PassUpDomain(const FuseNode* s, } } -void PassUpDomain(const RebaseNode* s, - const std::unordered_map& dom_map, - const IntSet& rebased, - IntSet* parent) { +void PassUpDomain(const RebaseNode* s, const std::unordered_map& dom_map, + const IntSet& rebased, IntSet* parent) { CHECK(dom_map.count(s->parent)); if (rebased.match_range(dom_map.at(s->rebased))) { *parent = IntSet::range(dom_map.at(s->parent)); return; } PrimExpr parent_min = dom_map.at(s->parent)->min; - *parent = arith::EvalSet(s->rebased->var + parent_min, - {{s->rebased, rebased}}); + *parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}}); } -void PassUpDomain(const Stage& stage, - const std::unordered_map& dom_map, +void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { IterVarRelation rel = stage->relations[i - 1]; if (const SplitNode* r = rel.as()) { IntSet parent; - PassUpDomain(r, dom_map, - state.at(r->outer), state.at(r->inner), - &parent); + PassUpDomain(r, dom_map, state.at(r->outer), state.at(r->inner), &parent); state[r->parent] = parent; } else if (const FuseNode* r = rel.as()) { IntSet outer, inner; - PassUpDomain(r, dom_map, - state.at(r->fused), - &outer, &inner); + PassUpDomain(r, dom_map, state.at(r->fused), &outer, &inner); state[r->outer] = outer; state[r->inner] = inner; } else if (const RebaseNode* r = rel.as()) { IntSet parent; - PassUpDomain(r, dom_map, - state.at(r->rebased), - &parent); + PassUpDomain(r, dom_map, state.at(r->rebased), &parent); state[r->parent] = parent; } else if (rel.as()) { } else { @@ -413,8 +375,7 @@ void PassUpDomain(const Stage& stage, } // Pass up bit mask with or relation. -void PassUpBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassUpBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (size_t i = stage->relations.size(); i != 0; --i) { @@ -461,8 +422,7 @@ void PassUpBitMaskOr(const Stage& stage, } } -void PassDownBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassDownBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing) { auto& state = *p_state; for (IterVarRelation rel : stage->relations) { @@ -509,17 +469,14 @@ void PassDownBitMaskOr(const Stage& stage, } } - /*! * \brief message passing to find if boundary checking on IterVar is needed. * \param s The stage to be used. 
* \param p_state The message passing state * IterVar->flag */ -void PassUpBoundCheck(const Stage& s, - const Map& dom_map, - std::unordered_map* p_state, - arith::Analyzer* analyzer) { +void PassUpBoundCheck(const Stage& s, const Map& dom_map, + std::unordered_map* p_state, arith::Analyzer* analyzer) { auto& state = *p_state; for (size_t i = s->relations.size(); i != 0; --i) { IterVarRelation rel = s->relations[i - 1]; @@ -560,16 +517,14 @@ bool IsRangeSame(const Range input_1, const Range input_2) { arith::Analyzer analyzer; if (input_1.same_as(input_2)) return true; - return (analyzer.CanProve(input_1->min == input_2->min) - && analyzer.CanProve(input_1->extent == input_2->extent)); + return (analyzer.CanProve(input_1->min == input_2->min) && + analyzer.CanProve(input_1->extent == input_2->extent)); } -std::vector MakeBoundCheck( - const Stage& stage, - const Map& dom_map, - const std::unordered_map& value_map, - bool skip_ivar_domain, - const std::unordered_set& skip_iter) { +std::vector MakeBoundCheck(const Stage& stage, const Map& dom_map, + const std::unordered_map& value_map, + bool skip_ivar_domain, + const std::unordered_set& skip_iter) { arith::Analyzer analyzer; std::unordered_map bound_state; @@ -579,11 +534,15 @@ std::vector MakeBoundCheck( PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); std::vector preds; - std::unordered_map iset_dmap; + Map iset_dmap; // setup domain map for set analysis for (const auto& kv : dom_map) { - iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); + iset_dmap.Set(kv.first->var, IntSet::range(kv.second)); + } + + for (auto entry : dom_map) { + analyzer.Bind(entry.first->var, entry.second); } for (const IterVar& iv : stage->all_iter_vars) { @@ -591,7 +550,7 @@ std::vector MakeBoundCheck( if (bound_state.at(iv)) { Range dom = dom_map.at(iv); PrimExpr value = value_map.at(iv) - dom->min; - PrimExpr vmax = EvalSet(value, iset_dmap).max(); + PrimExpr vmax = analyzer.int_set(value, iset_dmap).max(); if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { preds.emplace_back(value < dom->extent); } @@ -603,7 +562,7 @@ std::vector MakeBoundCheck( CHECK(iv->dom.defined()); if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { PrimExpr value = value_map.at(iv) - iv->dom->min; - IntSet s = EvalSet(value, iset_dmap); + IntSet s = analyzer.int_set(value, iset_dmap); PrimExpr vmin = s.min(); PrimExpr vmax = s.max(); // The range of `value` resides in [vmin, vmax] diff --git a/src/te/schedule/message_passing.h b/src/te/schedule/message_passing.h index 187723516f97..c382b90d630c 100644 --- a/src/te/schedule/message_passing.h +++ b/src/te/schedule/message_passing.h @@ -25,10 +25,11 @@ #ifndef TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ #define TVM_TE_SCHEDULE_MESSAGE_PASSING_H_ -#include -#include -#include #include +#include +#include +#include + #include #include #include @@ -45,11 +46,8 @@ namespace te { * \param analyzer Analyzer context, storing information about bounds in p_state. * \param allow_missing Whether allow missing value. */ -void PassDownDomain( - const Stage& stage, - std::unordered_map* p_state, - arith::Analyzer* analyzer, - bool allow_missing = false); +void PassDownDomain(const Stage& stage, std::unordered_map* p_state, + arith::Analyzer* analyzer, bool allow_missing = false); /*! * \param Upward inference of index of each IterVar. @@ -60,10 +58,8 @@ void PassDownDomain( * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. 
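
The rule MakeBoundCheck applies is: compute the largest value the offset index can take, and emit a predicate (value < extent) only when that maximum cannot be proven to stay inside the domain. A standalone sketch with plain integers standing in for what the Analyzer proves:

#include <cassert>
#include <cstdint>

// The real pass asks analyzer.CanProve(vmax < dom->extent) on PrimExprs; here a
// concrete integer stands in for the proven maximum of (value_map.at(iv) - dom->min).
bool NeedsBoundCheckSketch(int64_t value_max, int64_t extent) {
  return !(value_max < extent);
}

int main() {
  // Extent 18 split by factor 4: the recombined index can reach 19, so the
  // predicate (value < 18) must be emitted.
  assert(NeedsBoundCheckSketch(19, 18));
  // Extent 16 split by factor 4: the index never exceeds 15, no predicate needed.
  assert(!NeedsBoundCheckSketch(15, 16));
  return 0;
}
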
*/ -void PassUpIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing = false); +void PassUpIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing = false); /*! * \param Downward inference of index of each IterVar. @@ -74,10 +70,8 @@ void PassUpIndex(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassDownIndex(const Stage& stage, - const Map& dom_map, - std::unordered_map* p_state, - bool allow_missing = false); +void PassDownIndex(const Stage& stage, const Map& dom_map, + std::unordered_map* p_state, bool allow_missing = false); /*! * \param Upward inference of domain set of each IterVar. @@ -87,8 +81,7 @@ void PassDownIndex(const Stage& stage, * \param dom_map The domain map of each iteration variable's maximum domain. * \param p_state The index state of each IterVar. */ -void PassUpDomain(const Stage& stage, - const std::unordered_map& dom_map, +void PassUpDomain(const Stage& stage, const std::unordered_map& dom_map, std::unordered_map* p_state); /*! @@ -97,8 +90,7 @@ void PassUpDomain(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassUpBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassUpBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing = false); /*! @@ -107,8 +99,7 @@ void PassUpBitMaskOr(const Stage& stage, * \param p_state The index state of each IterVar. * \param allow_missing Whether allow missing value. */ -void PassDownBitMaskOr(const Stage& stage, - std::unordered_map* p_state, +void PassDownBitMaskOr(const Stage& stage, std::unordered_map* p_state, bool allow_missing = false); /*! @@ -120,13 +111,10 @@ void PassDownBitMaskOr(const Stage& stage, * \param skip_iter The set of variables to skip bound condition. * \return List of predicates that we need to check. */ -std::vector -MakeBoundCheck( - const Stage& stage, - const Map& dom_map, - const std::unordered_map& value_map, - bool skip_ivar_domain, - const std::unordered_set& skip_iter); +std::vector MakeBoundCheck(const Stage& stage, const Map& dom_map, + const std::unordered_map& value_map, + bool skip_ivar_domain, + const std::unordered_set& skip_iter); } // namespace te } // namespace tvm diff --git a/src/tir/pass/inline.cc b/src/te/schedule/operation_inline.cc similarity index 57% rename from src/tir/pass/inline.cc rename to src/te/schedule/operation_inline.cc index 1b322964b873..fd613f47107a 100644 --- a/src/tir/pass/inline.cc +++ b/src/te/schedule/operation_inline.cc @@ -18,48 +18,54 @@ */ /*! 
- * \file inline.cc + * \file operation_inline.cc */ +#include "operation_inline.h" + +#include #include #include -#include #include +#include + +#include "../../tir/transforms/ir_util.h" + namespace tvm { -namespace tir { +namespace te { // inliner to inline a function // the result may not be SSA, // ConvertSSA need to be applied after this pass -class IRInline final : public StmtExprMutator { +class OperationInliner final : public StmtExprMutator { public: - IRInline(FunctionRef f, Array args, PrimExpr body) - : f_(f), args_(args), body_(body) {} + OperationInliner(Operation op, Array args, PrimExpr body) + : operation_(op), args_(args), body_(body) {} - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); + op = expr.as(); + auto tensor = Downcast(op->producer); - if (op->func == f_) { - CHECK_EQ(op->value_index, 0); + if (tensor->op.same_as(operation_)) { + CHECK_EQ(tensor->value_index, 0); expr = body_; - CHECK_EQ(args_.size(), op->args.size()); + CHECK_EQ(args_.size(), op->indices.size()); bool has_side_effect = false; - for (size_t i = 0; i < op->args.size(); ++i) { - if (HasSideEffect(op->args[i])) has_side_effect = true; + for (size_t i = 0; i < op->indices.size(); ++i) { + if (HasSideEffect(op->indices[i])) has_side_effect = true; } if (has_side_effect) { for (size_t i = 0; i < args_.size(); ++i) { - expr = LetNode::make(args_[i], op->args[i], expr); + expr = Let(args_[i], op->indices[i], expr); } } else { Map vmap; for (size_t i = 0; i < args_.size(); ++i) { - vmap.Set(args_[i], op->args[i]); + vmap.Set(args_[i], op->indices[i]); } - expr = Substitute( - EvaluateNode::make(expr), vmap).as()->value; + expr = Substitute(Evaluate(expr), vmap).as()->value; } return expr; } else { @@ -68,20 +74,16 @@ class IRInline final : public StmtExprMutator { } private: - FunctionRef f_; + Operation operation_; Array args_; PrimExpr body_; }; -Stmt Inline(Stmt stmt, - FunctionRef f, - Array args, - PrimExpr body) { - CHECK_EQ(f->num_outputs(), 1) - << "can only inline output single value operation"; - Stmt ret = IRInline(f, args, body)(std::move(stmt)); +Stmt Inline(Stmt stmt, Operation f, Array args, PrimExpr body) { + CHECK_EQ(f->num_outputs(), 1) << "can only inline output single value operation"; + Stmt ret = OperationInliner(f, args, body)(std::move(stmt)); if (ret.same_as(stmt)) return ret; return ConvertSSA(ret); } -} // namespace tir +} // namespace te } // namespace tvm diff --git a/src/te/schedule/operation_inline.h b/src/te/schedule/operation_inline.h new file mode 100644 index 000000000000..d475fbe3787e --- /dev/null +++ b/src/te/schedule/operation_inline.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file operation_inline.h + */ +#ifndef TVM_TE_SCHEDULE_OPERATION_INLINE_H_ +#define TVM_TE_SCHEDULE_OPERATION_INLINE_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace te { + +/*! + * \brief inline all calls of f in stmt. + * + * \param stmt The statement to apply inline optimization. + * \param op The op to be inlined. + * \param args The arguments variable of the function. + * \param body The definition body of the function. + * \return The result stmt + * + * \note All the passes in this file uses SSA form and outputs SSA form. + */ +Stmt Inline(Stmt stmt, Operation op, Array args, PrimExpr body); + +} // namespace te +} // namespace tvm +#endif // TVM_TE_SCHEDULE_OPERATION_INLINE_H_ diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 99f2fb9efd87..af72d3b1a1df 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -20,33 +20,33 @@ /*! * \file schedule_dataflow_rewrite.cc */ -#include #include +#include +#include #include -#include + #include + +#include "../../tir/transforms/ir_util.h" #include "message_passing.h" -#include "../../tir/pass/ir_util.h" -#include "../../arith/compute_expr.h" +#include "operation_inline.h" namespace tvm { namespace te { // find first occurance location in leaf -template +template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); - for (size_t i = 0; i < array_node->data.size(); ++i) { - if (array_node->data[i].get() == n) return i; + for (size_t i = 0; i < array_node->size(); ++i) { + if (array_node->at(i).get() == n) return i; } - return array_node->data.size(); + return array_node->size(); } // The replacer of cache. 
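
The inliner's core step is substitution: every load of the inlined operation's output is replaced by that operation's body with the formal axis variables substituted by the load indices (or Let-bound when the indices have side effects). A standalone sketch of that substitution, using lambdas in place of the TVM IR types:

#include <cassert>

int main() {
  // Stand-ins: A_like is some producer, B_body is the body of the op being
  // inlined, with one formal axis variable i.
  auto A_like = [](int i) { return 10 * i; };
  auto B_body = [&](int i) { return A_like(i) + 1; };

  // C(j) = B(2 * j): inlining substitutes i := 2 * j into B's body, so the
  // load of B disappears and C reads A_like directly.
  auto C_inlined = [&](int j) { return B_body(2 * j); };
  assert(C_inlined(3) == 61);  // A_like(6) + 1
  return 0;
}
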
class VarReplacer : public tir::StmtExprMutator { public: - explicit VarReplacer( - const std::unordered_map& vsub) - : vsub_(vsub) {} + explicit VarReplacer(const std::unordered_map& vsub) : vsub_(vsub) {} PrimExpr VisitExpr_(const VarNode* op) final { auto it = vsub_.find(op); if (it != vsub_.end()) return it->second; @@ -55,19 +55,16 @@ class VarReplacer : public tir::StmtExprMutator { tir::CommReducer MutateCommReducer(tir::CommReducer combiner) { // Replace free variables in combiner - auto new_identity = tir::UpdateArray(combiner->identity_element, [this] (const PrimExpr& e) { - return this->VisitExpr(e); - }); - auto new_result = tir::UpdateArray(combiner->result, [this] (const PrimExpr& e) { - return this->VisitExpr(e); - }); + auto new_identity = tir::UpdateArray(combiner->identity_element, + [this](const PrimExpr& e) { return this->VisitExpr(e); }); + auto new_result = tir::UpdateArray(combiner->result, + [this](const PrimExpr& e) { return this->VisitExpr(e); }); if (combiner->identity_element.same_as(new_identity) && combiner->identity_element.same_as(new_result)) { return combiner; } else { - return tir::CommReducerNode::make( - combiner->lhs, combiner->rhs, new_result, new_identity); + return tir::CommReducer(combiner->lhs, combiner->rhs, new_result, new_identity); } } @@ -78,12 +75,8 @@ class VarReplacer : public tir::StmtExprMutator { if (op->combiner.same_as(new_combiner)) { return new_e; } else { - return tir::ReduceNode::make( - new_combiner, - new_reduce->source, - new_reduce->axis, - new_reduce->condition, - new_reduce->value_index); + return tir::Reduce(new_combiner, new_reduce->source, new_reduce->axis, new_reduce->condition, + new_reduce->value_index); } } @@ -91,27 +84,25 @@ class VarReplacer : public tir::StmtExprMutator { const std::unordered_map& vsub_; }; -PrimExpr InjectPredicate(const Array& predicates, - PrimExpr body) { +PrimExpr InjectPredicate(const Array& predicates, PrimExpr body) { using tir::ReduceNode; using tir::SelectNode; if (predicates.size() == 0) return body; const ReduceNode* reduce = body.as(); + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; + if (reduce) { auto n = make_object(*reduce); - n->condition = n->condition && arith::ComputeReduce(predicates, PrimExpr()); + n->condition = foldl(fand, n->condition, predicates); return PrimExpr(n); } - return SelectNode::make(arith::ComputeReduce(predicates, PrimExpr()), - body, - make_zero(body.dtype())); + return Select(foldl(fand, const_true(1), predicates), body, make_zero(body.dtype())); } // Replace data flow appears in all stages given the tensor change. // Also update vmap if subsequent dataflow need to be replaced. // Need to keep an update to the date transitive closure property on the vmap by a reverse map. 
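
InjectPredicate folds the predicate list with logical AND and then either strengthens a Reduce condition or guards the body with a Select against a zero of the body's dtype. A standalone sketch of the fold-and-guard shape, on plain bools and ints:

#include <cassert>
#include <vector>

int main() {
  std::vector<bool> predicates = {true, true, false};

  // foldl(fand, const_true(1), predicates): start from true, AND everything in.
  bool cond = true;
  for (bool p : predicates) cond = cond && p;

  // Select(cond, body, make_zero(body.dtype())): keep the value only when all
  // predicates hold, otherwise produce the neutral zero.
  int body = 42;
  int guarded = cond ? body : 0;
  assert(guarded == 0);
  return 0;
}
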
-void ReplaceDataFlow(const Array& stages, - std::unordered_map* vmap, +void ReplaceDataFlow(const Array& stages, std::unordered_map* vmap, std::unordered_map* rvmap) { for (Stage s : stages) { Operation op = s->op->ReplaceInputs(s->op, *vmap); @@ -131,14 +122,11 @@ void ReplaceDataFlow(const Array& stages, } inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && - (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && - (a->condition.same_as(b->condition)); + return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)); } -Tensor Schedule::cache_read(const Tensor& tensor, - const std::string& scope, +Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, const Array& readers) { (*this)->InvalidateCache(); // create identity mapping. @@ -152,9 +140,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, std::unordered_map vsub; Stage s = operator[](tensor->op); Tensor sugar_tensor = s->op.output(tensor->value_index); - Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array& i) { - return sugar_tensor(Array(i.begin(), i.end())); - }, os.str()); + Tensor cache = compute( + sugar_tensor->shape, + [&sugar_tensor](const Array& i) { + return sugar_tensor(Array(i.begin(), i.end())); + }, + os.str()); vsub[sugar_tensor] = cache; std::unordered_map vmap; @@ -162,22 +153,19 @@ Tensor Schedule::cache_read(const Tensor& tensor, for (Operation op : readers) { Stage s = operator[](op); Operation repl_op = s->op->ReplaceInputs(s->op, vsub); - CHECK(!repl_op.same_as(s->op)) - << "Cannot find " << tensor - << " in the inputs of " << s->op; + CHECK(!repl_op.same_as(s->op)) << "Cannot find " << tensor << " in the inputs of " << s->op; vmap[s->op.output(0)] = repl_op.output(0); rvmap[repl_op.output(0)] = s->op.output(0); s->op = repl_op; } ReplaceDataFlow((*this)->stages, &vmap, &rvmap); - ArrayNode* stages = (*this)->stages.CopyOnWrite(); + Array& stages = (*this)->stages; Stage op_stage = operator[](tensor->op); - size_t pos = FindNodeRef(stages, op_stage); + size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage); Stage cache_stage = Stage(cache->op); cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos + 1, - cache_stage); + CHECK_LT(pos, stages.size()); + stages.insert(stages.begin() + pos + 1, cache_stage); (*this)->stage_map.Set(cache->op, cache_stage); // Update group cache_stage->group = op_stage->group; @@ -187,12 +175,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } -template -void PrepareAxisMapping(Stage orig_stage, - OpType* op, - std::unordered_set* p_red_axis, - Array* p_new_axis, - std::unordered_map* p_dom_map, +template +void PrepareAxisMapping(Stage orig_stage, OpType* op, std::unordered_set* p_red_axis, + Array* p_new_axis, std::unordered_map* p_dom_map, std::unordered_map* p_vsub, std::unordered_map* p_vsub2newvar, std::vector* p_predicates) { @@ -217,11 +202,9 @@ void PrepareAxisMapping(Stage orig_stage, std::unordered_map value_map; for (IterVar iv : orig_stage->leaf_iter_vars) { if (red_axis.count(iv)) continue; - CHECK_EQ(iv->iter_type, kDataPar) - << "Can only relayout with in data parallel dimensions"; + CHECK_EQ(iv->iter_type, kDataPar) << "Can only relayout with in data parallel dimensions"; Range dom = dom_map.at(iv); - IterVar new_iv = IterVarNode::make( - dom, 
iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVar(dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); if (is_one(dom->min)) { value_map[iv] = dom->min; @@ -236,8 +219,7 @@ void PrepareAxisMapping(Stage orig_stage, skip_bound_check.insert(iv); } PassUpIndex(orig_stage, dom_map, &value_map, true); - predicates = MakeBoundCheck( - orig_stage, dom_map, value_map, true, skip_bound_check); + predicates = MakeBoundCheck(orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis for (IterVar iv : op->axis) { if (value_map.count(iv)) { @@ -247,12 +229,8 @@ void PrepareAxisMapping(Stage orig_stage, } } -Array ReplaceOriginalOp(Schedule sch, - Stage orig_stage, - const std::string& scope, - Operation cache_op, - Operation orig_new_op, - size_t tensor_size) { +Array ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::string& scope, + Operation cache_op, Operation orig_new_op, size_t tensor_size) { Array cache_tensor_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); @@ -274,13 +252,12 @@ Array ReplaceOriginalOp(Schedule sch, orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; orig_stage->relations = Array(); // create schedule for new cached stage. - ArrayNode* stages = sch->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); + Array& stages = sch->stages; + size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage); Stage cache_stage = Stage(cache_op); cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage); + CHECK_LT(pos, stages.size()); + stages.insert(stages.begin() + pos, cache_stage); sch->stage_map.Set(cache_op, cache_stage); // Update group cache_stage->group = orig_stage->group; @@ -290,10 +267,8 @@ Array ReplaceOriginalOp(Schedule sch, return cache_tensor_list; } - // Cache write and relayout the data according to loop pattern -Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, +Array CacheWriteWithReLayout(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); @@ -309,8 +284,8 @@ Array CacheWriteWithReLayout(Schedule sch, std::unordered_map vsub2newvar; std::vector predicates; - PrepareAxisMapping(orig_stage, compute, - &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + PrepareAxisMapping(orig_stage, compute, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, + &predicates); PrimExpr body; Array body_list; @@ -325,17 +300,13 @@ Array CacheWriteWithReLayout(Schedule sch, const tir::ReduceNode* reduce_body = body.as(); if (first_reduce != nullptr) { CHECK(ReduceEqual(reduce_body, first_reduce)); - body = tir::ReduceNode::make(first_reduce->combiner, - first_reduce->source, - first_reduce->axis, - first_reduce->condition, - reduce_body->value_index); + body = tir::Reduce(first_reduce->combiner, first_reduce->source, first_reduce->axis, + first_reduce->condition, reduce_body->value_index); } else { first_reduce = reduce_body; } } else { - CHECK(first_reduce == nullptr) - << "cannot mix reduce and other node in ONE compute bodys"; + CHECK(first_reduce == nullptr) << "cannot mix reduce and other node in ONE compute bodys"; } body_list.push_back(body); } @@ -353,26 +324,21 @@ Array CacheWriteWithReLayout(Schedule sch, args.push_back(value_map.at(iv)); } } - Operation cache_op = ComputeOpNode::make( - compute->name + "." 
+ scope, compute->tag, compute->attrs, - new_axis, body_list); + Operation cache_op = + ComputeOp(compute->name + "." + scope, compute->tag, compute->attrs, new_axis, body_list); Array cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make( - compute->name, compute->tag, compute->attrs, - compute->axis, cache_expr_list); - return ReplaceOriginalOp(sch, orig_stage, scope, - cache_op, orig_new_op, tensor_size); + Operation orig_new_op = + ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } - // for tensor compute op -Array CacheWriteWithReLayoutTensor(Schedule sch, - const Array& tensor_array, +Array CacheWriteWithReLayoutTensor(Schedule sch, const Array& tensor_array, const std::string& scope) { size_t tensor_size = tensor_array.size(); sch->InvalidateCache(); @@ -390,14 +356,12 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, std::unordered_map vsub2newvar; std::vector predicates; - PrepareAxisMapping(orig_stage, tensor_op, - &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); - + PrepareAxisMapping(orig_stage, tensor_op, &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, + &predicates); for (int i = tensor_op->schedulable_ndim; i < static_cast(tensor_op->axis.size()); ++i) { IterVar iv = tensor_op->axis[i]; - IterVar new_iv = IterVarNode::make( - iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + IterVar new_iv = IterVar(iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); new_axis.push_back(new_iv); } Array new_regions; @@ -416,16 +380,16 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, new_scalar_inputs.push_back(VarReplacer(vsub2newvar)(old_input)); } - Operation cache_op = TensorComputeOpNode::make( - tensor_op->name + "." + scope, tensor_op->tag, new_axis, - tensor_op->reduce_axis, tensor_op->schedulable_ndim, - tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs); + Operation cache_op = + TensorComputeOp(tensor_op->name + "." 
+ scope, tensor_op->tag, new_axis, + tensor_op->reduce_axis, tensor_op->schedulable_ndim, tensor_op->intrin, + tensor_op->inputs, new_regions, new_scalar_inputs); // axis will be used in generating compute op Array compute_axis = tensor_op->axis; for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { IterVar iv = tensor_op->axis[i]; - IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + IterVar aiv = IterVar(iv->dom, iv->var, kDataPar); compute_axis.Set(i, aiv); } @@ -454,19 +418,14 @@ Array CacheWriteWithReLayoutTensor(Schedule sch, Tensor cache_tensor = cache_op.output(i); cache_expr_list.push_back(cache_tensor(args)); } - Operation orig_new_op = ComputeOpNode::make( - tensor_op->name, tensor_op->tag, {}, - compute_axis, cache_expr_list); - return ReplaceOriginalOp(sch, orig_stage, scope, - cache_op, orig_new_op, tensor_size); + Operation orig_new_op = + ComputeOp(tensor_op->name, tensor_op->tag, {}, compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, cache_op, orig_new_op, tensor_size); } - -Array Schedule::cache_write(const Array& tensor_array, - const std::string& scope) { +Array Schedule::cache_write(const Array& tensor_array, const std::string& scope) { (*this)->InvalidateCache(); - CHECK(tensor_array.size() > 0) - << "size of tensor_array must be greater than 0"; + CHECK(tensor_array.size() > 0) << "size of tensor_array must be greater than 0"; Tensor tensor = tensor_array[0]; Stage orig_stage = operator[](tensor->op); const ComputeOpNode* compute = tensor->op.as(); @@ -474,15 +433,12 @@ Array Schedule::cache_write(const Array& tensor_array, << "size of input tensor list must be same as number of stage outputs"; for (size_t i = 1; i < tensor_array.size(); i++) { Stage tmp_stage = operator[](tensor_array[i]->op); - CHECK(orig_stage.same_as(tmp_stage)) - << "Input tensor list must be generated by ONE computeOp"; + CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } return CacheWriteWithReLayout(*this, tensor_array, scope); } - -Tensor Schedule::cache_write(const Tensor& tensor, - const std::string& scope) { +Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { // support original compute and tensor compute both (*this)->InvalidateCache(); if (tensor->op.as()) { @@ -495,7 +451,6 @@ Tensor Schedule::cache_write(const Tensor& tensor, } } - void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { @@ -505,21 +460,19 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); for (IterVar iv : root_iter_vars) { size_t idx = FindNodeRef(leaf_vars, iv); - auto it = s->iter_var_attrs.find(iv); + auto it = s->iter_var_attrs.find(iv); // don;t need to rebase path that are binded. 
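
A hedged usage sketch of the cache_read / cache_write entry points above, assuming the TVM 0.7-era C++ TE headers; exact include paths and the compute/placeholder overloads may differ between versions, and the tensor names are illustrative:

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

using namespace tvm;
using namespace tvm::te;

void CacheSketch() {
  Tensor A = placeholder({64, 64}, DataType::Float(32), "A");
  Tensor B = compute(
      {64, 64}, [&](Var i, Var j) { return A(i, j) * 2.0f; }, "B");
  Schedule s = create_schedule({B->op});

  // Stage a cached copy of A for B's reads; the new stage is inserted right
  // after A's stage in the stage list, as in Schedule::cache_read above.
  Tensor AA = s.cache_read(A, "shared", {B->op});

  // B is produced by a ComputeOp, so the single-tensor cache_write overload
  // dispatches to the CacheWriteWithReLayout path.
  Tensor BL = s.cache_write(B, "local");
  (void)AA;
  (void)BL;
}
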
- if (it != s->iter_var_attrs.end() && - (*it).second->bind_thread.defined()) { + if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } - if (idx < leaf_vars->data.size()) { + if (idx < leaf_vars->size()) { // insert rebase - IterVar rebased = IterVarNode::make( - Range(), iv->var.copy_with_suffix(""), iv->iter_type); - s->relations.push_back(RebaseNode::make(iv, rebased)); + IterVar rebased = IterVar(Range(), iv->var.copy_with_suffix(""), iv->iter_type); + s->relations.push_back(te::Rebase(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); } - leaf_vars->data[idx] = rebased; + leaf_vars->SetItem(idx, rebased); rebase_map[iv] = rebased; } } @@ -556,13 +509,11 @@ void InjectInline(ScheduleNode* sch) { { // setup args const ComputeOpNode* compute = stage->op.as(); - CHECK(compute) - << "can only inline compute op"; + CHECK(compute) << "can only inline compute op"; for (auto iv : compute->axis) { args.push_back(iv->var); } - CHECK_EQ(compute->body.size(), 1U) - << "can only inline compute op with 1 output"; + CHECK_EQ(compute->body.size(), 1U) << "can only inline compute op with 1 output"; body = compute->body[0]; } for (size_t j = i; j < sch->stages.size(); ++j) { @@ -579,12 +530,12 @@ void InjectInline(ScheduleNode* sch) { for (size_t k = 1; k < new_body[j].size(); ++k) { const tir::ReduceNode* reduce_ = new_body[j][k].as(); CHECK(reduce_); - CHECK(ReduceEqual(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + CHECK(ReduceEqual(reduce_, reduce)) << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; } - PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][0]), - stage->op, args, body).as()->value; + PrimExpr new_value = Inline(tir::Evaluate(new_body[j][0]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][0])) { changed[j] = true; const tir::ReduceNode* r = new_value.as(); @@ -599,8 +550,9 @@ void InjectInline(ScheduleNode* sch) { } } else { for (size_t k = 0; k < new_body[j].size(); ++k) { - PrimExpr new_value = tir::Inline(tir::EvaluateNode::make(new_body[j][k]), - stage->op, args, body).as()->value; + PrimExpr new_value = Inline(tir::Evaluate(new_body[j][k]), stage->op, args, body) + .as() + ->value; if (!new_value.same_as(new_body[j][k])) { new_body[j].Set(k, new_value); changed[j] = true; @@ -611,7 +563,7 @@ void InjectInline(ScheduleNode* sch) { if (!new_hybrid_body[j].defined()) { new_hybrid_body[j] = hybrid->body; } - Stmt new_stmt = tir::Inline(new_hybrid_body[j], stage->op, args, body); + Stmt new_stmt = Inline(new_hybrid_body[j], stage->op, args, body); if (!new_stmt.same_as(new_hybrid_body[j])) { new_hybrid_body[j] = new_stmt; hybrid_changed[j] = true; @@ -631,9 +583,7 @@ void InjectInline(ScheduleNode* sch) { CHECK(compute); Operation op = s->op; if (changed[i]) { - op = ComputeOpNode::make( - compute->name, compute->tag, compute->attrs, - compute->axis, new_body[i]); + op = ComputeOp(compute->name, compute->tag, compute->attrs, compute->axis, new_body[i]); } op = op->ReplaceInputs(op, repl); if (!op.same_as(s->op)) { @@ -645,9 +595,8 @@ void InjectInline(ScheduleNode* sch) { } else if (hybrid_changed[i]) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); CHECK(hybrid); - Operation op = HybridOpNode::make( - hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + Operation op = 
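
The Rebase relation that RebaseNonZeroMinLoop introduces shifts a loop whose domain starts at a non-zero min onto an equivalent loop over [0, extent), with the offset added back when indices and domains are passed up. A standalone sketch with plain integers:

#include <cassert>

int main() {
  const int min = 3, extent = 5;

  int sum_original = 0;
  for (int i = min; i < min + extent; ++i) sum_original += i;

  // Rebased form: iterate over [0, extent) and add the offset back, as
  // PassUpDomain for a RebaseNode does with parent_min.
  int sum_rebased = 0;
  for (int rebased = 0; rebased < extent; ++rebased) sum_rebased += rebased + min;

  assert(sum_original == sum_rebased);
  return 0;
}
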
HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, + hybrid->outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); @@ -673,21 +622,17 @@ Schedule Schedule::normalize() { } // Handle reduction factor. -Array Schedule::rfactor(const Tensor& tensor, - const IterVar& axis, - int factor_axis) { +Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) { (*this)->InvalidateCache(); using tir::ReduceNode; - CHECK_EQ(axis->iter_type, kCommReduce) - << "Can only factor reduction axis"; + CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis"; Stage reduce_stage = operator[](tensor->op); const ComputeOpNode* compute_op = reduce_stage->op.as(); CHECK(compute_op) << "Can only factor ComputeOp"; ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite(); { size_t axis_pos = FindNodeRef(leaf_vars, axis); - CHECK_NE(axis_pos, leaf_vars->data.size()) - << "Cannot find IterVar " << axis << " in leaf iter vars"; + CHECK_NE(axis_pos, leaf_vars->size()) << "Cannot find IterVar " << axis << " in leaf iter vars"; } // Find touched reduction axis. std::unordered_map touch_map; @@ -698,8 +643,7 @@ Array Schedule::rfactor(const Tensor& tensor, std::unordered_set skip_bound_check; // Verify normal axis are not touched. for (IterVar iv : compute_op->axis) { - CHECK(!touch_map.count(iv)) - << "Factor axis touches normal axis."; + CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis."; skip_bound_check.insert(iv); } // get analyzer. @@ -727,11 +671,11 @@ Array Schedule::rfactor(const Tensor& tensor, } } te::PassUpIndex(reduce_stage, dom_map, &value_map, true); - std::vector predicates = MakeBoundCheck( - reduce_stage, dom_map, value_map, true, skip_bound_check); + std::vector predicates = + MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check); // Get the factored op node. - const int factor_axis_pos = \ + const int factor_axis_pos = factor_axis >= 0 ? factor_axis : static_cast(compute_op->axis.size() + 1) + factor_axis; CHECK_LE(factor_axis_pos, compute_op->axis.size()); auto n = make_object(); @@ -740,8 +684,7 @@ Array Schedule::rfactor(const Tensor& tensor, // axis relacement. 
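
rfactor's effect on a reduction is two-phase arithmetic: the factored operation produces one partial result per value of the factored axis, and the replacement operation reduces over that new axis. A standalone sketch of the equivalence for a sum, with plain loops in place of Reduce nodes:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> data(18);
  for (int i = 0; i < 18; ++i) data[i] = i;

  const int factor = 4;
  const int outer = (18 + factor - 1) / factor;

  // Factored op: one partial sum per outer index (the new data-parallel axis).
  std::vector<int> partial(outer, 0);
  for (int o = 0; o < outer; ++o)
    for (int i = 0; i < factor; ++i)
      if (o * factor + i < 18) partial[o] += data[o * factor + i];

  // Replacement op (".repl"): reduce over the factored axis.
  int total = 0;
  for (int o = 0; o < outer; ++o) total += partial[o];

  int expected = 0;
  for (int v : data) expected += v;
  assert(total == expected);
  return 0;
}
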
auto iv_node = make_object(); iv_node->dom = dom_map.at(axis); - CHECK(is_zero(iv_node->dom->min)) - << "Can only factor reduction domain starting from 0"; + CHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0"; iv_node->var = axis->var; iv_node->iter_type = kDataPar; @@ -761,7 +704,9 @@ Array Schedule::rfactor(const Tensor& tensor, const ReduceNode* reduce = compute_op->body[idx].as(); CHECK(reduce) << "Can only rfactor non-inline reductions"; predicates.push_back(reduce->condition); - PrimExpr predicate = likely(arith::ComputeReduce(predicates, PrimExpr())); + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; + + PrimExpr predicate = likely(foldl(fand, const_true(1), predicates)); std::unordered_map vsub; @@ -785,18 +730,14 @@ Array Schedule::rfactor(const Tensor& tensor, } } VarReplacer replacer(vsub); - Array new_source = tir::UpdateArray(reduce->source, - [&replacer] (const PrimExpr& e) { return replacer(e); }); + Array new_source = + tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); }); PrimExpr new_pred = replacer(predicate); std::vector body; for (size_t idx = 0; idx < reduce->source.size(); ++idx) { - body.emplace_back(ReduceNode::make(reduce->combiner, - new_source, - n->reduce_axis, - new_pred, - idx)); + body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx)); } n->body = Array(body); // refresh relations, keep the un-touched relations. @@ -818,21 +759,19 @@ Array Schedule::rfactor(const Tensor& tensor, } // initialize the factored stage. Operation factor_op(n); - ArrayNode* stages = (*this)->stages.CopyOnWrite(); - size_t stage_pos = FindNodeRef(stages, reduce_stage); + Array& stages = (*this)->stages; + size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage); Stage factor_stage = Stage(factor_op); factor_stage->relations = rels; - CHECK_LT(stage_pos, stages->data.size()); - stages->data.insert(stages->data.begin() + stage_pos, - factor_stage); + CHECK_LT(stage_pos, stages.size()); + stages.insert(stages.begin() + stage_pos, factor_stage); (*this)->stage_map.Set(factor_op, factor_stage); factor_stage->group = reduce_stage->group; if (factor_stage->group.defined()) { ++factor_stage->group->num_child_stages; } // Replace the old reduction. 
- IterVar repl_red_axis = reduce_axis( - dom_map.at(axis), axis->var->name_hint + ".v"); + IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v"); Array factor_tensors; Array old_tensors; int size = factor_op->num_outputs(); @@ -840,32 +779,33 @@ Array Schedule::rfactor(const Tensor& tensor, factor_tensors.push_back(factor_op.output(idx)); old_tensors.push_back(reduce_stage->op.output(idx)); } - Array repl_tensors = compute(old_tensors[0]->shape, - [&](const Array& i) { - Array indices; - const int idx_size = static_cast(i.size()); - for (int idx = 0; idx < idx_size; ++idx) { - if (factor_axis_pos == idx) { - indices.push_back(repl_red_axis->var); + Array repl_tensors = compute( + old_tensors[0]->shape, + [&](const Array& i) { + Array indices; + const int idx_size = static_cast(i.size()); + for (int idx = 0; idx < idx_size; ++idx) { + if (factor_axis_pos == idx) { + indices.push_back(repl_red_axis->var); + } + indices.push_back(i[idx]); } - indices.push_back(i[idx]); - } - if (factor_axis_pos == idx_size) { + if (factor_axis_pos == idx_size) { indices.push_back(repl_red_axis->var); - } - Array factor_exprs; - for (int idx = 0; idx < size; ++idx) { - factor_exprs.push_back(factor_tensors[idx](indices)); - } - Array reductions; - Array axis = {repl_red_axis}; - PrimExpr cond = const_true(); - for (int idx = 0; idx < size; ++idx) { - reductions.push_back(ReduceNode::make(reduce->combiner, - factor_exprs, axis, cond, idx)); - } - return reductions; - }, reduce_stage->op->name + ".repl"); + } + Array factor_exprs; + for (int idx = 0; idx < size; ++idx) { + factor_exprs.push_back(factor_tensors[idx](indices)); + } + Array reductions; + Array axis = {repl_red_axis}; + PrimExpr cond = const_true(); + for (int idx = 0; idx < size; ++idx) { + reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx)); + } + return reductions; + }, + reduce_stage->op->name + ".repl"); std::unordered_map vmap; std::unordered_map rvmap; diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index bfee0d5a0a6b..707d52fb186a 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -22,68 +22,61 @@ */ #include #include -#include #include +#include + #include #include + #include "graph.h" namespace tvm { namespace te { // find first occurance location in leaf -template +template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); - for (size_t i = 0; i < array_node->data.size(); ++i) { - if (array_node->data[i].get() == n) return i; + for (size_t i = 0; i < array_node->size(); ++i) { + if (array_node->at(i).get() == n) return i; } - return array_node->data.size(); + return array_node->size(); } size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) { size_t pos = FindNodeRef(leaf_vars, v); - if (pos < leaf_vars->data.size()) return pos; + if (pos < leaf_vars->size()) return pos; - if (FindNodeRef(all_vars, v) < all_vars->data.size()) { - LOG(FATAL) << "Operate on iter var " << v - << "that has already been split"; + if (FindNodeRef(all_vars, v) < all_vars->size()) { + LOG(FATAL) << "Operate on iter var " << v << "that has already been split"; } else { - LOG(FATAL) << "Operate on iter var " << v - << "that is not part of the schedule"; + LOG(FATAL) << "Operate on iter var " << v << "that is not part of the schedule"; } return 0; } -void Split(StageNode* self, - IterVar parent, - PrimExpr factor, - PrimExpr nparts, - IterVar* p_outer, - IterVar* p_inner) { +void 
SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, + IterVar* p_outer, IterVar* p_inner) { // Check if split is valid. - CHECK(parent->iter_type == kDataPar || - parent->iter_type == kCommReduce || + CHECK(parent->iter_type == kDataPar || parent->iter_type == kCommReduce || parent->iter_type == kOrdered) << "Cannot split on " << IterVarType2String(parent->iter_type); - IterVar outer = IterVarNode::make( - Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); - IterVar inner = IterVarNode::make( - Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); + IterVar outer = IterVar(Range(), parent->var.copy_with_suffix(".outer"), parent->iter_type); + IterVar inner = IterVar(Range(), parent->var.copy_with_suffix(".inner"), parent->iter_type); *p_outer = outer; *p_inner = inner; // The splits - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - size_t pos = FindLeafVar(all_vars, leaf_vars, parent); - self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts)); + Array& all_vars = self->all_iter_vars; + Array& leaf_vars = self->leaf_iter_vars; + size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); + self->relations.push_back(Split(parent, outer, inner, factor, nparts)); // add vars to all vars - all_vars->data.push_back(outer); - all_vars->data.push_back(inner); + all_vars.push_back(outer); + all_vars.push_back(inner); // replace the position. - leaf_vars->data.erase(leaf_vars->data.begin() + pos); - leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner); - leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer); + leaf_vars.erase(leaf_vars.begin() + pos); + leaf_vars.insert(leaf_vars.begin() + pos, inner); + leaf_vars.insert(leaf_vars.begin() + pos, outer); } Stage::Stage(Operation op) { @@ -112,8 +105,7 @@ bool Stage::is_scheduled() const { Stage Stage::GetAttachSpec() const { Stage attach_spec = *this; - while (attach_spec->attach_type == kGroupRoot && - attach_spec->group.defined()) { + while (attach_spec->attach_type == kGroupRoot && attach_spec->group.defined()) { attach_spec = attach_spec->group; } return attach_spec; @@ -124,9 +116,8 @@ Stage& Stage::set_scope(std::string scope) { // NOLINT(*) return *this; } -Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; // Group constraint checking. 
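
The leaf_iter_vars bookkeeping in SplitHelper is positional: the parent axis is erased at its position and (outer, inner) are inserted there in that order. A standalone sketch with std::vector<std::string> standing in for Array<IterVar>:

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> leaf = {"i", "j", "k"};

  // Split "j": remove it and insert inner, then outer, at the same position,
  // so the final order is {..., j.outer, j.inner, ...}.
  size_t pos = 1;
  leaf.erase(leaf.begin() + pos);
  leaf.insert(leaf.begin() + pos, "j.inner");
  leaf.insert(leaf.begin() + pos, "j.outer");

  assert((leaf == std::vector<std::string>{"i", "j.outer", "j.inner", "k"}));
  return 0;
}
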
Stage group = (*this)->group; if (group.defined()) { @@ -134,8 +125,7 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) while (pg.defined() && !pg.same_as(group)) { pg = pg->group; } - CHECK(pg.same_as(group)) - << "Can only assign compute_at to stages within the same group"; + CHECK(pg.same_as(group)) << "Can only assign compute_at to stages within the same group"; } (*this)->attach_type = kScope; @@ -144,34 +134,30 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) bool found = false; for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) { if (scope == parent->leaf_iter_vars[i]) { - found = true; break; + found = true; + break; } } - CHECK(found) - << "Cannot find the axis " << scope - << " in parent's leaf_iter_vars" - << " parent=" << parent; + CHECK(found) << "Cannot find the axis " << scope << " in parent's leaf_iter_vars" + << " parent=" << parent; return *this; } -Stage& Stage::compute_inline() { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_inline() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kInline; return *this; } -Stage& Stage::compute_root() { // NOLINT(*) - CHECK_NE((*this)->attach_type, kScanUpdate) - << "Cannot specify compute_at for scan updates"; +Stage& Stage::compute_root() { // NOLINT(*) + CHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kGroupRoot; return *this; } -Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) +Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) StageNode* self = operator->(); - CHECK(ivar->iter_type == kDataPar || - ivar->iter_type == kCommReduce) + CHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce) << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread"; CHECK(thread_ivar->iter_type == kThreadIndex) << "Cannot rebase by " << IterVarType2String(ivar->iter_type) @@ -184,10 +170,8 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) ObjectPtr n; if (it != self->iter_var_attrs.end()) { n = make_object(*(*it).second.operator->()); - if (n->bind_thread.defined() && - !n->bind_thread.same_as(thread_ivar)) { - LOG(WARNING) << "Axis " << ivar - << " is already bind to another thread " << n->bind_thread; + if (n->bind_thread.defined() && !n->bind_thread.same_as(thread_ivar)) { + LOG(WARNING) << "Axis " << ivar << " is already bind to another thread " << n->bind_thread; } } else { n = make_object(); @@ -201,18 +185,15 @@ Stage& Stage::env_threads(Array threads) { StageNode* self = operator->(); CHECK(self->op.defined() && self->op.as()) << "env_threads is only valid for composite ops such as ScanOp"; - CHECK_EQ(self->env_threads.size(), 0U) - << "Already set env_threads"; - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); + CHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads"; + Array& leaf_vars = self->leaf_iter_vars; + Array& all_vars = self->all_iter_vars; std::vector temp; for (IterVar iv : threads) { temp.push_back(iv); } - leaf_vars->data.insert( - leaf_vars->data.begin(), temp.begin(), temp.end()); - all_vars->data.insert( - all_vars->data.end(), temp.begin(), temp.end()); + leaf_vars.insert(leaf_vars.begin(), temp.begin(), temp.end()); + all_vars.insert(all_vars.end(), temp.begin(), 
temp.end()); self->env_threads = threads; return *this; } @@ -223,54 +204,48 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { return *this; } -Stage& Stage::split( - IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); +Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, + IterVar* p_inner) { // NOLINT(*) + SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } -Stage& Stage::split_by_nparts( - IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) - Split(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); +Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, + IterVar* p_inner) { // NOLINT(*) + SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*) StageNode* self = operator->(); - CHECK(outer->iter_type == kDataPar || - outer->iter_type == kCommReduce || + CHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce || outer->iter_type == kOrdered) << "Cannot fuse " << IterVarType2String(outer->iter_type); - CHECK(inner->iter_type == kDataPar || - inner->iter_type == kCommReduce || + CHECK(inner->iter_type == kDataPar || inner->iter_type == kCommReduce || inner->iter_type == kOrdered) << "Cannot fuse " << IterVarType2String(inner->iter_type); IterVarType iter_type = outer->iter_type; if (inner->iter_type > iter_type) iter_type = inner->iter_type; - std::string fused_name = - outer->var->name_hint + "." + inner->var->name_hint + ".fused"; + std::string fused_name = outer->var->name_hint + "." + inner->var->name_hint + ".fused"; - IterVar fused = IterVarNode::make( - Range(), Var(fused_name, outer->var.dtype()), iter_type); + IterVar fused = IterVar(Range(), Var(fused_name, outer->var.dtype()), iter_type); - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); + Array& all_vars = self->all_iter_vars; + Array& leaf_vars = self->leaf_iter_vars; - size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner); - size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer); + size_t pos_inner = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), inner); + size_t pos_outer = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), outer); if (pos_inner + 1 == pos_outer) { std::swap(outer, inner); std::swap(pos_inner, pos_outer); } CHECK_EQ(pos_inner, pos_outer + 1) << "Can only fuse iterations that are consecutive between each other"; - self->relations.push_back(FuseNode::make(outer, inner, fused)); - all_vars->data.push_back(fused); - leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, - leaf_vars->data.begin() + pos_inner + 1); - leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, - fused); + self->relations.push_back(Fuse(outer, inner, fused)); + all_vars.push_back(fused); + leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1); + leaf_vars.insert(leaf_vars.begin() + pos_outer, fused); *p_target = fused; return *this; } @@ -286,14 +261,13 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* StageNode* self = operator->(); // special handle fuse empty array. 
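
Stage::fuse does the inverse bookkeeping: the two axes must be consecutive leaves (swapped if given inner-first), and they are replaced by a single fused axis at the outer position. A standalone sketch, again with strings standing in for IterVars:

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> leaf = {"i.outer", "i.inner", "j"};

  // Fuse the consecutive leaves at (pos_outer, pos_inner) into one axis whose
  // name follows the outer.inner.fused convention used above.
  size_t pos_outer = 0, pos_inner = 1;
  leaf.erase(leaf.begin() + pos_outer, leaf.begin() + pos_inner + 1);
  leaf.insert(leaf.begin() + pos_outer, "i.outer.i.inner.fused");

  assert((leaf == std::vector<std::string>{"i.outer.i.inner.fused", "j"}));
  return 0;
}
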
// insert at the outer most loop - IterVar singleton = IterVarNode::make( - Range::make_by_min_extent(0, 1), - Var("singleton", DataType::Int(32)), kDataPar); - self->relations.push_back(SingletonNode::make(singleton)); - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - all_vars->data.push_back(singleton); - leaf_vars->data.insert(leaf_vars->data.begin(), singleton); + IterVar singleton = + IterVar(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar); + self->relations.push_back(Singleton(singleton)); + Array& all_vars = self->all_iter_vars; + Array& leaf_vars = self->leaf_iter_vars; + all_vars.push_back(singleton); + leaf_vars.insert(leaf_vars.begin(), singleton); *p_target = singleton; } return *this; @@ -303,14 +277,11 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) std::unordered_set seen_var; StageNode* self = operator->(); for (IterVar iv : order) { - CHECK(iv->iter_type == kDataPar || - iv->iter_type == kCommReduce || + CHECK(iv->iter_type == kDataPar || iv->iter_type == kCommReduce || iv->iter_type == kThreadIndex) - << "Cannot reorder IterVar(" - << IterVarType2String(iv->iter_type) << ")"; + << "Cannot reorder IterVar(" << IterVarType2String(iv->iter_type) << ")"; - CHECK_EQ(seen_var.count(iv), 0) - << "Same axis can not appear more than once " << iv; + CHECK_EQ(seen_var.count(iv), 0) << "Same axis can not appear more than once " << iv; seen_var.insert(iv); } ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); @@ -322,29 +293,25 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) } std::vector temp; for (size_t i = 0; i < pos.size(); ++i) { - temp.emplace_back(leaf_vars->data[pos[i]]); + temp.emplace_back(leaf_vars->at(pos[i])); } std::sort(pos.begin(), pos.end()); for (size_t i = 0; i < pos.size(); ++i) { - leaf_vars->data[pos[i]] = temp[i]; + leaf_vars->SetItem(pos[i], temp[i]); } return *this; } -Stage& Stage::tile(IterVar x_parent, IterVar y_parent, - PrimExpr x_factor, PrimExpr y_factor, - IterVar* p_x_outer, IterVar* p_y_outer, - IterVar* p_x_inner, IterVar* p_y_inner) { +Stage& Stage::tile(IterVar x_parent, IterVar y_parent, PrimExpr x_factor, PrimExpr y_factor, + IterVar* p_x_outer, IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner) { split(x_parent, x_factor, p_x_outer, p_x_inner); split(y_parent, y_factor, p_y_outer, p_y_inner); reorder(Array({*p_x_outer, *p_y_outer, *p_x_inner, *p_y_inner})); return *this; } -template -inline void UpdateIterVarAttr(StageNode* self, - IterVar var, - FUpdate fupdate, +template +inline void UpdateIterVarAttr(StageNode* self, IterVar var, FUpdate fupdate, bool need_leaf = true) { if (need_leaf) { ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); @@ -363,60 +330,53 @@ inline void UpdateIterVarAttr(StageNode* self, } inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) { - UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { - n->iter_type = iter_type; - }); + UpdateIterVarAttr(self, var, [iter_type](IterVarAttrNode* n) { n->iter_type = iter_type; }); } -Stage& Stage::vectorize(IterVar var) { // NOLINT(*) - CHECK(var->iter_type == kDataPar || - var->iter_type == kOpaque || - var->iter_type == kUnrolled || - var->iter_type == kVectorized || - var->iter_type == kTensorized || +Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + CHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || + var->iter_type == kVectorized || 
var->iter_type == kTensorized || var->iter_type == kParallelized) << "Cannot vectorize on " << IterVarType2String(var->iter_type); SetAttrIterType(operator->(), var, kVectorized); return *this; } -Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) +Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { - n->iter_type = kTensorized; - n->tensor_intrin = f; - }); + n->iter_type = kTensorized; + n->tensor_intrin = f; + }); return *this; } -Stage& Stage::unroll(IterVar var) { // NOLINT(*) +Stage& Stage::unroll(IterVar var) { // NOLINT(*) SetAttrIterType(operator->(), var, kUnrolled); return *this; } -Stage& Stage::parallel(IterVar var) { // NOLINT(*) +Stage& Stage::parallel(IterVar var) { // NOLINT(*) SetAttrIterType(operator->(), var, kParallelized); return *this; } -Stage& Stage::pragma(IterVar var, - const std::string& pragma_type, - const PrimExpr& pragma_value) { // NOLINT(*) +Stage& Stage::pragma(IterVar var, const std::string& pragma_type, + const PrimExpr& pragma_value) { // NOLINT(*) if (pragma_type == "unroll") { this->unroll(var); } else if (pragma_type == "vectorize") { this->vectorize(var); } else { - UpdateIterVarAttr( - operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { - n->pragma_keys.push_back(tir::StringImmNode::make(pragma_type)); - n->pragma_values.push_back(pragma_value); - }); + UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { + n->pragma_keys.push_back(tir::StringImm(pragma_type)); + n->pragma_values.push_back(pragma_value); + }); } return *this; } -Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { - StageNode *self = operator->(); +Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { + StageNode* self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); FindLeafVar(all_vars, leaf_vars, var); @@ -434,73 +394,33 @@ Stage& Stage::prefetch(const Tensor &tensor, IterVar var, PrimExpr offset) { } Stage& Stage::storage_align(IterVar axis, int factor, int offset) { - StageNode *self = operator->(); - UpdateIterVarAttr(self, axis, [factor, offset](IterVarAttrNode* n) { - n->dim_align_factor = factor; - n->dim_align_offset = offset; - }, false); + StageNode* self = operator->(); + UpdateIterVarAttr( + self, axis, + [factor, offset](IterVarAttrNode* n) { + n->dim_align_factor = factor; + n->dim_align_offset = offset; + }, + false); return *this; } Stage& Stage::double_buffer() { - StageNode *self = operator->(); + StageNode* self = operator->(); CHECK(!self->is_output) << "Cannot apply double buffer on output"; self->double_buffer = true; return *this; } -Stage& Stage::opengl() { - CHECK(!is_scheduled()) << "Must be a fresh schedule"; - StageNode *self = operator->(); - - auto all_iter_vars = self->all_iter_vars; // curr version of all_iter_vars - CHECK(!all_iter_vars.empty()) << "At least one iter var"; - - // Fuse all data parallel dimensions to 1. 
- IterVar fused = all_iter_vars[0]; - for (size_t i = 1; i != all_iter_vars.size(); ++i) { - auto iter_var = all_iter_vars[i]; - switch (iter_var->iter_type) { - case IterVarType::kDataPar: { - fuse(fused, all_iter_vars[i], &fused); - break; - } - case IterVarType::kThreadIndex: { - LOG(ERROR) << "A fresh schedule shouldn't have thread index iter var"; - break; - } - case IterVarType::kCommReduce: - case IterVarType::kOrdered: - case IterVarType::kOpaque: { - break; - } - default: { - LOG(ERROR) << "Invalid iter var type " - << IterVarType2String(iter_var->iter_type); - break; - } - } - } - - // Bind the only dimension to threadIdx.x. - bind(fused, thread_axis(Range(nullptr), "threadIdx.x")); - - // Mark this stage as OpenGL. - (*this)->is_opengl = true; - - return *this; -} - Stage CopyStage(const Stage& s) { - ObjectPtr n = - make_object(*s.operator->()); + ObjectPtr n = make_object(*s.operator->()); return Stage(n); } Schedule Schedule::copy() const { // map of stages. const ScheduleNode* self = operator->(); - std::unordered_map smap; + std::unordered_map smap; ObjectPtr n = make_object(); n->outputs = self->outputs; // Copy the stages. @@ -521,24 +441,22 @@ Schedule Schedule::copy() const { for (Stage s : n->stages) { if (s->attach_stage.defined()) { CHECK(smap.find(s->attach_stage) != smap.end()) - << s->attach_stage << " not found in " << (*this); + << s->attach_stage << " not found in " << (*this); s->attach_stage = smap.at(s->attach_stage); } if (s->group.defined()) { - CHECK(smap.find(s->group) != smap.end()) - << s->group << " not found in " << (*this); + CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); s->group = smap.at(s->group); } } for (Stage s : n->groups) { if (s->attach_stage.defined()) { CHECK(smap.find(s->attach_stage) != smap.end()) - << s->attach_stage << " not found in " << (*this); + << s->attach_stage << " not found in " << (*this); s->attach_stage = smap.at(s->attach_stage); } if (s->group.defined()) { - CHECK(smap.find(s->group) != smap.end()) - << s->group << " not found in " << (*this); + CHECK(smap.find(s->group) != smap.end()) << s->group << " not found in " << (*this); s->group = smap.at(s->group); } } @@ -548,8 +466,7 @@ Schedule Schedule::copy() const { Stage Schedule::operator[](const Operation& op) { auto it = (*this)->stage_map.find(op); CHECK(it != (*this)->stage_map.end()) - << "Cannot find Stage for operator " << op - << " in the schedule"; + << "Cannot find Stage for operator " << op << " in the schedule"; return (*it).second; } @@ -570,15 +487,13 @@ Stage LeastCommonAncestor(Stage g1, Stage g2) { return g; } -Array RemapTensor(ScheduleNode* self, - const Array& arr) { +Array RemapTensor(ScheduleNode* self, const Array& arr) { self->InitCache(); const auto& op2stage_cache = self->op2stage_cache_; Array ret; for (Tensor t : arr) { if (!op2stage_cache.count(t->op.get())) { - CHECK(self->stage_map.count(t->op)) - << "Given tensor is not in the schedule plan"; + CHECK(self->stage_map.count(t->op)) << "Given tensor is not in the schedule plan"; t = self->stage_map[t->op]->op.output(t->value_index); } ret.push_back(t); @@ -587,24 +502,21 @@ Array RemapTensor(ScheduleNode* self, } // Group the schedule stages. -Stage Schedule::create_group(const Array& outputs, - const Array& inputs, +Stage Schedule::create_group(const Array& outputs, const Array& inputs, bool include_inputs) { ScheduleNode* self = operator->(); self->InitCache(); const auto& op2stage_cache = self->op2stage_cache_; // Get the ops. 
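Schedule::copy above clones every StageNode and then patches attach_stage and group references through the old-to-new stage map, so the clone never aliases the original schedule. A minimal sketch of the observable behaviour, not from this patch, with placeholder tensors:

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

using namespace tvm;
using namespace tvm::te;

void CopySketch() {
  Tensor A = placeholder({64}, DataType::Float(32), "A");
  Tensor B = compute(
      {64}, [&](const Array<Var>& i) { return A(i[0]) * 2.0f; }, "B");

  Schedule s = create_schedule({B->op});

  // Every stage is a fresh StageNode in the copy, and any attach_stage or
  // group pointer is remapped into the copied schedule.
  Schedule cloned = s.copy();
  bool distinct_stage_objects = !cloned[B].same_as(s[B]);
  (void)distinct_stage_objects;
}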
- Array ops = te::GetSubGraph( - RemapTensor(self, outputs), - RemapTensor(self, inputs), - include_inputs); + Array ops = + te::GetSubGraph(RemapTensor(self, outputs), RemapTensor(self, inputs), include_inputs); // local counter entry // Automatically initialize to 0 during creation. struct Entry { int count{0}; }; // Map of group->touched counter - std::unordered_map counter; + std::unordered_map counter; // The parent group; Stage parent_group; // Detect common parent and child. @@ -631,7 +543,7 @@ Stage Schedule::create_group(const Array& outputs, // Propagate the counter statistics from by checking if subgroup // Is full and propagate. std::vector stack; - for (auto &kv : counter) { + for (auto& kv : counter) { if (!kv.first.same_as(parent_group)) { if (kv.first->num_child_stages == kv.second.count) { stack.push_back(kv.first); @@ -650,7 +562,7 @@ Stage Schedule::create_group(const Array& outputs, } } // Verification and remappig the subgroups. - for (auto &kv : counter) { + for (auto& kv : counter) { if (kv.first.same_as(parent_group)) continue; CHECK_EQ(kv.first->num_child_stages, kv.second.count) << "Trying to group region that intersect with an already existed group"; @@ -695,9 +607,7 @@ Stage Schedule::create_group(const Array& outputs, return gstage; } -void ScheduleNode::InvalidateCache() { - op2stage_cache_.clear(); -} +void ScheduleNode::InvalidateCache() { op2stage_cache_.clear(); } void ScheduleNode::InitCache() { if (op2stage_cache_.size() == stages.size()) return; @@ -714,9 +624,9 @@ bool ScheduleNode::Contain(const Operation& op) const { return stage_map.find(op) != stage_map.end(); } -Schedule ScheduleNode::make(Array ops) { +Schedule::Schedule(Array ops) { auto n = make_object(); - Schedule sch(n); + data_ = n; n->outputs = ops; auto g = te::CreateReadGraph(n->outputs); Array post_order = te::PostDFSOrder(n->outputs, g); @@ -740,7 +650,7 @@ Schedule ScheduleNode::make(Array ops) { inputs.push_back(t); } // Create the scan group. 
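create_group, which the scan handling here relies on and which is also exposed through the schedule API, collects the sub-graph between inputs and outputs, counts how many of the touched stages each existing group already owns, and only succeeds when no existing group would be split. A rough usage sketch follows; it is not part of the patch and the shapes and names are made up.

#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>

using namespace tvm;
using namespace tvm::te;

void GroupSketch() {
  Tensor A = placeholder({256}, DataType::Float(32), "A");
  Tensor B = compute(
      {256}, [&](const Array<Var>& i) { return A(i[0]) + 1.0f; }, "B");
  Tensor C = compute(
      {256}, [&](const Array<Var>& i) { return B(i[0]) * 3.0f; }, "C");

  Schedule s = create_schedule({C->op});

  // Group everything on the path from A (exclusive) to B (inclusive).
  Stage g = s.create_group({B}, {A}, /*include_inputs=*/false);

  // The group is then scheduled as a unit, e.g. attached under C's axis.
  g.compute_at(s[C], C->op.as<ComputeOpNode>()->axis[0]);
}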
- Stage scan_group = sch.create_group(scan->update, inputs, false); + Stage scan_group = this->create_group(scan->update, inputs, false); scan_group->attach_type = kScanUpdate; scan_group->attach_stage = stage; @@ -750,43 +660,37 @@ Schedule ScheduleNode::make(Array ops) { } } } - return sch; } -IterVarRelation SplitNode::make(IterVar parent, - IterVar outer, - IterVar inner, - PrimExpr factor, - PrimExpr nparts) { +Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { auto n = make_object(); n->parent = parent; n->outer = outer; n->inner = inner; n->factor = factor; n->nparts = nparts; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation FuseNode::make( - IterVar outer, IterVar inner, IterVar fused) { +Fuse::Fuse(IterVar outer, IterVar inner, IterVar fused) { auto n = make_object(); n->outer = outer; n->inner = inner; n->fused = fused; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) { +Rebase::Rebase(IterVar parent, IterVar rebased) { auto n = make_object(); n->parent = parent; n->rebased = rebased; - return IterVarRelation(n); + data_ = std::move(n); } -IterVarRelation SingletonNode::make(IterVar iter) { +Singleton::Singleton(IterVar iter) { auto n = make_object(); n->iter = iter; - return IterVarRelation(n); + data_ = std::move(n); } SpecializedCondition::SpecializedCondition(Array conditions) { @@ -805,19 +709,19 @@ struct TVMSpecializationThreadLocalEntry { typedef dmlc::ThreadLocalStore TVMSpecializationThreadLocalStore; void SpecializedCondition::EnterWithScope() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); entry->condition_stack.push(*this); } void SpecializedCondition::ExitWithScope() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); CHECK(!entry->condition_stack.empty()); CHECK(entry->condition_stack.top().same_as(*this)); entry->condition_stack.pop(); } SpecializedCondition SpecializedCondition::Current() { - TVMSpecializationThreadLocalEntry *entry = TVMSpecializationThreadLocalStore::Get(); + TVMSpecializationThreadLocalEntry* entry = TVMSpecializationThreadLocalStore::Get(); SpecializedCondition cond; if (entry->condition_stack.size() > 0) { cond = entry->condition_stack.top(); @@ -827,13 +731,9 @@ SpecializedCondition SpecializedCondition::Current() { class SpecializedCondition::Internal { public: - static void EnterScope(SpecializedCondition cond) { - cond.EnterWithScope(); - } + static void EnterScope(SpecializedCondition cond) { cond.EnterWithScope(); } - static void ExitScope(SpecializedCondition cond) { - cond.ExitWithScope(); - } + static void ExitScope(SpecializedCondition cond) { cond.ExitWithScope(); } }; TVM_REGISTER_NODE_TYPE(StageNode); @@ -847,193 +747,156 @@ TVM_REGISTER_NODE_TYPE(SpecializedConditionNode); // Printer TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - if (op->op.defined()) { - p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; - } else { - p->stream << "group-stage(" << op << ")"; - } -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << IterVarType2String(op->iter_type); -}) 
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "split(parent="; - p->Print(op->parent); - p->stream << ", outer="; - p->Print(op->outer); - p->stream << ", inner="; - p->Print(op->inner); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "split("; - p->stream << "outer="; - p->Print(op->outer); - p->stream << ", inner="; - p->Print(op->inner); - p->stream << ", fused="; - p->Print(op->fused); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "rebase("; - p->stream << "parent="; - p->Print(op->parent); - p->stream << ", rebased="; - p->Print(op->rebased); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "singleton("; - p->Print(op->iter); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "schedule(" << op << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "specialized_condition("; - p->Print(op->clauses); - p->stream << ')'; -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + if (op->op.defined()) { + p->stream << "stage(" << op->origin_op->name << ", " << op << ")"; + } else { + p->stream << "group-stage(" << op << ")"; + } + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << IterVarType2String(op->iter_type); + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "split(parent="; + p->Print(op->parent); + p->stream << ", outer="; + p->Print(op->outer); + p->stream << ", inner="; + p->Print(op->inner); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "split("; + p->stream << "outer="; + p->Print(op->outer); + p->stream << ", inner="; + p->Print(op->inner); + p->stream << ", fused="; + p->Print(op->fused); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "rebase("; + p->stream << "parent="; + p->Print(op->parent); + p->stream << ", rebased="; + p->Print(op->rebased); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "singleton("; + p->Print(op->iter); + p->stream << ')'; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "schedule(" << op << ")"; + }) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "specialized_condition("; + p->Print(op->clauses); + p->stream << ')'; + }); -TVM_REGISTER_GLOBAL("te.CreateSchedule") -.set_body_typed(create_schedule); +TVM_REGISTER_GLOBAL("te.CreateSchedule").set_body_typed(create_schedule); -TVM_REGISTER_GLOBAL("te.StageSetScope") -.set_body_method(&Stage::set_scope); +TVM_REGISTER_GLOBAL("te.StageSetScope").set_body_method(&Stage::set_scope); -TVM_REGISTER_GLOBAL("te.StageBind") -.set_body_method(&Stage::bind); +TVM_REGISTER_GLOBAL("te.StageBind").set_body_method(&Stage::bind); 
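The TVM_REGISTER_GLOBAL entries in this block are the names the Python frontend resolves; the same packed functions can be fetched from C++ as well. A hedged sketch follows, with error handling trimmed and the literal 8 standing in for an arbitrary split factor:

#include <tvm/runtime/registry.h>
#include <tvm/te/schedule.h>

using namespace tvm;
using namespace tvm::te;

void CallRegisteredSplit(Stage stage, IterVar axis) {
  // Any name registered above is resolvable at runtime by string.
  const runtime::PackedFunc* split = runtime::Registry::Get("te.StageSplitByFactor");
  if (split == nullptr) return;  // the registration was not linked in

  // Equivalent to stage.split(axis, 8, &outer, &inner); the typed body packs
  // the resulting outer/inner IterVars into an Array for the caller.
  Array<IterVar> outer_inner = (*split)(stage, axis, PrimExpr(8));
  (void)outer_inner;
}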
TVM_REGISTER_GLOBAL("te.StageSplitByFactor") -.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { - IterVar outer, inner; - stage.split(parent, factor, &outer, &inner); - return Array({outer, inner}); -}); + .set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) { + IterVar outer, inner; + stage.split(parent, factor, &outer, &inner); + return Array({outer, inner}); + }); TVM_REGISTER_GLOBAL("te.StageSplitByNParts") -.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { - IterVar outer, inner; - stage.split_by_nparts(parent, nparts, &outer, &inner); - return Array({outer, inner}); -}); + .set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) { + IterVar outer, inner; + stage.split_by_nparts(parent, nparts, &outer, &inner); + return Array({outer, inner}); + }); -TVM_REGISTER_GLOBAL("te.StageFuse") -.set_body_typed([](Stage stage, Array axes) { - IterVar fused; - stage.fuse(axes, &fused); - return fused; - }); +TVM_REGISTER_GLOBAL("te.StageFuse").set_body_typed([](Stage stage, Array axes) { + IterVar fused; + stage.fuse(axes, &fused); + return fused; +}); -TVM_REGISTER_GLOBAL("te.StageComputeAt") -.set_body_method(&Stage::compute_at); +TVM_REGISTER_GLOBAL("te.StageComputeAt").set_body_method(&Stage::compute_at); -TVM_REGISTER_GLOBAL("te.StageComputeInline") -.set_body_method(&Stage::compute_inline); +TVM_REGISTER_GLOBAL("te.StageComputeInline").set_body_method(&Stage::compute_inline); -TVM_REGISTER_GLOBAL("te.StageComputeRoot") -.set_body_method(&Stage::compute_root); +TVM_REGISTER_GLOBAL("te.StageComputeRoot").set_body_method(&Stage::compute_root); -TVM_REGISTER_GLOBAL("te.StageReorder") -.set_body_method(&Stage::reorder); +TVM_REGISTER_GLOBAL("te.StageReorder").set_body_method(&Stage::reorder); TVM_REGISTER_GLOBAL("te.StageTile") -.set_body_typed([]( - Stage stage, - IterVar x_parent, IterVar y_parent, - PrimExpr x_factor, PrimExpr y_factor -) { - IterVar x_outer, y_outer, x_inner, y_inner; - stage.tile(x_parent, y_parent, - x_factor, y_factor, - &x_outer, &y_outer, - &x_inner, &y_inner); - return Array({x_outer, y_outer, x_inner, y_inner}); - }); - -TVM_REGISTER_GLOBAL("te.StageEnvThreads") -.set_body_method(&Stage::env_threads); + .set_body_typed([](Stage stage, IterVar x_parent, IterVar y_parent, PrimExpr x_factor, + PrimExpr y_factor) { + IterVar x_outer, y_outer, x_inner, y_inner; + stage.tile(x_parent, y_parent, x_factor, y_factor, &x_outer, &y_outer, &x_inner, &y_inner); + return Array({x_outer, y_outer, x_inner, y_inner}); + }); -TVM_REGISTER_GLOBAL("te.StageSetStorePredicate") -.set_body_method(&Stage::set_store_predicate); +TVM_REGISTER_GLOBAL("te.StageEnvThreads").set_body_method(&Stage::env_threads); -TVM_REGISTER_GLOBAL("te.StageUnroll") -.set_body_method(&Stage::unroll); +TVM_REGISTER_GLOBAL("te.StageSetStorePredicate").set_body_method(&Stage::set_store_predicate); -TVM_REGISTER_GLOBAL("te.StageVectorize") -.set_body_method(&Stage::vectorize); +TVM_REGISTER_GLOBAL("te.StageUnroll").set_body_method(&Stage::unroll); -TVM_REGISTER_GLOBAL("te.StageTensorize") -.set_body_method(&Stage::tensorize); +TVM_REGISTER_GLOBAL("te.StageVectorize").set_body_method(&Stage::vectorize); -TVM_REGISTER_GLOBAL("te.StageParallel") -.set_body_method(&Stage::parallel); +TVM_REGISTER_GLOBAL("te.StageTensorize").set_body_method(&Stage::tensorize); -TVM_REGISTER_GLOBAL("te.StagePragma") -.set_body_method(&Stage::pragma); +TVM_REGISTER_GLOBAL("te.StageParallel").set_body_method(&Stage::parallel); -TVM_REGISTER_GLOBAL("te.StagePrefetch") 
-.set_body_method(&Stage::prefetch); +TVM_REGISTER_GLOBAL("te.StagePragma").set_body_method(&Stage::pragma); -TVM_REGISTER_GLOBAL("te.StageStorageAlign") -.set_body_method(&Stage::storage_align); +TVM_REGISTER_GLOBAL("te.StagePrefetch").set_body_method(&Stage::prefetch); -TVM_REGISTER_GLOBAL("te.StageDoubleBuffer") -.set_body_method(&Stage::double_buffer); +TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_align); -TVM_REGISTER_GLOBAL("te.StageOpenGL") -.set_body_method(&Stage::opengl); +TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer); -TVM_REGISTER_GLOBAL("te.ScheduleNormalize") -.set_body_method(&Schedule::normalize); +TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize); -TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup") -.set_body_method(&Schedule::create_group); +TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group); -TVM_REGISTER_GLOBAL("te.ScheduleCacheRead") -.set_body_method(&Schedule::cache_read); +TVM_REGISTER_GLOBAL("te.ScheduleCacheRead").set_body_method(&Schedule::cache_read); -TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[1].IsObjectRef()) { - *ret = args[0].operator Schedule() - .cache_write(args[1].operator Tensor(), args[2]); - } else { - *ret = args[0].operator Schedule() - .cache_write(args[1].operator Array(), args[2]); - } - }); +TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args[1].IsObjectRef()) { + *ret = args[0].operator Schedule().cache_write(args[1].operator Tensor(), args[2]); + } else { + *ret = args[0].operator Schedule().cache_write(args[1].operator Array(), args[2]); + } +}); -TVM_REGISTER_GLOBAL("te.ScheduleRFactor") -.set_body_method(&Schedule::rfactor); +TVM_REGISTER_GLOBAL("te.ScheduleRFactor").set_body_method(&Schedule::rfactor); -TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition") -.set_body_typed([](Array condition) { - return SpecializedCondition(condition); +TVM_REGISTER_GLOBAL("te.CreateSpecializedCondition").set_body_typed([](Array condition) { + return SpecializedCondition(condition); }); -TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SpecializedCondition::Current(); +TVM_REGISTER_GLOBAL("te.GetCurrentSpecialization").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = SpecializedCondition::Current(); }); TVM_REGISTER_GLOBAL("te.EnterSpecializationScope") -.set_body_typed(SpecializedCondition::Internal::EnterScope); + .set_body_typed(SpecializedCondition::Internal::EnterScope); TVM_REGISTER_GLOBAL("te.ExitSpecializationScope") -.set_body_typed(SpecializedCondition::Internal::ExitScope); + .set_body_typed(SpecializedCondition::Internal::ExitScope); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 57b637df0570..f2955f33e225 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -21,31 +21,30 @@ * \file schedule_ops.cc */ #include -#include -#include -#include #include #include -#include +#include +#include +#include + #include #include -#include "graph.h" +#include + +#include "../../tir/transforms/ir_util.h" #include "../operation/op_util.h" -#include "../../tir/pass/ir_util.h" +#include "graph.h" namespace tvm { namespace te { using namespace tir; -Stmt MakePipeline(const Stage& s, - const std::unordered_map& dom_map, - Stmt consumer, +Stmt 
MakePipeline(const Stage& s, const std::unordered_map& dom_map, Stmt consumer, bool debug_keep_trivial_loop) { Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); if (s->double_buffer) { - producer = AttrStmtNode::make( - s->op, tir::attr::double_buffer_scope, 1, producer); + producer = AttrStmt(s->op, tir::attr::double_buffer_scope, 1, producer); } Stmt pipeline = producer; @@ -54,43 +53,32 @@ Stmt MakePipeline(const Stage& s, } pipeline = s->op->BuildRealize(s, dom_map, pipeline); // use attribute to mark scope of the operation. - pipeline = AttrStmtNode::make( - s->op, tir::attr::realize_scope, - StringImmNode::make(s->scope), - pipeline); + pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); - if (s->is_opengl) { - pipeline = AttrStmtNode::make( - s->op, tir::attr::opengl_stage_scope, StringImmNode::make(""), pipeline); - } return pipeline; } // inject the operator's realization on the stmt. class InjectAttach : public StmtMutator { public: - InjectAttach(const Stage& stage, - const Stage& attach_spec, - const std::unordered_map& dom_map, - bool debug_keep_trivial_loop) - : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), + InjectAttach(const Stage& stage, const Stage& attach_spec, + const std::unordered_map& dom_map, bool debug_keep_trivial_loop) + : stage_(stage), + attach_spec_(attach_spec), + dom_map_(dom_map), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} Stmt VisitStmt(const Stmt& input_stmt) final { CHECK(input_stmt.defined()); auto stmt = StmtMutator::VisitStmt(input_stmt); const AttrStmtNode* op = stmt.as(); - if (op != nullptr && - op->attr_key == tir::attr::loop_scope) { - if (attach_spec_->attach_type == kScope && - op->node == attach_spec_->attach_ivar) { - CHECK(!found_attach) - << "Find IterVar" << attach_spec_->attach_ivar - << " in multiple places in the IR"; + if (op != nullptr && op->attr_key == tir::attr::loop_scope) { + if (attach_spec_->attach_type == kScope && op->node == attach_spec_->attach_ivar) { + CHECK(!found_attach) << "Find IterVar" << attach_spec_->attach_ivar + << " in multiple places in the IR"; found_attach = true; - stmt = AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = AttrStmt(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -113,27 +101,26 @@ class InjectAttach : public StmtMutator { // inject the operator's realization on the stmt. 
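MakePipeline wraps each stage's provide statement in a fixed attribute nesting: an optional double_buffer_scope marker, the realize produced by BuildRealize, and finally a realize_scope attribute carrying the storage scope string. The two helpers below simply restate those wrappers in isolation; this is a sketch, not code from the patch.

#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

// Matches the double_buffer branch: an integer-valued AttrStmt is enough,
// later passes key off attr::double_buffer_scope.
Stmt MarkDoubleBuffer(ObjectRef op, Stmt producer) {
  return AttrStmt(op, tir::attr::double_buffer_scope, 1, producer);
}

// Matches the tail of MakePipeline: the whole per-stage pipeline is wrapped
// in an AttrStmt whose value is the storage scope as a StringImm.
Stmt MarkRealizeScope(ObjectRef op, Stmt pipeline, const std::string& scope) {
  return AttrStmt(op, tir::attr::realize_scope, StringImm(scope), pipeline);
}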
class InjectScanStep : public StmtMutator { public: - InjectScanStep(const Stage& stage, - const Operation& scan_op, - const std::unordered_map& dom_map, - bool is_init, + InjectScanStep(const Stage& stage, const Operation& scan_op, + const std::unordered_map& dom_map, bool is_init, bool debug_keep_trivial_loop) - : stage_(stage), scan_op_(scan_op), - dom_map_(dom_map), is_init_(is_init), debug_keep_trivial_loop_(debug_keep_trivial_loop) {} + : stage_(stage), + scan_op_(scan_op), + dom_map_(dom_map), + is_init_(is_init), + debug_keep_trivial_loop_(debug_keep_trivial_loop) {} Stmt VisitStmt(const Stmt& input_stmt) final { CHECK(input_stmt.defined()); auto stmt = StmtMutator::VisitStmt(input_stmt); // update const AttrStmtNode* op = stmt.as(); - if (op != nullptr && - ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || - (op->attr_key == tir::attr::scan_init_scope && is_init_))) { + if (op != nullptr && ((op->attr_key == tir::attr::scan_update_scope && !is_init_) || + (op->attr_key == tir::attr::scan_init_scope && is_init_))) { if (op->node.same_as(scan_op_)) { found_attach = true; - stmt = AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); + stmt = AttrStmt(op->node, op->attr_key, op->value, + MakePipeline(stage_, dom_map_, op->body, debug_keep_trivial_loop_)); } } return stmt; @@ -169,8 +156,7 @@ class SchedulePostProc : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::loop_scope || - op->attr_key == tir::attr::scan_init_scope) { + if (op->attr_key == tir::attr::loop_scope || op->attr_key == tir::attr::scan_init_scope) { return this->VisitStmt(op->body); } else if (op->attr_key == tir::attr::scan_update_scope) { const ScanOpNode* scan = op->node.as(); @@ -181,7 +167,7 @@ class SchedulePostProc : public StmtExprMutator { // delete duplicated thread extent attr auto it = thread_extent_scope_.find(op->node.get()); if (it != thread_extent_scope_.end()) { - CHECK(is_zero(tir::Simplify(it->second - op->value))); + CHECK(is_zero(analyzer_.Simplify(it->second - op->value))); return this->VisitStmt(op->body); } else { thread_extent_scope_[op->node.get()] = op->value; @@ -194,8 +180,7 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - Stmt ret = AttrStmtNode::make( - it->second, op->attr_key, op->value, op->body); + Stmt ret = AttrStmt(it->second, op->attr_key, op->value, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -207,9 +192,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - return AttrStmtNode::make( - Array{tuple[0], it->second.output(tensor->value_index)}, - op->attr_key, op->value, this->VisitStmt(op->body)); + return AttrStmt(Array{tuple[0], it->second.output(tensor->value_index)}, + op->attr_key, op->value, this->VisitStmt(op->body)); } else { return this->VisitStmt(op->body); } @@ -219,9 +203,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_op_.find(tensor->op.get()); if (it != replace_op_.end()) { if (it->second.defined()) { - return AttrStmtNode::make( - it->second.output(tensor->value_index), - op->attr_key, op->value, this->VisitStmt(op->body)); + return AttrStmt(it->second.output(tensor->value_index), op->attr_key, op->value, + this->VisitStmt(op->body)); } else 
{ return this->VisitStmt(op->body); } @@ -230,14 +213,12 @@ class SchedulePostProc : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = RealizeNode::make( - it->second->op, it->second->value_index, - op->dtype, op->bounds, op->condition, op->body); + Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); @@ -247,32 +228,31 @@ class SchedulePostProc : public StmtExprMutator { } } - Stmt VisitStmt_(const ProvideNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const ProducerStoreNode* op) final { + auto key = Downcast(op->producer); auto it = replace_buffer_.find(key); if (it != replace_buffer_.end()) { const Tensor& dst = it->second; - Stmt ret = ProvideNode::make( - dst->op, dst->value_index, op->value, op->args); + Stmt ret = ProducerStore(dst, op->value, op->indices); return this->VisitStmt(ret); } else { return StmtExprMutator::VisitStmt_(op); } } - PrimExpr VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = replace_buffer_.find(key); - if (it != replace_buffer_.end()) { - const Tensor& dst = it->second; - PrimExpr ret = CallNode::make( - op->dtype, dst->op->name, op->args, - op->call_type, dst->op, dst->value_index); - return this->VisitExpr(ret); - } + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + CHECK(op != nullptr); + + auto key = Downcast(op->producer); + auto it = replace_buffer_.find(key); + if (it != replace_buffer_.end()) { + const Tensor& dst = it->second; + return ProducerLoad(dst, op->indices); + } else { + return expr; } - return StmtExprMutator::VisitExpr_(op); } PrimExpr VisitExpr_(const VarNode* op) final { @@ -299,8 +279,7 @@ class SchedulePostProc : public StmtExprMutator { if (!s->op.same_as(s->origin_op)) { for (int i = 0; i < s->op->num_outputs(); ++i) { Tensor target = s->origin_op.output(i); - AddReplace(s->op.output(i), target, - target, s->origin_op); + AddReplace(s->op.output(i), target, target, s->origin_op); } } // Specially add replacements for scan op. @@ -316,13 +295,10 @@ class SchedulePostProc : public StmtExprMutator { } private: - void AddReplace(Tensor src, - Tensor dst, - Tensor repl_realize = Tensor(), + void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { - TensorKey key{src->op, src->value_index}; - replace_buffer_[key] = dst; - replace_realize_[key] = repl_realize; + replace_buffer_[src] = dst; + replace_realize_[src] = repl_realize; replace_op_[src->op.get()] = repl_op; } // The thread extent scope. @@ -330,15 +306,16 @@ class SchedulePostProc : public StmtExprMutator { // The scan value std::unordered_map var_value_; // buffer replacement - std::unordered_map replace_buffer_; + std::unordered_map replace_buffer_; // buffere realization to be replaced - std::unordered_map replace_realize_; + std::unordered_map replace_realize_; // replace producer consumer. 
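SchedulePostProc now keys its replacement tables on te::Tensor instead of the old TensorKey, because the flattened Provide/Realize/Call(Halide) trio has been replaced by ProducerStore, ProducerRealize and ProducerLoad, all of which carry the producer directly. A small sketch of building the new nodes by hand, with illustrative tensors and indices:

#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::te;
using namespace tvm::tir;

void ProducerNodeSketch() {
  Tensor A = placeholder({16}, DataType::Float(32), "A");
  Tensor B = placeholder({16}, DataType::Float(32), "B");
  Var i("i");

  // Old form: CallNode::make(dtype, name, args, CallNode::Halide, func, idx).
  // New form: the tensor itself is the producer and the indices are explicit.
  PrimExpr load = ProducerLoad(A, {i});

  // Old form: ProvideNode::make(func, value_index, value, args).
  Stmt store = ProducerStore(B, load + 1.0f, {i});
  (void)store;
}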
std::unordered_map replace_op_; + // integer analyzer + arith::Analyzer analyzer_; }; -Stmt ScheduleOps( - Schedule sch, Map dom_map_, bool debug_keep_trivial_loop) { +Stmt ScheduleOps(Schedule sch, Map dom_map_, bool debug_keep_trivial_loop) { Stmt body = Stmt(); std::unordered_map dom_map = as_unordered_map(dom_map_); // scan init and scan updates @@ -348,8 +325,7 @@ Stmt ScheduleOps( if (!scan) continue; for (Tensor t : scan->init) { if (scan_init.count(t->op)) { - CHECK(scan_init.at(t->op).same_as(s->op)) - << "Scan init tensor can only belong to one scan"; + CHECK(scan_init.at(t->op).same_as(s->op)) << "Scan init tensor can only belong to one scan"; } else { scan_init[t->op] = s->op; } @@ -363,8 +339,7 @@ Stmt ScheduleOps( // reverse the post DFS order. for (size_t i = sch->stages.size(); i != 0; --i) { Stage s = sch->stages[i - 1]; - CHECK_NE(s->attach_type, kInline) - << "call schedule.normalize before scheduleops"; + CHECK_NE(s->attach_type, kInline) << "call schedule.normalize before scheduleops"; CHECK(s->op.defined()); // no need to specify place holder op. if (s->op.as()) continue; @@ -375,15 +350,13 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, debug_keep_trivial_loop); body = mu(std::move(body)); - CHECK(mu.found_attach) - << "did not find attachment point for scan.init"; + CHECK(mu.found_attach) << "did not find attachment point for scan.init"; } else if (attach_spec->attach_type == kScanUpdate) { // Handle scan update CHECK(body.defined()); InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, debug_keep_trivial_loop); body = mu(std::move(body)); - CHECK(mu.found_attach) - << "did not find attachment point for scan.update"; + CHECK(mu.found_attach) << "did not find attachment point for scan.update"; } else if (attach_spec->attach_type == kInlinedAlready) { // do nothing } else if (attach_spec->attach_type == kGroupRoot) { @@ -394,11 +367,10 @@ Stmt ScheduleOps( CHECK(body.defined()); InjectAttach mutator(s, attach_spec, dom_map, debug_keep_trivial_loop); body = mutator(std::move(body)); - CHECK(mutator.found_attach) - << "did not find attachment point for " << s << " in " - << attach_spec->attach_stage->op << " x " << attach_spec->attach_ivar - << ", body:\n" - << body; + CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in " + << attach_spec->attach_stage->op << " x " + << attach_spec->attach_ivar << ", body:\n" + << body; } } SchedulePostProc post_proc; @@ -406,8 +378,7 @@ Stmt ScheduleOps( return post_proc(std::move(body)); } -TVM_REGISTER_GLOBAL("schedule.ScheduleOps") -.set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_GLOBAL("schedule.ScheduleOps").set_body([](TVMArgs args, TVMRetValue* ret) { if (args.size() == 2) *ret = ScheduleOps(args[0], args[1], false); else diff --git a/src/tir/pass/tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc similarity index 55% rename from src/tir/pass/tensor_core.cc rename to src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index dc2df985a8ee..1ff569f29f1f 100644 --- a/src/tir/pass/tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -18,32 +18,34 @@ */ /*! - * \file tensor_core.cc + * \file schedule_postproc_rewrite_for_tensor_core.cc + * + * \brief Rewrite the Stmt generated by ScheduleOps + * to accomondate tensorcore. 
*/ -// IR Passes for TensorCore CodeGen +#include +#include +#include +#include +#include +#include #include +#include #include -#include #include -#include -#include -#include -#include -#include -#include + #include -#include "ir_util.h" -#include "../../arith/compute_expr.h" + #include "../../runtime/thread_storage_scope.h" namespace tvm { -namespace tir { +namespace te { using namespace te; +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; -using intrinsic::tvm_address_of; struct Tile { int m{-1}; @@ -60,7 +62,7 @@ std::string simplify_name(std::string input) { } } -PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { +PrimExpr unpack_type_cast(const PrimExpr& input, const DataType& target_type) { auto cast = input.as(); if (cast == nullptr) { return input; @@ -73,7 +75,7 @@ PrimExpr unpack_type_cast(const PrimExpr &input, const DataType &target_type) { // MMAMatcher matches C = Cast(A)*Cast(B)+C, // where A & B are fp16/int8 local buffers, // and C is fp32/int32 local buffer. -class MMAMatcher: public StmtVisitor { +class MMAMatcher : public StmtVisitor { public: explicit MMAMatcher(Map extern_buffer) { for (auto kv : extern_buffer) { @@ -81,15 +83,15 @@ class MMAMatcher: public StmtVisitor { bi.name = kv.second->name; bi.dtype = kv.second->dtype; bi.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; + buf_map_[kv.first] = bi; } } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::pragma_tensor_core) { + if (op->attr_key == tir::attr::pragma_tensor_core) { tensor_core_on_ = true; StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::realize_scope) { + } else if (op->attr_key == tir::attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; this->VisitStmt(op->body); } else { @@ -97,9 +99,9 @@ class MMAMatcher: public StmtVisitor { } } - void VisitStmt_(const ProvideNode* op) final { + void VisitStmt_(const ProducerStoreNode* op) final { StmtVisitor::VisitStmt_(op); - auto it = buf_map_.find(TensorKey{op->func, op->value_index}); + auto it = buf_map_.find(Downcast(op->producer)); if (it == buf_map_.end()) { return; } @@ -112,8 +114,8 @@ class MMAMatcher: public StmtVisitor { } } - void VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + void VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); if (buf_map_.count(key)) { if (!buf_map_.at(key).external) { return; @@ -121,15 +123,15 @@ class MMAMatcher: public StmtVisitor { this->VisitStmt(op->body); } else { BufferInfo bi; - bi.name = key.GetName(); - bi.dtype = op->dtype; + bi.name = key->GetNameHint(); + bi.dtype = key->dtype; buf_map_[key] = bi; this->VisitStmt(op->body); buf_map_[key].released = true; } } - inline bool Matched() const {return matched_;} + inline bool Matched() const { return matched_; } friend class ScheduleAnalyser; friend class BufferAnalyser; @@ -140,7 +142,7 @@ class MMAMatcher: public StmtVisitor { DataType dtype; bool external{false}; bool released{false}; - bool same_as(const BufferInfo &bi) { + bool same_as(const BufferInfo& bi) { if (this->dtype != bi.dtype) return false; if (this->name != bi.name) return false; if (this->external != bi.external) return false; @@ -150,42 +152,38 @@ class MMAMatcher: public StmtVisitor { }; // Check whether the storage scope is local - bool check_local_buffer_(const CallNode* op, BufferInfo* bi) { - if (op->call_type == 
CallNode::Halide) { - auto it = storage_scope_.find(op->func.get()); - if (it == storage_scope_.end()) { - return false; - } - const std::string& strkey = it->second; - if (strkey != "local") { - return false; - } - auto it1 = buf_map_.find(TensorKey{op->func, op->value_index}); - if (it1 == buf_map_.end()) { - return false; - } - *bi = it1->second; - if (bi->released) { - return false; - } - return true; + bool check_local_buffer_(const ProducerLoadNode* op, BufferInfo* bi) { + auto tensor = Downcast(op->producer); + auto it = storage_scope_.find(tensor.get()); + if (it == storage_scope_.end()) { + return false; } - return false; + const std::string& strkey = it->second; + if (strkey != "local") { + return false; + } + auto it1 = buf_map_.find(tensor); + if (it1 == buf_map_.end()) { + return false; + } + *bi = it1->second; + if (bi->released) { + return false; + } + return true; } // Do the pattern matching - bool mma_sync_match_(const ProvideNode* op, BufferInfo store_buffer) { + bool mma_sync_match_(const ProducerStoreNode* op, BufferInfo store_buffer) { auto* add = op->value.as(); if (add == nullptr) { return false; } - auto* load_c = add->a.as(); + auto* load_c = add->a.as(); BufferInfo buffer_c; - if (!check_local_buffer_(load_c, &buffer_c) - || !buffer_c.same_as(store_buffer) - || !(buffer_c.dtype == DataType::Float(32) || - buffer_c.dtype == DataType::Int(32))) { + if (!check_local_buffer_(load_c, &buffer_c) || !buffer_c.same_as(store_buffer) || + !(buffer_c.dtype == DataType::Float(32) || buffer_c.dtype == DataType::Int(32))) { return false; } @@ -195,28 +193,22 @@ class MMAMatcher: public StmtVisitor { } auto load_a_expr = unpack_type_cast(mul->a, buffer_c.dtype); - auto load_a = load_a_expr.as(); + auto load_a = load_a_expr.as(); BufferInfo buffer_a; - if (!check_local_buffer_(load_a, &buffer_a) - || !(buffer_a.dtype == DataType::Float(16) || - buffer_a.dtype == DataType::Int(8) || - buffer_a.dtype == DataType::UInt(8) || - buffer_a.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || - buffer_a.dtype == DataType::Int(1))) { + if (!check_local_buffer_(load_a, &buffer_a) || + !(buffer_a.dtype == DataType::Float(16) || buffer_a.dtype == DataType::Int(8) || + buffer_a.dtype == DataType::UInt(8) || buffer_a.dtype == DataType::Int(4) || + buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { return false; } auto load_b_expr = unpack_type_cast(mul->b, buffer_c.dtype); - auto load_b = load_b_expr.as(); + auto load_b = load_b_expr.as(); BufferInfo buffer_b; - if (!check_local_buffer_(load_b, &buffer_b) - || !(buffer_b.dtype == DataType::Float(16) || - buffer_b.dtype == DataType::Int(8) || - buffer_b.dtype == DataType::UInt(8) || - buffer_b.dtype == DataType::Int(4) || - buffer_a.dtype == DataType::UInt(4) || - buffer_a.dtype == DataType::Int(1))) { + if (!check_local_buffer_(load_b, &buffer_b) || + !(buffer_b.dtype == DataType::Float(16) || buffer_b.dtype == DataType::Int(8) || + buffer_b.dtype == DataType::UInt(8) || buffer_b.dtype == DataType::Int(4) || + buffer_a.dtype == DataType::UInt(4) || buffer_a.dtype == DataType::Int(1))) { return false; } @@ -225,15 +217,14 @@ class MMAMatcher: public StmtVisitor { frag_reg_.insert(buffer_b.name); buf_name_.insert(std::make_pair(load_a, buffer_a.name)); buf_name_.insert(std::make_pair(load_b, buffer_b.name)); - mma_sync_.insert(std::make_pair(op, - Array{load_a_expr, load_b_expr, add->a})); + mma_sync_.insert(std::make_pair(op, Array{load_a_expr, load_b_expr, add->a})); return true; } - 
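mma_sync_match_ fires on the classic mixed-precision GEMM update: a store whose value is an Add that reloads the fp32/int32 accumulator on one side and multiplies two casted fp16/int8 loads on the other. A minimal TE expression that lowers to that shape looks roughly like this; it is a sketch, and the sizes and names are placeholders.

#include <tvm/te/operation.h>
#include <tvm/tir/op.h>

using namespace tvm;
using namespace tvm::te;

void MmaPatternSketch() {
  const int M = 64, N = 64, K = 64;
  Tensor A = placeholder({M, K}, DataType::Float(16), "A");
  Tensor B = placeholder({K, N}, DataType::Float(16), "B");
  IterVar k = reduce_axis(Range(0, K), "k");

  // C(i, j) = sum(cast<fp32>(A(i, k)) * cast<fp32>(B(k, j)), axis=k).
  // Each lowered update becomes ProducerStore(C, load_C + cast(A) * cast(B)),
  // which is exactly what the matcher checks for.
  Tensor C = compute(
      {M, N},
      [&](const Array<Var>& ij) {
        return sum(cast(DataType::Float(32), A(ij[0], k)) *
                       cast(DataType::Float(32), B(k, ij[1])),
                   {k});
      },
      "C");
  (void)C;
}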
std::unordered_map buf_map_; + std::unordered_map buf_map_; std::unordered_map storage_scope_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map buf_name_; std::unordered_set frag_reg_; bool matched_{false}; @@ -279,9 +270,8 @@ class BodyVisitor : public StmtExprVisitor { // ScheduleAnalyser figures out matrix_a/matrix_b and row_major/col_major class ScheduleAnalyser { public: - explicit ScheduleAnalyser(const MMAMatcher &mma_matcher) - : mma_sync_(mma_matcher.mma_sync_), - buf_name_(mma_matcher.buf_name_) {} + explicit ScheduleAnalyser(const MMAMatcher& mma_matcher) + : mma_sync_(mma_matcher.mma_sync_), buf_name_(mma_matcher.buf_name_) {} bool MatrixIdentify(Schedule schedule) { // TODO(minmin): handle the case where MatMul is not the output stage @@ -298,8 +288,8 @@ class ScheduleAnalyser { } const VarNode* axis_var[2]; const VarNode* reduce_axis_var; - axis_var[0] = axis[axis.size()-2]->var.as(); - axis_var[1] = axis[axis.size()-1]->var.as(); + axis_var[0] = axis[axis.size() - 2]->var.as(); + axis_var[1] = axis[axis.size() - 1]->var.as(); reduce_axis_var = reduce_axis[0]->var.as(); BodyVisitor body_visitor; @@ -341,8 +331,8 @@ class ScheduleAnalyser { matrix_major_.insert(std::make_pair(compute->name, "col_major")); } - for (auto &mma_sync : mma_sync_) { - auto &operands = mma_sync.second; + for (auto& mma_sync : mma_sync_) { + auto& operands = mma_sync.second; auto* load_a = operands[0].as(); auto* load_b = operands[1].as(); auto input0 = simplify_name(buf_name_.find(load_a)->second); @@ -370,7 +360,7 @@ class ScheduleAnalyser { private: std::unordered_map matrix_abc_; std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map buf_name_; }; @@ -397,8 +387,7 @@ class IndexVisitor : public StmtExprVisitor { class BufferAnalyser : public StmtExprVisitor { public: explicit BufferAnalyser(Map extern_buffer, - const ScheduleAnalyser &schedule_analyser, - const MMAMatcher &mma_matcher) + const ScheduleAnalyser& schedule_analyser, const MMAMatcher& mma_matcher) : matrix_abc_(schedule_analyser.matrix_abc_), matrix_major_(schedule_analyser.matrix_major_), frag_reg_(mma_matcher.frag_reg_) { @@ -409,27 +398,25 @@ class BufferAnalyser : public StmtExprVisitor { bi.strides = kv.second->strides; bi.shape = kv.second->shape; bi.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = bi; + buf_map_[kv.first] = bi; } } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { if (const IntImmNode* value = op->value.as()) { thread_extent_.insert( - std::make_pair( - op->node.as()->var->name_hint, - value->value)); + std::make_pair(op->node.as()->var->name_hint, value->value)); } StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::realize_scope) { + } else if (op->attr_key == tir::attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as()->value; this->VisitStmt(op->body); - } else if (op->attr_key == attr::buffer_dim_align) { + } else if (op->attr_key == tir::attr::buffer_dim_align) { te::Tensor tensor = Downcast(op->node); const CallNode* tuple = op->value.as(); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - auto& vinfo = dim_align_[TensorKey{tensor->op, tensor->value_index}]; + auto& vinfo = dim_align_[tensor]; size_t dim = tuple->args[0].as()->value; if (dim >= vinfo.size()) { vinfo.resize(dim + 1); @@ -442,17 +429,15 @@ class BufferAnalyser : public 
StmtExprVisitor { } } - void VisitStmt_(const ProvideNode* op) final { + void VisitStmt_(const ProducerStoreNode* op) final { StmtExprVisitor::VisitStmt_(op); - TensorKey key{op->func, op->value_index}; + auto key = Downcast(op->producer); auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key->GetNameHint(); const BufferInfo& bi = it->second; - CHECK(!bi.released) - << "Read a buffer that is already out of scope"; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; - if (matrix_abc_.count(key.GetName())) { + if (matrix_abc_.count(key->GetNameHint())) { if (bi.shape.size() < 2) { invalid_ = true; return; @@ -473,30 +458,25 @@ class BufferAnalyser : public StmtExprVisitor { for (size_t i = 1; i < bi.shape.size(); ++i) { PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, bi.shape[j]); + stride = Mul(stride, bi.shape[j]); } strides.push_back(stride); } strides.push_back(make_const(DataType::Int(32), 1)); } - strides_.insert(std::make_pair(key.GetName(), strides)); + strides_.insert(std::make_pair(key->GetNameHint(), strides)); if (frag_reg_.count(bi.name)) { - PrimExpr dst = CallNode::make(bi.dtype, - bi.name, - op->args, - CallNode::Halide, - op->func, - 0); + PrimExpr dst = ProducerLoad(op->producer, op->indices); frag_load_.insert(std::make_pair(op, dst)); - auto rel_index = bi.RelIndex(op->args); - if (op->args.size() < 2) { + auto rel_index = bi.RelIndex(op->indices); + if (op->indices.size() < 2) { invalid_ = true; return; } std::vector tile_size; - for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) { + for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { index_visitor.scaling_factor_ = 16; if (const IntImmNode* shape = bi.shape[i].as()) { tile_size.push_back(shape->value); @@ -506,7 +486,7 @@ class BufferAnalyser : public StmtExprVisitor { return; } auto index = rel_index[i]; - auto simplified_index = tir::Simplify(index); + auto simplified_index = analyzer_.Simplify(index); index_visitor(simplified_index); } @@ -542,81 +522,74 @@ class BufferAnalyser : public StmtExprVisitor { } } - const CallNode* value = op->value.as(); - if (value != nullptr && frag_reg_.count(value->name)) { - PrimExpr dst = CallNode::make(bi.dtype, - bi.name, - op->args, - CallNode::Halide, - op->func, - 0); + const ProducerLoadNode* value = op->value.as(); + // TODO(tvm-team): string matching is dangerous, consider other means. 
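The fallback stride computation above is plain row-major: for a shape (s0, s1, ..., sn-1) the stride of dimension i is the product of all later extents, with the innermost stride fixed at 1. A standalone sketch with concrete numbers, in plain C++ without TVM types:

#include <cstddef>
#include <vector>

// Row-major strides as BufferAnalyser derives them when a buffer declares no
// explicit strides. For shape {2, 16, 16} this returns {256, 16, 1}.
std::vector<long> RowMajorStrides(const std::vector<long>& shape) {
  std::vector<long> strides(shape.size(), 1);
  for (size_t i = shape.size(); i-- > 1;) {
    strides[i - 1] = strides[i] * shape[i];
  }
  return strides;
}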
+ if (value != nullptr && frag_reg_.count(value->producer->GetNameHint())) { + PrimExpr dst = ProducerLoad(op->producer, op->indices); frag_store_.insert(std::make_pair(op, dst)); } } - void VisitExpr_(const CallNode* op) final { + void VisitExpr_(const ProducerLoadNode* op) final { StmtExprVisitor::VisitExpr_(op); - if (op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - const BufferInfo& bi = it->second; - CHECK(!bi.released) - << "Read a buffer that is already out of scope"; - - if (matrix_abc_.count(op->name)) { - if (bi.shape.size() < 2) { + + auto tensor = Downcast(op->producer); + auto it = buf_map_.find(tensor); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << tensor->GetNameHint(); + const BufferInfo& bi = it->second; + CHECK(!bi.released) << "Read a buffer that is already out of scope"; + + if (matrix_abc_.count(tensor->op->name)) { + if (bi.shape.size() < 2) { + invalid_ = true; + return; + } + for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { + const IntImmNode* shape = bi.shape[i].as(); + if (shape == nullptr || shape->value % 16 != 0) { invalid_ = true; return; } - for (auto i = bi.shape.size() - 1; i + 2 >= bi.shape.size(); --i) { - const IntImmNode* shape = bi.shape[i].as(); - if (shape == nullptr || shape->value % 16 != 0) { - invalid_ = true; - return; - } - } } + } - Array strides; - if (bi.strides.size() > 0) { - strides = bi.strides; - } else { - for (size_t i = 1; i < bi.shape.size(); ++i) { - PrimExpr stride = IntImm(DataType::Int(32), 1); - for (size_t j = bi.shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, bi.shape[j]); - } - strides.push_back(stride); + Array strides; + if (bi.strides.size() > 0) { + strides = bi.strides; + } else { + for (size_t i = 1; i < bi.shape.size(); ++i) { + PrimExpr stride = IntImm(DataType::Int(32), 1); + for (size_t j = bi.shape.size() - 1; j >= i; --j) { + stride = Mul(stride, bi.shape[j]); } - strides.push_back(make_const(DataType::Int(32), 1)); + strides.push_back(stride); } - strides_.insert(std::make_pair(key.GetName(), strides)); + strides.push_back(make_const(DataType::Int(32), 1)); + } + strides_.insert(std::make_pair(tensor->GetNameHint(), strides)); - if (!frag_reg_.count(bi.name)) { - return; - } + if (!frag_reg_.count(bi.name)) { + return; + } - auto rel_index = bi.RelIndex(op->args); - if (op->args.size() < 2) { - invalid_ = true; - return; - } - for (auto i = op->args.size() - 1; i + 2 >= op->args.size(); --i) { - index_visitor.scaling_factor_ = 16; - if (const IntImmNode* shape = bi.shape[i].as()) { - index_visitor.scaling_factor_ = shape->value; - } - auto index = rel_index[i]; - auto simplified_index = tir::Simplify(index); - index_visitor(simplified_index); + auto rel_index = bi.RelIndex(op->indices); + if (op->indices.size() < 2) { + invalid_ = true; + return; + } + for (auto i = op->indices.size() - 1; i + 2 >= op->indices.size(); --i) { + index_visitor.scaling_factor_ = 16; + if (const IntImmNode* shape = bi.shape[i].as()) { + index_visitor.scaling_factor_ = shape->value; } + auto index = rel_index[i]; + auto simplified_index = analyzer_.Simplify(index); + index_visitor(simplified_index); } } - void VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + void VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); if (buf_map_.count(key)) { 
CHECK(buf_map_.at(key).external); this->VisitStmt(op->body); @@ -641,9 +614,8 @@ class BufferAnalyser : public StmtExprVisitor { if (dim < avec.size() && avec[dim].align_factor != 0) { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + \ - indexmod(factor + offset - indexmod(stride, factor), factor); - stride = tir::Simplify(stride); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + stride = analyzer_.Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -651,8 +623,8 @@ class BufferAnalyser : public StmtExprVisitor { strides = Array(rstrides.rbegin(), rstrides.rend()); } - bi.name = key.GetName(); - bi.dtype = op->dtype; + bi.name = key->GetNameHint(); + bi.dtype = key->dtype; bi.strides = strides; bi.shape = shape; @@ -729,48 +701,39 @@ class BufferAnalyser : public StmtExprVisitor { } bool supported_warp_tile_() { - if (warp_tile_.m == 16 && - warp_tile_.n == 16 && - warp_tile_.k == 16) { + if (warp_tile_.m == 16 && warp_tile_.n == 16 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 32 && - warp_tile_.k == 16) { + if (warp_tile_.m == 8 && warp_tile_.n == 32 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 32 && - warp_tile_.n == 8 && - warp_tile_.k == 16) { + if (warp_tile_.m == 32 && warp_tile_.n == 8 && warp_tile_.k == 16) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 8 && - warp_tile_.k == 32) { + if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 32) { return true; } - if (warp_tile_.m == 8 && - warp_tile_.n == 8 && - warp_tile_.k == 128) { + if (warp_tile_.m == 8 && warp_tile_.n == 8 && warp_tile_.k == 128) { return true; } return false; } - std::unordered_map buf_map_; - std::unordered_map > dim_align_; + std::unordered_map buf_map_; + std::unordered_map> dim_align_; std::unordered_map storage_scope_; std::unordered_map matrix_abc_; std::unordered_map matrix_major_; std::unordered_set frag_reg_; std::unordered_map> strides_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; std::unordered_map thread_extent_; IndexVisitor index_visitor; Tile warp_tile_; Tile thread_tile_; + arith::Analyzer analyzer_; int warp_threads_y_{-1}; bool invalid_{false}; }; @@ -778,7 +741,7 @@ class BufferAnalyser : public StmtExprVisitor { // ThreadIdxMutator does the thread index unification inside a warp class ThreadIdxMutator : public StmtExprMutator { public: - explicit ThreadIdxMutator(PrimExpr warp_y): warp_y_(warp_y) {} + explicit ThreadIdxMutator(PrimExpr warp_y) : warp_y_(warp_y) {} PrimExpr VisitExpr_(const VarNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); @@ -789,8 +752,8 @@ class ThreadIdxMutator : public StmtExprMutator { return zero; } if (op->name_hint == "threadIdx.y") { - PrimExpr div = DivNode::make(expr, warp_y_); - PrimExpr mul = MulNode::make(div, warp_y_); + PrimExpr div = Div(expr, warp_y_); + PrimExpr mul = Mul(div, warp_y_); return mul; } } @@ -805,52 +768,49 @@ class ThreadIdxMutator : public StmtExprMutator { // based on tensor core intrinsics class TensorCoreIRMutator : public StmtExprMutator { public: - explicit TensorCoreIRMutator(const ScheduleAnalyser &schedule_analyser, - const BufferAnalyser &buffer_analyser) + explicit TensorCoreIRMutator(const ScheduleAnalyser& schedule_analyser, + const BufferAnalyser& buffer_analyser) : 
matrix_abc_(schedule_analyser.matrix_abc_), - matrix_major_(schedule_analyser.matrix_major_), - mma_sync_(schedule_analyser.mma_sync_), - strides_(buffer_analyser.strides_), - frag_reg_(buffer_analyser.frag_reg_), - loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), - frag_load_(buffer_analyser.frag_load_), - frag_store_(buffer_analyser.frag_store_), - warp_tile_(buffer_analyser.warp_tile_), - warp_threads_y_(buffer_analyser.warp_threads_y_) {} - - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + matrix_major_(schedule_analyser.matrix_major_), + mma_sync_(schedule_analyser.mma_sync_), + strides_(buffer_analyser.strides_), + frag_reg_(buffer_analyser.frag_reg_), + loop_scaling_(buffer_analyser.index_visitor.loop_scaling_), + frag_load_(buffer_analyser.frag_load_), + frag_store_(buffer_analyser.frag_store_), + warp_tile_(buffer_analyser.warp_tile_), + warp_threads_y_(buffer_analyser.warp_threads_y_) {} + + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + auto key = Downcast(op->producer); bounds_[key] = op->bounds; Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); if (op != nullptr) { - if (!frag_reg_.count(key.GetName())) { + if (!frag_reg_.count(key->GetNameHint())) { return stmt; } - auto new_extents = get_tile_size_(simplify_name(key.GetName())); + auto new_extents = get_tile_size_(simplify_name(key->GetNameHint())); Region new_bounds; for (size_t i = 0; i < op->bounds.size() - 2; ++i) { new_bounds.push_back(op->bounds[i]); } - CHECK_GE(op->bounds.size(), 2) - << "Less than 2 dimensions for matrix " << key.GetName(); - new_bounds.push_back(Range::make_by_min_extent( - op->bounds[op->bounds.size() - 2]->min, new_extents[0])); - new_bounds.push_back(Range::make_by_min_extent( - op->bounds[op->bounds.size() - 1]->min, new_extents[1])); - - return RealizeNode::make(op->func, op->value_index, - op->dtype, new_bounds, - op->condition, op->body); + CHECK_GE(op->bounds.size(), 2) << "Less than 2 dimensions for matrix " << key->GetNameHint(); + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 2]->min, new_extents[0])); + new_bounds.push_back( + Range::make_by_min_extent(op->bounds[op->bounds.size() - 1]->min, new_extents[1])); + + return ProducerRealize(op->producer, new_bounds, op->condition, op->body); } return stmt; } Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); - if (op->attr_key == attr::realize_scope) { + if (op->attr_key == tir::attr::realize_scope) { auto node = op->node.as(); if (node != nullptr) { if (!frag_reg_.count(node->name)) { @@ -858,189 +818,142 @@ class TensorCoreIRMutator : public StmtExprMutator { } auto it = matrix_abc_.find(simplify_name(node->name)); - CHECK(it != matrix_abc_.end()) - << "Cannot find matrix info for " << node->name; - auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second); + CHECK(it != matrix_abc_.end()) << "Cannot find matrix info for " << node->name; + auto matrix_abc = tvm::tir::StringImm("wmma." 
+ it->second); Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make(op->node, - op->attr_key, - matrix_abc, - body); + return AttrStmt(op->node, op->attr_key, matrix_abc, body); } } return stmt; } - Stmt VisitStmt_(const ProvideNode* op) final { + Stmt VisitStmt_(const ProducerStoreNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); auto it = mma_sync_.find(op); if (it != mma_sync_.end()) { - const auto &operands = it->second; + const auto& operands = it->second; PrimExpr a = operands[0]; - auto ca = a.as(); + auto ca = a.as(); PrimExpr b = operands[1]; - auto cb = b.as(); + auto cb = b.as(); PrimExpr c = operands[2]; - auto cc = c.as(); + auto cc = c.as(); ObjectPtr buffer_node_a = make_object(); ObjectPtr buffer_node_b = make_object(); ObjectPtr buffer_node_c = make_object(); - auto mma_sync_call = - [&buffer_node_a, &buffer_node_b, &ca, &cb] - (const Buffer &buffer) { - Buffer buffer_a(buffer_node_a); - Buffer buffer_b(buffer_node_b); - if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_bmma_sync, - {buffer->data, buffer->elem_offset, - buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, - buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); - } else { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_mma_sync, - {buffer->data, buffer->elem_offset, - buffer_a->data, buffer_a->elem_offset, - buffer_b->data, buffer_b->elem_offset, - buffer->data, buffer->elem_offset}, - CallNode::Intrinsic)); - } - }; + auto mma_sync_call = [&buffer_node_a, &buffer_node_b, &ca, &cb](const Buffer& buffer) { + Buffer buffer_a(buffer_node_a); + Buffer buffer_b(buffer_node_b); + if (ca->dtype == DataType::Int(1) && cb->dtype == DataType::Int(1)) { + return Evaluate( + Call(DataType::Handle(), intrinsic::tvm_bmma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); + } else { + return Evaluate( + Call(DataType::Handle(), intrinsic::tvm_mma_sync, + {buffer->data, buffer->elem_offset, buffer_a->data, buffer_a->elem_offset, + buffer_b->data, buffer_b->elem_offset, buffer->data, buffer->elem_offset}, + CallNode::Intrinsic)); + } + }; - auto call_add_c = - [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer &buffer) { - return add_buffer_bind_scope_(cc, buffer_node_c, - TensorKey{cc->func, cc->value_index}, mma_sync_call, cc->dtype); - }; + auto call_add_c = [this, &cc, &buffer_node_c, &mma_sync_call](const Buffer& buffer) { + return add_buffer_bind_scope_(cc, buffer_node_c, mma_sync_call); + }; - auto call_add_b = - [this, &cb, &buffer_node_b, &call_add_c](const Buffer &buffer) { - return add_buffer_bind_scope_(cb, buffer_node_b, - TensorKey{cb->func, cb->value_index}, call_add_c, cb->dtype); - }; + auto call_add_b = [this, &cb, &buffer_node_b, &call_add_c](const Buffer& buffer) { + return add_buffer_bind_scope_(cb, buffer_node_b, call_add_c); + }; - return add_buffer_bind_scope_(ca, buffer_node_a, - TensorKey{ca->func, ca->value_index}, call_add_b, ca->dtype); + return add_buffer_bind_scope_(ca, buffer_node_a, call_add_b); } auto it2 = frag_load_.find(op); if (it2 != frag_load_.end()) { PrimExpr dst = it2->second; - if (op->value.as() != nullptr || - op->value.as() != nullptr) { - auto call = dst.as(); - - auto fill_fragment_call = - [this, &op](const Buffer &buffer) { - return 
EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_fill_fragment, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, op->value}, - CallNode::Intrinsic)); - }; + if (op->value.as() != nullptr || op->value.as() != nullptr) { + auto pload = dst.as(); + + auto fill_fragment_call = [this, &op](const Buffer& buffer) { + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_fill_fragment, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, op->value}, + CallNode::Intrinsic)); + }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{call->func, call->value_index}, - fill_fragment_call, call->dtype); + return add_buffer_bind_scope_(pload, buffer_node, fill_fragment_call); } const CallNode* value = op->value.as(); - CHECK(value != nullptr) - << "Can only load fragment from a buffer"; + CHECK(value != nullptr) << "Can only load fragment from a buffer"; auto it = strides_.find(value->name); - CHECK(it != strides_.end()) - << "Cannot find stride for " << value->name; + CHECK(it != strides_.end()) << "Cannot find stride for " << value->name; auto strides = it->second; CHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size() - 2]; // thread index unification inside a warp PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); PrimExpr mutated_value = thread_idx_mutator(op->value); - PrimExpr src = CallNode::make(value->dtype, - "&", - {mutated_value}, - CallNode::Extern); + PrimExpr src = Call(value->dtype, "&", {mutated_value}, CallNode::Extern); - auto call = dst.as(); + auto pload = dst.as(); PrimExpr matrix_major; - auto iter2 = matrix_major_.find(simplify_name(call->name)); + auto iter2 = matrix_major_.find(simplify_name(pload->producer->GetNameHint())); CHECK(iter2 != matrix_major_.end()) - << "Can not determine matrix major for " << call->name; + << "Can not determine matrix major for " << pload->producer->GetNameHint(); if (iter2->second == "col_major") { - matrix_major = StringImmNode::make("col_major"); + matrix_major = StringImm("col_major"); } else if (iter2->second == "row_major") { - matrix_major = StringImmNode::make("row_major"); + matrix_major = StringImm("row_major"); } else { - LOG(FATAL) << "invalid matrix major for " << call->name; + LOG(FATAL) << "invalid matrix major for " << pload->producer->GetNameHint(); } - auto load_matrix_call = - [this, &src, &stride, &matrix_major](const Buffer &buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_load_matrix_sync, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, src, stride, matrix_major}, - CallNode::Intrinsic)); + auto load_matrix_call = [this, &src, &stride, &matrix_major](const Buffer& buffer) { + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_load_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, src, stride, matrix_major}, + CallNode::Intrinsic)); }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{op->func, op->value_index}, - load_matrix_call, call->dtype); + return add_buffer_bind_scope_(pload, buffer_node, load_matrix_call); } auto it3 = frag_store_.find(op); if (it3 != frag_store_.end()) { - TensorKey key{op->func, op->value_index}; - auto it = strides_.find(key.GetName()); - CHECK(it != strides_.end()) 
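// For reference, the fragment-load rewrite above ultimately emits an intrinsic
// call of the following shape (pseudo-TIR with illustrative buffer names, not
// literal output of this pass):
//   tvm_load_matrix_sync(A_frag->data, m, n, k, A_frag->elem_offset,
//                        &(A_shared[...]), stride, "row_major")
// where (m, n, k) is the warp tile recorded by BufferAnalyser, stride is the
// second-to-last stride of the source buffer, and the address-of expression is
// the extern "&" call built from the value after thread-index unification.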
- << "Cannot find stride for " << key.GetName(); + auto it = strides_.find(op->producer->GetNameHint()); + CHECK(it != strides_.end()) << "Cannot find stride for " << op->producer->GetNameHint(); auto strides = it->second; CHECK_GE(strides.size(), 2); - PrimExpr stride = strides[strides.size()-2]; + PrimExpr stride = strides[strides.size() - 2]; PrimExpr dst = it3->second; // thread index unification inside a warp PrimExpr warp_y = IntImm(DataType::Int(32), warp_threads_y_); ThreadIdxMutator thread_idx_mutator(warp_y); dst = thread_idx_mutator(dst); - dst = CallNode::make(DataType::Handle(), - "&", - {dst}, - CallNode::Extern); - - auto call = op->value.as(); - - auto store_matrix_call = - [this, &dst, &stride](const Buffer &buffer) { - return EvaluateNode::make( - CallNode::make(DataType::Handle(), - intrinsic::tvm_store_matrix_sync, - {buffer->data, - warp_tile_.m, warp_tile_.n, warp_tile_.k, - buffer->elem_offset, dst, stride, - StringImmNode::make("col_major")}, - CallNode::Intrinsic)); - }; + dst = Call(DataType::Handle(), "&", {dst}, CallNode::Extern); + + auto pload = op->value.as(); + + auto store_matrix_call = [this, &dst, &stride](const Buffer& buffer) { + return Evaluate(Call(DataType::Handle(), intrinsic::tvm_store_matrix_sync, + {buffer->data, warp_tile_.m, warp_tile_.n, warp_tile_.k, + buffer->elem_offset, dst, stride, StringImm("col_major")}, + CallNode::Intrinsic)); + }; ObjectPtr buffer_node = make_object(); - return add_buffer_bind_scope_(call, buffer_node, - TensorKey{call->func, call->value_index}, - store_matrix_call, call->dtype); + return add_buffer_bind_scope_(pload, buffer_node, store_matrix_call); } return stmt; @@ -1054,55 +967,54 @@ class TensorCoreIRMutator : public StmtExprMutator { if (it != loop_scaling_.end()) { int scale_factor = it->second; int scaled_extent_value = 1; - if (const IntImmNode *ori_extent = op->extent.as()) { + if (const IntImmNode* ori_extent = op->extent.as()) { int ori_extent_value = ori_extent->value; scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = ForNode::make(op->loop_var, op->min, scaled_extent, op->for_type, - op->device_api, op->body); + stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body); } } return stmt; } private: - Array get_tile_size_(const std::string &name) { - auto it = matrix_abc_.find(name); - auto it2 = matrix_major_.find(name); - CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) - << "Cannot find matrix info for " << name; - PrimExpr size0 = make_const(DataType::Int(32), 16); - PrimExpr size1 = make_const(DataType::Int(32), 16); - if (it->second == "matrix_a" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - if (it->second == "matrix_a" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.m); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_b" && it2->second == "row_major") { - size0 = make_const(DataType::Int(32), warp_tile_.k); - size1 = make_const(DataType::Int(32), warp_tile_.n); - } - if (it->second == "matrix_b" && it2->second == "col_major") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.k); - } - if (it->second == "matrix_c") { - size0 = make_const(DataType::Int(32), warp_tile_.n); - size1 = make_const(DataType::Int(32), warp_tile_.m); - } - 
Array tile_size = {size0, size1}; - return tile_size; + Array get_tile_size_(const std::string& name) { + auto it = matrix_abc_.find(name); + auto it2 = matrix_major_.find(name); + CHECK(it != matrix_abc_.end() && it2 != matrix_major_.end()) + << "Cannot find matrix info for " << name; + PrimExpr size0 = make_const(DataType::Int(32), 16); + PrimExpr size1 = make_const(DataType::Int(32), 16); + if (it->second == "matrix_a" && it2->second == "col_major") { + size0 = make_const(DataType::Int(32), warp_tile_.k); + size1 = make_const(DataType::Int(32), warp_tile_.m); + } + if (it->second == "matrix_a" && it2->second == "row_major") { + size0 = make_const(DataType::Int(32), warp_tile_.m); + size1 = make_const(DataType::Int(32), warp_tile_.k); + } + if (it->second == "matrix_b" && it2->second == "row_major") { + size0 = make_const(DataType::Int(32), warp_tile_.k); + size1 = make_const(DataType::Int(32), warp_tile_.n); + } + if (it->second == "matrix_b" && it2->second == "col_major") { + size0 = make_const(DataType::Int(32), warp_tile_.n); + size1 = make_const(DataType::Int(32), warp_tile_.k); + } + if (it->second == "matrix_c") { + size0 = make_const(DataType::Int(32), warp_tile_.n); + size1 = make_const(DataType::Int(32), warp_tile_.m); + } + Array tile_size = {size0, size1}; + return tile_size; } - Stmt add_buffer_bind_scope_(const CallNode* call, - const ObjectPtr &buffer_node, const TensorKey &key, - const std::function &call_back, - DataType datatype) { - auto it = bounds_.find(key); + Stmt add_buffer_bind_scope_(const ProducerLoadNode* pload, + const ObjectPtr& buffer_node, + const std::function& call_back) { + auto tensor = Downcast(pload->producer); + auto it = bounds_.find(tensor); CHECK(it != bounds_.end()); Array min_bound; for (auto i : it->second) { @@ -1114,7 +1026,7 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 0; i < it->second.size() - 2; ++i) { shape.push_back(it->second[i]->extent); } - auto tile_size = get_tile_size_(simplify_name(call->name)); + auto tile_size = get_tile_size_(simplify_name(tensor->op->name)); shape.push_back(tile_size[0]); shape.push_back(tile_size[1]); @@ -1122,73 +1034,57 @@ class TensorCoreIRMutator : public StmtExprMutator { for (size_t i = 1; i < shape.size(); ++i) { PrimExpr stride = IntImm(DataType::Int(32), 1); for (size_t j = shape.size() - 1; j >= i; --j) { - stride = MulNode::make(stride, shape[j]); + stride = Mul(stride, shape[j]); } strides.push_back(stride); } strides.push_back(make_const(DataType::Int(32), 1)); PrimExpr elem_offset = IntImm(DataType::Int(32), 0); - CHECK_EQ(call->args.size(), min_bound.size()); + CHECK_EQ(pload->indices.size(), min_bound.size()); for (size_t i = 0; i < min_bound.size(); i++) { - elem_offset = AddNode::make( - elem_offset, MulNode::make( - strides[i], SubNode::make(call->args[i], min_bound[i]))); + elem_offset = Add(elem_offset, Mul(strides[i], Sub(pload->indices[i], min_bound[i]))); } - auto it2 = matrix_abc_.find(simplify_name(call->name)); - CHECK(it2 != matrix_abc_.end()) - << "Cannot find matrix info for " << call->name; - buffer_node->data = Var(call->name, DataType::Handle()); - buffer_node->name = call->name; + auto it2 = matrix_abc_.find(simplify_name(tensor->op->name)); + CHECK(it2 != matrix_abc_.end()) << "Cannot find matrix info for " << tensor->op->name; + buffer_node->data = Var(tensor->op->name, DataType::Handle()); + buffer_node->name = tensor->op->name; buffer_node->scope = "wmma." 
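// A small worked example of the stride and elem_offset computation in
// add_buffer_bind_scope_ above, assuming an illustrative fragment shape of
// {2, 16, 16} (outer dimension followed by the two tile dimensions):
//   strides     = {16 * 16, 16, 1} = {256, 16, 1}
//   elem_offset = 256 * (i0 - min0) + 16 * (i1 - min1) + 1 * (i2 - min2)
// i.e. the strides are suffix products of the shape and the offset is the
// row-major flattening of the ProducerLoad indices relative to the realize
// bounds recorded in bounds_.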
+ it2->second; - buffer_node->dtype = datatype; + buffer_node->dtype = tensor->dtype; buffer_node->strides = strides; buffer_node->shape = shape; buffer_node->data_alignment = 1; - buffer_node->elem_offset = Simplify(elem_offset); + buffer_node->elem_offset = analyzer_.Simplify(elem_offset); buffer_node->offset_factor = 1; Buffer buffer(buffer_node); - ObjectPtr tensor_node = make_object(); - tensor_node->value_index = key.value_index; - tensor_node->op = Downcast(key.f); - tensor_node->shape = shape; - tensor_node->dtype = datatype; - Tensor tensor(tensor_node); - Array args; - for (size_t i = 0; i < call->args.size(); ++i) { - args.push_back(call->args[i]); + for (size_t i = 0; i < pload->indices.size(); ++i) { + args.push_back(pload->indices[i]); args.push_back(shape[i]); } - auto tuple = CallNode::make(DataType::Handle(), - intrinsic::tvm_tuple, - args, - CallNode::Intrinsic); + auto tuple = Call(DataType::Handle(), intrinsic::tvm_tuple, args, CallNode::Intrinsic); Array node = {buffer, tensor}; - return AttrStmtNode::make(node, - "buffer_bind_scope", - tuple, - call_back(buffer)); + return AttrStmt(node, "buffer_bind_scope", tuple, call_back(buffer)); } std::unordered_map matrix_abc_; std::unordered_map matrix_major_; - std::unordered_map> mma_sync_; + std::unordered_map> mma_sync_; std::unordered_map> strides_; std::unordered_set frag_reg_; std::unordered_map loop_scaling_; - std::unordered_map frag_load_; - std::unordered_map frag_store_; - std::unordered_map bounds_; + std::unordered_map frag_load_; + std::unordered_map frag_store_; + std::unordered_map bounds_; + arith::Analyzer analyzer_; Tile warp_tile_; int warp_threads_y_{-1}; }; -Stmt RewriteForTensorCore(Stmt stmt, - Schedule schedule, - Map extern_buffer) { +Stmt SchedulePostProcRewriteForTensorCore(Stmt stmt, Schedule schedule, + Map extern_buffer) { // Check if current lower target is CUDA auto target = tvm::Target::Current(true); if (target.defined() && target->target_name != "cuda") { @@ -1213,8 +1109,7 @@ Stmt RewriteForTensorCore(Stmt stmt, return stmt; } - BufferAnalyser buffer_analyser(extern_buffer, - schedule_analyser, mma_matcher); + BufferAnalyser buffer_analyser(extern_buffer, schedule_analyser, mma_matcher); buffer_analyser(stmt); if (!buffer_analyser.QualifiedForTensorCore()) { return stmt; @@ -1223,5 +1118,10 @@ Stmt RewriteForTensorCore(Stmt stmt, return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt)); } -} // namespace tir +TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore") + .set_body_typed([](Stmt stmt, Schedule schedule, Map extern_buffer) { + return SchedulePostProcRewriteForTensorCore(stmt, schedule, extern_buffer); + }); + +} // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc new file mode 100644 index 000000000000..a86ad76b0eb9 --- /dev/null +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file schedule_postproc_to_primfunc.cc + * + * \brief Translate the function body generated by ScheduleOps, + * with te related dialects that incorporate Tensor + * into the Stmts, to a PrimFunc. + * + * Perform this translation before running any TIR optimizations. + * + * Rationale: The body generated by ScheduleOps is not + * a formal PrimFunc and cannot be used for further optimization. + * This function canonicalizes that body and creates a formal PrimFunc. + * + * List of actions taken by the function: + * - Remove occurrences of te::Tensor, te::Operation in the IR + * and replace them by corresponding IR nodes via tir::Buffer. + * - Add annotation of extern buffers using the buffer_map field + * in the PrimFunc type. + */ +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace te { + +// create a buffer for tensor. +Buffer CreateBufferFor(const Tensor& tensor) { + std::string name = tensor->op->name; + if (tensor->op->num_outputs() != 1) { + name += ".v" + std::to_string(tensor->value_index); + } + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + return buffer; +} + +// A remapper that maps tensor to buffer +class TensorToBufferMapper : public StmtExprMutator { + public: + explicit TensorToBufferMapper(std::unordered_map buffer_map) + : buffer_map_(buffer_map) {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + // TODO(tvm-team): remove realize_scope, turn the info into + // Buffer's scope field in this pass.
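// Informally, the rewrite performed by this mutator is (pseudo-IR with an
// illustrative tensor T and its generated buffer T_buf, not literal output):
//   ProducerRealize(T, bounds, cond, body) -> BufferRealize(T_buf, bounds, cond, body)
//   ProducerStore(T, value, {i, j})        -> BufferStore(T_buf, value, {i, j})
//   ProducerLoad(T, {i, j})                -> BufferLoad(T_buf, {i, j})
// where T_buf is either the extern buffer supplied for T or the Buffer created
// by CreateBufferFor(T) at the first realization point encountered.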
+ if (op->attr_key == tir::attr::realize_scope || + op->attr_key == tir::attr::double_buffer_scope) { + Stmt body = op->body; + Operation operation = Downcast(op->node); + for (int i = operation->num_outputs(); i != 0; --i) { + Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); + body = AttrStmt(buffer, op->attr_key, op->value, body); + } + return body; + } else if (op->attr_key == tir::attr::buffer_bind_scope) { + Array tuple = Downcast>(op->node); + Tensor tensor = Downcast(tuple[1]); + return AttrStmt(Array{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value, + op->body); + } else if (op->attr_key == tir::attr::buffer_dim_align || + op->attr_key == tir::attr::prefetch_scope) { + Tensor tensor = Downcast(op->node); + Buffer buffer = GetOrAllocBuffer(tensor); + return AttrStmt(buffer, op->attr_key, op->value, op->body); + } else { + return ret; + } + } + + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + Tensor tensor = Downcast(op->producer); + Buffer buffer = GetOrAllocBuffer(tensor); + + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + return BufferRealize(buffer, op->bounds, op->condition, op->body); + } + + Stmt VisitStmt_(const ProducerStoreNode* op) final { + Tensor tensor = Downcast(op->producer); + Buffer buffer = GetBuffer(tensor); + + auto ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + + return BufferStore(buffer, op->value, op->indices); + } + + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + auto ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + Tensor tensor = Downcast(op->producer); + Buffer buffer = GetBuffer(tensor); + return tir::BufferLoad(buffer, op->indices); + } + + private: + Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } + + Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + auto it = buffer_map_.find(tensor); + if (it != buffer_map_.end()) return it->second; + CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; + + auto buffer = CreateBufferFor(tensor); + buffer_map_[tensor] = buffer; + return buffer; + } + + // maps tensor to buffer. 
+ std::unordered_map buffer_map_; +}; + +PrimFunc SchedulePostProcToPrimFunc(Array arg_list, Stmt body, + Optional> extern_buffer_opt) { + std::unordered_map extern_buffer; + + if (extern_buffer_opt.defined()) { + auto v = extern_buffer_opt.value(); + extern_buffer = std::unordered_map(v.begin(), v.end()); + } + + Array params; + Map buffer_map; + + for (auto var : arg_list) { + if (auto* n = var.as()) { + params.push_back(GetRef(n)); + } else if (auto* n = var.as()) { + te::Tensor tensor = GetRef(n); + CHECK(!extern_buffer.count(tensor)); + + tir::Buffer buffer = CreateBufferFor(tensor); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + extern_buffer[tensor] = buffer; + } else { + tir::Buffer buffer = Downcast(var); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + } + } + + body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body)); + return tir::PrimFunc(params, body, VoidType(), buffer_map); +} + +TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc") + .set_body_typed(SchedulePostProcToPrimFunc); + +} // namespace te +} // namespace tvm diff --git a/src/tir/pass/verify_compact_buffer.cc b/src/te/schedule/verify_compact_buffer.cc similarity index 83% rename from src/tir/pass/verify_compact_buffer.cc rename to src/te/schedule/verify_compact_buffer.cc index 5328165ffb91..0089c36dc607 100644 --- a/src/tir/pass/verify_compact_buffer.cc +++ b/src/te/schedule/verify_compact_buffer.cc @@ -21,16 +21,18 @@ * \file verify_compact_buffer.cc * \brief Verify if there was any compact buffer bound to a statement. */ +#include +#include +#include #include #include #include #include -#include #include namespace tvm { -namespace tir { +namespace te { class VerifyBuffer : public StmtVisitor { public: @@ -41,7 +43,7 @@ class VerifyBuffer : public StmtVisitor { void VisitStmt_(const AttrStmtNode* op) final { StmtVisitor::VisitStmt_(op); - if (op->attr_key == attr::buffer_bind_scope) { + if (op->attr_key == tir::attr::buffer_bind_scope) { is_compact_ = true; } } @@ -50,10 +52,12 @@ class VerifyBuffer : public StmtVisitor { bool is_compact_{false}; }; -bool VerifyCompactBuffer(Stmt stmt) { +bool VerifyCompactBuffer(const Stmt& stmt) { VerifyBuffer verifier; return verifier.Verify(stmt); } -} // namespace tir +TVM_REGISTER_GLOBAL("schedule.VerifyCompactBuffer").set_body_typed(VerifyCompactBuffer); + +} // namespace te } // namespace tvm diff --git a/src/te/tensor.cc b/src/te/tensor.cc index cb14f6a35270..e66b9632d8a2 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -21,27 +21,22 @@ * \file tensor.cc */ #include -#include #include +#include #include + #include namespace tvm { namespace te { IterVar thread_axis(Range dom, std::string tag) { - return IterVarNode::make( - dom, Var(tag), kThreadIndex, tag); + return IterVar(dom, Var(tag), kThreadIndex, tag); } -IterVar reduce_axis(Range dom, std::string name) { - return IterVarNode::make( - dom, Var(name), kCommReduce); -} +IterVar reduce_axis(Range dom, std::string name) { return IterVar(dom, Var(name), kCommReduce); } -Var var(std::string name_hint, DataType t) { - return Var(name_hint, t); -} +Var var(std::string name_hint, DataType t) { return Var(name_hint, t); } // Tensor PrimExpr Tensor::operator()(Array indices) const { @@ -50,16 +45,16 @@ PrimExpr Tensor::operator()(Array indices) const { } PrimExpr Tensor::operator()(Array indices) const { - using tir::CallNode; if (ndim() != 0) { - CHECK_EQ(ndim(), 
indices.size()) - << "Tensor dimension mismatch in read" - << "ndim = " << ndim() << ", indices.size=" << indices.size(); + CHECK_EQ(ndim(), indices.size()) << "Tensor dimension mismatch in read" + << "ndim = " << ndim() << ", indices.size=" << indices.size(); } - auto n = CallNode::make( - (*this)->dtype, (*this)->op->name, indices, CallNode::Halide, - (*this)->op, (*this)->value_index); - return n; + + return ProducerLoad((*this), indices); +} + +String TensorNode::GetNameHint() const { + return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } Tensor Operation::output(size_t i) const { @@ -71,38 +66,32 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } -Tensor TensorNode::make(Array shape, - DataType dtype, - Operation op, - int value_index) { +Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { auto n = make_object(); n->shape = std::move(shape); n->dtype = dtype; n->op = op; n->value_index = value_index; - return Tensor(n); + data_ = std::move(n); } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* t = static_cast(node.get()); - p->stream << "Tensor(shape=" << t->shape - << ", op.name=" << t->op->name << ')'; - }); +TVM_REGISTER_GLOBAL("te.Tensor") + .set_body_typed([](Array shape, DataType dtype, Operation op, int value_index) { + return Tensor(shape, dtype, op, value_index); + }); TVM_REGISTER_NODE_TYPE(TensorNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* t = static_cast(node.get()); + p->stream << "Tensor(shape=" << t->shape << ", op.name=" << t->op->name << ')'; + }); // TensorIntrin - -TensorIntrin TensorIntrinNode::make(std::string name, - Operation op, - Array inputs, - Array buffers, - Array scalar_params, - Stmt body, - Stmt reduce_init, - Stmt reduce_update) { +TensorIntrin::TensorIntrin(std::string name, Operation op, Array inputs, + Array buffers, Array scalar_params, Stmt body, + Stmt reduce_init, Stmt reduce_update) { auto n = make_object(); n->name = std::move(name); n->op = std::move(op); @@ -112,69 +101,65 @@ TensorIntrin TensorIntrinNode::make(std::string name, n->body = std::move(body); n->reduce_init = std::move(reduce_init); n->reduce_update = std::move(reduce_update); - return TensorIntrin(n); + data_ = std::move(n); } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; - }); +TVM_REGISTER_GLOBAL("te.TensorIntrin") + .set_body_typed([](std::string name, Operation op, Array inputs, Array buffers, + Array scalar_params, Stmt body, Stmt reduce_init, Stmt reduce_update) { + return TensorIntrin(name, op, inputs, buffers, scalar_params, body, reduce_init, + reduce_update); + }); TVM_REGISTER_NODE_TYPE(TensorIntrinNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "TensorIntrin(name=" << op->name << ", " << op << ")"; + }); // TensorIntrinCall - -TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin, - Array tensors, - Array regions, - Array reduce_axis, - Array scalar_inputs) { +TensorIntrinCall::TensorIntrinCall(TensorIntrin intrin, Array tensors, + Array regions, Array reduce_axis, + Array scalar_inputs) { auto n = make_object(); n->intrin = std::move(intrin); n->tensors = 
std::move(tensors); n->regions = std::move(regions); n->reduce_axis = std::move(reduce_axis); n->scalar_inputs = std::move(scalar_inputs); - return TensorIntrinCall(n); + data_ = std::move(n); } +TVM_REGISTER_GLOBAL("te.TensorIntrinCall") + .set_body_typed([](TensorIntrin intrin, Array tensors, Array regions, + Array reduce_axis, Array scalar_inputs) { + return TensorIntrinCall(intrin, tensors, regions, reduce_axis, scalar_inputs); + }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* n = static_cast(node.get()); - p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* n = static_cast(node.get()); + p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")"; + }); TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode); -TVM_REGISTER_GLOBAL("te.Tensor") -.set_body_typed(TensorNode::make); - -TVM_REGISTER_GLOBAL("te.TensorIntrin") -.set_body_typed(TensorIntrinNode::make); - -TVM_REGISTER_GLOBAL("te.TensorIntrinCall") -.set_body_typed(TensorIntrinCallNode::make); - -TVM_REGISTER_GLOBAL("te.TensorEqual") -.set_body_method(&Tensor::operator==); +// Other tensor ops. +TVM_REGISTER_GLOBAL("te.TensorEqual").set_body_method(&Tensor::operator==); -TVM_REGISTER_GLOBAL("te.TensorHash") -.set_body_typed([](Tensor tensor) -> int64_t { - return static_cast(std::hash()(tensor)); - }); +TVM_REGISTER_GLOBAL("te.TensorHash").set_body_typed([](Tensor tensor) -> int64_t { + return static_cast(std::hash()(tensor)); +}); -TVM_REGISTER_GLOBAL("te.OpGetOutput") -.set_body_typed([](Operation op, int64_t output) { +TVM_REGISTER_GLOBAL("te.OpGetOutput").set_body_typed([](Operation op, int64_t output) { return op.output(static_cast(output)); }); -TVM_REGISTER_GLOBAL("te.OpNumOutputs") -.set_body_method(&OperationNode::num_outputs); +TVM_REGISTER_GLOBAL("te.OpNumOutputs").set_body_method(&OperationNode::num_outputs); -TVM_REGISTER_GLOBAL("te.OpInputTensors") -.set_body_method(&OperationNode::InputTensors); +TVM_REGISTER_GLOBAL("te.OpInputTensors").set_body_method(&OperationNode::InputTensors); } // namespace te } // namespace tvm diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 763e3eb7cdae..7eb8013f2a85 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -21,16 +21,15 @@ * \file tir/analysis/deep_equal.cc * \brief Deep equality checking. */ -#include #include +#include #include #include namespace tvm { namespace tir { -class DeepCmpSEqualHandler : - public SEqualReducer::Handler { +class DeepCmpSEqualHandler : public SEqualReducer::Handler { public: // use direct recursion. 
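// Usage sketch for the deep-equality check this handler backs (hypothetical
// variables; the public entry point is the ExprDeepEqual functor below):
//   tir::Var x("x");
//   tir::ExprDeepEqual deep_equal;
//   deep_equal(x + 1, x + 1);  // true: same structure, same Var object
//   deep_equal(x + 1, 1 + x);  // false: no commutation or simplification
// Free variables are compared by identity (map_free_vars is false), so two
// distinct Vars sharing a name hint do not compare equal.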
bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final { @@ -41,12 +40,9 @@ class DeepCmpSEqualHandler : return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false)); } - ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { - return ObjectRef(nullptr); - } + ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); } - void MarkGraphNode() final { - } + void MarkGraphNode() final {} private: // reflection vtable @@ -67,9 +63,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { } TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal") -.set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { - return ExprDeepEqual()(lhs, rhs); -}); + .set_body_typed([](const PrimExpr& lhs, const PrimExpr& rhs) { + return ExprDeepEqual()(lhs, rhs); + }); } // namespace tir } // namespace tvm diff --git a/src/arith/util.cc b/src/tir/analysis/side_effect.cc similarity index 57% rename from src/arith/util.cc rename to src/tir/analysis/side_effect.cc index 058c3e959528..b5fb328bf2b9 100644 --- a/src/arith/util.cc +++ b/src/tir/analysis/side_effect.cc @@ -18,36 +18,40 @@ */ /*! - * \file util.cc - * \brief The utils for arithmetic analysis. + * \file side_effect.cc + * \brief side effect analysis */ -#include -#include +#include +#include +#include namespace tvm { -namespace arith { +namespace tir { -std::tuple xgcd(int64_t a, int64_t b) { - int64_t s = 0, old_s = 1; - int64_t t = 1, old_t = 0; - int64_t r = b, old_r = a; +class ExprSideEffect : public ExprVisitor { + public: + void VisitExpr(const PrimExpr& e) final { + if (has_side_effect_) return; + ExprVisitor::VisitExpr(e); + } - while (r != 0) { - int64_t q = old_r / r; - std::swap(r, old_r); - r -= q * old_r; - std::swap(s, old_s); - s -= q * old_s; - std::swap(t, old_t); - t -= q * old_t; + void VisitExpr_(const CallNode* op) final { + if (!op->is_pure()) { + has_side_effect_ = true; + return; + } else { + ExprVisitor::VisitExpr_(op); + } } - CHECK_EQ(a % old_r, 0); - CHECK_EQ(b % old_r, 0); - CHECK(old_r == old_s*a + old_t*b); + bool has_side_effect_{false}; +}; - return std::make_tuple(old_r, old_s, old_t); +bool HasSideEffect(const PrimExpr& e) { + ExprSideEffect v; + v(e); + return v.has_side_effect_; } -} // namespace arith +} // namespace tir } // namespace tvm diff --git a/src/tir/analysis/var_touch.cc b/src/tir/analysis/var_touch.cc new file mode 100644 index 000000000000..2a2332955582 --- /dev/null +++ b/src/tir/analysis/var_touch.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file var_touch.cc + * \brief Analysis of variable usage in expressions + */ +#include +#include +#include + +namespace tvm { +namespace tir { + +class VarTouchVisitor : public ExprVisitor { + public: + explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} + + void VisitExpr(const PrimExpr& e) final { + if (use_var_) return; + ExprVisitor::VisitExpr(e); + } + + void VisitExpr_(const VarNode* op) final { Handle(op); } + + void VisitExpr_(const LoadNode* op) final { + Handle(op->buffer_var.get()); + ExprVisitor::VisitExpr_(op); + } + + void Handle(const VarNode* var) { + if (var_set_(var)) use_var_ = true; + } + + bool use_var_{false}; + + private: + std::function var_set_; +}; + +bool ExprUseVar(const PrimExpr& e, std::function var_set) { + VarTouchVisitor visitor(var_set); + visitor(e); + return visitor.use_var_; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/pass/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc similarity index 82% rename from src/tir/pass/verify_gpu_code.cc rename to src/tir/analysis/verify_gpu_code.cc index 70d909a859cc..1fbae0fd2dcd 100644 --- a/src/tir/pass/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -25,7 +25,7 @@ */ #include - +#include #include #include #include @@ -35,12 +35,8 @@ namespace tir { class GPUCodeVerifier : public StmtVisitor { public: - bool Verify(Stmt stmt, - int64_t max_local_memory_per_block, - int64_t max_shared_memory_per_block, - int64_t max_threads_per_block, - int64_t max_thread_x, - int64_t max_thread_y, + bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, + int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); @@ -84,7 +80,7 @@ class GPUCodeVerifier : public StmtVisitor { } Var var = op->node.as()->var; - const auto *extent = op->value.as(); + const auto* extent = op->value.as(); CHECK(extent); // record the number of threads in a block @@ -136,8 +132,8 @@ class GPUCodeVerifier : public StmtVisitor { private: int nest_level_{0}; - std::unordered_set visited_local_buffers_; - std::unordered_set visited_shared_buffers_; + std::unordered_set visited_local_buffers_; + std::unordered_set visited_shared_buffers_; std::unordered_set visited_threads_; size_t thread_x_extent_, thread_y_extent_, thread_z_extent_; @@ -164,8 +160,7 @@ class GPUCodeVerifier : public StmtVisitor { } }; -bool VerifyGPUCode(Stmt stmt, - Map constraints) { +bool VerifyGPUCode(const PrimFunc& func, Map constraints) { GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; @@ -193,14 +188,29 @@ bool VerifyGPUCode(Stmt stmt, LOG(FATAL) << "Invalid check item: " << iter.first; } - return verifier.Verify(stmt, - max_local_memory_per_block, - max_shared_memory_per_block, - max_threads_per_block, - max_thread_x, - max_thread_y, - max_thread_z); + return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, + max_threads_per_block, max_thread_x, max_thread_y, max_thread_z); +} + +TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); + +namespace transform { + +Pass VerifyGPUCode(Map constraints) { + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU
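// Usage sketch for the ExprUseVar helper in var_touch.cc above (hypothetical
// names; the predicate is assumed to receive the visited VarNode pointer):
//   tir::Var x("x"), y("y");
//   PrimExpr e = x + 1;
//   ExprUseVar(e, [&](const VarNode* v) { return v == x.get(); });  // true
//   ExprUseVar(e, [&](const VarNode* v) { return v == y.get(); });  // false
// A Load is treated as a use of its buffer_var, which is why the visitor
// overrides VisitExpr_(const LoadNode*) and calls Handle() on it.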
constraint violated" << func; + } + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyGPUCode", {}); } +TVM_REGISTER_GLOBAL("tir.transform.VerifyGPUCode").set_body_typed(VerifyGPUCode); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 8e684e966770..8eb846b7d618 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -21,12 +21,12 @@ * \file verify_memory.cc * \brief Pass to check if memory accesses are legal. */ -#include +#include +#include +#include #include +#include #include -#include -#include - namespace tvm { namespace tir { @@ -46,13 +46,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ - explicit MemoryAccessVerifier(PrimFunc f, int device_type) - : func_(f), dev_type_(device_type) {} + explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {} virtual ~MemoryAccessVerifier() = default; - MemoryAccessVerifier(const MemoryAccessVerifier &) = delete; - MemoryAccessVerifier(MemoryAccessVerifier &&) = delete; - MemoryAccessVerifier &operator=(const MemoryAccessVerifier &) = delete; - MemoryAccessVerifier &operator=(MemoryAccessVerifier &&) = delete; + MemoryAccessVerifier(const MemoryAccessVerifier&) = delete; + MemoryAccessVerifier(MemoryAccessVerifier&&) = delete; + MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete; + MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete; //@} /// Interface to perform memory access verification @@ -67,12 +66,12 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { protected: /// Visitor implementation //@{ - void VisitExpr(const PrimExpr &n) final { + void VisitExpr(const PrimExpr& n) final { if (Failed()) return; StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt &n) final { + void VisitStmt(const Stmt& n) final { if (Failed()) return; StmtExprVisitor::VisitStmt(n); } @@ -84,8 +83,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (!InThreadEnv() && (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope)) { + if (!InThreadEnv() && + (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) { EnterThreadEnv(); StmtExprVisitor::VisitStmt_(op); ExitThreadEnv(); @@ -106,8 +105,8 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { //@} /// Check if the value of a Variable comes from function argument. - bool IsFromFunctionArgs(const VarNode *var) const { - const VarNode *V = var; + bool IsFromFunctionArgs(const VarNode* var) const { + const VarNode* V = var; for (auto kv : func_->buffer_map) { if (V == kv.second->data.get()) return true; } @@ -118,9 +117,9 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { // The value is expected to come from a tvm_struct_get Call. // Get the first argument of tvm_struct_get, and continue. 
- const auto &iter = defs_.find(V); + const auto& iter = defs_.find(V); if (iter == defs_.end()) return false; - const CallNode *C = iter->second.as(); + const CallNode* C = iter->second.as(); if (!C || C->name != intrinsic::tvm_struct_get) return false; V = C->args[0].as(); } @@ -128,7 +127,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } /// Handle memory access to a Variable - void HandleLoadStoreToVariable(const Var &var) { + void HandleLoadStoreToVariable(const Var& var) { // We skip the access within thread env. if (InThreadEnv()) return; @@ -152,14 +151,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device. static bool IsGPUDevice(int dev_type) { - return kDLGPU == dev_type || kDLOpenCL == dev_type || - kDLVulkan == dev_type || kDLMetal == dev_type || - kDLROCM == dev_type || kOpenGL == dev_type; + return kDLGPU == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type || + kDLMetal == dev_type || kDLROCM == dev_type || kOpenGL == dev_type; } /// Check if a given DLDeviceType/TVMDeviceExtType value denotes FPGA device. - static bool IsFPGADevice(int dev_type) { - return kDLSDAccel == dev_type || kDLAOCL == dev_type; - } + static bool IsFPGADevice(int dev_type) { return kDLSDAccel == dev_type || kDLAOCL == dev_type; } private: /// Status of visitor @@ -167,38 +163,50 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { bool in_thread_env_{false}; bool failure_{false}; ///< If the verification fails (i.e. has illegal access) //@} - tir::PrimFunc func_{nullptr}; ///< Function to be verified. - int dev_type_{kDLCPU}; ///< Device type - std::unordered_map defs_; ///< Variable definitions + tir::PrimFunc func_{nullptr}; ///< Function to be verified. + int dev_type_{kDLCPU}; ///< Device type + std::unordered_map defs_; ///< Variable definitions }; } // namespace /// Interface of VerifyMemory pass -void VerifyMemory(const IRModule& mod) { - for (auto kv : mod->functions) { - if (auto* n = kv.second.as()) { - PrimFunc func = GetRef(n); - auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; - - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { - MemoryAccessVerifier v(func, target.value()->device_type); - v.Run(); - if (v.Failed()) { - LOG(FATAL) - << "ValueError: Direct host side access to device memory is detected." - << " Did you forget to bind?\n" - << func; - } - } - } +bool VerifyMemory(const PrimFunc& func) { + auto target = func->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; + + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { + MemoryAccessVerifier v(func, target.value()->device_type); + v.Run(); + return !v.Failed(); + } else { + return true; } } -TVM_REGISTER_GLOBAL("tir.analysis.verify_memory") -.set_body_typed(VerifyMemory); +TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory); + +namespace transform { + +Pass VerifyMemory() { + auto pass_func = + [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifyMemory(func)) + << "RuntimeError: Direct host side access to device memory is detected." 
+ << " Did you forget to bind?\n" + << func; + } + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.VerifyMemory").set_body_typed(VerifyMemory); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc new file mode 100644 index 000000000000..c57cbf7d0703 --- /dev/null +++ b/src/tir/analysis/verify_ssa.cc @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * SSA related checks and pass. + * + * SSA requires each variable to be defined only once. + * \file verify_ssa.cc + */ +#include +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace tir { + +class IRVerifySSA final : public StmtExprVisitor { + public: + bool is_ssa{true}; + + void VisitExpr(const PrimExpr& n) final { + if (!is_ssa) return; + StmtExprVisitor::VisitExpr(n); + } + void VisitStmt(const Stmt& n) final { + if (!is_ssa) return; + StmtExprVisitor::VisitStmt(n); + } + void VisitExpr_(const LetNode* op) final { + MarkDef(op->var.get()); + StmtExprVisitor::VisitExpr_(op); + } + void VisitStmt_(const LetStmtNode* op) final { + MarkDef(op->var.get()); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const ForNode* op) final { + MarkDef(op->loop_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + void VisitStmt_(const AllocateNode* op) final { + MarkDef(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitExpr_(const VarNode* node) final { + if (match_scope_) { + MarkDef(node, true); + } + } + + void Run(const PrimFunc& func) { + for (auto param : func->params) { + MarkDef(param.get()); + } + + for (auto kv : func->buffer_map) { + this->DefineBuffer(kv.second); + } + this->VisitStmt(func->body); + } + + void DefineBuffer(const Buffer& buffer) { + match_scope_ = true; + this->VisitExpr(buffer->data); + for (size_t i = 0; i < buffer->shape.size(); ++i) { + this->VisitExpr(buffer->shape[i]); + } + + if (buffer->strides.defined()) { + for (size_t i = 0; i < buffer->strides.size(); ++i) { + this->VisitExpr(buffer->strides[i]); + } + } + this->VisitExpr(buffer->elem_offset); + + match_scope_ = false; + } + + private: + void MarkDef(const VarNode* v, bool allow_dup = false) { + if (defined_.count(v) != 0) { + if (!allow_dup) { + is_ssa = false; + return; + } + } else { + defined_[v] = 1; + } + } + // whether we are in match scope, where a var can occur multiple times.
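// An illustrative example of what this verifier rejects (pseudo-TIR, assuming
// the same Var object x is reused as a definition site twice):
//   LetStmt(x, 1, LetStmt(x, 2, Evaluate(x)))  // second MarkDef(x) -> is_ssa = false
// Binding a fresh Var that merely shares the name hint is fine; only reusing
// the same VarNode as a definition violates SSA here. Buffer declarations are
// the sanctioned exception: DefineBuffer() sets match_scope_ so variables that
// appear in several buffers' shape/stride/offset fields may be marked again.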
+ bool match_scope_{false}; + std::unordered_map defined_; +}; + +bool VerifySSA(const PrimFunc& func) { + IRVerifySSA visitor; + visitor.Run(func); + return visitor.is_ssa; +} + +TVM_REGISTER_GLOBAL("tir.analysis.verify_ssa").set_body_typed(VerifySSA); + +namespace transform { + +Pass VerifySSA() { + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + auto func = GetRef(n); + CHECK(VerifySSA(func)) << "RuntimeError: IR is not in SSA form" << func; + } + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.VerifySSA").set_body_typed(VerifySSA); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 6bbf6451b7ac..4e433fc718b1 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -20,47 +20,37 @@ /*! * \file buffer.cc */ +#include +#include #include +#include #include -#include #include -#include -#include +#include #include #include -#include "../../arith/compute_expr.h" namespace tvm { namespace tir { -// TODO(tqchen): change to floormod/div + using IndexMod = tir::FloorModNode; using IndexDiv = tir::FloorDivNode; -Array SimplifyArray(Array array) { +Array SimplifyArray(arith::Analyzer* ana, Array array) { for (size_t i = 0; i < array.size(); ++i) { - array.Set(i, tir::Simplify(array[i])); + array.Set(i, ana->Simplify(array[i])); } return array; } -Buffer decl_buffer(Array shape, - DataType dtype, - std::string name) { - return BufferNode::make( - Var(name, PointerType(PrimType(dtype))), - dtype, - shape, - Array(), - PrimExpr(), - name, - "", - 0, 0, - kDefault); +Buffer decl_buffer(Array shape, DataType dtype, String name) { + return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), + PrimExpr(), name, "", 0, 0, kDefault); } // Split the given expression w.r.t the add operator -inline std::vector ExprSplitAddition(const PrimExpr &expr) { +inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; std::vector ret; std::stack split_buffer; @@ -79,7 +69,6 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { return ret; } - // Searches for the following types of expr: // mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki // mod_l_expr = c @@ -87,9 +76,9 @@ inline std::vector ExprSplitAddition(const PrimExpr &expr) { // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(const PrimExpr &mult_expr, - const PrimExpr &mod_l_expr, - const PrimExpr &mod_r_expr) { +inline std::pair MergeMulModInner(const PrimExpr& mult_expr, + const PrimExpr& mod_l_expr, + const PrimExpr& mod_r_expr) { using namespace tir; const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, PrimExpr()); @@ -124,9 +113,8 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, return std::make_pair(false, PrimExpr()); } else if (inner_div_ptr) { PrimExpr overall_mult = mult_inner.get() ? 
mult_inner * mult_outer : mult_outer; - if (expr_equal(overall_mult, inner_div_ptr->b) - && expr_equal(overall_mult, mod_r_expr) - && expr_equal(inner_div_ptr->a, mod_l_expr)) { + if (expr_equal(overall_mult, inner_div_ptr->b) && expr_equal(overall_mult, mod_r_expr) && + expr_equal(inner_div_ptr->a, mod_l_expr)) { // Found! PrimExpr ret = no_opt_sum.get() ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr; return std::make_pair(true, ret); @@ -157,9 +145,7 @@ inline std::pair MergeMulModInner(const PrimExpr &mult_expr, inline void MergeMulModInsertElements(const std::vector& eles, std::list* mult_exprs, std::list >* mod_exprs, - PrimExpr* no_opt_sum, - bool* has_mult, - bool* has_mod) { + PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { using namespace tir; *has_mult = false; *has_mod = false; @@ -185,22 +171,21 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline PrimExpr MergeMulMod(const PrimExpr &base) { +inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { using namespace tir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. - PrimExpr simplified_base = Simplify(base); + PrimExpr simplified_base = analyzer->Simplify(base); std::vector eles = ExprSplitAddition(simplified_base); std::list mult_exprs; std::list > mod_exprs; PrimExpr no_opt_sum; bool has_mult; bool has_mod; - MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, - &no_opt_sum, &has_mult, &has_mod); + MergeMulModInsertElements(eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, &has_mod); bool find_opt = false; std::list >::iterator search_mod_it = mod_exprs.begin(); // 2. Exhaustive Search @@ -208,9 +193,8 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) { std::list::iterator mult_it = mult_exprs.begin(); bool inner_find_opt = false; while (mult_it != mult_exprs.end()) { - std::pair ret = MergeMulModInner(*mult_it, - search_mod_it->first, - search_mod_it->second); + std::pair ret = + MergeMulModInner(*mult_it, search_mod_it->first, search_mod_it->second); if (ret.first) { inner_find_opt = true; auto temp_mod_it = search_mod_it; @@ -218,8 +202,8 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) { mod_exprs.erase(temp_mod_it); mult_exprs.erase(mult_it); std::vector ret_eles = ExprSplitAddition(ret.second); - MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, - &no_opt_sum, &has_mult, &has_mod); + MergeMulModInsertElements(ret_eles, &mult_exprs, &mod_exprs, &no_opt_sum, &has_mult, + &has_mod); if (has_mult) { search_mod_it = mod_exprs.begin(); } else if (has_mod && search_mod_it == mod_exprs.end()) { @@ -242,9 +226,9 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) { no_opt_sum = no_opt_sum.get() ? no_opt_sum + *it : *it; } for (std::list >::iterator it = mod_exprs.begin(); - it != mod_exprs.end(); ++it) { - no_opt_sum = no_opt_sum.get() ? - no_opt_sum + indexmod(it->first, it->second) : indexmod(it->first, it->second); + it != mod_exprs.end(); ++it) { + no_opt_sum = no_opt_sum.get() ? 
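// A concrete instance of the mul/mod pattern documented above MergeMulModInner
// (illustrative index expressions, written with floordiv/floormod as used in
// this file):
//   floordiv(x, 16) * 16 + floormod(x, 16)        merges to  x
//   (i0 + floordiv(i1, 4)) * 4 + floormod(i1, 4)  merges to  i0 * 4 + i1
// This is exactly the shape produced when an index is first split by
// floordiv/floormod during scheduling and then re-flattened by ElemOffset
// below, so the merge keeps the flattened offset in its simplest form.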
no_opt_sum + indexmod(it->first, it->second) + : indexmod(it->first, it->second); } return no_opt_sum; } @@ -254,6 +238,7 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) { // We also perform optimization to simplify the indexing expression. inline PrimExpr ElemOffset(const BufferNode* n, Array index) { PrimExpr base = n->elem_offset; + arith::Analyzer ana; if (n->strides.size() == 0) { // Scalar case if (n->shape.size() == 0 && index.size() == 1) { @@ -265,7 +250,7 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array index) { if (index.size() > 0) { PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { - offset = MergeMulMod(offset * n->shape[i] + index[i]); + offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]); } base = base + offset; } @@ -273,12 +258,12 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array index) { } else { CHECK_EQ(n->strides.size(), index.size()); if (is_zero(base)) { - base = MergeMulMod(index[0] * n->strides[0]); + base = MergeMulMod(&ana, index[0] * n->strides[0]); } else { - base = MergeMulMod(base + index[0] * n->strides[0]); + base = MergeMulMod(&ana, base + index[0] * n->strides[0]); } for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(base + index[i] * n->strides[i]); + base = MergeMulMod(&ana, base + index[i] * n->strides[i]); } } return base; @@ -290,7 +275,7 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp offset = offset * make_const(offset.dtype(), dtype.lanes()); } if (dtype.lanes() != 1) { - return tir::RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + return tir::Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } else { return offset; } @@ -299,20 +284,14 @@ inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataTyp PrimExpr Buffer::vload(Array begin, DataType dtype) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); - CHECK(dtype.element_of() == n->dtype.element_of() && - dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype - << " from buffer of " << n->dtype; + CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << dtype << " from buffer of " << n->dtype; if (dtype == DataType::Bool()) { - return tir::CastNode::make( - DataType::Bool(), - tir::LoadNode::make( - DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), - const_true())); + return tir::Cast(DataType::Bool(), + tir::Load(DataType::Int(8), n->data, BufferOffset(n, begin, DataType::Int(8)), + const_true())); } else { - return tir::LoadNode::make( - dtype, n->data, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + return tir::Load(dtype, n->data, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } } @@ -320,18 +299,13 @@ Stmt Buffer::vstore(Array begin, PrimExpr value) const { // specially handle bool, stored asDataType::Int(8) const BufferNode* n = operator->(); DataType dtype = value.dtype(); - CHECK(dtype.element_of() == n->dtype.element_of() && - dtype.lanes() % n->dtype.lanes() == 0) - << "Cannot load " << dtype - << " from buffer of " << n->dtype; + CHECK(dtype.element_of() == n->dtype.element_of() && dtype.lanes() % n->dtype.lanes() == 0) + << "Cannot load " << dtype << " from buffer of " << n->dtype; if (value.dtype() == DataType::Bool()) { - return tir::StoreNode::make(n->data, - tir::CastNode::make(DataType::Int(8), value), - BufferOffset(n, begin, DataType::Int(8)), - const_true()); + return 
tir::Store(n->data, tir::Cast(DataType::Int(8), value), + BufferOffset(n, begin, DataType::Int(8)), const_true()); } else { - return tir::StoreNode::make(n->data, value, BufferOffset(n, begin, dtype), - const_true(dtype.lanes())); + return tir::Store(n->data, value, BufferOffset(n, begin, dtype), const_true(dtype.lanes())); } } @@ -341,7 +315,7 @@ Buffer Buffer::MakeStrideView() const { std::vector temp; auto n = make_object(*operator->()); PrimExpr acc = make_const(n->DefaultIndexType(), 1); - for (size_t i = n->shape.size(); i != 0 ; --i) { + for (size_t i = n->shape.size(); i != 0; --i) { temp.push_back(acc); acc = acc * n->shape[i - 1]; } @@ -353,8 +327,9 @@ Buffer Buffer::MakeStrideView() const { Buffer Buffer::MakeSlice(Array begins, Array extents) const { const BufferNode* n = operator->(); - begins = SimplifyArray(begins); - PrimExpr elem_offset = tir::Simplify(ElemOffset(n, begins)); + arith::Analyzer ana; + begins = SimplifyArray(&ana, begins); + PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins)); Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; @@ -362,8 +337,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const // check if stride is needed. for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { - if (!is_zero(begins[i]) || - !is_zero(tir::Simplify(extents[i] - n->shape[i]))) { + if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } @@ -374,21 +348,11 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const return MakeStrideView().MakeSlice(begins, extents); } } - return BufferNode::make(n->data, - n->dtype, - extents, - strides, - elem_offset, - n->name + "_slice", - n->scope, - n->data_alignment, - 0, - n->buffer_type); + return Buffer(n->data, n->dtype, extents, strides, elem_offset, n->name + "_slice", n->scope, + n->data_alignment, 0, n->buffer_type); } -PrimExpr Buffer::access_ptr(int access_mask, - DataType ptr_type, - int content_lanes, +PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes, PrimExpr offset) const { const BufferNode* self = operator->(); PrimExpr e_dtype; @@ -399,34 +363,25 @@ PrimExpr Buffer::access_ptr(int access_mask, int highest_dim = 0; extent = self->strides[highest_dim] * self->shape[highest_dim] - offset; } else { - extent = arith::ComputeReduce(self->shape, PrimExpr()) - offset; + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + extent = foldl(fmul, make_const(DataType::Int(32), 1), self->shape) - offset; } PrimExpr elem_offset = self->elem_offset + offset; if (content_lanes > 1) { e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.dtype(), content_lanes); - elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), - content_lanes); + elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); } else { e_dtype = tir::TypeAnnotation(self->dtype); } - Array acc_args{ - e_dtype, self->data, elem_offset, - extent, make_const(DataType::Int(32), access_mask)}; - return tir::CallNode::make( - ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); + Array acc_args{e_dtype, self->data, elem_offset, extent, + make_const(DataType::Int(32), access_mask)}; + return tir::Call(ptr_type, tir::intrinsic::tvm_access_ptr, acc_args, tir::CallNode::Intrinsic); } -Buffer BufferNode::make(Var data, - DataType dtype, - Array shape, - Array strides, - PrimExpr elem_offset, - std::string name, 
- std::string scope, - int data_alignment, - int offset_factor, - BufferType buffer_type) { +Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, + PrimExpr elem_offset, String name, String scope, int data_alignment, + int offset_factor, BufferType buffer_type) { auto n = make_object(); n->data = std::move(data); n->dtype = dtype; @@ -455,35 +410,30 @@ Buffer BufferNode::make(Var data, n->strides.push_back(Var("stride", n->shape[i].dtype())); } } - return Buffer(n); + data_ = std::move(n); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "buffer(" << op->name << ", " << op << ")"; -}); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "buffer(" << op->name << ", " << op << ")"; + }); TVM_REGISTER_NODE_TYPE(BufferNode); +TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { + CHECK_EQ(args.size(), 10); + auto buffer_type = args[9].operator String(); + BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; + *ret = + Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], type); +}); -TVM_REGISTER_GLOBAL("tir.Buffer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size(), 10); - auto buffer_type = args[9].operator std::string(); - BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; - *ret = BufferNode::make(args[0], args[1], args[2], args[3], args[4], - args[5], args[6], args[7], args[8], type); - }); - -TVM_REGISTER_GLOBAL("tir.BufferAccessPtr") -.set_body_method(&Buffer::access_ptr); +TVM_REGISTER_GLOBAL("tir.BufferAccessPtr").set_body_method(&Buffer::access_ptr); -TVM_REGISTER_GLOBAL("tir.BufferVLoad") -.set_body_method(&Buffer::vload); +TVM_REGISTER_GLOBAL("tir.BufferVLoad").set_body_method(&Buffer::vload); -TVM_REGISTER_GLOBAL("tir.BufferVStore") -.set_body_method(&Buffer::vstore); +TVM_REGISTER_GLOBAL("tir.BufferVStore").set_body_method(&Buffer::vstore); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 85842a0b9dcf..bc777db55dbe 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -21,44 +21,43 @@ * \file src/lang/data_layout.cc * \brief Data Layout expression. 
*/ +#include #include #include -#include +#include + #include namespace tvm { namespace tir { -using tir::Var; using tir::IterVar; using tir::IterVarNode; +using tir::Var; TVM_REGISTER_NODE_TYPE(LayoutNode); TVM_REGISTER_NODE_TYPE(BijectiveLayoutNode); const LayoutAxis LayoutAxis::UPPER_CASE[] = { - LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), - LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'), - LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'), - LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'), - LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'), - LayoutAxis('Z') -}; + LayoutAxis('A'), LayoutAxis('B'), LayoutAxis('C'), LayoutAxis('D'), LayoutAxis('E'), + LayoutAxis('F'), LayoutAxis('G'), LayoutAxis('H'), LayoutAxis('I'), LayoutAxis('J'), + LayoutAxis('K'), LayoutAxis('L'), LayoutAxis('M'), LayoutAxis('N'), LayoutAxis('O'), + LayoutAxis('P'), LayoutAxis('Q'), LayoutAxis('R'), LayoutAxis('S'), LayoutAxis('T'), + LayoutAxis('U'), LayoutAxis('V'), LayoutAxis('W'), LayoutAxis('X'), LayoutAxis('Y'), + LayoutAxis('Z')}; const LayoutAxis LayoutAxis::LOWER_CASE[] = { - LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'), - LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'), - LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'), - LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'), - LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'), - LayoutAxis('z') -}; + LayoutAxis('a'), LayoutAxis('b'), LayoutAxis('c'), LayoutAxis('d'), LayoutAxis('e'), + LayoutAxis('f'), LayoutAxis('g'), LayoutAxis('h'), LayoutAxis('i'), LayoutAxis('j'), + LayoutAxis('k'), LayoutAxis('l'), LayoutAxis('m'), LayoutAxis('n'), LayoutAxis('o'), + LayoutAxis('p'), LayoutAxis('q'), LayoutAxis('r'), LayoutAxis('s'), LayoutAxis('t'), + LayoutAxis('u'), LayoutAxis('v'), LayoutAxis('w'), LayoutAxis('x'), LayoutAxis('y'), + LayoutAxis('z')}; const LayoutAxis& LayoutAxis::Get(const char name) { CHECK((name >= 'A' && name <= 'Z') || (name >= 'a' && name <= 'z')) - << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; - return (name >= 'A' && name <= 'Z') ? - LayoutAxis::UPPER_CASE[name-'A'] : - LayoutAxis::LOWER_CASE[name-'a']; + << "Invalid layout axis name: " << name << ". Has to be A-Z or a-z."; + return (name >= 'A' && name <= 'Z') ? 
LayoutAxis::UPPER_CASE[name - 'A'] + : LayoutAxis::LOWER_CASE[name - 'a']; } const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { @@ -67,7 +66,7 @@ const LayoutAxis& LayoutAxis::Get(const IterVar& itvar) { return LayoutAxis::Get(axis[0]); } -const LayoutAxis& LayoutAxis::make(const std::string& name) { +const LayoutAxis& LayoutAxis::Get(const std::string& name) { CHECK_EQ(name.length(), 1) << "Invalid axis " << name; return LayoutAxis::Get(name[0]); } @@ -81,9 +80,9 @@ Layout::Layout(const Array& axes) { CHECK_GT(factor->value, 0); repr << factor->value; } - CHECK_EQ(axis->var.get()->name_hint.size(), 1) << "Invalid layout axis " - << axis->var.get()->name_hint; - char c = axis->var.get()->name_hint[0]; + CHECK_EQ(axis->var.get()->name_hint.size(), 1) + << "Invalid layout axis " << axis->var.get()->name_hint; + char c = axis->var.get()->name_hint.operator std::string()[0]; CHECK((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) << "Invalid layout axis " << c; repr << axis->var.get()->name_hint; } @@ -91,7 +90,7 @@ Layout::Layout(const Array& axes) { data_ = std::move(node); } -Layout::Layout(const std::string& name) { // NOLINT(*) +Layout::Layout(const std::string& name) { // NOLINT(*) if (name == "__undef__") return; auto node = make_object(); @@ -103,19 +102,18 @@ Layout::Layout(const std::string& name) { // NOLINT(*) int32_t factor = 0; for (char c : name) { if (c >= 'A' && c <= 'Z') { - CHECK_EQ(factor, 0) << "Invalid layout " << name - << ": invalid factor size " << factor + CHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); - IterVar axis = IterVarNode::make(Range(PrimExpr(0), Var(shape_name)), - Var(std::string(1, c)), tir::kDataPar); + IterVar axis = + IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { - CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " - << factor << " for dimension " << c; - IterVar axis = IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(std::string(1, c)), tir::kDataPar); + CHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor + << " for dimension " << c; + IterVar axis = + IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { @@ -129,7 +127,7 @@ Layout::Layout(const std::string& name) { // NOLINT(*) // validate layout std::vector exist_axis(256, false); for (const IterVar& v : node->axes) { - auto axis_str = v->var.get()->name_hint; + auto axis_str = v->var.get()->name_hint.operator std::string(); CHECK_EQ(axis_str.size(), 1); char axis = axis_str[0]; CHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); @@ -137,19 +135,15 @@ Layout::Layout(const std::string& name) { // NOLINT(*) exist_axis[axis] = true; } for (const IterVar& v : node->axes) { - char axis = v->var.get()->name_hint[0]; + char axis = v->var.get()->name_hint.operator std::string()[0]; if (axis >= 'a' && axis <= 'z') { - CHECK(exist_axis[axis-'a'+'A']) << "Invalid layout " << name << ": missing axis " - << std::toupper(axis); + CHECK(exist_axis[axis - 'a' + 'A']) + << "Invalid layout " << name << ": missing axis " << std::toupper(axis); } } data_ = std::move(node); } -Layout LayoutNode::make(const std::string& layout) { - return Layout(layout); -} - Layout Layout::SubLayout(size_t 
pos, size_t len) const { if (!defined() || pos > ndim()) return Layout::Undef(); if (len == 0) return Layout(Array()); @@ -162,22 +156,22 @@ Layout Layout::SubLayout(size_t pos, size_t len) const { return Layout(new_layout); } -Layout Layout::Split(const LayoutAxis &axis, size_t target_pos, int32_t factor) const { +Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) const { if (!defined()) return Layout::Undef(); const std::string& name = operator->()->name; const auto axes = operator->()->axes; - CHECK(target_pos <= this->ndim()) << "Invalid split position " - << target_pos << " for layout " << name; + CHECK(target_pos <= this->ndim()) + << "Invalid split position " << target_pos << " for layout " << name; CHECK(axis.IsPrimal()) << "Cannot split a subordinate axis " << axis; CHECK(this->Contains(axis)) << "Axis " << axis << " does not exist in " << name; - CHECK(!this->Contains(axis.ToSubordinate())) << "Axis " << axis - << " has already been split in " << name; + CHECK(!this->Contains(axis.ToSubordinate())) + << "Axis " << axis << " has already been split in " << name; CHECK(factor > 0) << "Invalid split size " << factor; Array new_layout; for (size_t i = 0; i <= this->ndim(); ++i) { if (i == target_pos) { - new_layout.push_back(IterVarNode::make(Range(PrimExpr(0), PrimExpr(factor)), - Var(axis.ToSubordinate().name()), tir::kDataPar)); + new_layout.push_back(IterVar(Range(PrimExpr(0), PrimExpr(factor)), + Var(axis.ToSubordinate().name()), tir::kDataPar)); } if (i == this->ndim()) break; new_layout.push_back(axes[i]); @@ -200,16 +194,15 @@ int32_t Layout::FactorOf(const LayoutAxis& axis) const { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* l = static_cast(node.get()); - p->stream << "Layout(" << l->name << ")"; - }); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* l = static_cast(node.get()); + p->stream << "Layout(" << l->name << ")"; + }); -inline bool GetStoreRule(Array* rule, - const Layout& src_layout, +inline bool GetStoreRule(Array* rule, const Layout& src_layout, const Layout& dst_layout) { - if (!src_layout.defined() || src_layout.name().empty() || - !dst_layout.defined() || dst_layout.name().empty()) { + if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() || + dst_layout.name().empty()) { return false; } for (size_t i = 0; i < dst_layout.ndim(); ++i) { @@ -253,15 +246,16 @@ inline bool GetStoreRule(Array* rule, } inline Array TransformIndex(const Array& src_index, - const Array& src_axis, - const Array& transform_rule) { + const Array& src_axis, + const Array& transform_rule) { + arith::Analyzer ana; Array result; std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } for (PrimExpr rule : transform_rule) { - result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } return result; } @@ -270,23 +264,23 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(src_index.size(), self->src_layout->axes.size()) - << "Input mismatch with layout " << self->src_layout; + << "Input mismatch with layout " << self->src_layout; return TransformIndex(src_index, self->src_layout->axes, self->forward_rule); } - Array BijectiveLayout::BackwardIndex(const Array& dst_index) 
const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); CHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) - << "Output mismatch with layout " << self->dst_layout; + << "Output mismatch with layout " << self->dst_layout; return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule); } inline Array TransformShape(const Array& src_shape, - const Array& src_axis, - const Array& target_axis, - const Array& transform_rule) { + const Array& src_axis, + const Array& target_axis, + const Array& transform_rule) { + arith::Analyzer ana; CHECK_EQ(src_shape.size(), src_axis.size()); // bind variables for original axes // for major-axis, bind the corresponding size @@ -306,8 +300,8 @@ inline Array TransformShape(const Array& src_shape, const auto* orig_axis_extent = orig_axis->dom->extent.as(); if (orig_shape_const) { CHECK_EQ(orig_shape_const->value, orig_axis_extent->value) - << "Input shape mismatch at index " << i << ". Expected " - << orig_axis->dom->extent << ", get " << orig_shape; + << "Input shape mismatch at index " << i << ". Expected " << orig_axis->dom->extent + << ", get " << orig_shape; } } bind_map[orig_axis->var.get()] = PrimExpr(0); @@ -327,9 +321,9 @@ inline Array TransformShape(const Array& src_shape, result.push_back(axis->dom->extent); } else { if (symbolic_var_set.count(i)) { - result.push_back(tir::AnyNode::make()); + result.push_back(tir::Any()); } else { - result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } } } @@ -339,15 +333,13 @@ inline Array TransformShape(const Array& src_shape, Array BijectiveLayout::ForwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->src_layout->axes, - self->dst_layout->axes, self->forward_rule); + return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule); } Array BijectiveLayout::BackwardShape(const Array& shape) const { CHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->dst_layout->axes, - self->src_layout->axes, self->backward_rule); + return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule); } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { @@ -365,51 +357,47 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* b = static_cast(node.get()); - p->stream << "BijectiveLayout(" << b->src_layout.name() - << "->" << b->dst_layout.name() << ")"; - }); - -TVM_REGISTER_GLOBAL("tir.Layout") -.set_body_typed(LayoutNode::make); - -TVM_REGISTER_GLOBAL("tir.LayoutIndexOf") -.set_body_typed([](Layout layout, std::string axis) -> int { - return layout.IndexOf(LayoutAxis::make(axis)); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* b = static_cast(node.get()); + p->stream << "BijectiveLayout(" << b->src_layout.name() << "->" << b->dst_layout.name() + << ")"; + }); + +TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); }); + +TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { + return 
layout.IndexOf(LayoutAxis::Get(axis)); }); TVM_REGISTER_GLOBAL("tir.LayoutFactorOf") -.set_body_typed([](Layout layout, std::string axis) -> int { - return layout.FactorOf(LayoutAxis::make(axis)); -}); + .set_body_typed([](Layout layout, std::string axis) -> int { + return layout.FactorOf(LayoutAxis::Get(axis)); + }); -TVM_REGISTER_GLOBAL("tir.LayoutNdim") -.set_body_typed([](Layout layout) -> int { +TVM_REGISTER_GLOBAL("tir.LayoutNdim").set_body_typed([](Layout layout) -> int { return layout.ndim(); }); -TVM_REGISTER_GLOBAL("tir.LayoutGetItem") -.set_body_typed([](Layout layout, int idx) -> std::string { +TVM_REGISTER_GLOBAL("tir.LayoutGetItem").set_body_typed([](Layout layout, int idx) -> std::string { const LayoutAxis& axis = layout[idx]; return axis.name(); }); TVM_REGISTER_GLOBAL("tir.BijectiveLayout") -.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { - return BijectiveLayout(src_layout, dst_layout); -}); + .set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout { + return BijectiveLayout(src_layout, dst_layout); + }); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex") -.set_body_method(&BijectiveLayout::ForwardIndex); + .set_body_method(&BijectiveLayout::ForwardIndex); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardIndex") -.set_body_method(&BijectiveLayout::BackwardIndex); + .set_body_method(&BijectiveLayout::BackwardIndex); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardShape") -.set_body_method(&BijectiveLayout::ForwardShape); + .set_body_method(&BijectiveLayout::ForwardShape); TVM_REGISTER_GLOBAL("tir.BijectiveLayoutBackwardShape") -.set_body_method(&BijectiveLayout::BackwardShape); + .set_body_method(&BijectiveLayout::BackwardShape); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 65d424e31212..94efff5180fb 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -22,25 +22,52 @@ */ #include #include -#include #include -#include -#include +#include + #include -#include "../pass/ir_util.h" +#include + #include "../../support/str_escape.h" namespace tvm { namespace tir { -Var::Var(std::string name_hint, DataType dtype) { +#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b) { \ + using T = Name::ContainerType; \ + CHECK(a.defined()) << "ValueError: a is undefined\n"; \ + CHECK(b.defined()) << "ValueError: b is undefined\n"; \ + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = a.dtype(); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + data_ = std::move(node); \ + } + +#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \ + Name::Name(PrimExpr a, PrimExpr b) { \ + using T = Name::ContainerType; \ + CHECK(a.defined()) << "ValueError: a is undefined\n"; \ + CHECK(b.defined()) << "ValueError: b is undefined\n"; \ + CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \ + ObjectPtr node = make_object(); \ + node->dtype = DataType::Bool(a.dtype().lanes()); \ + node->a = std::move(a); \ + node->b = std::move(b); \ + data_ = std::move(node); \ + } + +// Var +Var::Var(String name_hint, DataType dtype) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = std::move(dtype); data_ = std::move(n); } -Var::Var(std::string name_hint, Type type_annotation) { +Var::Var(String name_hint, Type type_annotation) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); @@ -48,7 +75,7 @@ Var::Var(std::string 
name_hint, Type type_annotation) { data_ = std::move(n); } -Var Var::copy_with_suffix(const std::string& suffix) const { +Var Var::copy_with_suffix(const String& suffix) const { const VarNode* node = get(); ObjectPtr new_ptr; if (auto* ptr = this->as()) { @@ -56,21 +83,11 @@ Var Var::copy_with_suffix(const std::string& suffix) const { } else { new_ptr = make_object(*node); } - new_ptr->name_hint += suffix; - + new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string(); return Var(new_ptr); } -SizeVar::SizeVar(std::string name_hint, DataType dtype) { - auto n = make_object(); - n->name_hint = std::move(name_hint); - n->dtype = std::move(dtype); - data_ = std::move(n); -} - - -TVM_REGISTER_GLOBAL("tir.Var") -.set_body_typed([](std::string name_hint, runtime::TVMArgValue type) { +TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) { if (type.IsObjectRef()) { return Var(name_hint, type.operator Type()); } else { @@ -78,73 +95,364 @@ TVM_REGISTER_GLOBAL("tir.Var") } }); -TVM_REGISTER_GLOBAL("tir.SizeVar") -.set_body_typed([](std::string s, DataType t) { - return SizeVar(s, t); +TVM_REGISTER_NODE_TYPE(VarNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + // omit the type + // stream << op->name << "." << op->type; + p->stream << op->name_hint; + }); + +// SizeVar +SizeVar::SizeVar(String name_hint, DataType dtype) { + auto n = make_object(); + n->name_hint = std::move(name_hint); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) { + return SizeVar(s, t); }); +TVM_REGISTER_NODE_TYPE(SizeVarNode); -IterVar IterVarNode::make(Range dom, - Var var, - IterVarType t, - std::string thread_tag) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; + }); + +// IterVar +IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag) { ObjectPtr n = make_object(); n->dom = dom; n->var = var; n->iter_type = t; n->thread_tag = thread_tag; - return IterVar(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.IterVar") -.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { - return IterVarNode::make( - dom, var, - static_cast(iter_type), - thread_tag); -}); + .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag) { + return IterVar(dom, var, static_cast(iter_type), thread_tag); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "iter_var("; - if (op->var->name_hint.length() != 0) { - p->stream << op->var->name_hint << ", "; - } - if (op->dom.defined()) { - p->stream << op->dom; - } - if (op->thread_tag.length() != 0) { - p->stream << ", " << op->thread_tag; - } - p->stream << ")"; - }); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "iter_var("; + if (op->var->name_hint.length() != 0) { + p->stream << op->var->name_hint << ", "; + } + if (op->dom.defined()) { + p->stream << op->dom; + } + if (op->thread_tag.length() != 0) { + p->stream << ", " << op->thread_tag; + } + p->stream << ")"; + }); TVM_REGISTER_NODE_TYPE(IterVarNode); -PrimExpr 
StringImmNode::make(std::string value) { +// StringImm +StringImm::StringImm(String value) { ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); - return PrimExpr(node); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.StringImm") -.set_body_typed(StringImmNode::make); +TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value) { return StringImm(value); }); + +TVM_REGISTER_NODE_TYPE(StringImmNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '\"' << support::StrEscape(op->value) << '\"'; + }); -PrimExpr CastNode::make(DataType t, PrimExpr value) { +// Cast +Cast::Cast(DataType t, PrimExpr value) { CHECK(value.defined()); CHECK_EQ(t.lanes(), value.dtype().lanes()); ObjectPtr node = make_object(); node->dtype = t; node->value = std::move(value); - return PrimExpr(node); + data_ = std::move(node); } +TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value) { + return Cast(dtype, value); +}); + +TVM_REGISTER_NODE_TYPE(CastNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->dtype << '('; + p->Print(op->value); + p->stream << ')'; + }); + +// Add +TVM_DEFINE_BINOP_CONSTRUCTOR(Add); + +TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b) { return Add(a, b); }); + +TVM_REGISTER_NODE_TYPE(AddNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " + "; + p->Print(op->b); + p->stream << ')'; + }); + +// Sub +TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); + +TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b) { return Sub(a, b); }); + +TVM_REGISTER_NODE_TYPE(SubNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " - "; + p->Print(op->b); + p->stream << ')'; + }); + +// Mul +TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); + +TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b) { return Mul(a, b); }); + +TVM_REGISTER_NODE_TYPE(MulNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "*"; + p->Print(op->b); + p->stream << ')'; + }); + +// Div +TVM_DEFINE_BINOP_CONSTRUCTOR(Div); + +TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b) { return Div(a, b); }); + +TVM_REGISTER_NODE_TYPE(DivNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << "/"; + p->Print(op->b); + p->stream << ')'; + }); + +// Mod +TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); + +TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b) { return Mod(a, b); }); + +TVM_REGISTER_NODE_TYPE(ModNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " % "; + p->Print(op->b); + p->stream << ')'; + }); + +// FloorDiv 
+TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); + +TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b) { + return FloorDiv(a, b); +}); + +TVM_REGISTER_NODE_TYPE(FloorDivNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floordiv(" << op->a << ", " << op->b << ")"; + }); + +// FloorMod +TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); + +TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b) { + return FloorMod(a, b); +}); -PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { +TVM_REGISTER_NODE_TYPE(FloorModNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "floormod(" << op->a << ", " << op->b << ")"; + }); + +// Min +TVM_DEFINE_BINOP_CONSTRUCTOR(Min); + +TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b) { return Min(a, b); }); + +TVM_REGISTER_NODE_TYPE(MinNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "min("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; + }); + +// Max +TVM_DEFINE_BINOP_CONSTRUCTOR(Max); + +TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b) { return Max(a, b); }); + +TVM_REGISTER_NODE_TYPE(MaxNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "max("; + p->Print(op->a); + p->stream << ", "; + p->Print(op->b); + p->stream << ")"; + }); + +// EQ +TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); + +TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b) { return EQ(a, b); }); + +TVM_REGISTER_NODE_TYPE(EQNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " == "; + p->Print(op->b); + p->stream << ')'; + }); + +// NE +TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); + +TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b) { return NE(a, b); }); + +TVM_REGISTER_NODE_TYPE(NENode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " != "; + p->Print(op->b); + p->stream << ')'; + }); + +// LT +TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); + +TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b) { return LT(a, b); }); + +TVM_REGISTER_NODE_TYPE(LTNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " < "; + p->Print(op->b); + p->stream << ')'; + }); + +// LE +TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); + +TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b) { return LE(a, b); }); + +TVM_REGISTER_NODE_TYPE(LENode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " <= "; + p->Print(op->b); + p->stream << ')'; + }); + +// GT +TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); + 
+TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b) { return GT(a, b); }); + +TVM_REGISTER_NODE_TYPE(GTNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " > "; + p->Print(op->b); + p->stream << ')'; + }); + +// GE +TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); + +TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b) { return GE(a, b); }); + +TVM_REGISTER_NODE_TYPE(GENode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " >= "; + p->Print(op->b); + p->stream << ')'; + }); + +// And +And::And(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); @@ -155,10 +463,25 @@ PrimExpr AndNode::make(PrimExpr a, PrimExpr b) { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); - return PrimExpr(node); + data_ = std::move(node); } -PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { +TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b) { return And(a, b); }); + +TVM_REGISTER_NODE_TYPE(AndNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " && "; + p->Print(op->b); + p->stream << ')'; + }); + +// Or +Or::Or(PrimExpr a, PrimExpr b) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(b.defined()) << "ValueError: b is undefined"; CHECK(a.dtype().is_bool()); @@ -169,29 +492,52 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) { node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); node->b = std::move(b); - return PrimExpr(node); + data_ = std::move(node); } +TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b) { return Or(a, b); }); -PrimExpr NotNode::make(PrimExpr a) { +TVM_REGISTER_NODE_TYPE(OrNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '('; + p->Print(op->a); + p->stream << " || "; + p->Print(op->b); + p->stream << ')'; + }); + +// Not +Not::Not(PrimExpr a) { CHECK(a.defined()) << "ValueError: a is undefined"; CHECK(a.dtype().is_bool()); ObjectPtr node = make_object(); node->dtype = DataType::Bool(a.dtype().lanes()); node->a = std::move(a); - return PrimExpr(node); + data_ = std::move(node); } +TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a) { return Not(a); }); +TVM_REGISTER_NODE_TYPE(NotNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '!'; + p->Print(op->a); + }); -PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { +// Select +Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { CHECK(condition.defined()) << "ValueError: condition is undefined"; CHECK(true_value.defined()) << "ValueError: true_value is undefined"; CHECK(false_value.defined()) << "ValueError: true_value is undefined"; CHECK(condition.dtype().is_bool()); - CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || - 
condition.dtype().lanes() == 1); + CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1); CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types"; ObjectPtr node = make_object(); @@ -199,10 +545,30 @@ PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr fals node->condition = std::move(condition); node->true_value = std::move(true_value); node->false_value = std::move(false_value); - return PrimExpr(node); + data_ = std::move(node); } -PrimExpr LoadNode::make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { +TVM_REGISTER_GLOBAL("tir.Select") + .set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value) { + return Select(condition, true_value, false_value); + }); + +TVM_REGISTER_NODE_TYPE(SelectNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "select("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->true_value); + p->stream << ", "; + p->Print(op->false_value); + p->stream << ")"; + }); + +// Load +Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) { CHECK(buffer_var.defined()); CHECK(predicate.defined()); CHECK(index.defined()); @@ -215,10 +581,34 @@ PrimExpr LoadNode::make(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr node->index = std::move(index); node->predicate = std::move(predicate); - return PrimExpr(node); + data_ = std::move(node); } -PrimExpr RampNode::make(PrimExpr base, PrimExpr stride, int lanes) { +TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) { + DataType t = args[0]; + if (args.size() == 3) { + *ret = Load(t, args[1], args[2], const_true(t.lanes())); + } else { + *ret = Load(t, args[1], args[2], args[3]); + } +}); + +TVM_REGISTER_NODE_TYPE(LoadNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "]"; + if (!is_one(op->predicate)) { + p->stream << " if "; + p->Print(op->predicate); + } + }); + +// Ramp +Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes) { CHECK(base.defined()); CHECK(stride.defined()); CHECK(base.dtype().is_scalar()); @@ -231,10 +621,27 @@ PrimExpr RampNode::make(PrimExpr base, PrimExpr stride, int lanes) { node->base = base; node->stride = stride; node->lanes = lanes; - return PrimExpr(node); + data_ = std::move(node); } -PrimExpr BroadcastNode::make(PrimExpr value, int lanes) { +TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed([](PrimExpr base, PrimExpr stride, int lanes) { + return Ramp(base, stride, lanes); +}); + +TVM_REGISTER_NODE_TYPE(RampNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "ramp("; + p->Print(op->base); + p->stream << ", "; + p->Print(op->stride); + p->stream << ", " << op->lanes << ")"; + }); + +// Broadcast +Broadcast::Broadcast(PrimExpr value, int lanes) { CHECK(value.defined()); CHECK(value.dtype().is_scalar()); CHECK_GT(lanes, 1); @@ -243,10 +650,25 @@ PrimExpr BroadcastNode::make(PrimExpr value, int lanes) { node->dtype = value.dtype().with_lanes(lanes); node->value = std::move(value); node->lanes = lanes; - return PrimExpr(node); + data_ = node; } -PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { 
+TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes) { + return Broadcast(value, lanes); +}); + +TVM_REGISTER_NODE_TYPE(BroadcastNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "x" << op->lanes << "("; + p->Print(op->value); + p->stream << ")"; + }); + +// Let +Let::Let(Var var, PrimExpr value, PrimExpr body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); @@ -256,14 +678,57 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) { node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); - return PrimExpr(node); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body) { + return Let(var, value, body); +}); + +TVM_REGISTER_NODE_TYPE(LetNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "(let " << op->var << " = "; + p->Print(op->value); + p->stream << " in "; + p->Print(op->body); + p->stream << ")"; + }); + +// Call +Call::Call(DataType dtype, String name, Array args, CallType call_type) { + for (size_t i = 0; i < args.size(); ++i) { + CHECK(args[i].defined()); + } + + ObjectPtr node = make_object(); + node->dtype = dtype; + node->name = std::move(name); + node->args = std::move(args); + node->call_type = call_type; + data_ = std::move(node); } -const char* CallNode::vectorizable_intrinsics[] = { - "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt", - "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right, - tir::CallNode::likely, tir::CallNode::popcount -}; +const char* CallNode::vectorizable_intrinsics[] = {"floor", + "ceil", + "sign", + "trunc", + "fabs", + "round", + "exp", + "tanh", + "sqrt", + "log", + "sin", + "cos", + "pow", + "tan", + tir::CallNode::shift_left, + tir::CallNode::shift_right, + tir::CallNode::likely, + tir::CallNode::popcount}; bool CallNode::is_vectorizable() const { size_t cnt = sizeof(CallNode::vectorizable_intrinsics) / sizeof(char*); @@ -275,34 +740,37 @@ bool CallNode::is_vectorizable() const { return false; } -PrimExpr CallNode::make(DataType dtype, - std::string name, - Array args, - CallType call_type, - FunctionRef func, - int value_index) { - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].defined()); - } +TVM_REGISTER_GLOBAL("tir.Call") + .set_body_typed([](DataType type, String name, Array args, int call_type) { + Array prim_expr_args; + for (const auto& it : args) { + CHECK(it->IsInstance() || it->IsInstance()); + if (const auto* str = it.as()) { + prim_expr_args.push_back(StringImm(str->data)); + } else { + prim_expr_args.push_back(Downcast(it)); + } + } + return Call(type, name, prim_expr_args, static_cast(call_type)); + }); - if (call_type == Halide) { - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].dtype().is_int()); - } - } +TVM_REGISTER_NODE_TYPE(CallNode); - ObjectPtr node = make_object(); - node->dtype = dtype; - node->name = std::move(name); - node->args = std::move(args); - node->call_type = call_type; - node->func = std::move(func); - node->value_index = value_index; - return PrimExpr(node); -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->name << 
"("; + for (size_t i = 0; i < op->args.size(); ++i) { + p->Print(op->args[i]); + if (i < op->args.size() - 1) { + p->stream << ", "; + } + } + p->stream << ")"; + }); -PrimExpr ShuffleNode::make(Array vectors, - Array indices) { +// Shuffle +Shuffle::Shuffle(Array vectors, Array indices) { CHECK_NE(vectors.size(), 0U); CHECK_NE(indices.size(), 0U); @@ -319,10 +787,10 @@ PrimExpr ShuffleNode::make(Array vectors, node->dtype = base_type.with_lanes(static_cast(indices.size())); node->vectors = std::move(vectors); node->indices = std::move(indices); - return PrimExpr(node); + data_ = node; } -PrimExpr ShuffleNode::make_concat(Array vectors) { +PrimExpr Shuffle::Concat(Array vectors) { CHECK_NE(vectors.size(), 0); if (vectors.size() == 1) { return vectors[0]; @@ -334,23 +802,49 @@ PrimExpr ShuffleNode::make_concat(Array vectors) { indices.push_back(IntImm(DataType::Int(32), index++)); } } - return make(vectors, indices); + return Shuffle(vectors, indices); } -PrimExpr ShuffleNode::make_extract_element(PrimExpr vector, int index) { - return make({vector}, {Integer(index)}); +PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index) { + return Shuffle({vector}, {Integer(index)}); } -CommReducer CommReducerNode::make(Array lhs, - Array rhs, - Array result, - Array identity_element) { +TVM_REGISTER_GLOBAL("tir.Shuffle") + .set_body_typed([](Array vectors, Array indices) { + return Shuffle(vectors, indices); + }); + +TVM_REGISTER_NODE_TYPE(ShuffleNode); + +template +void PrintList(const Array& exprs, ReprPrinter* p) { + for (size_t i = 0; i < exprs.size(); ++i) { + p->Print(exprs[i]); + if (i < exprs.size() - 1) { + p->stream << ", "; + } + } +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "shuffle("; + PrintList(op->vectors, p); + p->stream << ", "; + PrintList(op->indices, p); + p->stream << ")"; + }); + +// CommReducer +CommReducer::CommReducer(Array lhs, Array rhs, Array result, + Array identity_element) { auto node = make_object(); node->lhs = lhs; node->rhs = rhs; node->result = result; node->identity_element = identity_element; - return CommReducer(node); + data_ = std::move(node); } Array CommReducerNode::operator()(Array a, Array b) const { @@ -362,23 +856,34 @@ Array CommReducerNode::operator()(Array a, Array b value_map.Set(lhs[i], a[i]); value_map.Set(rhs[i], b[i]); } - return UpdateArray(result, [&value_map] (const PrimExpr& e) { - return Substitute(e, value_map); - }); + auto ret = this->result; + ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); }); + return ret; } TVM_REGISTER_GLOBAL("tir.CommReducer") -.set_body_typed(CommReducerNode::make); + .set_body_typed([](Array lhs, Array rhs, Array result, + Array identity_element) { + return CommReducer(lhs, rhs, result, identity_element); + }); TVM_REGISTER_GLOBAL("tir.CommReducerCombine") -.set_body_method(&tir::CommReducerNode::operator()); + .set_body_method(&tir::CommReducerNode::operator()); + +TVM_REGISTER_NODE_TYPE(CommReducerNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs + << ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")"; + }); -PrimExpr ReduceNode::make(CommReducer combiner, Array source, - Array axis, PrimExpr condition, int value_index) { +// Reduce +Reduce::Reduce(CommReducer 
combiner, Array source, Array axis, + PrimExpr condition, int value_index) { for (size_t i = 0; i < axis.size(); ++i) { - CHECK_EQ(axis[i]->iter_type, kCommReduce) - << "Can only take axis created by reduce_axis"; + CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; } if (!condition.defined()) { condition = const_true(); @@ -394,19 +899,39 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array source, n->axis = std::move(axis); n->condition = condition; n->value_index = value_index; - return PrimExpr(n); + data_ = std::move(n); } - TVM_REGISTER_GLOBAL("tir.Reduce") -.set_body_typed(ReduceNode::make); + .set_body_typed([](CommReducer combiner, Array source, Array axis, + PrimExpr condition, int value_index) { + return Reduce(combiner, source, axis, condition, value_index); + }); +TVM_REGISTER_NODE_TYPE(ReduceNode); -PrimExpr AnyNode::make() { - auto n = make_object(); - return PrimExpr(n); -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "reduce(combiner=" << op->combiner; + p->stream << ", source=" << op->source; + p->stream << ", axis=" << op->axis; + p->stream << ", where=" << op->condition; + p->stream << ", value_index=" << op->value_index; + p->stream << ")"; + }); +// Any +Any::Any() { data_ = make_object(); } + +TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([]() { return Any(); }); + +TVM_REGISTER_NODE_TYPE(AnyNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; }); + +// BufferLoad BufferLoad::BufferLoad(Buffer buffer, Array indices) { ObjectPtr node = make_object(); node->dtype = buffer->dtype; @@ -415,410 +940,52 @@ BufferLoad::BufferLoad(Buffer buffer, Array indices) { data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.BufferLoad") -.set_body_typed([](Buffer buffer, Array indices) { +TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array indices) { return BufferLoad(buffer, indices); }); TVM_REGISTER_NODE_TYPE(BufferLoadNode); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '\"' << support::StrEscape(op->value) << '\"'; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->dtype << '('; - p->Print(op->value); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - // omit the type - // stream << op->name << "." 
<< op->type; - p->stream << op->name_hint; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " + "; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " - "; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "*"; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << "/"; - p->Print(op->b); - p->stream << ')'; - }) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " % "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "min("; - p->Print(op->a); - p->stream << ", "; - p->Print(op->b); - p->stream << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "max("; - p->Print(op->a); - p->stream << ", "; - p->Print(op->b); - p->stream << ")"; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " == "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " != "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " < "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " <= "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " > "; - p->Print(op->b); - p->stream << ')'; -}) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " >= "; - p->Print(op->b); - p->stream << ')'; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floordiv(" << op->a << ", " << op->b << ")"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "floormod(" << op->a << ", " << op->b << ")"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " && "; - p->Print(op->b); - p->stream << ')'; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, 
vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '('; - p->Print(op->a); - p->stream << " || "; - p->Print(op->b); - p->stream << ')'; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << '!'; - p->Print(op->a); -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "select("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->true_value); - p->stream << ", "; - p->Print(op->false_value); - p->stream << ")"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "]"; - if (!is_one(op->predicate)) { - p->stream << " if "; - p->Print(op->predicate); - } -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "ramp("; - p->Print(op->base); - p->stream << ", "; - p->Print(op->stride); - p->stream << ", " << op->lanes << ")"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "x" << op->lanes << "("; - p->Print(op->value); - p->stream << ")"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->name << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) { - p->stream << ", "; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } } - } - p->stream << ")"; - }); + p->stream << "]"; + }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "(let " << op->var << " = "; - p->Print(op->value); - p->stream << " in "; - p->Print(op->body); - p->stream << ")"; -}); +// ProducerLoad +ProducerLoad::ProducerLoad(DataProducer producer, Array indices) { + ObjectPtr node = make_object(); + node->dtype = producer->GetDataType(); + node->producer = std::move(producer); + node->indices = std::move(indices); + data_ = std::move(node); +} -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << "?"; -}); +TVM_REGISTER_GLOBAL("tir.ProducerLoad") + .set_body_typed([](DataProducer producer, Array indices) { + return ProducerLoad(producer, indices); + }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "reduce(combiner=" - << op->combiner; - p->stream << ", source=" << op->source; - p->stream << ", axis=" << op->axis; - p->stream << ", where=" << op->condition; - p->stream << ", value_index=" << op->value_index; - p->stream << ")"; - }); +TVM_REGISTER_NODE_TYPE(ProducerLoadNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - 
auto* op = static_cast(node.get()); - p->stream << "comm_reducer(result=" << op->result - << ", lhs=" << op->lhs - << ", rhs=" << op->rhs - << ", identity_element=" << op->identity_element - << ")"; - }); - -TVM_REGISTER_NODE_TYPE(StringImmNode); -TVM_REGISTER_NODE_TYPE(CastNode); -TVM_REGISTER_NODE_TYPE(VarNode); -TVM_REGISTER_NODE_TYPE(SizeVarNode); -TVM_REGISTER_NODE_TYPE(AddNode); -TVM_REGISTER_NODE_TYPE(SubNode); -TVM_REGISTER_NODE_TYPE(MulNode); -TVM_REGISTER_NODE_TYPE(DivNode); -TVM_REGISTER_NODE_TYPE(ModNode); -TVM_REGISTER_NODE_TYPE(FloorDivNode); -TVM_REGISTER_NODE_TYPE(FloorModNode); -TVM_REGISTER_NODE_TYPE(MinNode); -TVM_REGISTER_NODE_TYPE(MaxNode); -TVM_REGISTER_NODE_TYPE(EQNode); -TVM_REGISTER_NODE_TYPE(NENode); -TVM_REGISTER_NODE_TYPE(LTNode); -TVM_REGISTER_NODE_TYPE(LENode); -TVM_REGISTER_NODE_TYPE(GTNode); -TVM_REGISTER_NODE_TYPE(GENode); -TVM_REGISTER_NODE_TYPE(AndNode); -TVM_REGISTER_NODE_TYPE(OrNode); -TVM_REGISTER_NODE_TYPE(NotNode); -TVM_REGISTER_NODE_TYPE(SelectNode); -TVM_REGISTER_NODE_TYPE(LoadNode); -TVM_REGISTER_NODE_TYPE(RampNode); -TVM_REGISTER_NODE_TYPE(BroadcastNode); -TVM_REGISTER_NODE_TYPE(ShuffleNode); -TVM_REGISTER_NODE_TYPE(CommReducerNode); -TVM_REGISTER_NODE_TYPE(ReduceNode); -TVM_REGISTER_NODE_TYPE(AnyNode); - - -TVM_REGISTER_GLOBAL("tir.Add") -.set_body_typed(AddNode::make); - -TVM_REGISTER_GLOBAL("tir.Sub") -.set_body_typed(SubNode::make); - -TVM_REGISTER_GLOBAL("tir.Mul") -.set_body_typed(MulNode::make); - -TVM_REGISTER_GLOBAL("tir.Div") -.set_body_typed(DivNode::make); - -TVM_REGISTER_GLOBAL("tir.Mod") -.set_body_typed(ModNode::make); - -TVM_REGISTER_GLOBAL("tir.FloorDiv") -.set_body_typed(FloorDivNode::make); - -TVM_REGISTER_GLOBAL("tir.FloorMod") -.set_body_typed(FloorModNode::make); - -TVM_REGISTER_GLOBAL("tir.Min") -.set_body_typed(MinNode::make); - -TVM_REGISTER_GLOBAL("tir.Max") -.set_body_typed(MaxNode::make); - -TVM_REGISTER_GLOBAL("tir.EQ") -.set_body_typed(EQNode::make); - -TVM_REGISTER_GLOBAL("tir.NE") -.set_body_typed(NENode::make); - -TVM_REGISTER_GLOBAL("tir.LT") -.set_body_typed(LTNode::make); - -TVM_REGISTER_GLOBAL("tir.LE") -.set_body_typed(LENode::make); - -TVM_REGISTER_GLOBAL("tir.GT") -.set_body_typed(GTNode::make); - -TVM_REGISTER_GLOBAL("tir.GE") -.set_body_typed(GENode::make); - -TVM_REGISTER_GLOBAL("tir.And") -.set_body_typed(AndNode::make); - -TVM_REGISTER_GLOBAL("tir.Or") -.set_body_typed(OrNode::make); - -TVM_REGISTER_GLOBAL("tir.Not") -.set_body_typed(NotNode::make); - -TVM_REGISTER_GLOBAL("tir.Select") -.set_body_typed(SelectNode::make); - -TVM_REGISTER_GLOBAL("tir.Ramp") -.set_body_typed(RampNode::make); - -TVM_REGISTER_GLOBAL("tir.Cast") -.set_body_typed(CastNode::make); - -TVM_REGISTER_GLOBAL("tir.Broadcast") -.set_body_typed(BroadcastNode::make); - -TVM_REGISTER_GLOBAL("tir.Shuffle") -.set_body_typed(ShuffleNode::make); - -TVM_REGISTER_GLOBAL("tir.Let") -.set_body_typed(LetNode::make); - -TVM_REGISTER_GLOBAL("tir.Load") -.set_body([](TVMArgs args, TVMRetValue *ret) { - DataType t = args[0]; - if (args.size() == 3) { - *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes())); - } else { - *ret = LoadNode::make(t, args[1], args[2], args[3]); - } - }); - -TVM_REGISTER_GLOBAL("tir.Call") -.set_body_typed([]( - DataType type, std::string name, - Array args, int call_type, - FunctionRef func, int value_index -) { - Array prim_expr_args; - for (const auto& it : args) { - CHECK(it->IsInstance() || - it->IsInstance()); - if (const auto* str = it.as()) { - prim_expr_args.push_back(StringImmNode::make(str->data)); - 
} else { - prim_expr_args.push_back(Downcast(it)); - } - } - return CallNode::make(type, - name, - prim_expr_args, - static_cast(call_type), - func, - value_index); -}); - + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 57ff627ceaf1..b92127b24e2b 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -20,6 +20,7 @@ * \file expr_functor.cc */ #include + #include "functor_common.h" namespace tvm { @@ -40,6 +41,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + void ExprVisitor::VisitExpr_(const LetNode* op) { this->VisitExpr(op->value); this->VisitExpr(op->body); @@ -49,10 +54,10 @@ void ExprVisitor::VisitExpr_(const CallNode* op) { VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); } -#define DEFINE_BINOP_VISIT_(OP) \ - void ExprVisitor::VisitExpr_(const OP* op) { \ - this->VisitExpr(op->a); \ - this->VisitExpr(op->b); \ +#define DEFINE_BINOP_VISIT_(OP) \ + void ExprVisitor::VisitExpr_(const OP* op) { \ + this->VisitExpr(op->a); \ + this->VisitExpr(op->b); \ } DEFINE_BINOP_VISIT_(AddNode); @@ -79,20 +84,16 @@ void ExprVisitor::VisitExpr_(const StringImmNode* op) {} void ExprVisitor::VisitExpr_(const ReduceNode* op) { VisitArray(op->axis, [this](const IterVar& r) { - this->VisitExpr(r->dom->min); - this->VisitExpr(r->dom->extent); - }); + this->VisitExpr(r->dom->min); + this->VisitExpr(r->dom->extent); + }); VisitArray(op->source, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->condition); } -void ExprVisitor::VisitExpr_(const CastNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::VisitExpr_(const CastNode* op) { this->VisitExpr(op->value); } -void ExprVisitor::VisitExpr_(const NotNode* op) { - this->VisitExpr(op->a); -} +void ExprVisitor::VisitExpr_(const NotNode* op) { this->VisitExpr(op->a); } void ExprVisitor::VisitExpr_(const SelectNode* op) { this->VisitExpr(op->condition); @@ -110,13 +111,9 @@ void ExprVisitor::VisitExpr_(const ShuffleNode* op) { VisitArray(op->vectors, [this](const PrimExpr& e) { this->VisitExpr(e); }); } -void ExprVisitor::VisitExpr_(const BroadcastNode* op) { - this->VisitExpr(op->value); -} +void ExprVisitor::VisitExpr_(const BroadcastNode* op) { this->VisitExpr(op->value); } -PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { - return GetRef(op); -} +PrimExpr ExprMutator::VisitExpr_(const VarNode* op) { return GetRef(op); } PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) { return this->VisitExpr_(static_cast(op)); @@ -128,7 +125,7 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { if (index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { - return LoadNode::make(op->dtype, op->buffer_var, index, predicate); + return Load(op->dtype, op->buffer_var, index, predicate); } } @@ -142,14 +139,23 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { } } +PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { + auto 
fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array indices = MutateArray(op->indices, fmutate); + if (indices.same_as(op->indices)) { + return GetRef(op); + } else { + return ProducerLoad(op->producer, indices); + } +} + PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } @@ -160,66 +166,55 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { if (args.same_as(op->args)) { return GetRef(op); } else { - return CallNode::make(op->dtype, - op->name, - args, - op->call_type, - op->func, - op->value_index); + return Call(op->dtype, op->name, args, op->call_type); } } -#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP *op) { \ - return GetRef(op); \ - } +#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP* op) { return GetRef(op); } DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) -#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ - PrimExpr ExprMutator::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return OP::make(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_(OP) \ + PrimExpr ExprMutator::VisitExpr_(const OP##Node* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return OP(a, b); \ + } \ } -DEFINE_BIOP_EXPR_MUTATE_(AddNode); -DEFINE_BIOP_EXPR_MUTATE_(SubNode); -DEFINE_BIOP_EXPR_MUTATE_(MulNode); -DEFINE_BIOP_EXPR_MUTATE_(DivNode); -DEFINE_BIOP_EXPR_MUTATE_(ModNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorDivNode); -DEFINE_BIOP_EXPR_MUTATE_(FloorModNode); -DEFINE_BIOP_EXPR_MUTATE_(MinNode); -DEFINE_BIOP_EXPR_MUTATE_(MaxNode); -DEFINE_BIOP_EXPR_MUTATE_(EQNode); -DEFINE_BIOP_EXPR_MUTATE_(NENode); -DEFINE_BIOP_EXPR_MUTATE_(LTNode); -DEFINE_BIOP_EXPR_MUTATE_(LENode); -DEFINE_BIOP_EXPR_MUTATE_(GTNode); -DEFINE_BIOP_EXPR_MUTATE_(GENode); -DEFINE_BIOP_EXPR_MUTATE_(AndNode); -DEFINE_BIOP_EXPR_MUTATE_(OrNode); +DEFINE_BIOP_EXPR_MUTATE_(Add); +DEFINE_BIOP_EXPR_MUTATE_(Sub); +DEFINE_BIOP_EXPR_MUTATE_(Mul); +DEFINE_BIOP_EXPR_MUTATE_(Div); +DEFINE_BIOP_EXPR_MUTATE_(Mod); +DEFINE_BIOP_EXPR_MUTATE_(FloorDiv); +DEFINE_BIOP_EXPR_MUTATE_(FloorMod); +DEFINE_BIOP_EXPR_MUTATE_(Min); +DEFINE_BIOP_EXPR_MUTATE_(Max); +DEFINE_BIOP_EXPR_MUTATE_(EQ); +DEFINE_BIOP_EXPR_MUTATE_(NE); +DEFINE_BIOP_EXPR_MUTATE_(LT); +DEFINE_BIOP_EXPR_MUTATE_(LE); +DEFINE_BIOP_EXPR_MUTATE_(GT); +DEFINE_BIOP_EXPR_MUTATE_(GE); +DEFINE_BIOP_EXPR_MUTATE_(And); +DEFINE_BIOP_EXPR_MUTATE_(Or); PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { - auto fitervar = [this](const IterVar& v) { + auto fitervar = [this](const IterVar& v) { Range r = v->dom; PrimExpr min = this->VisitExpr(r->min); PrimExpr extent = this->VisitExpr(r->extent); - if (min.same_as(r->min) && - extent.same_as(r->extent)) { + if (min.same_as(r->min) && extent.same_as(r->extent)) { return v; } else { - return IterVarNode::make( - Range::make_by_min_extent(min, extent), - v->var, 
v->iter_type, v->thread_tag); + return IterVar(Range::make_by_min_extent(min, extent), v->var, v->iter_type, v->thread_tag); } }; Array axis = MutateArray(op->axis, fitervar); @@ -229,13 +224,10 @@ PrimExpr ExprMutator::VisitExpr_(const ReduceNode* op) { PrimExpr condition = this->VisitExpr(op->condition); - if (axis.same_as(op->axis) && - source.same_as(op->source) && - condition.same_as(op->condition)) { + if (axis.same_as(op->axis) && source.same_as(op->source) && condition.same_as(op->condition)) { return GetRef(op); } else { - return ReduceNode::make( - op->combiner, source, axis, condition, op->value_index); + return Reduce(op->combiner, source, axis, condition, op->value_index); } } @@ -244,7 +236,7 @@ PrimExpr ExprMutator::VisitExpr_(const CastNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return CastNode::make(op->dtype, value); + return Cast(op->dtype, value); } } @@ -253,7 +245,7 @@ PrimExpr ExprMutator::VisitExpr_(const NotNode* op) { if (a.same_as(op->a)) { return GetRef(op); } else { - return NotNode::make(a); + return Not(a); } } @@ -261,23 +253,21 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr true_value = this->VisitExpr(op->true_value); PrimExpr false_value = this->VisitExpr(op->false_value); - if (condition.same_as(op->condition) && - true_value.same_as(op->true_value) && + if (condition.same_as(op->condition) && true_value.same_as(op->true_value) && false_value.same_as(op->false_value)) { return GetRef(op); } else { - return SelectNode::make(condition, true_value, false_value); + return Select(condition, true_value, false_value); } } PrimExpr ExprMutator::VisitExpr_(const RampNode* op) { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); - if (base.same_as(op->base) && - stride.same_as(op->stride)) { + if (base.same_as(op->base) && stride.same_as(op->stride)) { return GetRef(op); } else { - return RampNode::make(base, stride, op->lanes); + return Ramp(base, stride, op->lanes); } } @@ -286,7 +276,7 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { if (value.same_as(op->value)) { return GetRef(op); } else { - return BroadcastNode::make(value, op->lanes); + return Broadcast(value, op->lanes); } } @@ -296,7 +286,7 @@ PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { if (vectors.same_as(op->vectors)) { return GetRef(op); } else { - return ShuffleNode::make(vectors, op->indices); + return Shuffle(vectors, op->indices); } } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index ecaad586f894..1149e039cae4 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -29,11 +29,8 @@ namespace tvm { namespace tir { // Get the function type of a PrimFunc -PrimFunc::PrimFunc(Array params, - Stmt body, - Type ret_type, - Map buffer_map, - DictAttrs attrs) { +PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, + Map buffer_map, DictAttrs attrs) { // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { @@ -60,29 +57,25 @@ FuncType PrimFuncNode::func_type_annotation() const { TVM_REGISTER_NODE_TYPE(PrimFuncNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - // TODO(tvm-team) redirect to Text printer once we have a good text format. 
- auto* node = static_cast(ref.get()); - p->stream << "PrimFunc(" << node->params << ") "; - if (node->attrs.defined()) { - p->stream << "attrs=" << node->attrs; - } - p->stream << " {\n"; - p->indent += 2; - p->Print(node->body); - p->indent -= 2; - p->stream << "}\n"; -}); - + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + // TODO(tvm-team) redirect to Text printer once we have a good text format. + auto* node = static_cast(ref.get()); + p->stream << "PrimFunc(" << node->params << ") "; + if (node->attrs.defined()) { + p->stream << "attrs=" << node->attrs; + } + p->stream << " {\n"; + p->indent += 2; + p->Print(node->body); + p->indent -= 2; + p->stream << "}\n"; + }); TVM_REGISTER_GLOBAL("tir.PrimFunc") -.set_body_typed([](Array params, - Stmt body, - Type ret_type, - Map buffer_map, - DictAttrs attrs) { - return PrimFunc(params, body, ret_type, buffer_map, attrs); -}); + .set_body_typed([](Array params, Stmt body, Type ret_type, + Map buffer_map, DictAttrs attrs) { + return PrimFunc(params, body, ret_type, buffer_map, attrs); + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/functor_common.h b/src/tir/ir/functor_common.h index 76a91ea42d42..f63dcfe003c6 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tir/ir/functor_common.h @@ -27,7 +27,7 @@ namespace tvm { namespace tir { // Implementation of Visitors -template +template inline void VisitArray(const Array& arr, F fvisit) { for (size_t i = 0; i < arr.size(); i++) { fvisit(arr[i]); @@ -35,10 +35,8 @@ inline void VisitArray(const Array& arr, F fvisit) { } // Implementation of mutators -template -inline Array MutateArray(const Array& arr, - F fmutate, - bool allow_copy_on_write = false) { +template +inline Array MutateArray(const Array& arr, F fmutate, bool allow_copy_on_write = false) { if (allow_copy_on_write) { // if we allow copy on write, we can directly // call the inplace mutate function. diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index 4ad244ff02b2..5ac9f5902c12 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -24,6 +24,7 @@ #include #include #include + #include // Centralized header for constant folders. #include "../../arith/const_fold.h" @@ -32,15 +33,15 @@ namespace tvm { using namespace tir; - runtime::DataType GetRuntimeDataType(const Type& type) { - if (auto * n = type.as()) { + if (auto* n = type.as()) { return n->dtype; } else if (type.as()) { return DataType::Handle(); + } else if (IsVoidType(type)) { + return DataType::Void(); } else { - LOG(FATAL) << "Type " << type - << " does not have a corresponding runtime::DataType"; + LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType"; return DataType::Handle(); } } @@ -57,9 +58,8 @@ Type GetType(const PrimExpr& expr) { } // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); - // These types already implies the specific type. 
- if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) { - return PrimType(dtype); + if (dtype.is_void()) { + return VoidType(); } return PrimType(dtype); } @@ -67,15 +67,13 @@ Type GetType(const PrimExpr& expr) { // simple cast that only checks if type matches and cast inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::CastNode::make(t, value); + return tir::Cast(t, value); } PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) { - return tir::CallNode::make( - t, tir::intrinsic::tvm_large_uint_imm, - {make_const(DataType::UInt(32), low), - make_const(DataType::UInt(32), high)}, - tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::intrinsic::tvm_large_uint_imm, + {make_const(DataType::UInt(32), low), make_const(DataType::UInt(32), high)}, + tir::CallNode::PureIntrinsic); } // The public function with a quick checking path. @@ -84,12 +82,11 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) { // NOLINT(*) DataType ltype = lhs.dtype(); DataType rtype = rhs.dtype(); if (ltype.lanes() == 1 && rtype.lanes() != 1) { - lhs = tir::BroadcastNode::make(lhs, rtype.lanes()); + lhs = tir::Broadcast(lhs, rtype.lanes()); } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { - rhs = tir::BroadcastNode::make(rhs, ltype.lanes()); + rhs = tir::Broadcast(rhs, ltype.lanes()); } else { - CHECK(ltype.lanes() == rtype.lanes()) - << "Cannot match type " << ltype << " vs " << rtype; + CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype; } if (lhs.dtype() == rhs.dtype()) return; // Only do very simple type coversion @@ -196,8 +193,8 @@ PrimExpr infinity(const DataType& dtype) { } namespace tir { -template -inline bool ConstPowerHelper(ValueType val, int *shift) { +template +inline bool ConstPowerHelper(ValueType val, int* shift) { if (val <= 0) return false; shift[0] = 0; while (val != 0) { @@ -229,7 +226,7 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } else if (const FloatImmNode* op = value.as()) { return make_const(t, op->value); } - return tir::CastNode::make(t, value); + return tir::Cast(t, value); } else { if (value.dtype().lanes() == 1) { // manually unroll cast @@ -240,34 +237,33 @@ PrimExpr cast(const DataType& t, PrimExpr value) { } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value); } else { - value = tir::CastNode::make(vtype, value); + value = tir::Cast(vtype, value); } } - return tir::BroadcastNode::make(value, t.lanes()); + return tir::Broadcast(value, t.lanes()); } else { CHECK(value.dtype().lanes() == t.lanes()); - return tir::CastNode::make(t, value); + return tir::Cast(t, value); } } } PrimExpr reinterpret(const DataType& t, PrimExpr value) { if (value.dtype() == t) return value; - return tir::CallNode::make( - t, tir::CallNode::reinterpret, { value }, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::reinterpret, {value}, tir::CallNode::PureIntrinsic); } PrimExpr operator+(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::AddNode::make(a, b); + return tir::Add(a, b); } // negation PrimExpr operator-(PrimExpr a) { - using tir::IntImmNode; using tir::FloatImmNode; + using tir::IntImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); if (pa) return IntImm(a.dtype(), -pa->value); @@ -277,23 +273,23 @@ PrimExpr operator-(PrimExpr a) { PrimExpr 
operator-(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::SubNode::make(a, b); + return tir::Sub(a, b); } PrimExpr operator*(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::MulNode::make(a, b); + return tir::Mul(a, b); } PrimExpr div(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::DivNode::make(a, b); + return tir::Div(a, b); } PrimExpr truncdiv(PrimExpr a, PrimExpr b) { @@ -304,72 +300,64 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b) { PrimExpr truncmod(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::ModNode::make(a, b); + return tir::Mod(a, b); } -PrimExpr operator/(PrimExpr a, PrimExpr b) { - return div(a, b); -} +PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } -PrimExpr operator%(PrimExpr a, PrimExpr b) { - return truncmod(a, b); -} +PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); } // TODO(tqchen): switch to floordiv -PrimExpr indexdiv(PrimExpr a, PrimExpr b) { - return floordiv(a, b); -} +PrimExpr indexdiv(PrimExpr a, PrimExpr b) { return floordiv(a, b); } -PrimExpr indexmod(PrimExpr a, PrimExpr b) { - return floormod(a, b); -} +PrimExpr indexmod(PrimExpr a, PrimExpr b) { return floormod(a, b); } PrimExpr floordiv(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::FloorDivNode::make(a, b); + return tir::FloorDiv(a, b); } PrimExpr floormod(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::FloorModNode::make(a, b); + return tir::FloorMod(a, b); } PrimExpr min(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton - using arith::is_pos_inf; using arith::is_neg_inf; + using arith::is_pos_inf; if (is_pos_inf(a)) return b; if (is_neg_inf(a)) return a; if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::MinNode::make(a, b); + return tir::Min(a, b); } PrimExpr max(PrimExpr a, PrimExpr b) { // inf-aware simplificaiton - using arith::is_pos_inf; using arith::is_neg_inf; + using arith::is_pos_inf; if (is_pos_inf(a)) return a; if (is_neg_inf(a)) return b; if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::MaxNode::make(a, b); + return tir::Max(a, b); } PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { @@ -383,84 +371,78 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr 
false_value) return false_value; } } - return tir::CallNode::make( - true_value.dtype(), - tir::intrinsic::tvm_if_then_else, - {cond, true_value, false_value}, - tir::CallNode::PureIntrinsic); + return tir::Call(true_value.dtype(), tir::intrinsic::tvm_if_then_else, + {cond, true_value, false_value}, tir::CallNode::PureIntrinsic); } PrimExpr likely(PrimExpr cond) { if (is_const(cond)) return cond; - return tir::CallNode::make(cond.dtype(), - tir::CallNode::likely, - { cond }, - tir::CallNode::PureIntrinsic); + return tir::Call(cond.dtype(), tir::CallNode::likely, {cond}, tir::CallNode::PureIntrinsic); } PrimExpr operator>(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::GTNode::make(a, b); + return tir::GT(a, b); } PrimExpr operator>=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::GENode::make(a, b); + return tir::GE(a, b); } PrimExpr operator<(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::LTNode::make(a, b); + return tir::LT(a, b); } PrimExpr operator<=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::LENode::make(a, b); + return tir::LE(a, b); } PrimExpr operator==(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::EQNode::make(a, b); + return tir::EQ(a, b); } PrimExpr operator!=(PrimExpr a, PrimExpr b) { BinaryOpMatchTypes(a, b); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::NENode::make(a, b); + return tir::NE(a, b); } PrimExpr operator&&(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::AndNode::make(a, b); + return tir::And(a, b); } PrimExpr operator||(PrimExpr a, PrimExpr b) { CHECK(a.dtype().is_bool()); CHECK(b.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a, b); + PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; - return tir::OrNode::make(a, b); + return tir::Or(a, b); } PrimExpr operator!(PrimExpr a) { CHECK(a.dtype().is_bool()); - PrimExpr ret = arith::TryConstFold(a); + PrimExpr ret = arith::TryConstFold(a); if (ret.defined()) return ret; - return tir::NotNode::make(a); + return tir::Not(a); } PrimExpr operator>>(PrimExpr a, PrimExpr b) { @@ -468,17 +450,17 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) << - "Shift amount must be non-negative and less than " << rtype.bits() - << " for type " << rtype; - if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); - if (pb) { - if (pb->value == 0) return a; - } - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::shift_right, { a, b }, tir::CallNode::PureIntrinsic); 
+ const DataType& rtype = a.dtype(); + if (pb) + CHECK(pb->value >= 0 && pb->value < rtype.bits()) + << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " + << rtype; + if (pa && pb) return IntImm(rtype, (pa->value >> pb->value)); + if (pb) { + if (pb->value == 0) return a; + } + }); + return tir::Call(a.dtype(), tir::CallNode::shift_right, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator<<(PrimExpr a, PrimExpr b) { @@ -486,17 +468,17 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pb) CHECK(pb->value >= 0 && pb->value < rtype.bits()) << - "Shift amount must be non-negative and less than " << rtype.bits() - << " for type " << rtype; - if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); - if (pb) { - if (pb->value == 0) return a; - } - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::shift_left, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pb) + CHECK(pb->value >= 0 && pb->value < rtype.bits()) + << "Shift amount must be non-negative and less than " << rtype.bits() << " for type " + << rtype; + if (pa && pb) return IntImm(rtype, (pa->value << pb->value)); + if (pb) { + if (pb->value == 0) return a; + } + }); + return tir::Call(a.dtype(), tir::CallNode::shift_left, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator&(PrimExpr a, PrimExpr b) { @@ -504,11 +486,10 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_and, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, (pa->value & pb->value)); + }); + return tir::Call(a.dtype(), tir::CallNode::bitwise_and, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator|(PrimExpr a, PrimExpr b) { @@ -516,11 +497,10 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_or, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, (pa->value | pb->value)); + }); + return tir::Call(a.dtype(), tir::CallNode::bitwise_or, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator^(PrimExpr a, PrimExpr b) { @@ -528,24 +508,21 @@ PrimExpr operator^(PrimExpr a, PrimExpr b) { CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ - const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); - }); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_xor, { a, b }, tir::CallNode::PureIntrinsic); + const DataType& rtype = a.dtype(); + if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value)); + }); + return tir::Call(a.dtype(), tir::CallNode::bitwise_xor, {a, b}, tir::CallNode::PureIntrinsic); } PrimExpr operator~(PrimExpr a) { CHECK(a.dtype().is_int() || a.dtype().is_uint()); - return tir::CallNode::make( - a.dtype(), tir::CallNode::bitwise_not, { 
a }, tir::CallNode::PureIntrinsic); + return tir::Call(a.dtype(), tir::CallNode::bitwise_not, {a}, tir::CallNode::PureIntrinsic); } PrimExpr pow(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "power only applies to float"; - return tir::CallNode::make( - x.dtype(), "pow", { x, y }, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "pow", {x, y}, tir::CallNode::PureIntrinsic); } PrimExpr abs(PrimExpr x) { @@ -555,19 +532,19 @@ PrimExpr abs(PrimExpr x) { if (px) { return IntImm(x.dtype(), std::abs(px->value)); } - return tir::SelectNode::make(x >= make_zero(x.dtype()), x, -x); + return tir::Select(x >= make_zero(x.dtype()), x, -x); } else if (x.dtype().is_float()) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value)); } - return tir::CallNode::make(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "fabs", {x}, tir::CallNode::PureIntrinsic); } else if (x.dtype().is_uint()) { return x; } else { LOG(FATAL) << "Data type " << x.dtype() - <<" not supported for absolute op. Skipping absolute op..."; + << " not supported for absolute op. Skipping absolute op..."; return x; } } @@ -583,15 +560,14 @@ PrimExpr isnan(PrimExpr x) { return make_const(t, std::isnan(fx->value)); } if (x.dtype().bits() == 16) { - return tir::CallNode::make(t, tir::CallNode::isnan, - {cast(DataType::Float(32, t.lanes()), std::move(x))}, - tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::isnan, + {cast(DataType::Float(32, t.lanes()), std::move(x))}, + tir::CallNode::PureIntrinsic); } else { - return tir::CallNode::make(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); + return tir::Call(t, tir::CallNode::isnan, {x}, tir::CallNode::PureIntrinsic); } } else { - LOG(FATAL) << "Data type " << x.dtype() - <<" not supported for isnan op. Skipping isnan op..."; + LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. 
Skipping isnan op..."; return x; } } @@ -613,64 +589,58 @@ PrimExpr isfinite(PrimExpr x) { return !isinf(x) && !isnan(x); } PrimExpr sum(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::AddNode::make(x, y); + PrimExpr result = tir::Add(x, y); PrimExpr identity_element = make_zero(source.dtype()); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr all(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::AndNode::make(x, y); + PrimExpr result = tir::And(x, y); PrimExpr identity_element = make_const(source.dtype(), true); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr any(PrimExpr source, Array rdom) { CHECK(source.dtype().is_bool()); Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::OrNode::make(x, y); + PrimExpr result = tir::Or(x, y); PrimExpr identity_element = make_const(source.dtype(), false); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr max(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::MaxNode::make(x, y); + PrimExpr result = tir::Max(x, y); PrimExpr identity_element = min_value(source.dtype()); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr min(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::MinNode::make(x, y); + PrimExpr result = tir::Min(x, y); PrimExpr identity_element = max_value(source.dtype()); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr prod(PrimExpr source, Array rdom) { Var x("x", source.dtype()), y("y", source.dtype()); - PrimExpr result = tir::MulNode::make(x, y); + PrimExpr result = tir::Mul(x, y); PrimExpr identity_element = make_const(source.dtype(), 1); - tir::CommReducer combiner = - tir::CommReducerNode::make({x}, {y}, 
{result}, {identity_element}); - return tir::ReduceNode::make(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); + tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}); + return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), true), 0); } PrimExpr fmod(PrimExpr x, PrimExpr y) { BinaryOpMatchTypes(x, y); CHECK(x.dtype().is_float()) << "fmod only applies to float"; - return tir::CallNode::make(x.dtype(), "fmod", { x, y }, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "fmod", {x, y}, tir::CallNode::PureIntrinsic); } PrimExpr floor(PrimExpr x) { @@ -680,7 +650,7 @@ PrimExpr floor(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); - return tir::CallNode::make(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "floor", {x}, tir::CallNode::PureIntrinsic); } PrimExpr ceil(PrimExpr x) { @@ -690,7 +660,7 @@ PrimExpr ceil(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); - return tir::CallNode::make(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "ceil", {x}, tir::CallNode::PureIntrinsic); } PrimExpr round(PrimExpr x) { @@ -700,7 +670,7 @@ PrimExpr round(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::CallNode::make(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "round", {x}, tir::CallNode::PureIntrinsic); } PrimExpr nearbyint(PrimExpr x) { @@ -710,7 +680,7 @@ PrimExpr nearbyint(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); - return tir::CallNode::make(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "nearbyint", {x}, tir::CallNode::PureIntrinsic); } PrimExpr trunc(PrimExpr x) { @@ -720,91 +690,67 @@ PrimExpr trunc(PrimExpr x) { using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { - return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : - std::floor(fx->value))); + return FloatImm(x.dtype(), (fx->value < 0 ? 
std::ceil(fx->value) : std::floor(fx->value))); } - return tir::CallNode::make(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); + return tir::Call(x.dtype(), "trunc", {x}, tir::CallNode::PureIntrinsic); } - // expose basic functions to node namespace -TVM_REGISTER_GLOBAL("node._const") -.set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[0].type_code() == kDLInt) { - *ret = tir::make_const(args[1], args[0].operator int64_t()); - } else if (args[0].type_code() == kDLFloat) { - *ret = tir::make_const(args[1], args[0].operator double()); - } else { - LOG(FATAL) << "only accept int or float"; - } - }); - -TVM_REGISTER_GLOBAL("node.LargeUIntImm") -.set_body_typed(LargeUIntImm); - -TVM_REGISTER_GLOBAL("node.String") -.set_body_typed(tir::StringImmNode::make); - -TVM_REGISTER_GLOBAL("tir.min_value") -.set_body_typed(min_value); +TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { + if (args[0].type_code() == kDLInt) { + *ret = tir::make_const(args[1], args[0].operator int64_t()); + } else if (args[0].type_code() == kDLFloat) { + *ret = tir::make_const(args[1], args[0].operator double()); + } else { + LOG(FATAL) << "only accept int or float"; + } +}); -TVM_REGISTER_GLOBAL("tir.max_value") -.set_body_typed(max_value); +TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm); -TVM_REGISTER_GLOBAL("tir.abs") -.set_body_typed(tvm::abs); +TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value); -TVM_REGISTER_GLOBAL("tir.isnan") -.set_body_typed(tvm::isnan); +TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value); -TVM_REGISTER_GLOBAL("tir.isfinite") -.set_body_typed(tvm::isfinite); +TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs); -TVM_REGISTER_GLOBAL("tir.isinf") -.set_body_typed(tvm::isinf); +TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan); -TVM_REGISTER_GLOBAL("tir.floor") -.set_body_typed(tvm::floor); +TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite); -TVM_REGISTER_GLOBAL("tir.ceil") -.set_body_typed(tvm::ceil); +TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf); -TVM_REGISTER_GLOBAL("tir.round") -.set_body_typed(tvm::round); +TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor); -TVM_REGISTER_GLOBAL("tir.nearbyint") -.set_body_typed(tvm::nearbyint); +TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil); -TVM_REGISTER_GLOBAL("tir.trunc") -.set_body_typed(tvm::trunc); +TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round); -TVM_REGISTER_GLOBAL("tir._cast") -.set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint); +TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); +TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); // operator overloading, smarter than make -#define REGISTER_MAKE_BINARY_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir."#Node) \ - .set_body_typed([](PrimExpr a, PrimExpr b) { \ - return (Func(a, b)); \ +#define REGISTER_MAKE_BINARY_OP(Node, Func) \ + TVM_REGISTER_GLOBAL("tir." 
#Node).set_body_typed([](PrimExpr a, PrimExpr b) { \ + return (Func(a, b)); \ }) -#define REGISTER_MAKE_BIT_OP(Node, Func) \ - TVM_REGISTER_GLOBAL("tir."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - bool lhs_is_int = args[0].type_code() == kDLInt; \ - bool rhs_is_int = args[1].type_code() == kDLInt; \ - if (lhs_is_int) { \ - *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \ - } else if (rhs_is_int) { \ - *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \ - } else { \ - *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \ - } \ +#define REGISTER_MAKE_BIT_OP(Node, Func) \ + TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) { \ + bool lhs_is_int = args[0].type_code() == kDLInt; \ + bool rhs_is_int = args[1].type_code() == kDLInt; \ + if (lhs_is_int) { \ + *ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \ + } else if (rhs_is_int) { \ + *ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \ + } else { \ + *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \ + } \ }) - REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); REGISTER_MAKE_BINARY_OP(_OpSub, operator-); REGISTER_MAKE_BINARY_OP(_OpMul, operator*); @@ -821,20 +767,20 @@ REGISTER_MAKE_BINARY_OP(_OpMin, min); REGISTER_MAKE_BINARY_OP(_OpMax, max); REGISTER_MAKE_BINARY_OP(_OpEQ, operator==); REGISTER_MAKE_BINARY_OP(_OpNE, operator!=); -REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*) REGISTER_MAKE_BINARY_OP(_OpGE, operator>=); REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&); REGISTER_MAKE_BINARY_OP(_OpOr, operator||); REGISTER_MAKE_BIT_OP(bitwise_and, operator&); REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); -REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) +REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); TVM_REGISTER_GLOBAL("tir._OpIfThenElse") -.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { - return if_then_else(cond, true_value, false_value); -}); + .set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) { + return if_then_else(cond, true_value, false_value); + }); } // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1f6a7dd027ea..66497755c88a 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -21,14 +21,14 @@ * \file tvm/tir/stmt.cc */ #include +#include #include -#include -#include "../pass/ir_util.h" namespace tvm { namespace tir { -Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { +// LetStmt +LetStmt::LetStmt(Var var, PrimExpr value, Stmt body) { CHECK(value.defined()); CHECK(body.defined()); CHECK_EQ(value.dtype(), var.dtype()); @@ -37,57 +37,94 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) { node->var = std::move(var); node->value = std::move(value); node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.LetStmt") -.set_body_typed(LetStmtNode::make); +TVM_REGISTER_GLOBAL("tir.LetStmt").set_body_typed([](Var var, PrimExpr value, Stmt body) { + return LetStmt(var, value, body); +}); + +TVM_REGISTER_NODE_TYPE(LetStmtNode); 
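The LetStmt hunk above is representative of the pattern this patch applies to every statement node in stmt.cc: the static XNode::make(...) factory is replaced by a constructor on the reference class, and the FFI global, TVM_REGISTER_NODE_TYPE, and ReprPrinter dispatch for that node are now registered right next to it. A minimal sketch of what client code looks like after the change, assuming the usual tvm/tir headers and the constructor signatures shown in this patch (the names x, MakeLetExample, and the demo namespace are made up for illustration, not part of the patch):

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

namespace demo {
using namespace tvm;
using namespace tvm::tir;

// Build "let x = 1 in evaluate(x + x)" with the new constructors.
Stmt MakeLetExample() {
  Var x("x", DataType::Int(32));                  // hypothetical variable
  PrimExpr value = make_const(DataType::Int(32), 1);
  Stmt body = Evaluate(x + x);                    // Evaluate(...) replaces EvaluateNode::make(...) later in this file
  return LetStmt(x, value, body);                 // LetStmt(...) replaces LetStmtNode::make(...)
}
}  // namespace demo

The same colocated constructor/registration/printer layout repeats below for AttrStmt, AssertStmt, For, Store, and the new producer statements.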
-Stmt AttrStmtNode::make(ObjectRef node, - std::string attr_key, - PrimExpr value, - Stmt body) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "let " << op->var << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); + +// AttrStmt +AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); n->value = std::move(value); n->body = std::move(body); - return Stmt(n); + data_ = std::move(n); } TVM_REGISTER_GLOBAL("tir.AttrStmt") -.set_body_typed(AttrStmtNode::make); + .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body) { + return AttrStmt(node, attr_key, value, body); + }); + +TVM_REGISTER_NODE_TYPE(AttrStmtNode); -Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "// attr ["; + p->Print(op->node); + p->stream << "] " << op->attr_key << " = "; + p->Print(op->value); + p->stream << '\n'; + p->Print(op->body); + }); + +// AssertStmt +AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body) { CHECK(condition.defined()); - CHECK(message.dtype() == DataType::Int(32) || - message.as()) - << "TypeError: AssertStmt message must be an int or string:" - << message << "\n"; + CHECK(message.dtype() == DataType::Int(32) || message.as()) + << "TypeError: AssertStmt message must be an int or string:" << message << "\n"; ObjectPtr node = make_object(); node->condition = std::move(condition); node->message = std::move(message); node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } +TVM_REGISTER_NODE_TYPE(AssertStmtNode); + TVM_REGISTER_GLOBAL("tir.AssertStmt") -.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { - if (const auto* str = message.as()) { - auto msg = StringImmNode::make(str->data); - return AssertStmtNode::make(condition, msg, body); - } else { - return AssertStmtNode::make(condition, Downcast(message), body); - } -}); + .set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { + if (const auto* str = message.as()) { + auto msg = StringImm(str->data); + return AssertStmt(condition, msg, body); + } else { + return AssertStmt(condition, Downcast(message), body); + } + }); -Stmt ForNode::make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body) { +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "assert("; + p->Print(op->condition); + p->stream << ", "; + p->Print(op->message); + p->stream << ")\n"; + p->Print(op->body); + }); + +// For +For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, + Stmt body) { CHECK(min.defined()); CHECK(extent.defined()); CHECK(min.dtype().is_scalar()); @@ -102,23 +139,55 @@ Stmt ForNode::make(Var loop_var, node->for_type = for_type; node->device_api = device_api; node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.For") -.set_body_typed([]( - Var loop_var, PrimExpr min, PrimExpr extent, - int for_type, int device_api, Stmt body) { - return 
ForNode::make(loop_var, - min, - extent, - static_cast(for_type), - static_cast(device_api), - body); +TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, + int for_type, int device_api, Stmt body) { + return For(loop_var, min, extent, static_cast(for_type), + static_cast(device_api), body); }); +TVM_REGISTER_NODE_TYPE(ForNode); + +std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*) + switch (type) { + case ForType::Serial: + out << "for"; + break; + case ForType::Parallel: + out << "parallel"; + break; + case ForType::Unrolled: + out << "unrolled"; + break; + case ForType::Vectorized: + out << "vectorized"; + break; + } + return out; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->for_type << " (" << op->loop_var << ", "; + p->Print(op->min); + p->stream << ", "; + p->Print(op->extent); + p->stream << ") {\n"; + + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + + p->PrintIndent(); + p->stream << "}\n"; + }); -Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { +// Store +Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) { CHECK(value.defined()); CHECK(index.defined()); CHECK(predicate.defined()); @@ -130,77 +199,90 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr node->value = std::move(value); node->index = std::move(index); node->predicate = std::move(predicate); - return Stmt(node); + data_ = std::move(node); } +TVM_REGISTER_GLOBAL("tir.Store").set_body([](TVMArgs args, TVMRetValue* ret) { + PrimExpr value = args[1]; + if (args.size() == 3) { + *ret = Store(args[0], value, args[2], const_true(value.dtype().lanes())); + } else { + *ret = Store(args[0], value, args[2], args[3]); + } +}); -TVM_REGISTER_GLOBAL("tir.Store") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PrimExpr value = args[1]; - if (args.size() == 3) { - *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes())); - } else { - *ret = StoreNode::make(args[0], value, args[2], args[3]); - } - }); - - -Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array args) { - CHECK(value_index >=0 && value_index < func->num_outputs()) - << "value index output function return value bound"; - CHECK(value.defined()) << "Provide of undefined value\n"; +TVM_REGISTER_NODE_TYPE(StoreNode); - for (size_t i = 0; i < args.size(); ++i) { - CHECK(args[i].defined()) << "Provide to undefined location\n"; - } +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer_var << "["; + p->Print(op->index); + p->stream << "] = "; + p->Print(op->value); + if (!is_one(op->predicate)) { + p->stream << " if "; + p->Print(op->predicate); + } + p->stream << '\n'; + }); - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; +// ProducerStore +ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices) { + ObjectPtr node = make_object(); + node->producer = std::move(producer); node->value = std::move(value); - node->args = std::move(args); - return Stmt(node); + node->indices = std::move(indices); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Provide") -.set_body_typed(ProvideNode::make); 
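At this point Provide and Realize, which addressed a FunctionRef plus a value_index, are being dropped in favor of ProducerStore and ProducerRealize keyed on a DataProducer, while low-level buffer access keeps the Store/Load pair. A small sketch of building a loop with the For/Store/Load constructors shown above, assuming the signatures in this patch (the copy-loop shape, the variable names, and DeviceAPI::None as the placeholder device API are illustrative, not part of the patch):

#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

namespace demo {
using namespace tvm;
using namespace tvm::tir;

// for (i, 0, n) dst[i] = src[i], built with the new constructors instead of
// ForNode::make / StoreNode::make / LoadNode::make.
Stmt MakeCopyLoop(Var dst, Var src, PrimExpr n) {
  Var i("i", DataType::Int(32));
  // Load/Store still carry a per-lane predicate; const_true(1) means "always".
  PrimExpr value = Load(DataType::Float(32), src, i, const_true(1));
  Stmt store = Store(dst, value, i, const_true(1));
  return For(i, 0, n, ForType::Serial, DeviceAPI::None, store);
}
}  // namespace demo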
+TVM_REGISTER_GLOBAL("tir.ProducerStore") + .set_body_typed([](DataProducer producer, PrimExpr value, Array indices) { + return ProducerStore(producer, value, indices); + }); +TVM_REGISTER_NODE_TYPE(ProducerStoreNode); -Stmt AllocateNode::make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body) { - for (size_t i = 0; i < extents.size(); ++i) { - CHECK(extents[i].defined()); - CHECK(extents[i].dtype().is_scalar()); - } - CHECK(body.defined()); - CHECK(condition.defined()); - CHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->buffer_var = std::move(buffer_var); - node->dtype = dtype; - node->extents = std::move(extents); - node->condition = std::move(condition); - node->body = std::move(body); - return Stmt(node); -} +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->producer->GetNameHint() << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " ="; + p->Print(op->value); + p->stream << '\n'; + }); -// overloaded, needs special handling -// has default args -TVM_REGISTER_GLOBAL("tir.Allocate") -.set_body_typed([]( - Var buffer_var, DataType type, Array extents, PrimExpr condition, Stmt body - ){ - return AllocateNode::make(buffer_var, type, extents, condition, body); -}); +// Allocate +Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, + Stmt body) { + for (size_t i = 0; i < extents.size(); ++i) { + CHECK(extents[i].defined()); + CHECK(extents[i].dtype().is_scalar()); + } + CHECK(body.defined()); + CHECK(condition.defined()); + CHECK(condition.dtype().is_bool()); + + ObjectPtr node = make_object(); + node->buffer_var = std::move(buffer_var); + node->dtype = dtype; + node->extents = std::move(extents); + node->condition = std::move(condition); + node->body = std::move(body); + data_ = std::move(node); +} int32_t AllocateNode::constant_allocation_size(const Array& extents) { int64_t result = 1; for (size_t i = 0; i < extents.size(); ++i) { - if (const IntImmNode *int_size = extents[i].as()) { + if (const IntImmNode* int_size = extents[i].as()) { result *= int_size->value; if (result > std::numeric_limits::max()) { return 0; @@ -212,22 +294,33 @@ int32_t AllocateNode::constant_allocation_size(const Array& extents) { return static_cast(result); } -Stmt FreeNode::make(Var buffer_var) { - ObjectPtr node = make_object(); - node->buffer_var = buffer_var; - return Stmt(node); -} +TVM_REGISTER_GLOBAL("tir.Allocate") + .set_body_typed([](Var buffer_var, DataType type, Array extents, PrimExpr condition, + Stmt body) { return Allocate(buffer_var, type, extents, condition, body); }); -TVM_REGISTER_GLOBAL("tir.Free") -.set_body_typed(FreeNode::make); +TVM_REGISTER_NODE_TYPE(AllocateNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "allocate " << op->buffer_var << "[" << op->dtype; + for (size_t i = 0; i < op->extents.size(); ++i) { + p->stream << " * "; + p->Print(op->extents[i]); + } + p->stream << "]"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << "\n"; + p->Print(op->body); + }); -Stmt RealizeNode::make(FunctionRef func, - int value_index, - DataType dtype, - Region 
bounds, - PrimExpr condition, - Stmt body) { +// ProducerRealize +ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, + Stmt body) { for (size_t i = 0; i < bounds.size(); ++i) { CHECK(bounds[i]->min.defined()); CHECK(bounds[i]->extent.defined()); @@ -238,79 +331,188 @@ Stmt RealizeNode::make(FunctionRef func, CHECK(condition.defined()); CHECK(condition.dtype().is_bool()); - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; + ObjectPtr node = make_object(); + node->producer = std::move(producer); node->bounds = std::move(bounds); node->condition = std::move(condition); node->body = std::move(body); - return Stmt(node); + data_ = std::move(node); } +TVM_REGISTER_GLOBAL("tir.ProducerRealize") + .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body) { + return ProducerRealize(producer, bounds, condition, body); + }); -TVM_REGISTER_GLOBAL("tir.Realize") -.set_body_typed(RealizeNode::make); +TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "producer_realize " << op->producer->GetNameHint() << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); + } + p->stream << " {\n"; -Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) { - for (size_t i = 0; i < bounds.size(); ++i) { - CHECK(bounds[i]->min.defined()); - CHECK(bounds[i]->extent.defined()); - CHECK(bounds[i]->min.dtype().is_scalar()); - CHECK(bounds[i]->extent.dtype().is_scalar()); - } + p->indent += 2; + p->Print(op->body); + p->indent -= 2; - ObjectPtr node = make_object(); - node->func = std::move(func); - node->value_index = value_index; - node->dtype = dtype; - node->bounds = std::move(bounds); - return Stmt(node); + p->PrintIndent(); + p->stream << "}\n"; + }); + +// Free +Free::Free(Var buffer_var) { + ObjectPtr node = make_object(); + node->buffer_var = buffer_var; + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Free").set_body_typed([](Var buffer_var) { return Free(buffer_var); }); + +TVM_REGISTER_NODE_TYPE(FreeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "free " << op->buffer_var; + p->stream << '\n'; + }); + +// Prefetch +Prefetch::Prefetch(Buffer buffer, Array bounds) { + data_ = make_object(buffer, bounds); } -TVM_REGISTER_GLOBAL("tir.Prefetch") -.set_body_typed(PrefetchNode::make); +TVM_REGISTER_GLOBAL("tir.Prefetch").set_body_typed([](Buffer buffer, Array bounds) { + return Prefetch(buffer, bounds); +}); +TVM_REGISTER_NODE_TYPE(PrefetchNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "prefetch " << op->buffer << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < 
op->bounds.size() - 1) p->stream << ", "; + } + p->stream << ")"; + }); +// SeqStmt SeqStmt::SeqStmt(Array seq) { auto node = make_object(); node->seq = std::move(seq); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.SeqStmt") -.set_body_typed([](Array seq) { +TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array seq) { return SeqStmt(std::move(seq)); }); -Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) { +TVM_REGISTER_NODE_TYPE(SeqStmtNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + for (Stmt stmt : op->seq) { + p->Print(stmt); + } + }); + +// IfThenElse +IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case) { CHECK(condition.defined()); CHECK(then_case.defined()); // else_case may be null. - ObjectPtr node = make_object(); node->condition = std::move(condition); node->then_case = std::move(then_case); node->else_case = std::move(else_case); - return Stmt(node); + data_ = std::move(node); } +TVM_REGISTER_NODE_TYPE(IfThenElseNode); + TVM_REGISTER_GLOBAL("tir.IfThenElse") -.set_body_typed(IfThenElseNode::make); + .set_body_typed([](PrimExpr condition, Stmt then_case, Stmt else_case) { + return IfThenElse(condition, then_case, else_case); + }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + while (true) { + p->stream << "if (" << op->condition << ") {\n"; + p->indent += 2; + p->Print(op->then_case); + p->indent -= 2; + + if (!op->else_case.defined()) { + break; + } + + if (const IfThenElseNode* nested_if = op->else_case.as()) { + p->PrintIndent(); + p->stream << "} else "; + op = nested_if; + } else { + p->PrintIndent(); + p->stream << "} else {\n"; + p->indent += 2; + p->Print(op->else_case); + p->indent -= 2; + break; + } + } + p->PrintIndent(); + p->stream << "}\n"; + }); -Stmt EvaluateNode::make(PrimExpr value) { +// Evaluate +Evaluate::Evaluate(PrimExpr value) { CHECK(value.defined()); ObjectPtr node = make_object(); node->value = std::move(value); - return Stmt(node); + data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.Evaluate") -.set_body_typed(EvaluateNode::make); +TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr value) { return Evaluate(value); }); + +TVM_REGISTER_NODE_TYPE(EvaluateNode); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->Print(op->value); + p->stream << "\n"; + }); + +// BufferStore BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) { ObjectPtr node = make_object(); node->buffer = std::move(buffer); @@ -320,276 +522,65 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array indices) } TVM_REGISTER_GLOBAL("tir.BufferStore") -.set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { - return BufferStore(buffer, value, indices); -}); + .set_body_typed([](Buffer buffer, PrimExpr value, Array indices) { + return BufferStore(buffer, value, indices); + }); TVM_REGISTER_NODE_TYPE(BufferStoreNode); -// Printers - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "let " << op->var << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) 
-.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "// attr ["; - p->Print(op->node); - p->stream << "] " - << op->attr_key << " = "; - p->Print(op->value); - p->stream << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "assert("; - p->Print(op->condition); - p->stream << ", "; - p->Print(op->message); - p->stream << ")\n"; - p->Print(op->body); - }); - -std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) - switch (type) { - case ForType::Serial: - out << "for"; - break; - case ForType::Parallel: - out << "parallel"; - break; - case ForType::Unrolled: - out << "unrolled"; - break; - case ForType::Vectorized: - out << "vectorized"; - break; - } - return out; -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->for_type << " (" << op->loop_var << ", "; - p->Print(op->min); - p->stream << ", "; - p->Print(op->extent); - p->stream << ") {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->buffer_var << "["; - p->Print(op->index); - p->stream << "] = "; - p->Print(op->value); - if (!is_one(op->predicate)) { - p->stream << " if "; - p->Print(op->predicate); - } - p->stream << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << op->func->func_name() << "("; - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - p->stream << " ="; - p->Print(op->value); - p->stream << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "allocate " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - p->stream << " * "; - p->Print(op->extents[i]); - } - p->stream << "]"; - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << "\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "free " << op->buffer_var; - p->stream << '\n'; - }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "realize " << op->func->func_name() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < 
op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - if (!is_one(op->condition)) { - p->stream << " if "; - p->Print(op->condition); - } - p->stream << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - p->stream << "}\n"; - }); + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; + }); + +// BufferRealize +BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + data_ = make_object(buffer, bounds, condition, body); +} -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->stream << "prefetch " << op->func->func_name() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - p->stream << "["; - p->Print(op->bounds[i]->min); - p->stream << ", "; - p->Print(op->bounds[i]->extent); - p->stream << "]"; - if (i < op->bounds.size() - 1) p->stream << ", "; - } - p->stream << ")"; - if (op->func->num_outputs() != 1) { - p->stream << ".value[" << op->value_index << "]"; - } - }); +TVM_REGISTER_GLOBAL("tir.BufferRealize") + .set_body_typed([](Buffer buffer, Array bounds, PrimExpr condition, Stmt body) { + return BufferRealize(buffer, bounds, condition, body); + }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - for (Stmt stmt : op->seq) { - p->Print(stmt); - } - }); +TVM_REGISTER_NODE_TYPE(BufferRealizeNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - while (true) { - p->stream << "if (" << op->condition << ") {\n"; - p->indent += 2; - p->Print(op->then_case); - p->indent -= 2; - - if (!op->else_case.defined()) { - break; + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << "buffer_realize " << op->buffer->name << "("; + for (size_t i = 0; i < op->bounds.size(); ++i) { + p->stream << "["; + p->Print(op->bounds[i]->min); + p->stream << ", "; + p->Print(op->bounds[i]->extent); + p->stream << "]"; + if (i < op->bounds.size() - 1) p->stream << ", "; } - - if (const IfThenElseNode *nested_if = op->else_case.as()) { - p->PrintIndent(); - p->stream << "} else "; - op = nested_if; - } else { - p->PrintIndent(); - p->stream << "} else {\n"; - p->indent += 2; - p->Print(op->else_case); - p->indent -= 2; - break; + p->stream << ")"; + if (!is_one(op->condition)) { + p->stream << " if "; + p->Print(op->condition); } - } - p->PrintIndent(); - p->stream << "}\n"; -}); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->Print(op->value); - p->stream << "\n"; - }); - -template -void PrintList(const Array &exprs, ReprPrinter* p) { - for (size_t i = 0; i < exprs.size(); ++i) { - p->Print(exprs[i]); - if (i < exprs.size() - 1) { - p->stream << ", "; - } - } -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "shuffle("; - PrintList(op->vectors, p); - p->stream << 
", "; - PrintList(op->indices, p); - p->stream << ")"; - }); + p->stream << " {\n"; -TVM_REGISTER_NODE_TYPE(AttrStmtNode); -TVM_REGISTER_NODE_TYPE(PrefetchNode); -TVM_REGISTER_NODE_TYPE(CallNode); -TVM_REGISTER_NODE_TYPE(LetNode); -TVM_REGISTER_NODE_TYPE(LetStmtNode); -TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_REGISTER_NODE_TYPE(ForNode); -TVM_REGISTER_NODE_TYPE(StoreNode); -TVM_REGISTER_NODE_TYPE(ProvideNode); -TVM_REGISTER_NODE_TYPE(AllocateNode); -TVM_REGISTER_NODE_TYPE(FreeNode); -TVM_REGISTER_NODE_TYPE(RealizeNode); -TVM_REGISTER_NODE_TYPE(SeqStmtNode); -TVM_REGISTER_NODE_TYPE(IfThenElseNode); -TVM_REGISTER_NODE_TYPE(EvaluateNode); + p->indent += 2; + p->Print(op->body); + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index ed3c2c75ef47..67329aa6414c 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -19,116 +19,16 @@ /*! * \file stmt_functor.cc */ +#include #include + +#include + #include "functor_common.h" namespace tvm { namespace tir { -// visitor to implement apply -class IRApplyVisit : - public StmtExprVisitor { - public: - explicit IRApplyVisit(std::function f) : f_(f) {} - - void VisitExpr(const PrimExpr& node) final { - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - ExprVisitor::VisitExpr(node); - f_(node); - } - - void VisitStmt(const Stmt& node) final { - if (visited_.count(node.get()) != 0) return; - visited_.insert(node.get()); - StmtVisitor::VisitStmt(node); - f_(node); - } - - private: - std::function f_; - std::unordered_set visited_; -}; - -void PostOrderVisit(const ObjectRef& node, - std::function fvisit) { - if (node.as()) { - IRApplyVisit visitor(fvisit); - visitor(Downcast(node)); - } else { - IRApplyVisit visitor(fvisit); - visitor(Downcast(node)); - } -} - -class IRTransformer final : - public StmtExprMutator { - public: - IRTransformer(const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const std::unordered_set& only_enable) - : f_preorder_(f_preorder), - f_postorder_(f_postorder), - only_enable_(only_enable) { - } - - Stmt VisitStmt(const Stmt& stmt) final { - return MutateInternal(stmt, [this](const Stmt& s) { - return this->BaseVisitStmt(s); - }); - } - PrimExpr VisitExpr(const PrimExpr& expr) final { - return MutateInternal(expr, [this](const PrimExpr& e) { - return this->BaseVisitExpr(e); - }); - } - - private: - // NOTE: redirect to parent's call - // This is used to get around limitation of gcc-4.8 - Stmt BaseVisitStmt(const Stmt& s) { - return StmtMutator::VisitStmt(s); - } - PrimExpr BaseVisitExpr(const PrimExpr& e) { - return ExprMutator::VisitExpr(e); - } - - template - T MutateInternal(const T& node, F fmutate) { - if (only_enable_.size() && - !only_enable_.count(node->type_index())) { - return fmutate(node); - } - if (f_preorder_ != nullptr) { - T pre = f_preorder_(node); - if (pre.defined()) return pre; - } - T new_node = fmutate(node); - if (f_postorder_ != nullptr) { - T post = f_postorder_(new_node); - if (post.defined()) return post; - } - return new_node; - } - // The functions - const runtime::PackedFunc& f_preorder_; - const runtime::PackedFunc& f_postorder_; - // type indices enabled. 
- const std::unordered_set& only_enable_; -}; - -Stmt IRTransform(Stmt ir_node, - const runtime::PackedFunc& f_preorder, - const runtime::PackedFunc& f_postorder, - const Array& only_enable) { - std::unordered_set only_type_index; - for (auto s : only_enable) { - only_type_index.insert(Object::TypeKey2Index(s.c_str())); - } - IRTransformer transform(f_preorder, f_postorder, only_type_index); - return transform(std::move(ir_node)); -} - void StmtVisitor::VisitStmt_(const LetStmtNode* op) { this->VisitExpr(op->value); this->VisitStmt(op->body); @@ -158,9 +58,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) { } void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { + this->VisitExpr(op->value); VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { + VisitArray(op->bounds, [this](const Range& r) { + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); + this->VisitExpr(op->condition); + this->VisitStmt(op->body); +} + void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); @@ -177,37 +87,32 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProvideNode* op) { - VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); +void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->value); } -void StmtVisitor::VisitStmt_(const RealizeNode* op) { +void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); this->VisitStmt(op->body); this->VisitExpr(op->condition); } void StmtVisitor::VisitStmt_(const PrefetchNode* op) { VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); + this->VisitExpr(r->min); + this->VisitExpr(r->extent); + }); } void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { - VisitArray(op->seq, [this](const Stmt& s) { - this->VisitStmt(s); - }); -} - -void StmtVisitor::VisitStmt_(const EvaluateNode* op) { - this->VisitExpr(op->value); + VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); } +void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } class StmtMutator::Internal { public: @@ -238,8 +143,7 @@ class StmtMutator::Internal { Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -252,8 +156,7 @@ Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { PrimExpr value = this->VisitExpr(op->value); Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -267,9 +170,7 @@ Stmt StmtMutator::VisitStmt_(const ForNode* op) { PrimExpr min = this->VisitExpr(op->min); PrimExpr extent = this->VisitExpr(op->extent); Stmt body = this->VisitStmt(op->body); - if (min.same_as(op->min) && - 
extent.same_as(op->extent) && - body.same_as(op->body)) { + if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -285,9 +186,7 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); - if (extents.same_as(op->extents) && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -305,8 +204,7 @@ Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { if (op->else_case.defined()) { else_case = this->VisitStmt(op->else_case); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { @@ -322,9 +220,7 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { PrimExpr value = this->VisitExpr(op->value); PrimExpr index = this->VisitExpr(op->index); PrimExpr predicate = this->VisitExpr(op->predicate); - if (value.same_as(op->value) && - index.same_as(op->index) && - predicate.same_as(op->predicate)) { + if (value.same_as(op->value) && index.same_as(op->index) && predicate.same_as(op->predicate)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -336,37 +232,53 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { } Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { + PrimExpr value = this->VisitExpr(op->value); Array indices = Internal::Mutate(this, op->indices); - if (indices.same_as(op->indices)) { + + if (value.same_as(op->value) && indices.same_as(op->indices)) { return GetRef(op); } else { auto n = CopyOnWrite(op); + n->value = std::move(value); n->indices = std::move(indices); return Stmt(n); } } -Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { - Array args = Internal::Mutate(this, op->args); +Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { + Region bounds = Internal::Mutate(this, op->bounds); + PrimExpr condition = this->VisitExpr(op->condition); + Stmt body = this->VisitStmt(op->body); + + if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->bounds = std::move(bounds); + n->condition = std::move(condition); + n->body = std::move(body); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) { + Array indices = Internal::Mutate(this, op->indices); PrimExpr value = this->VisitExpr(op->value); - if (args.same_as(op->args) && - value.same_as(op->value)) { + if (indices.same_as(op->indices) && value.same_as(op->value)) { return GetRef(op); } else { auto n = CopyOnWrite(op); - n->args = std::move(args); + n->indices = std::move(indices); n->value = std::move(value); return Stmt(n); } } -Stmt StmtMutator::VisitStmt_(const RealizeNode* op) { +Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); - if (bounds.same_as(op->bounds) && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -400,8 +312,7 @@ Stmt 
StmtMutator::VisitStmt_(const SeqStmtNode* op) { } // advanced visit function for seqstmt. -Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, - bool flatten_before_visit, +Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit, std::function fmutate) { if (flatten_before_visit) { // Pass 1, check if we need to flatten. @@ -414,10 +325,8 @@ Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, } // function to run the visit. auto frunvisit = [&](const SeqStmtNode* op) { - Array seq = - fmutate != nullptr ? - MutateArray(op->seq, fmutate, allow_copy_on_write_) : - Internal::Mutate(this, op->seq); + Array seq = fmutate != nullptr ? MutateArray(op->seq, fmutate, allow_copy_on_write_) + : Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { return GetRef(op); } else { @@ -450,9 +359,7 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { PrimExpr message = this->VisitExpr(op->message); Stmt body = this->VisitStmt(op->body); - if (condition.same_as(op->condition) && - message.same_as(op->message) && - body.same_as(op->body)) { + if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { return GetRef(op); } else { auto n = CopyOnWrite(op); @@ -474,11 +381,156 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } } -Stmt StmtMutator::VisitStmt_(const FreeNode* op) { - return GetRef(op); +Stmt StmtMutator::VisitStmt_(const FreeNode* op) { return GetRef(op); } + +// Implementations of IRTransform, PostOrderVisit and Substitute +class IRApplyVisit : public StmtExprVisitor { + public: + explicit IRApplyVisit(std::function f) : f_(f) {} + + void VisitExpr(const PrimExpr& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + ExprVisitor::VisitExpr(node); + f_(node); + } + + void VisitStmt(const Stmt& node) final { + if (visited_.count(node.get()) != 0) return; + visited_.insert(node.get()); + StmtVisitor::VisitStmt(node); + f_(node); + } + + private: + std::function f_; + std::unordered_set visited_; +}; + +void PostOrderVisit(const ObjectRef& node, std::function fvisit) { + if (node.as()) { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } else { + IRApplyVisit visitor(fvisit); + visitor(Downcast(node)); + } +} + +class IRTransformer final : public StmtExprMutator { + public: + IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, + const std::unordered_set& only_enable) + : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} + + Stmt VisitStmt(const Stmt& stmt) final { + return MutateInternal(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); }); + } + PrimExpr VisitExpr(const PrimExpr& expr) final { + return MutateInternal(expr, + [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); + } + + private: + // NOTE: redirect to parent's call + // This is used to get around limitation of gcc-4.8 + Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } + PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } + + template + T MutateInternal(const T& node, F fmutate) { + if (only_enable_.size() && !only_enable_.count(node->type_index())) { + return fmutate(node); + } + if (f_preorder_ != nullptr) { + T pre = f_preorder_(node); + if (pre.defined()) return pre; + } + T new_node = fmutate(node); + if (f_postorder_ != nullptr) { + T post = f_postorder_(new_node); + if (post.defined()) return post; + } + return new_node; + } + // The functions + 
const runtime::PackedFunc& f_preorder_; + const runtime::PackedFunc& f_postorder_; + // type indices enabled. + const std::unordered_set& only_enable_; +}; + +Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, Optional> only_enable) { + std::unordered_set only_type_index; + if (only_enable.defined()) { + for (auto s : only_enable.value()) { + only_type_index.insert(Object::TypeKey2Index(s.c_str())); + } + } + IRTransformer transform(f_preorder, f_postorder, only_type_index); + return transform(std::move(ir_node)); +} + +class IRSubstitue : public StmtExprMutator { + public: + explicit IRSubstitue(std::function(const Var&)> vmap) : vmap_(vmap) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto ret = vmap_(var); + if (ret.defined()) return ret.value(); + return std::move(var); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + PrimExpr ret = StmtExprMutator::VisitExpr_(op); + op = ret.as(); + if (auto mapped_var = vmap_(op->buffer_var)) { + return Load(op->dtype, Downcast(mapped_var.value()), op->index, op->predicate); + } else { + return ret; + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + // NOTE: we do not explicit recursivly mutate op->buffer_var + Stmt ret = StmtExprMutator::VisitStmt_(op); + op = ret.as(); + if (auto mapped_var = vmap_(op->buffer_var)) { + return Store(Downcast(mapped_var.value()), op->value, op->index, op->predicate); + } else { + return ret; + } + } + + private: + std::function(const Var&)> vmap_; +}; + +Stmt Substitute(Stmt stmt, std::function(const Var&)> vmap) { + return IRSubstitue(vmap)(std::move(stmt)); } +PrimExpr Substitute(PrimExpr expr, std::function(const Var&)> vmap) { + return IRSubstitue(vmap)(std::move(expr)); +} + +TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform); + +TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { + tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); +}); +TVM_REGISTER_GLOBAL("tir.Substitute") + .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { + if (node->IsInstance()) { + return Substitute(Downcast(node), vmap); + } else { + return Substitute(Downcast(node), vmap); + } + }); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index dda9ff460cf0..50106c90a5e5 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -21,16 +21,14 @@ * \file tir/ir/transform.cc * \brief TIR specific transformation passes. */ -#include #include +#include #include - namespace tvm { namespace tir { namespace transform { - /*! * \brief Function level pass that applies transformations to all * TIR functions within the module. @@ -43,9 +41,7 @@ class PrimFuncPassNode : public PassNode { /*! \brief The pass function called on each. */ runtime::TypedPackedFunc pass_func; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); } /*! * \brief Run a function pass on given pass context. @@ -90,8 +86,7 @@ PrimFuncPass::PrimFuncPass( } // Perform Module -> Module optimizations at the PrimFunc level. 
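// ---------------------------------------------------------------------------
// Editor's sketch, not part of this diff: the callback overload of Substitute
// added above takes a functor returning Optional<PrimExpr>; returning NullOpt
// leaves a variable untouched. A standalone example, assuming the overload is
// declared in tvm/tir/stmt_functor.h and the PrimExpr operators in tvm/tir/op.h:
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

using namespace tvm;
using namespace tvm::tir;

// Rewrite every use of `x` in `expr` to `x + 1`; all other variables pass through.
PrimExpr ShiftVarByOne(PrimExpr expr, Var x) {
  return Substitute(expr, [&](const Var& v) -> Optional<PrimExpr> {
    if (v.same_as(x)) return x + 1;
    return NullOpt;
  });
}
// ---------------------------------------------------------------------------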
-IRModule PrimFuncPassNode::operator()(IRModule mod, - const PassContext& pass_ctx) const { +IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); CHECK(mod.defined()); pass_ctx.Trace(mod, pass_info, true); @@ -123,9 +118,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { + int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return PrimFuncPass(pass_func, pass_info); } @@ -133,18 +126,16 @@ Pass CreatePrimFuncPass( TVM_REGISTER_NODE_TYPE(PrimFuncPassNode); TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass") -.set_body_typed([](runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - return PrimFuncPass(pass_func, pass_info); -}); + .set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { return PrimFuncPass(pass_func, pass_info); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "PrimFuncPass(" << info->name - << ", opt_level=" << info->opt_level << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "PrimFuncPass(" << info->name << ", opt_level=" << info->opt_level << ")"; + }); } // namespace transform } // namespace tir diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc deleted file mode 100644 index 3083b6879635..000000000000 --- a/src/tir/pass/ffi_api.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Exposure of pass functions. 
- * \file ffi_api.cc - */ -#include -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace tir { - -TVM_REGISTER_GLOBAL("ir_pass.Simplify") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - if (args.size() > 1) { - *ret = Simplify(args[0].operator Stmt(), args[1]); - } else { - *ret = Simplify(args[0].operator Stmt()); - } - } else { - if (args.size() > 1) { - *ret = Simplify(args[0].operator PrimExpr(), args[1]); - } else { - *ret = Simplify(args[0].operator PrimExpr()); - } - } - }); - -TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); - } else { - *ret = CanonicalSimplify(args[0].operator Stmt()); - } - } else { - if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]); - } else { - *ret = CanonicalSimplify(args[0].operator PrimExpr()); - } - } - }); - -TVM_REGISTER_GLOBAL("ir_pass.Substitute") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - *ret = Substitute(args[0].operator Stmt(), args[1].operator Map()); - } else { - *ret = Substitute(args[0].operator PrimExpr(), args[1].operator Map()); - } - }); - -TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args.size() <= 3) { - *ret = StorageFlatten(args[0], args[1], args[2]); - } else { - *ret = StorageFlatten(args[0], args[1], args[2], args[3]); - } - }); - -TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") -.set_body_typed - ([](const Stmt& stmt, - const te::Schedule& schedule, - const Map& extern_buffer) { - return RewriteForTensorCore(stmt, schedule, extern_buffer); - }); - -TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var()); - }); - -TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc f = args[1]; - tir::PostOrderVisit(args[0], [f](const ObjectRef& n) { - f(n); - }); - }); - - -// make from two arguments -#define REGISTER_PASS(PassName) \ - TVM_REGISTER_GLOBAL("ir_pass."#PassName) \ - .set_body_typed(PassName); \ - - -REGISTER_PASS(ConvertSSA); -REGISTER_PASS(VerifySSA); -REGISTER_PASS(RewriteUnsafeSelect); -REGISTER_PASS(Inline); -REGISTER_PASS(IRTransform); -REGISTER_PASS(VectorizeLoop); -REGISTER_PASS(SkipVectorize); -REGISTER_PASS(UnrollLoop); -REGISTER_PASS(InjectCopyIntrin); -REGISTER_PASS(StorageRewrite); -REGISTER_PASS(CoProcSync); -REGISTER_PASS(LowerStorageAccessInfo); -REGISTER_PASS(InjectVirtualThread); -REGISTER_PASS(InjectPrefetch); -REGISTER_PASS(InjectDoubleBuffer); -REGISTER_PASS(LoopPartition); -REGISTER_PASS(RemoveNoOp); -REGISTER_PASS(LiftAttrScope); -REGISTER_PASS(VerifyGPUCode); -REGISTER_PASS(DecorateDeviceScope); -REGISTER_PASS(InstrumentBoundCheckers); -REGISTER_PASS(VerifyCompactBuffer); -REGISTER_PASS(HoistIfThenElse); -REGISTER_PASS(NarrowDataType); -} // namespace tir -} // namespace tvm diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc index 8bc462079eae..d1e24b94a32f 100644 --- a/src/tir/pass/hoist_if_then_else.cc +++ b/src/tir/pass/hoist_if_then_else.cc @@ -20,14 +20,15 @@ /*! 
* \file hoist_if_then_else.cc */ -#include -#include #include #include +#include +#include +#include #include #include -#include + #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" @@ -151,15 +152,14 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) { } }); - PackedFunc replace_target_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& current_for = args[0]; - if (current_for.get() == top_for_node) { - *ret = new_if_stmt; - } - }); + PackedFunc replace_target_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& current_for = args[0]; + if (current_for.get() == top_for_node) { + *ret = new_if_stmt; + } + }); - return IRTransform(parent_for_stmt, nullptr, replace_target_for, {"For"}); + return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array{"tir.For"}); } // Remove IfThenElse node from a For node. @@ -169,25 +169,23 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { Stmt else_for; CHECK(if_stmt.as()); - PackedFunc replace_then_case = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->then_case; - } - }); + PackedFunc replace_then_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->then_case; + } + }); - PackedFunc replace_else_case = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& node = args[0]; - if (node == if_stmt) { - *ret = node.as()->else_case; - } - }); + PackedFunc replace_else_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& node = args[0]; + if (node == if_stmt) { + *ret = node.as()->else_case; + } + }); - then_for = IRTransform(for_stmt, nullptr, replace_then_case, {"IfThenElse"}); + then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array{"tir.IfThenElse"}); if (if_stmt.as()->else_case.defined()) { - else_for = IRTransform(for_stmt, nullptr, replace_else_case, {"IfThenElse"}); + else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array{"tir.IfThenElse"}); } return std::make_pair(then_for, else_for); @@ -195,7 +193,7 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) { // Locate all For nodes and capture child IfThenElse nodes. void IfThenElseHoist::SelectCandidates(const Stmt& stmt) { - PostOrderVisit(stmt, [&](const ObjectRef& node){ + PostOrderVisit(stmt, [&](const ObjectRef& node) { const ForNode* for_node = node.as(); if (!for_node) return; @@ -268,10 +266,8 @@ void IfThenElseHoist::LocateTopFor() { CHECK(for_node); std::vector new_for_list{for_stmt}; for_tracking_map_.insert({for_stmt.get(), new_for_list}); - if (cond_var_map_[if_stmt] - .count(for_node->loop_var.get())) { - std::vector updated_for_list(for_list.begin(), - for_list.begin() + i); + if (cond_var_map_[if_stmt].count(for_node->loop_var.get())) { + std::vector updated_for_list(for_list.begin(), for_list.begin() + i); if2for_map_[if_stmt] = updated_for_list; break; } else { @@ -314,13 +310,11 @@ void IfThenElseHoist::LocateTopFor() { // We keep all For nodes tracing in for_tracking_map_. When we get a // hoisted IfThenElse, we match it with tracing For nodes to pick // the updated one. 
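// ---------------------------------------------------------------------------
// Editor's sketch, not part of this diff: as in update_for/RemoveIf above,
// IRTransform now receives the namespaced type keys ("tir.For", "tir.IfThenElse")
// as its only_enable filter. A small standalone rewrite restricted to For nodes;
// the function name and includes are assumptions:
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

using namespace tvm;
using namespace tvm::runtime;
using namespace tvm::tir;

// Replace any For loop whose extent is the constant 0 with a no-op.
// Leaving *ret unset inside the callback keeps the visited node unchanged.
Stmt DropEmptyLoops(Stmt stmt) {
  PackedFunc postorder([](TVMArgs args, TVMRetValue* ret) {
    const ObjectRef& node = args[0];
    const ForNode* op = node.as<ForNode>();
    if (op != nullptr && is_zero(op->extent)) {
      *ret = Evaluate(0);
    }
  });
  return IRTransform(stmt, nullptr, postorder, Array<String>{"tir.For"});
}
// ---------------------------------------------------------------------------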
-size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, - const Stmt& if_stmt) { +size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt) { std::vector tracked_for_list = for_tracking_map_[for_stmt.get()]; size_t updated_for_idx = 0; for (size_t i = 0; i < tracked_for_list.size(); ++i) { - const Stmt& current_for = - tracked_for_list.at(tracked_for_list.size() - 1 - i); + const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i); if (is_first_if(current_for, if_stmt)) { updated_for_idx = tracked_for_list.size() - 1 - i; break; @@ -339,11 +333,11 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) { const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i); size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if); - const Stmt& updated_for_node = - for_tracking_map_[for_stmt.get()].at(updated_for_idx); + const Stmt& updated_for_node = for_tracking_map_[for_stmt.get()].at(updated_for_idx); auto generated_for_pair = RemoveIf(updated_for_node, new_if); const Stmt& then_for = generated_for_pair.first; - const Stmt& else_for = generated_for_pair.second;; + const Stmt& else_for = generated_for_pair.second; + for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for; if (else_for.get()) { @@ -352,15 +346,13 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { const IfThenElseNode* new_if_node = new_if.as(); CHECK(new_if_node); - new_if = IfThenElseNode::make(new_if_node->condition, then_for, else_for); + new_if = IfThenElse(new_if_node->condition, then_for, else_for); if (i < if2for_map_[if_stmt.get()].size() - 1) { const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1); - const Stmt& actual_next_for = - for_tracking_map_[original_next_for.get()].at(updated_for_idx); + const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx); Stmt update_for_stmt = update_for(actual_next_for, new_if); - for_tracking_map_[original_next_for.get()]. - at(updated_for_idx) = update_for_stmt; + for_tracking_map_[original_next_for.get()].at(updated_for_idx) = update_for_stmt; } } return new_if; @@ -368,52 +360,45 @@ Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) { // Mutate For nodes in post order DFS manner. 
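// ---------------------------------------------------------------------------
// Editor's note, not part of this diff: HoistIf above now builds its result
// through the IfThenElse reference constructor rather than IfThenElseNode::make.
// The shape of one hoist step, written out by hand as a minimal sketch
// (assuming tvm/tir/stmt.h only):
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

// Given a loop whose body was `if (cond) { body }`, the hoisted result wraps the
// now if-free loop in a single conditional; the else branch may stay undefined.
Stmt WrapLoopInCondition(PrimExpr cond, Stmt if_free_loop) {
  return IfThenElse(cond, if_free_loop, Stmt());
}
// ---------------------------------------------------------------------------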
Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) { - PackedFunc replace_top_for = PackedFunc( - [&](TVMArgs args, TVMRetValue *ret){ - const ObjectRef& current_for = args[0]; - const ForNode* for_node = current_for.as(); - if (!for_node) return; - - if (top_for_var_map_.count(for_node->loop_var.get())) { - std::vector new_if_list; - for (const Stmt& if_stmt : - top_for_var_map_[for_node->loop_var.get()]) { - new_if_list.emplace_back(HoistIf(if_stmt)); - } + PackedFunc replace_top_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) { + const ObjectRef& current_for = args[0]; + const ForNode* for_node = current_for.as(); + if (!for_node) return; - const IfThenElseNode* next_if_node; - const IfThenElseNode* current_if_node = - new_if_list.back().as(); - Stmt new_for = Stmt(); - for (size_t i = new_if_list.size() - 1; i > 0; --i) { - CHECK(current_if_node); - const Stmt current_if_stmt = - IfThenElseNode::make(current_if_node->condition, - current_if_node->then_case, - current_if_node->else_case); - next_if_node = new_if_list[i - 1].as(); - CHECK(next_if_node); - new_for = IfThenElseNode::make(next_if_node->condition, current_if_stmt, - next_if_node->else_case); - current_if_node = new_for.as(); - } + if (top_for_var_map_.count(for_node->loop_var.get())) { + std::vector new_if_list; + for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) { + new_if_list.emplace_back(HoistIf(if_stmt)); + } - if (!new_for.get()) { - const IfThenElseNode* first_if_node = new_if_list[0].as(); - CHECK(first_if_node); - new_for = IfThenElseNode::make(first_if_node->condition, - first_if_node->then_case, - first_if_node->else_case); - } - *ret = new_for; + const IfThenElseNode* next_if_node; + const IfThenElseNode* current_if_node = new_if_list.back().as(); + Stmt new_for = Stmt(); + for (size_t i = new_if_list.size() - 1; i > 0; --i) { + CHECK(current_if_node); + const Stmt current_if_stmt = IfThenElse( + current_if_node->condition, current_if_node->then_case, current_if_node->else_case); + next_if_node = new_if_list[i - 1].as(); + CHECK(next_if_node); + new_for = IfThenElse(next_if_node->condition, current_if_stmt, next_if_node->else_case); + current_if_node = new_for.as(); } - }); - return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")}); -} -Stmt HoistIfThenElse(Stmt stmt) { - return IfThenElseHoist().VisitAndMutate(stmt); + if (!new_for.get()) { + const IfThenElseNode* first_if_node = new_if_list[0].as(); + CHECK(first_if_node); + new_for = IfThenElse(first_if_node->condition, first_if_node->then_case, + first_if_node->else_case); + } + *ret = new_for; + } + }); + return IRTransform(stmt, nullptr, replace_top_for, Array{"tir.For"}); } +Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); } + +TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse); + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ir_util.cc b/src/tir/pass/ir_util.cc deleted file mode 100644 index 7223c5b1c9e6..000000000000 --- a/src/tir/pass/ir_util.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir_util.cc - * \brief Helper functions to construct and compose IR nodes. - */ -#include "ir_util.h" - -namespace tvm { -namespace tir { - -Stmt MergeNest(const std::vector& nest, Stmt body) { - // use reverse iteration - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - Stmt s = *ri; - if (const auto* for_ = s.as()) { - auto n = make_object(*for_); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* let = s.as()) { - auto n = make_object(*let); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* attr = s.as()) { - auto n = make_object(*attr); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* ite = s.as()) { - auto n = make_object(*ite); - CHECK(is_no_op(n->then_case)); - CHECK(!n->else_case.defined()); - n->then_case = body; - body = Stmt(n); - } else if (const auto* seq = s.as()) { - auto n = make_object(*seq); - CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); - n->seq.Set(n->size() - 1, body); - body = Stmt(n); - } else if (const auto* assert_ = s.as()) { - auto n = make_object(*assert_); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else if (const auto* alloc = s.as()) { - auto n = make_object(*alloc); - CHECK(is_no_op(n->body)); - n->body = body; - body = Stmt(n); - } else { - LOG(FATAL) << "not supported nest type"; - } - } - return body; -} - -Stmt MergeNest(const std::vector >& nest, Stmt body) { - for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { - body = MergeNest(*ri, body); - } - return body; -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/pass/simple_passes.cc b/src/tir/pass/simple_passes.cc deleted file mode 100644 index 93d17ba347fc..000000000000 --- a/src/tir/pass/simple_passes.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file simple_passes.cc - * \brief Implementation of simple passes - */ -#include -#include -#include - -namespace tvm { -namespace tir { - -class IRSideEffect : public ExprVisitor { - public: - void VisitExpr(const PrimExpr& e) final { - if (has_side_effect_) return; - ExprVisitor::VisitExpr(e); - } - - void VisitExpr_(const CallNode* op) final { - if (!op->is_pure()) { - has_side_effect_ = true; return; - } else { - ExprVisitor::VisitExpr_(op); - } - } - - bool has_side_effect_{false}; -}; - -bool HasSideEffect(const PrimExpr& e) { - IRSideEffect v; - v(e); - return v.has_side_effect_; -} - -class IRSubstitue : public StmtExprMutator { - public: - explicit IRSubstitue( - const std::unordered_map& smap) - : smap_(smap) { - } - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = smap_.find(op); - if (it != smap_.end()) { - return it->second; - } else { - return GetRef(op); - } - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var - PrimExpr ret = StmtExprMutator::VisitExpr_(op); - op = ret.as(); - auto it = smap_.find(op->buffer_var.get()); - if (it != smap_.end()) { - return LoadNode::make( - op->dtype, Downcast(it->second), op->index, op->predicate); - } else { - return ret; - } - } - - Stmt VisitStmt_(const StoreNode* op) final { - // NOTE: we do not explicit recursivly mutate op->buffer_var - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - auto it = smap_.find(op->buffer_var.get()); - if (it != smap_.end()) { - return StoreNode::make( - Downcast(it->second), op->value, op->index, op->predicate); - } else { - return ret; - } - } - - private: - const std::unordered_map& smap_; -}; - -Stmt Substitute(Stmt stmt, - const std::unordered_map& value_map) { - if (value_map.size() == 0) return stmt; - return IRSubstitue(value_map)(std::move(stmt)); -} - -PrimExpr Substitute(PrimExpr expr, - const std::unordered_map& value_map) { - if (value_map.size() == 0) return expr; - return IRSubstitue(value_map)(std::move(expr)); -} - -Stmt Substitute(Stmt stmt, const Map& value_map) { - std::unordered_map vmap; - for (const auto& kv : value_map) { - vmap[kv.first.get()] = kv.second; - } - return Substitute(stmt, vmap); -} - -PrimExpr Substitute(PrimExpr expr, const Map& value_map) { - std::unordered_map vmap; - for (const auto& kv : value_map) { - vmap[kv.first.get()] = kv.second; - } - return Substitute(expr, vmap); -} - -class VarTouchVisitor : public ExprVisitor { - public: - void VisitExpr(const PrimExpr& e) final { - if (use_var_) return; - ExprVisitor::VisitExpr(e); - } - - void VisitExpr_(const VarNode* op) final { - Handle(op); - } - - void VisitExpr_(const LoadNode* op) final { - Handle(op->buffer_var.get()); - ExprVisitor::VisitExpr_(op); - } - - virtual void Handle(const VarNode* var) = 0; - - bool use_var_{false}; -}; - -class ExprUseVarVisitor : public VarTouchVisitor { - public: - explicit ExprUseVarVisitor(const VarNode* var) - : var_(var) {} - - void Handle(const VarNode* var) final { - if (var == var_) use_var_ = true; - } - private: - const VarNode* var_; -}; - -class ExprUseVSetVisitor : public VarTouchVisitor { - public: - explicit ExprUseVSetVisitor( - const std::unordered_set& vset) - : vset_(vset) {} - - void Handle(const VarNode* var) final { - if (vset_.count(var)) use_var_ = true; - } - private: - const std::unordered_set& vset_; -}; - -bool ExprUseVar(const PrimExpr& e, const Var& v) { - ExprUseVarVisitor visitor(v.get()); - visitor(e); - return visitor.use_var_; -} - -bool 
ExprUseVar(const PrimExpr& e, - const std::unordered_set& vset) { - ExprUseVSetVisitor visitor(vset); - visitor(e); - return visitor.use_var_; -} - -} // namespace tir -} // namespace tvm diff --git a/src/tir/pass/arg_binder.cc b/src/tir/transforms/arg_binder.cc similarity index 58% rename from src/tir/pass/arg_binder.cc rename to src/tir/transforms/arg_binder.cc index c684b9e68038..ae7065d94d80 100644 --- a/src/tir/pass/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -21,35 +21,32 @@ * \file arg_binder.cc * \brief Helper utility to match and bind arguments. */ -#include -#include +#include "arg_binder.h" + #include +#include +#include + #include "ir_util.h" -#include "arg_binder.h" -#include "../../arith/compute_expr.h" namespace tvm { namespace tir { -void BinderAddAssert(PrimExpr cond, - const std::string& arg_name, +void BinderAddAssert(arith::Analyzer* ana, PrimExpr cond, const std::string& arg_name, std::vector* asserts) { - PrimExpr scond = Simplify(cond); + PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { - LOG(FATAL) << "Bind have an unmet assertion: " - << cond << ", " << " on argument " << arg_name; + LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " + << " on argument " << arg_name; } if (!is_one(scond)) { std::ostringstream os; os << "Argument " << arg_name << " has an unsatisfied constraint"; - asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()), - EvaluateNode::make(0))); + asserts->emplace_back(AssertStmt(scond, tvm::tir::StringImm(os.str()), Evaluate(0))); } } -bool ArgBinder::Bind_(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, +bool ArgBinder::Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets) { CHECK_EQ(arg.dtype(), value.dtype()); if (const VarNode* v = arg.as()) { @@ -59,32 +56,28 @@ bool ArgBinder::Bind_(const PrimExpr& arg, defs_.emplace_back(v_arg); if (with_lets) { (*def_map_)[v] = arg; - init_nest_.emplace_back(LetStmtNode::make(v_arg, value, EvaluateNode::make(0))); + init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0))); } else { (*def_map_)[v] = value; } return true; } else { - BinderAddAssert(it->second == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_); } } else { - BinderAddAssert(arg == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_); } return false; } -void ArgBinder::Bind(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, +void ArgBinder::Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_let) { Bind_(arg, value, arg_name, with_let); } -void ArgBinder::BindArray(const Array& arg, - const Array& value, +void ArgBinder::BindArray(const Array& arg, const Array& value, const std::string& arg_name) { - CHECK_EQ(arg.size(), value.size()) - << "Argument " << arg_name << " array size mismatch"; + CHECK_EQ(arg.size(), value.size()) << "Argument " << arg_name << " array size mismatch"; for (size_t i = 0; i < arg.size(); ++i) { std::ostringstream os; os << arg_name << "[" << i << "]"; @@ -92,16 +85,11 @@ void ArgBinder::BindArray(const Array& arg, } } -void ArgBinder::BindBuffer(const Buffer& arg, - const Buffer& value, - const std::string& arg_name, +void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match) { - CHECK_EQ(arg->scope, value->scope) - << "Argument " << arg_name - << " Buffer bind scope 
mismatch"; + CHECK_EQ(arg->scope, value->scope) << "Argument " << arg_name << " Buffer bind scope mismatch"; CHECK_EQ(arg->dtype, value->dtype) - << "Argument " << arg_name - << " Buffer bind data type mismatch"; + << "Argument " << arg_name << " Buffer bind data type mismatch"; if (value->data_alignment % arg->data_alignment != 0) { LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement " << " required_alignment=" << arg->data_alignment @@ -121,8 +109,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(truncmod(offset, factor) == zero, - arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset", + &asserts_); } } @@ -130,9 +118,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { - CHECK(is_one(Simplify(value->shape[i]))) - << "Argument " << arg_name << " shape mismatch" - << arg->shape << " vs " << value->shape; + CHECK(is_one(analyzer_.Simplify(value->shape[i]))) + << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; } for (size_t i = 0; i < arg->shape.size(); ++i) { std::ostringstream os; @@ -158,40 +145,33 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, intrinsic::TVMStructFieldKind k return TVMStructGet(t, arr, 0, kind); } -void ArgBinder::BindDLTensor(const Buffer& buffer, - const PrimExpr& device_type, - const PrimExpr& device_id, - const Var& handle, +void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, + const PrimExpr& device_id, const Var& handle, const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); - const Stmt nop = EvaluateNode::make(0); + const Stmt nop = Evaluate(0); // dimension checks PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim); - PrimExpr a_ndim = make_const(tvm_ndim_type, - static_cast(buffer->shape.size())); + PrimExpr a_ndim = make_const(tvm_ndim_type, static_cast(buffer->shape.size())); std::ostringstream ndim_err_msg; - ndim_err_msg << arg_name - << ".ndim is expected to equal " - << buffer->shape.size(); - auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str()); - asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop)); + ndim_err_msg << arg_name << ".ndim is expected to equal " << buffer->shape.size(); + auto msg = tvm::tir::StringImm(ndim_err_msg.str()); + asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); // type checks DataType dtype = buffer->dtype; std::ostringstream type_err_msg; type_err_msg << arg_name << ".dtype is expected to be " << dtype; PrimExpr cond = (TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeCode) == - IntImm(DataType::UInt(8), dtype.code()) && - TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == - IntImm(DataType::UInt(8), dtype.bits()) && - TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == - IntImm(DataType::UInt(16), dtype.lanes())); - if (!(dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || - dtype == DataType::Int(1))) { - auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str()); - asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, 
nop)); - asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop)); + IntImm(DataType::UInt(8), dtype.code()) && + TVMArrayGet(DataType::UInt(8), handle, intrinsic::kArrTypeBits) == + IntImm(DataType::UInt(8), dtype.bits()) && + TVMArrayGet(DataType::UInt(16), handle, intrinsic::kArrTypeLanes) == + IntImm(DataType::UInt(16), dtype.lanes())); + if (!(dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1))) { + auto type_msg = tvm::tir::StringImm(type_err_msg.str()); + asserts_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop)); + asserts_.emplace_back(AssertStmt(cond, type_msg, nop)); } // data field if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData), @@ -199,38 +179,32 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Var vptr(buffer->data); def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype)); // mark alignment of external bufs - init_nest_.emplace_back(AttrStmtNode::make( - vptr, tir::attr::storage_alignment, - IntImm(DataType::Int(32), buffer->data_alignment), nop)); + init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, + IntImm(DataType::Int(32), buffer->data_alignment), nop)); } Var v_shape(arg_name + ".shape", DataType::Handle()); def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0)); - init_nest_.emplace_back(LetStmtNode::make( - v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); + init_nest_.emplace_back( + LetStmt(v_shape, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrShape), nop)); for (size_t k = 0; k < buffer->shape.size(); ++k) { - if (dtype == DataType::Int(4) || - dtype == DataType::UInt(4) || - dtype == DataType::Int(1)) { + if (dtype == DataType::Int(4) || dtype == DataType::UInt(4) || dtype == DataType::Int(1)) { break; } std::ostringstream field_name; field_name << v_shape->name_hint << '[' << k << ']'; Bind_(buffer->shape[k], cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_shape, - IntImm(DataType::Int(32), k), const_true(1))), + Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } // strides field Var v_strides(arg_name + ".strides", DataType::Handle()); def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); - init_nest_.emplace_back(LetStmtNode::make( - v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), - nop)); - PrimExpr is_null = CallNode::make( - DataType::Bool(1), intrinsic::tvm_handle_is_null, - {v_strides}, CallNode::PureIntrinsic); + init_nest_.emplace_back( + LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrStrides), nop)); + PrimExpr is_null = + Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, {v_strides}, CallNode::PureIntrinsic); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -238,10 +212,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, Array conds; for (size_t i = buffer->shape.size(); i != 0; --i) { size_t k = i - 1; - PrimExpr svalue = cast( - stype, - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr svalue = + cast(stype, Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); conds.push_back(expect_stride == svalue); expect_stride = expect_stride * buffer->shape[k]; } @@ -249,12 +221,11 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, stride_err_msg << arg_name << ".strides:" << " expected to be compact array"; 
if (conds.size() != 0) { - auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str()); - Stmt check = - AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()), - stride_msg, EvaluateNode::make(0)); - check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt()); - asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)})); + auto stride_msg = tvm::tir::StringImm(stride_err_msg.str()); + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; + Stmt check = AssertStmt(foldl(fand, const_true(1), conds), stride_msg, Evaluate(0)); + check = IfThenElse(Not(is_null), check, Stmt()); + asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { DataType stype = buffer->DefaultIndexType(); @@ -263,36 +234,35 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, size_t k = i - 1; std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; - PrimExpr value = cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))); + PrimExpr value = + cast(buffer->shape[k].dtype(), + Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); - stride = Simplify(stride * buffer->shape[k]); + stride = analyzer_.Simplify(stride * buffer->shape[k]); } } else { std::ostringstream stride_null_err_msg; stride_null_err_msg << arg_name << ".strides: expected non-null strides."; - asserts_.emplace_back(AssertStmtNode::make( - NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop)); + asserts_.emplace_back( + AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop)); for (size_t k = 0; k < buffer->strides.size(); ++k) { std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; Bind_(buffer->strides[k], cast(buffer->shape[k].dtype(), - LoadNode::make(tvm_shape_type, v_strides, - IntImm(DataType::Int(32), k), const_true(1))), + Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))), field_name.str(), true); } } // Byte_offset field. int data_bytes = GetVectorBytes(buffer->dtype); - int64_t const_offset; - if (arith::GetConst(buffer->elem_offset, &const_offset)) { - Bind_(make_const(DataType::UInt(64), const_offset * data_bytes), - TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), + + if (const auto* const_offset = buffer->elem_offset.as()) { + Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes), + TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), arg_name + ".byte_offset", true); } else { if (Bind_(buffer->elem_offset, @@ -304,16 +274,15 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset", + &asserts_); } } } // device info. 
- Bind_(device_type, - TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), + Bind_(device_type, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceType), arg_name + ".device_type", true); - Bind_(device_id, - TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), + Bind_(device_id, TVMArrayGet(DataType::Int(32), handle, intrinsic::kArrDeviceId), arg_name + ".device_id", true); } diff --git a/src/tir/pass/arg_binder.h b/src/tir/transforms/arg_binder.h similarity index 78% rename from src/tir/pass/arg_binder.h rename to src/tir/transforms/arg_binder.h index dfeb82853529..657ebdbec134 100644 --- a/src/tir/pass/arg_binder.h +++ b/src/tir/transforms/arg_binder.h @@ -21,14 +21,16 @@ * \file arg_binder.h * \brief Helper utility to match and bind arguments. */ -#ifndef TVM_TIR_PASS_ARG_BINDER_H_ -#define TVM_TIR_PASS_ARG_BINDER_H_ +#ifndef TVM_TIR_TRANSFORMS_ARG_BINDER_H_ +#define TVM_TIR_TRANSFORMS_ARG_BINDER_H_ -#include +#include #include +#include + #include -#include #include +#include namespace tvm { namespace tir { @@ -61,10 +63,7 @@ class ArgBinder { * \param def_map A definition map that contains definition of known variables. * ArgBinder will update this def_map when adding new definitions. */ - explicit ArgBinder( - std::unordered_map* def_map) - : def_map_(def_map) { - } + explicit ArgBinder(std::unordered_map* def_map) : def_map_(def_map) {} /*! * \brief Try to bind arg to value, generate constraint if necessary. * \param arg The argument to be binded. @@ -72,9 +71,7 @@ class ArgBinder { * \param arg_name argument name. * \param with_let Whether add lets during bind */ - void Bind(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, + void Bind(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_let = false); /*! * \brief Bind array to array @@ -82,19 +79,17 @@ class ArgBinder { * \param value The target expression value * \param arg_name argument name. */ - void BindArray(const Array& arg, - const Array& value, + void BindArray(const Array& arg, const Array& value, const std::string& arg_name); /*! * \brief Bind symbolic buffer to another symbolic buffer * \param arg The argument to be binded. * \param value The target expression value * \param arg_name argument name. - * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as arg's higher dimensions are of 1. + * \param fuzzy_match If enabled, we allow value's dimension to be smaller than arg, as long as + * arg's higher dimensions are of 1. */ - void BindBuffer(const Buffer& arg, - const Buffer& value, - const std::string& arg_name, + void BindBuffer(const Buffer& arg, const Buffer& value, const std::string& arg_name, bool fuzzy_match); /*! * \brief Bind symbolic buffer to a DLTensor handle. @@ -104,20 +99,13 @@ class ArgBinder { * \param handle The DLTensor handle. * \param arg_name argument name. */ - void BindDLTensor(const Buffer& buffer, - const PrimExpr& device_type, - const PrimExpr& device_id, - const Var& handle, - const std::string& arg_name); + void BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, + const Var& handle, const std::string& arg_name); /*! \return The defs generated in binding. */ - const std::vector& defs() const { - return defs_; - } + const std::vector& defs() const { return defs_; } /*! 
\return The asserts generated in binding */ - const std::vector& asserts() const { - return asserts_; - } + const std::vector& asserts() const { return asserts_; } /*! * \brief Initialization nest generated * This is only non-empty when BindDLTensor is called. @@ -129,19 +117,13 @@ class ArgBinder { * Let statement is usually generated when bind to DLTensor and memory load is involved. * \return The initialization nest generated during binding. */ - const std::vector& init_nest() const { - return init_nest_; - } + const std::vector& init_nest() const { return init_nest_; } /*! \return Handle data type of the data */ - const Map& def_handle_dtype() const { - return def_handle_dtype_; - } + const Map& def_handle_dtype() const { return def_handle_dtype_; } private: // Internal bind function - bool Bind_(const PrimExpr& arg, - const PrimExpr& value, - const std::string& arg_name, + bool Bind_(const PrimExpr& arg, const PrimExpr& value, const std::string& arg_name, bool with_lets); /*! \brief The definition map, can be uses to substitute */ std::unordered_map* def_map_; @@ -153,7 +135,9 @@ class ArgBinder { Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; + /*! \brief internal analyzer. */ + arith::Analyzer analyzer_; }; } // namespace tir } // namespace tvm -#endif // TVM_TIR_PASS_ARG_BINDER_H_ +#endif // TVM_TIR_TRANSFORMS_ARG_BINDER_H_ diff --git a/src/tir/pass/bound_checker.cc b/src/tir/transforms/bound_checker.cc similarity index 64% rename from src/tir/pass/bound_checker.cc rename to src/tir/transforms/bound_checker.cc index ee24d0f77673..94464a04f912 100644 --- a/src/tir/pass/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -22,12 +22,16 @@ */ // Instrument checkers for out of the bounds access. +#include +#include #include -#include +#include #include -#include +#include + #include #include +#include namespace tvm { namespace tir { @@ -38,20 +42,19 @@ class BoundCollector : public StmtVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == tir::attr::buffer_bound) { - if (const VarNode *key = op->node.as()) { + if (const VarNode* key = op->node.as()) { mem_to_shape[key] = op->value; } } StmtVisitor::VisitStmt_(op); } // Hashtable which maps buffer_var to shape. 
- std::unordered_map mem_to_shape; + std::unordered_map mem_to_shape; }; class BoundChecker : public StmtExprMutator { public: - explicit BoundChecker( - const std::unordered_map &mem_to_shape) + explicit BoundChecker(const std::unordered_map& mem_to_shape) : mem_to_shape_(mem_to_shape) {} Stmt VisitStmt_(const AllocateNode* op) final { @@ -82,12 +85,10 @@ class BoundChecker : public StmtExprMutator { if (store_scope_bound_collector_.size()) { PrimExpr condition = MakeCondition(); if (!condition.as()) { - Stmt nop = EvaluateNode::make(1); - Stmt then_case = - StoreNode::make(op->buffer_var, op->value, op->index, op->predicate); - Stmt else_case = - AssertStmtNode::make(condition, StringImmNode::make(error_message_), nop); - Stmt body = IfThenElseNode::make(condition, then_case, else_case); + Stmt nop = Evaluate(1); + Stmt then_case = Store(op->buffer_var, op->value, op->index, op->predicate); + Stmt else_case = AssertStmt(condition, StringImm(error_message_), nop); + Stmt body = IfThenElse(condition, then_case, else_case); return body; } } @@ -106,9 +107,7 @@ class BoundChecker : public StmtExprMutator { return (buffer_var.defined() && mem_to_shape_.count(buffer_var.get())); } - void Update(const Var& buffer_var, - const Array& new_shape, - const DataType& type) { + void Update(const Var& buffer_var, const Array& new_shape, const DataType& type) { // Sanity check at first. if (!new_shape.size()) { return; } @@ -122,12 +121,12 @@ } // Scalarize the shape. - PrimExpr shape = MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[0])); + PrimExpr shape = + Mul(make_const(DataType::UInt(64), type.lanes()), Cast(DataType::UInt(64), new_shape[0])); for (size_t i = 1; i < new_shape.size(); ++i) { // Cast to unsigned to avoid integer overflow at first.
- shape = MulNode::make(shape, MulNode::make(make_const(DataType::UInt(64), type.lanes()), - CastNode::make(DataType::UInt(64), new_shape[i]))); + shape = Mul(shape, Mul(make_const(DataType::UInt(64), type.lanes()), + Cast(DataType::UInt(64), new_shape[i]))); } mem_to_shape_[buffer_var.get()] = shape; } @@ -137,23 +136,21 @@ class BoundChecker : public StmtExprMutator { return false; } - if (const RampNode *ramp_index = index.as()) { - return ramp_index->base.defined() && - ramp_index->base.dtype().is_scalar() && - ramp_index->stride.defined() && - ramp_index->stride.dtype().is_scalar() && (ramp_index->lanes > 0); + if (const RampNode* ramp_index = index.as()) { + return ramp_index->base.defined() && ramp_index->base.dtype().is_scalar() && + ramp_index->stride.defined() && ramp_index->stride.dtype().is_scalar() && + (ramp_index->lanes > 0); } return true; } bool CanInstrument(const PrimExpr& index, const Var& buffer_var) const { - return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && - IndexIsValid(index) && !unsafe_rewritten_; + return buffer_var.defined() && mem_to_shape_.count(buffer_var.get()) && IndexIsValid(index) && + !unsafe_rewritten_; } void Collect(PrimExpr index, Var buffer_var) { - store_scope_bound_collector_.push_back( - std::make_pair(index, mem_to_shape_[buffer_var.get()])); + store_scope_bound_collector_.push_back(std::make_pair(index, mem_to_shape_[buffer_var.get()])); } PrimExpr MakeCondition() { @@ -163,30 +160,26 @@ class BoundChecker : public StmtExprMutator { PrimExpr index = buffer_to_mem.first; PrimExpr upper_bound = buffer_to_mem.second; - if (const RampNode *ramp_index = index.as()) { + if (const RampNode* ramp_index = index.as()) { // In case index is base + stride * i. // Non inclusive range. - index = AddNode::make( - ramp_index->base, - MulNode::make(ramp_index->stride, make_const(ramp_index->stride.dtype(), - ramp_index->lanes - 1))); + index = Add(ramp_index->base, Mul(ramp_index->stride, make_const(ramp_index->stride.dtype(), + ramp_index->lanes - 1))); } // Try to simplify index and bound. - index = tir::Simplify(index); - upper_bound = tir::Simplify(upper_bound); + index = analyzer_.Simplify(index); + upper_bound = analyzer_.Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. - index = CastNode::make(DataType::Int(64), index); - upper_bound = CastNode::make(DataType::Int(64), upper_bound); + index = Cast(DataType::Int(64), index); + upper_bound = Cast(DataType::Int(64), upper_bound); // Looks like a lower bound should always be zero after normalization. PrimExpr lower_bound = make_zero(DataType::Int(64)); - PrimExpr current_condition = - AndNode::make(GENode::make(index, lower_bound), LTNode::make(index, upper_bound)); - condition = - !i ? current_condition : AndNode::make(condition, current_condition); + PrimExpr current_condition = And(GE(index, lower_bound), LT(index, upper_bound)); + condition = !i ? current_condition : And(condition, current_condition); } return condition; } @@ -198,9 +191,11 @@ class BoundChecker : public StmtExprMutator { // Pool which collects the pair of index and shape for specific store/load. std::vector> store_scope_bound_collector_; // Error message. - const char *const error_message_ = "OUT OF THE BOUNDS"; + const char* const error_message_ = "OUT OF THE BOUNDS"; // Hashtable which maps buffer_var to shape. 
- std::unordered_map mem_to_shape_; + std::unordered_map mem_to_shape_; + // internal analyzer + arith::Analyzer analyzer_; }; Stmt InstrumentBoundCheckers(Stmt stmt) { @@ -209,5 +204,25 @@ Stmt InstrumentBoundCheckers(Stmt stmt) { bound_collector(stmt); return BoundChecker(bound_collector.mem_to_shape)(std::move(stmt)); } + +namespace transform { + +Pass InstrumentBoundCheckers() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + BoundCollector bound_collector; + // At first walk recursively and collect bound attributes. + bound_collector(n->body); + n->body = BoundChecker(bound_collector.mem_to_shape)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InstrumentBoundCheckers", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InstrumentBoundCheckers") + .set_body_typed(InstrumentBoundCheckers); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index f8e14a2a8fb3..73bf4c6f6db2 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -22,15 +22,13 @@ * * \file combine_context_call.cc */ +#include +#include +#include #include #include #include #include -#include -#include -#include - -#include #include @@ -45,7 +43,7 @@ class ContextCallCombiner final : public StmtExprMutator { if (op->is_intrinsic(intrinsic::tvm_thread_context)) { CHECK_EQ(op->args.size(), 1U); PrimExpr ctx = op->args[0]; - auto it = ctx_map_.find(ctx); + auto it = ctx_map_.find(ctx); if (it != ctx_map_.end()) { return it->second; } else { @@ -66,8 +64,7 @@ class ContextCallCombiner final : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::coproc_uop_scope) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::coproc_uop_scope) { // Map of comparison expression to variable std::unordered_map temp; std::swap(temp, ctx_map_); @@ -92,16 +89,13 @@ class ContextCallCombiner final : public StmtExprMutator { } } - Stmt Combine(Stmt stmt) { - return BuildContext(ctx_map_, this->VisitStmt(stmt)); - } + Stmt Combine(Stmt stmt) { return BuildContext(ctx_map_, this->VisitStmt(stmt)); } private: static Stmt BuildContext( - const std::unordered_map& cmap, - Stmt body) { + const std::unordered_map& cmap, Stmt body) { for (const auto& kv : cmap) { - body = LetStmtNode::make(kv.second, kv.first, body); + body = LetStmt(kv.second, kv.first, body); } return body; } @@ -109,7 +103,6 @@ class ContextCallCombiner final : public StmtExprMutator { std::unordered_map ctx_map_; }; - namespace transform { Pass CombineContextCall() { @@ -121,8 +114,7 @@ Pass CombineContextCall() { return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); } -TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall") -.set_body_typed(CombineContextCall); +TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall").set_body_typed(CombineContextCall); } // namespace transform } // namespace tir diff --git a/src/tir/pass/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc similarity index 80% rename from src/tir/pass/coproc_sync.cc rename to src/tir/transforms/coproc_sync.cc index 38b7798eae11..384dbcb0caee 100644 --- a/src/tir/pass/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -20,11 +20,14 @@ /*! 
* \file coproc_sync.cc */ +#include #include -#include #include +#include + #include #include + #include "ir_util.h" #include "storage_access.h" @@ -88,11 +91,9 @@ class CoProcTouchedBuffer : public StmtExprVisitor { // Synchronization planning with co-processor. class CoProcSyncPlanner : public StorageAccessVisitor { public: - explicit CoProcSyncPlanner( - const std::unordered_set& touched, - const std::string& coproc_name) - : touched_(touched), coproc_name_(coproc_name) { - } + explicit CoProcSyncPlanner(const std::unordered_set& touched, + const std::string& coproc_name) + : touched_(touched), coproc_name_(coproc_name) {} void Plan(const Stmt& stmt) { this->VisitStmt(stmt); @@ -106,22 +107,19 @@ class CoProcSyncPlanner : public StorageAccessVisitor { std::unordered_map > sync_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { return PlanSync(seq, loop, false); } private: // Plan write synchronization if write is not coherent - std::vector PlanSync( - std::vector seq, const ForNode* loop, - bool force_sync_at_end) { + std::vector PlanSync(std::vector seq, const ForNode* loop, + bool force_sync_at_end) { // detect write barriers // access by the co-processor. std::vector co_access; @@ -130,8 +128,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { auto find_conflict = [&](const AccessEntry& acc) { for (const AccessEntry& x : co_access) { if (x.buffer.same_as(acc.buffer) && - ((acc.type == kRead && x.type == kWrite) || - acc.type == kWrite)) { + ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) { return true; } } @@ -142,7 +139,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { bool sync_write = false; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_write = true; break; + sync_write = true; + break; } if (acc.type == kSync) { co_access.clear(); @@ -168,7 +166,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { const StmtEntry& s = seq[i]; for (const AccessEntry& acc : s.access) { if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_at_end = true; break; + sync_at_end = true; + break; } } if (sync_.count(s.stmt) || sync_at_end) break; @@ -196,10 +195,7 @@ class CoProcSyncPlanner : public StorageAccessVisitor { } std::vector GetSync(std::string sync_name) { - return {EvaluateNode::make(CallNode::make( - DataType::Int(32), - sync_name, - {}, CallNode::Intrinsic))}; + return {Evaluate(Call(DataType::Int(32), sync_name, {}, CallNode::Intrinsic))}; } const std::unordered_set& touched_; @@ -209,9 +205,8 @@ class CoProcSyncPlanner : public StorageAccessVisitor { // Detect memory barriers when coproc read/write memory class CoProcBarrierDetector : public StorageAccessVisitor { public: - explicit CoProcBarrierDetector( - const std::unordered_set& touched, - const std::string& coproc_name) + explicit CoProcBarrierDetector(const std::unordered_set& touched, + const std::string& coproc_name) : touched_(touched) { read_barrier_name_ = coproc_name + ".coproc_read_barrier"; write_barrier_name_ = coproc_name + ".coproc_write_barrier"; @@ -232,14 +227,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { std::unordered_map > barrier_after_; protected: - bool Enabled(const VarNode* buf, - 
const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return touched_.count(buf); } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { if (read_barrier_) { return PlanReadBarrier(seq, loop); } else { @@ -249,17 +242,15 @@ class CoProcBarrierDetector : public StorageAccessVisitor { private: // Plan write barrier at Read after write point. - std::vector PlanWriteBarrier( - std::vector seq, const ForNode* loop) { + std::vector PlanWriteBarrier(std::vector seq, const ForNode* loop) { std::vector read_seq; std::unordered_map > write_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = write_set.find(acc.buffer.get()); + auto it = write_set.find(acc.buffer.get()); if (it != write_set.end()) { CHECK_NE(i, 0U); - barrier_after_[seq[i - 1].stmt].push_back( - MakeBarrier(write_barrier_name_, it->second)); + barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second)); write_set.erase(it); } }; @@ -283,23 +274,21 @@ class CoProcBarrierDetector : public StorageAccessVisitor { fupdate(seq.size(), acc); } } - for (const auto &kv : write_set) { + for (const auto& kv : write_set) { read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end()); } return read_seq; } - std::vector PlanReadBarrier( - std::vector seq, const ForNode* loop) { + std::vector PlanReadBarrier(std::vector seq, const ForNode* loop) { std::vector write_seq; std::unordered_map > read_set; auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = read_set.find(acc.buffer.get()); + auto it = read_set.find(acc.buffer.get()); if (it != read_set.end()) { CHECK_NE(i, seq.size()); - barrier_before_[seq[i].stmt].push_back( - MakeBarrier(read_barrier_name_, it->second)); + barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second)); read_set.erase(it); } }; @@ -324,7 +313,7 @@ class CoProcBarrierDetector : public StorageAccessVisitor { fupdate(0, acc); } } - for (const auto &kv : read_set) { + for (const auto& kv : read_set) { write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end()); } return write_seq; @@ -339,13 +328,12 @@ class CoProcBarrierDetector : public StorageAccessVisitor { } Range none; Range r = arith::Union(wset).cover_range(none); - CHECK(r.defined()) - << "Cannot deduce write range of " << wvec[0].buffer; + CHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; PrimExpr min = r->min; PrimExpr extent = r->extent; - return EvaluateNode::make(CallNode::make( - DataType::Int(32), func, - {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), func, + {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, + CallNode::Intrinsic)); } // Write barrier name bool read_barrier_{false}; @@ -354,12 +342,9 @@ class CoProcBarrierDetector : public StorageAccessVisitor { const std::unordered_set& touched_; }; - class CoProcInstDepDetector : public StmtVisitor { public: - explicit CoProcInstDepDetector( - const IterVar& coproc_axis, - const std::string& coproc_name) + explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) : coproc_axis_(coproc_axis) { sync_push_name_ = coproc_name + ".coproc_dep_push"; sync_pop_name_ = coproc_name + ".coproc_dep_pop"; @@ -374,8 +359,7 @@ class CoProcInstDepDetector : public StmtVisitor { } void VisitStmt_(const AttrStmtNode* 
op) final { - if (op->attr_key == attr::coproc_scope && - op->node.same_as(coproc_axis_)) { + if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) { const IntImmNode* ctx_id = op->value.as(); CHECK(ctx_id != nullptr); curr_state_.clear(); @@ -398,9 +382,7 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state_.node = op; CHECK(first_state_.node != nullptr); // loop carry dependency - InjectSync(last_state_, first_state_, - &(curr_state_.exit_push), - &(curr_state_.enter_pop)); + InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop)); curr_state_.enter_ctx = first_state_.enter_ctx; curr_state_.exit_ctx = last_state_.exit_ctx; } @@ -422,12 +404,8 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert( - first_state_.enter_ctx.begin(), - first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert( - last_state_.exit_ctx.begin(), - last_state_.exit_ctx.end()); + curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); + curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } first_state_.clear(); last_state_.clear(); @@ -438,12 +416,8 @@ class CoProcInstDepDetector : public StmtVisitor { curr_state.node = op; MatchFixEnterPop(first_state_); MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert( - first_state_.enter_ctx.begin(), - first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert( - last_state_.exit_ctx.begin(), - last_state_.exit_ctx.end()); + curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); + curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); } } // update in the trace. @@ -486,15 +460,14 @@ class CoProcInstDepDetector : public StmtVisitor { // record the push/pop sequence that could be possibly un-matched. 
// return the push/pop message at enter/exit of the Block // after considering the existing unmatcheded events and added events - void InjectSync(const SyncState& prev, - const SyncState& next, + void InjectSync(const SyncState& prev, const SyncState& next, std::vector >* prev_exit_push, std::vector >* next_enter_pop) { prev_exit_push->clear(); next_enter_pop->clear(); // quick path - if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && - prev.exit_ctx.size() == 1 && next.enter_ctx.size() == 1) { + if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 && + next.enter_ctx.size() == 1) { int from = *prev.exit_ctx.begin(); int to = *next.enter_ctx.begin(); if (from != to) { @@ -519,15 +492,11 @@ class CoProcInstDepDetector : public StmtVisitor { // policy 1 std::vector prev_after, next_before; for (const std::pair& p : pending) { - if (std::find(prev.exit_push.begin(), - prev.exit_push.end(), p) == - prev.exit_push.end()) { + if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) { vpush.push_back(p); prev_after.emplace_back(MakePush(p.first, p.second)); } - if (std::find(next.enter_pop.begin(), - next.enter_pop.end(), p) == - next.enter_pop.end()) { + if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) { vpop.push_back(p); next_before.emplace_back(MakePop(p.first, p.second)); } @@ -548,18 +517,18 @@ class CoProcInstDepDetector : public StmtVisitor { } } if (prev_after.size() != 0) { - auto &v1 = insert_after_[prev.node]; + auto& v1 = insert_after_[prev.node]; v1.insert(v1.end(), prev_after.begin(), prev_after.end()); } if (next_before.size() != 0) { - auto &v2 = insert_before_[next.node]; + auto& v2 = insert_before_[next.node]; v2.insert(v2.end(), next_before.begin(), next_before.end()); } } void MatchFixEnterPop(const SyncState& state) { if (state.enter_pop.size() == 0) return; - auto &vec = insert_before_[state.node]; + auto& vec = insert_before_[state.node]; for (const std::pair& p : state.enter_pop) { vec.push_back(MakePush(p.first, p.second)); } @@ -567,7 +536,7 @@ class CoProcInstDepDetector : public StmtVisitor { void MatchFixExitPush(const SyncState& state) { if (state.exit_push.size() == 0) return; - auto &vec = insert_after_[state.node]; + auto& vec = insert_after_[state.node]; for (const std::pair& p : state.exit_push) { vec.push_back(MakePop(p.first, p.second)); } @@ -586,16 +555,14 @@ class CoProcInstDepDetector : public StmtVisitor { } Stmt MakePush(int from, int to) { - return EvaluateNode::make(CallNode::make( - DataType::Int(32), sync_push_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), sync_push_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } Stmt MakePop(int from, int to) { - return EvaluateNode::make(CallNode::make( - DataType::Int(32), sync_pop_name_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), sync_pop_name_, + {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)}, + CallNode::Intrinsic)); } // sync states. 
SyncState first_state_, last_state_, curr_state_; @@ -604,7 +571,6 @@ class CoProcInstDepDetector : public StmtVisitor { std::string sync_push_name_, sync_pop_name_; }; - class CoProcSyncInserter : public StmtMutator { public: Stmt Insert(Stmt stmt) { @@ -613,7 +579,7 @@ class CoProcSyncInserter : public StmtMutator { if (visitor.coproc_.size() == 0) return stmt; std::unordered_set touched; - for (const auto &kv : visitor.touched_) { + for (const auto& kv : visitor.touched_) { if (kv.second.normal && kv.second.coproc) { touched.insert(kv.first); } @@ -640,8 +606,7 @@ class CoProcSyncInserter : public StmtMutator { vec.insert(vec.end(), kv.second.begin(), kv.second.end()); } // Detect barrier - CoProcInstDepDetector sync_detector( - *visitor.coproc_.begin(), coproc_name); + CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name); sync_detector.Plan(stmt); for (const auto& kv : sync_detector.insert_before_) { auto& vec = insert_before_[kv.first]; @@ -660,9 +625,8 @@ class CoProcSyncInserter : public StmtMutator { Stmt new_stmt = StmtMutator::VisitStmt(stmt); return SeqStmt::Flatten( - it_before != insert_before_.end() ? it_before->second : std::vector(), - new_stmt, - it_after != insert_after_.end() ? it_after->second : std::vector()); + it_before != insert_before_.end() ? it_before->second : std::vector(), new_stmt, + it_after != insert_after_.end() ? it_after->second : std::vector()); } private: @@ -672,10 +636,22 @@ class CoProcSyncInserter : public StmtMutator { std::unordered_map > insert_after_; }; +Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } -Stmt CoProcSync(Stmt stmt) { - return CoProcSyncInserter().Insert(std::move(stmt)); +namespace transform { + +Pass CoProcSync() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = CoProcSyncInserter().Insert(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {}); } +TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/decorate_device_scope.cc b/src/tir/transforms/decorate_device_scope.cc new file mode 100644 index 000000000000..5034a858130d --- /dev/null +++ b/src/tir/transforms/decorate_device_scope.cc @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file decorate_device_scope.cc + */ +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +Stmt DecorateDeviceScope(Stmt&& stmt) { + Stmt body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt); + return body; +} + +namespace transform { + +Pass DecorateDeviceScope() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = DecorateDeviceScope(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.DecorateDeviceScope", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.DecorateDeviceScope").set_body_typed(DecorateDeviceScope); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/pass/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc similarity index 67% rename from src/tir/pass/inject_copy_intrin.cc rename to src/tir/transforms/inject_copy_intrin.cc index 4805caf5ac55..b27459f4bd45 100644 --- a/src/tir/pass/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -21,10 +21,13 @@ * \brief Replace certain copy with copy intrinsics. * \file copy_intrin_rewrite.cc */ +#include #include +#include #include #include -#include +#include + #include "../../arith/pattern_match.h" namespace tvm { @@ -34,11 +37,9 @@ using runtime::PackedFunc; class CopyIntrinInjector : public StmtMutator { public: - CopyIntrinInjector(const std::string& pragma_key, - const PackedFunc& flower_copy_fromto) - : pragma_key_(attr::pragma_scope_prefix+ pragma_key), - flower_copy_fromto_(flower_copy_fromto) { - } + CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto) + : pragma_key_(attr::pragma_scope_prefix + pragma_key), + flower_copy_fromto_(flower_copy_fromto) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { @@ -46,15 +47,14 @@ class CopyIntrinInjector : public StmtMutator { storage_scope_[buf] = op->value.as()->value; } else if (op->attr_key == pragma_key_) { Stmt ret; - CHECK(MatchCopyPattern(op->body, &ret)) - << "Cannot match copy pattern of " << op->body; + CHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; return ret; } return StmtMutator::VisitStmt_(op); } private: - bool MatchCopyPattern(Stmt stmt, Stmt *out) { + bool MatchCopyPattern(Stmt stmt, Stmt* out) { using namespace arith; Stmt body = stmt; @@ -70,9 +70,8 @@ class CopyIntrinInjector : public StmtMutator { // Expr sel_cond, sel_true_value, sel_false_value; // match select or if PVar sel_cond, sel_true_value, sel_false_value; - bool has_cond = - if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || - select(sel_cond, sel_true_value, sel_false_value).Match(store->value); + bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || + select(sel_cond, sel_true_value, sel_false_value).Match(store->value); const CastNode* cast = store->value.as(); const LoadNode* load = store->value.as(); @@ -93,11 +92,9 @@ class CopyIntrinInjector : public StmtMutator { for (const ForNode* op : loops) { loop_vars.push_back(op->loop_var); } - Array store_strides = - arith::DetectLinearEquation(store->index, loop_vars); - Array load_strides = - arith::DetectLinearEquation(load->index, loop_vars); - if (load_strides.size() == 0 || store_strides.size() == 0) return false; + Array store_strides = arith::DetectLinearEquation(store->index, loop_vars); + Array load_strides = arith::DetectLinearEquation(load->index, 
loop_vars); + if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; const size_t loop_var_size = loop_vars.size(); if (loop_var_size == 0) { @@ -112,8 +109,7 @@ class CopyIntrinInjector : public StmtMutator { PrimExpr pad_value; PrimExpr src_elem_offset = load_strides[loop_var_size]; if (has_cond) { - Array clip_bound = - arith::DetectClipBound(sel_cond.Eval(), loop_vars); + Array clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); pad_value = sel_false_value.Eval(); if (clip_bound.size() == 0) return false; CHECK_EQ(src_shape.size(), loop_vars.size()); @@ -124,7 +120,7 @@ class CopyIntrinInjector : public StmtMutator { DataType t = loop_vars[i].dtype(); PrimExpr svalue = src_shape[i]; if (min_value.defined()) { - PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t))); + PrimExpr pbefore = analyzer_.Simplify(Max(min_value, make_zero(t))); src_elem_offset = src_elem_offset + pbefore * load_strides[i]; svalue = svalue - pbefore; pad_before.push_back(pbefore); @@ -132,43 +128,31 @@ class CopyIntrinInjector : public StmtMutator { pad_before.push_back(make_zero(t)); } if (max_value.defined()) { - PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1), - make_zero(t))); + PrimExpr pafter = analyzer_.Simplify( + max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); svalue = svalue - pafter; pad_after.push_back(pafter); } else { pad_after.push_back(make_zero(t)); } - src_shape.Set(i, Simplify(svalue)); + src_shape.Set(i, analyzer_.Simplify(svalue)); } - src_elem_offset = Simplify(src_elem_offset); + src_elem_offset = analyzer_.Simplify(src_elem_offset); } CHECK_EQ(load_strides.size(), store_strides.size()); CHECK_EQ(load_strides.size(), loop_var_size + 1); Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); if (loop_var_size == 0) { - src_strides.push_back(make_const(DataType::Int(32), 1)); - dst_strides.push_back(make_const(DataType::Int(32), 1)); + src_strides.push_back(make_const(DataType::Int(32), 1)); + dst_strides.push_back(make_const(DataType::Int(32), 1)); } - Buffer dst = BufferNode::make( - store->buffer_var, - store->value.dtype(), - dst_shape, - dst_strides, - store_strides[loop_var_size], - store->buffer_var->name_hint, - GetStorageScope(store->buffer_var.get()), - 0, 0, kDefault); - Buffer src = BufferNode::make( - load->buffer_var, - load->dtype, - src_shape, - src_strides, - src_elem_offset, - load->buffer_var->name_hint, - GetStorageScope(load->buffer_var.get()), - 0, 0, kDefault); + Buffer dst = Buffer(store->buffer_var, store->value.dtype(), dst_shape, dst_strides, + store_strides[loop_var_size], store->buffer_var->name_hint, + GetStorageScope(store->buffer_var.get()), 0, 0, kDefault); + Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, + load->buffer_var->name_hint, GetStorageScope(load->buffer_var.get()), 0, 0, + kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); CHECK(out->defined()) << "flower function did not return correct stmt"; return true; @@ -188,13 +172,29 @@ class CopyIntrinInjector : public StmtMutator { const PackedFunc& flower_copy_fromto_; // Storage scope std::unordered_map storage_scope_; + // arith analyzer + arith::Analyzer analyzer_; }; -Stmt InjectCopyIntrin(Stmt stmt, - const std::string& pragma_key, +Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, 
const PackedFunc& flower_copy_fromto) { return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); } +namespace transform { + +Pass InjectCopyIntrin(String pragma_key, PackedFunc flower_copy_fromto) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin").set_body_typed(InjectCopyIntrin); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc similarity index 75% rename from src/tir/pass/inject_double_buffer.cc rename to src/tir/transforms/inject_double_buffer.cc index b9aa5a9e697e..9d5ee950cdfa 100644 --- a/src/tir/pass/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -21,15 +21,33 @@ * \brief Inject double buffering optimization for data fetch. * \file inject_double_buffer.cc */ -#include -#include +#include #include +#include +#include + #include "ir_util.h" -#include "../../arith/compute_expr.h" namespace tvm { namespace tir { +struct InjectDoubleBufferConfigNode : public tvm::AttrsNode { + int split_loop; + + TVM_DECLARE_ATTRS(InjectDoubleBufferConfigNode, "tir.transform.InjectDoubleBufferConfig") { + TVM_ATTR_FIELD(split_loop).describe("Split loop factors").set_default(1); + } +}; + +class InjectDoubleBufferConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(InjectDoubleBufferConfig, Attrs, + InjectDoubleBufferConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(InjectDoubleBufferConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.InjectDoubleBuffer", InjectDoubleBufferConfig); + // Detect double buffer variables. 
class DoubleBufferDetector : public StmtExprVisitor { public: @@ -51,7 +69,6 @@ class DoubleBufferDetector : public StmtExprVisitor { std::unordered_set touched_; }; - class StripDoubleBufferWrite : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -65,8 +82,7 @@ class StripDoubleBufferWrite : public StmtMutator { class DoubleBufferInjector : public StmtExprMutator { public: - explicit DoubleBufferInjector(int split_loop) - : split_loop_(split_loop) {} + explicit DoubleBufferInjector(int split_loop) : split_loop_(split_loop) {} Stmt Inject(Stmt stmt) { DoubleBufferDetector detector; @@ -98,8 +114,9 @@ class DoubleBufferInjector : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto it = dbuffer_info_.find(op->buffer_var.get()); if (it != dbuffer_info_.end()) { - it->second.stride = arith::ComputeReduce( - op->extents, PrimExpr()) * op->dtype.lanes(); + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + it->second.stride = + foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); Array new_extents{make_const(op->extents[0].dtype(), 2)}; @@ -108,13 +125,10 @@ class DoubleBufferInjector : public StmtExprMutator { } CHECK(it->second.loop != nullptr); auto& alloc_nest = loop_allocs_[it->second.loop]; - alloc_nest.emplace_back(AttrStmtNode::make( - op->buffer_var, attr::storage_scope, - StringImmNode::make(it->second.scope), - EvaluateNode::make(0))); - alloc_nest.emplace_back(AllocateNode::make( - op->buffer_var, op->dtype, new_extents, op->condition, - EvaluateNode::make(0))); + alloc_nest.emplace_back( + AttrStmt(op->buffer_var, attr::storage_scope, StringImm(it->second.scope), Evaluate(0))); + alloc_nest.emplace_back( + Allocate(op->buffer_var, op->dtype, new_extents, op->condition, Evaluate(0))); return op->body; } else { return StmtExprMutator::VisitStmt_(op); @@ -133,8 +147,7 @@ class DoubleBufferInjector : public StmtExprMutator { << "It is better to split with multiple of 2"; CHECK(is_zero(old_loop->min)); PrimExpr zero = old_loop->min; - PrimExpr new_ext = - old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); + PrimExpr new_ext = old_loop->extent - make_const(old_loop->loop_var.dtype(), 1); PrimExpr factor = make_const(new_ext.dtype(), split_loop_); PrimExpr outer_ext = new_ext / factor; PrimExpr tail_base = outer_ext * factor; @@ -145,18 +158,15 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } - Stmt loop = ForNode::make( - outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - SeqStmt::Flatten(loop_seq)); + Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, + SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); for (int32_t i = 0; i < split_loop_; ++i) { PrimExpr idx = tail_base + make_const(tail_base.dtype(), i); vmap[old_loop->loop_var.get()] = idx; - tail_seq.emplace_back( - IfThenElseNode::make(idx < old_loop->extent, - Substitute(tail_body, vmap))); + tail_seq.emplace_back(IfThenElse(idx < old_loop->extent, Substitute(tail_body, vmap))); } stmt = SeqStmt::Flatten(loop, tail_seq); } @@ -178,10 +188,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(in_double_buffer_scope_); CHECK(e.stride.defined()); - return 
StoreNode::make(op->buffer_var, - op->value, - e.switch_write_var * e.stride + op->index, - op->predicate); + return Store(op->buffer_var, op->value, e.switch_write_var * e.stride + op->index, + op->predicate); } else { return stmt; } @@ -195,10 +203,8 @@ class DoubleBufferInjector : public StmtExprMutator { const StorageEntry& e = it->second; CHECK(e.stride.defined()); CHECK(e.switch_read_var.defined()); - return LoadNode::make(op->dtype, - op->buffer_var, - e.switch_read_var * e.stride + op->index, - op->predicate); + return Load(op->dtype, op->buffer_var, e.switch_read_var * e.stride + op->index, + op->predicate); } else { return expr; } @@ -212,8 +218,7 @@ class DoubleBufferInjector : public StmtExprMutator { private: Stmt MakeProducer(const AttrStmtNode* op) { const Var buffer = Downcast(op->node); - CHECK_NE(loop_nest_.size(), 0U) - << "Double buffer scope must be inside a loop"; + CHECK_NE(loop_nest_.size(), 0U) << "Double buffer scope must be inside a loop"; auto it = dbuffer_info_.find(buffer.get()); if (it == dbuffer_info_.end()) { LOG(WARNING) << "Skip double buffer scope " << op->node; @@ -225,8 +230,7 @@ class DoubleBufferInjector : public StmtExprMutator { PrimExpr one = make_const(e.loop->loop_var.dtype(), 1); PrimExpr two = make_const(e.loop->loop_var.dtype(), 2); PrimExpr loop_shift = e.loop->loop_var + one; - e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", - e.loop->loop_var.dtype()); + e.switch_write_var = Var(e.loop->loop_var->name_hint + ".db", e.loop->loop_var.dtype()); e.switch_read_var = indexmod(e.loop->loop_var, two); in_double_buffer_scope_ = true; Stmt body = this->VisitStmt(op->body); @@ -238,8 +242,8 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[e.loop->loop_var.get()] = loop_shift; vmap[e.switch_write_var.get()] = indexmod(loop_shift, two); body = Substitute(body, vmap); - body = AttrStmtNode::make(buffer, attr::double_buffer_write, 1, body); - body = IfThenElseNode::make(loop_shift < e.loop->extent, body); + body = AttrStmt(buffer, attr::double_buffer_write, 1, body); + body = IfThenElse(loop_shift < e.loop->extent, body); return body; } // Storage entry for those who need double buffering. 
@@ -269,9 +273,24 @@ class DoubleBufferInjector : public StmtExprMutator { std::unordered_map dbuffer_info_; }; +namespace transform { -Stmt InjectDoubleBuffer(Stmt stmt, int split_loop) { - return DoubleBufferInjector(split_loop).Inject(stmt); +Pass InjectDoubleBuffer() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto cfg = ctx->GetConfig("tir.InjectDoubleBuffer"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + n->body = DoubleBufferInjector(cfg.value()->split_loop).Inject(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectDoubleBuffer", {}); } + +TVM_REGISTER_GLOBAL("tir.transform.InjectDoubleBuffer").set_body_typed(InjectDoubleBuffer); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc similarity index 74% rename from src/tir/pass/inject_prefetch.cc rename to src/tir/transforms/inject_prefetch.cc index 894ff3864864..9c27a71929c5 100644 --- a/src/tir/pass/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -21,17 +21,21 @@ * \file inject_prefetch.cc */ // Inject prefetch op in HalideIR +#include +#include +#include #include +#include #include -#include -#include +#include + #include namespace tvm { namespace tir { -using arith::IntSet; using arith::DomainTouched; +using arith::IntSet; class PrefetchInjector : public StmtMutator { public: @@ -39,9 +43,9 @@ class PrefetchInjector : public StmtMutator { Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); if (op && op->attr_key == attr::prefetch_scope) { - te::Tensor ts = Downcast(op->node); + Buffer buffer = Downcast(op->node); CHECK_NE(loop_nest_.size(), 0U); - Domain domain = DomainTouched(op->body, ts, true, false); + Region domain = DomainTouched(op->body, buffer, true, false); Region region; auto iter_var = loop_nest_.back().get(); @@ -49,7 +53,7 @@ class PrefetchInjector : public StmtMutator { for (Range r : domain) { if (!r.defined()) { - LOG(WARNING) << "Cannot decide prefetch region for " << ts; + LOG(WARNING) << "Cannot decide prefetch region for " << buffer; return op->body; } Range res(EvalSet(r, vectorized_).cover_range(none)); @@ -58,14 +62,14 @@ class PrefetchInjector : public StmtMutator { vectorized_.erase(iter_var); - Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region); + Stmt prefetch = Prefetch(buffer, region); return SeqStmt({prefetch, op->body}); } return ret; } Stmt VisitStmt_(const ForNode* op) final { - auto &var = op->loop_var; + auto& var = op->loop_var; loop_nest_.push_back(var); if (op->for_type == ForType::Vectorized) { vectorized_[var.get()] = IntSet::interval(op->min, (op->min + op->extent) - 1); @@ -80,15 +84,28 @@ class PrefetchInjector : public StmtMutator { private: std::vector loop_nest_; - std::unordered_map vectorized_; + std::unordered_map vectorized_; static const Range none; }; const Range PrefetchInjector::none; -Stmt InjectPrefetch(Stmt stmt) { - return PrefetchInjector()(std::move(stmt)); +Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } + +namespace transform { + +Pass InjectPrefetch() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = PrefetchInjector()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); } +TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch").set_body_typed(InjectPrefetch); + +} // 
namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc similarity index 80% rename from src/tir/pass/inject_virtual_thread.cc rename to src/tir/transforms/inject_virtual_thread.cc index e9c403ca5cb5..042ddab15a2f 100644 --- a/src/tir/pass/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -20,11 +20,14 @@ /*! * \file inject_virtual_thread.cc */ +#include #include #include -#include +#include + #include -#include "../../arith/compute_expr.h" + +#include "ir_util.h" namespace tvm { namespace tir { @@ -32,8 +35,7 @@ namespace tir { // If expression is touched by var. class ExprTouched final : public StmtExprVisitor { public: - explicit ExprTouched(const std::unordered_set &touched, - bool check_write) + explicit ExprTouched(const std::unordered_set& touched, bool check_write) : touched_var_(touched), check_write_(check_write) {} void VisitExpr(const PrimExpr& n) final { @@ -41,29 +43,27 @@ class ExprTouched final : public StmtExprVisitor { if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitExpr(n); } - void VisitStmt(const Stmt& n) final { + void VisitStmt(const Stmt& n) final { // early stopping if (expr_touched_ && !check_write_) return; StmtExprVisitor::VisitStmt(n); } - void VisitExpr_(const LoadNode *op) final { + void VisitExpr_(const LoadNode* op) final { HandleUseVar(op->buffer_var.get()); StmtExprVisitor::VisitExpr_(op); } - void VisitExpr_(const VarNode *op) final { - HandleUseVar(op); - } - void VisitExpr_(const CallNode *op) final { + void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } + void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { - int rw_mask = 0; - CHECK(arith::GetConstInt(op->args[4], &rw_mask)); + const auto* rw_mask = op->args[4].as(); const VarNode* buffer_var = op->args[1].as(); CHECK(buffer_var); + CHECK(rw_mask); // read - if (rw_mask & 1) { + if (rw_mask->value & 1) { HandleUseVar(buffer_var); } - if (rw_mask & 2) { + if (rw_mask->value & 2) { HandleWriteVar(buffer_var); } this->VisitExpr(op->args[2]); @@ -82,9 +82,7 @@ class ExprTouched final : public StmtExprVisitor { used_vars_.push_back(var); } } - void HandleWriteVar(const VarNode* var) { - write_vars_.push_back(var); - } + void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); } // the fields. bool expr_touched_{false}; std::vector used_vars_; @@ -132,8 +130,7 @@ class VarTouchedAnalysis : public StmtVisitor { Record(op->buffer_var.get(), tc); this->VisitStmt(op->body); } - void Record(const VarNode* var, - const ExprTouched& tc) { + void Record(const VarNode* var, const ExprTouched& tc) { if (touched_var_.count(var)) return; if (tc.expr_touched_) { touched_var_.insert(var); @@ -146,14 +143,11 @@ class VarTouchedAnalysis : public StmtVisitor { } } - std::unordered_set - TouchedVar(const Stmt& stmt, - const VarNode* var) { + std::unordered_set TouchedVar(const Stmt& stmt, const VarNode* var) { touched_var_.insert(var); this->VisitStmt(stmt); // do a DFS to push affect around dependency. - std::vector pending( - touched_var_.begin(), touched_var_.end()); + std::vector pending(touched_var_.begin(), touched_var_.end()); while (!pending.empty()) { const VarNode* v = pending.back(); pending.pop_back(); @@ -171,29 +165,26 @@ class VarTouchedAnalysis : public StmtVisitor { // Whether variable is touched by the thread variable. 
std::unordered_set touched_var_; // x -> all the buffers x read from - std::unordered_map > affect_; + std::unordered_map > affect_; }; - // Inject virtual thread loop // rewrite the buffer access pattern when necessary. class VTInjector : public StmtExprMutator { public: // constructor - VTInjector(Var var, - int num_threads, - const std::unordered_set& touched_var, + VTInjector(Var var, int num_threads, const std::unordered_set& touched_var, bool allow_share) - : var_(var), num_threads_(num_threads), - touched_var_(touched_var), allow_share_(allow_share) { - } + : var_(var), + num_threads_(num_threads), + touched_var_(touched_var), + allow_share_(allow_share) {} // Inject VTLoop when needed. Stmt VisitStmt(const Stmt& s) final { CHECK(!visit_touched_var_); auto stmt = StmtExprMutator::VisitStmt(s); if (visit_touched_var_ || trigger_base_inject_) { - if (!vt_loop_injected_) { + if (!vt_loop_injected_) { return InjectVTLoop(stmt, false); } visit_touched_var_ = false; @@ -203,8 +194,7 @@ class VTInjector : public StmtExprMutator { } // Variable PrimExpr VisitExpr_(const VarNode* op) final { - CHECK(!alloc_remap_.count(op)) - << "Buffer address may get rewritten in virtual thread"; + CHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread"; if (touched_var_.count(op)) { visit_touched_var_ = true; } @@ -222,9 +212,7 @@ class VTInjector : public StmtExprMutator { } auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return LoadNode::make(op->dtype, op->buffer_var, - RewriteIndex(op->index, it->second), - op->predicate); + return Load(op->dtype, op->buffer_var, RewriteIndex(op->index, it->second), op->predicate); } else { return expr; } @@ -240,13 +228,10 @@ class VTInjector : public StmtExprMutator { visit_touched_var_ = true; PrimExpr offset = this->VisitExpr(op->args[2]); PrimExpr extent = this->VisitExpr(op->args[3]); - PrimExpr stride = - it->second / make_const(offset.dtype(), dtype.lanes()); + PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes()); offset = stride * var_ + offset; - return CallNode::make( - op->dtype, op->name, - {op->args[0], op->args[1], offset, extent, op->args[4]}, - op->call_type); + return Call(op->dtype, op->name, {op->args[0], op->args[1], offset, extent, op->args[4]}, + op->call_type); } else if (op->is_intrinsic(intrinsic::tvm_context_id)) { return allow_share_ ? 
GetRef(op) : var_; } else { @@ -267,10 +252,7 @@ class VTInjector : public StmtExprMutator { trigger_base_inject_ = !allow_share_; auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { - return StoreNode::make(op->buffer_var, - op->value, - RewriteIndex(op->index, it->second), - op->predicate); + return Store(op->buffer_var, op->value, RewriteIndex(op->index, it->second), op->predicate); } else { return stmt; } @@ -281,16 +263,14 @@ class VTInjector : public StmtExprMutator { if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef(op), true); } else if (!allow_share_ && !vt_loop_injected_ && - (op->attr_key == attr::coproc_uop_scope || - op->attr_key == attr::coproc_scope)) { + (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) { return InjectVTLoop(GetRef(op), true); } else { Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return AttrStmtNode::make(op->node, op->attr_key, value, body); + return AttrStmt(op->node, op->attr_key, value, body); } } } @@ -302,11 +282,10 @@ class VTInjector : public StmtExprMutator { } visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetStmtNode::make(op->var, value, body); + return LetStmt(op->var, value, body); } } // For @@ -321,12 +300,10 @@ class VTInjector : public StmtExprMutator { visit_touched_var_ = false; Stmt body = this->VisitStmt(op->body); ++max_loop_depth_; - if (extent.same_as(op->extent) && - body.same_as(op->body)) { + if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make( - op->loop_var, op->min, extent, op->for_type, op->device_api, body); + return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -345,12 +322,11 @@ class VTInjector : public StmtExprMutator { else_case = this->VisitStmt(op->else_case); max_loop_depth_ = std::max(temp, max_loop_depth_); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(condition, then_case, else_case); + return IfThenElse(condition, then_case, else_case); } } @@ -389,8 +365,9 @@ class VTInjector : public StmtExprMutator { // always rewrite if not allow sharing. if (touched_var_.count(op->buffer_var.get()) || !allow_share_) { // place v on highest dimension. - PrimExpr stride = arith::ComputeReduce( - op->extents, PrimExpr()) * op->dtype.lanes(); + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + PrimExpr stride = + foldl(fmul, make_const(DataType::Int(32), 1), op->extents) * op->dtype.lanes(); Array other; other.push_back(make_const(op->extents[0].dtype(), num_threads_)); for (PrimExpr e : extents) { @@ -406,14 +383,10 @@ class VTInjector : public StmtExprMutator { // Mutate the body. 
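// A small standalone sketch (plain C++, not TVM code) of the rewrite performed
// on touched allocations above: the buffer gains the virtual-thread count as
// its highest dimension, and each access is shifted by vthread * stride, where
// stride is the product of the original extents times the vector lanes (the
// foldl over op->extents). All concrete numbers below are hypothetical.
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> extents = {4, 8};            // original allocation extents
  int lanes = 1, num_threads = 2, vthread = 1;  // virtual thread setup
  int stride = std::accumulate(extents.begin(), extents.end(), 1,
                               [](int a, int b) { return a * b; }) * lanes;
  int index = 5;                                // some flattened access index
  int rewritten = index + vthread * stride;     // mirrors offset = stride * var_ + offset
  std::printf("new extents: [%d, 4, 8], stride = %d, index %d -> %d\n",
              num_threads, stride, index, rewritten);
  return 0;
}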
body = this->VisitStmt(op->body); } - if (!changed && - body.same_as(op->body) && - condition.same_as(op->condition)) { + if (!changed && body.same_as(op->body) && condition.same_as(op->condition)) { return GetRef(op); } else { - return AllocateNode::make( - op->buffer_var, op->dtype, - extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extents, condition, body); } } @@ -443,9 +416,8 @@ class VTInjector : public StmtExprMutator { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, make_zero(idx.dtype()), - make_const(idx.dtype(), num_threads_), - ForType::Serial, DeviceAPI::None, stmt); + return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), + ForType::Serial, DeviceAPI::None, stmt); } } @@ -470,7 +442,6 @@ class VTInjector : public StmtExprMutator { std::unordered_map alloc_remap_; }; - class VirtualThreadInjector : public StmtMutator { public: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -489,7 +460,7 @@ class VirtualThreadInjector : public StmtMutator { } } - Stmt VisitStmt_(const ProvideNode* op) final { + Stmt VisitStmt_(const ProducerStoreNode* op) final { LOG(FATAL) << "Need to call StorageFlatten first"; return GetRef(op); } @@ -500,5 +471,20 @@ Stmt InjectVirtualThread(Stmt stmt) { return ConvertSSA(std::move(stmt)); } +namespace transform { + +Pass InjectVirtualThread() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = ConvertSSA(VirtualThreadInjector()(std::move(n->body))); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ssa.cc b/src/tir/transforms/ir_util.cc similarity index 61% rename from src/tir/pass/ssa.cc rename to src/tir/transforms/ir_util.cc index daef32c01bdb..4f21f0bb7411 100644 --- a/src/tir/pass/ssa.cc +++ b/src/tir/transforms/ir_util.cc @@ -18,61 +18,73 @@ */ /*! - * SSA related checks and pass. - * - * SSA requires each varaible to be only defined once. - * \file ssa.cc + * \file ir_util.cc + * \brief Helper functions to construct and compose IR nodes. 
*/ -#include +#include "ir_util.h" + #include -#include -#include + #include -#include +#include +#include namespace tvm { namespace tir { -namespace { -class IRVerifySSA final : public StmtExprVisitor { - public: - bool is_ssa{true}; - - void VisitExpr(const PrimExpr& n) final { - if (!is_ssa) return; - StmtExprVisitor::VisitExpr(n); - } - void VisitStmt(const Stmt& n) final { - if (!is_ssa) return; - StmtExprVisitor::VisitStmt(n); - } - void VisitExpr_(const LetNode* op) final { - MarkDef(op->var.get()); - StmtExprVisitor::VisitExpr_(op); - } - void VisitStmt_(const LetStmtNode* op) final { - MarkDef(op->var.get()); - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const ForNode* op) final { - MarkDef(op->loop_var.get()); - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const AllocateNode* op) final { - MarkDef(op->buffer_var.get()); - StmtExprVisitor::VisitStmt_(op); - } - private: - void MarkDef(const VarNode* v) { - if (defined_.count(v) != 0) { - is_ssa = false; return; +Stmt MergeNest(const std::vector& nest, Stmt body) { + // use reverse iteration + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + Stmt s = *ri; + if (const auto* for_ = s.as()) { + auto n = make_object(*for_); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* let = s.as()) { + auto n = make_object(*let); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* attr = s.as()) { + auto n = make_object(*attr); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* ite = s.as()) { + auto n = make_object(*ite); + CHECK(is_no_op(n->then_case)); + CHECK(!n->else_case.defined()); + n->then_case = body; + body = Stmt(n); + } else if (const auto* seq = s.as()) { + auto n = make_object(*seq); + CHECK(n->size() != 0 && is_no_op(n->seq[n->size() - 1])); + n->seq.Set(n->size() - 1, body); + body = Stmt(n); + } else if (const auto* assert_ = s.as()) { + auto n = make_object(*assert_); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); + } else if (const auto* alloc = s.as()) { + auto n = make_object(*alloc); + CHECK(is_no_op(n->body)); + n->body = body; + body = Stmt(n); } else { - defined_[v] = 1; + LOG(FATAL) << "not supported nest type"; } } - std::unordered_map defined_; -}; + return body; +} +Stmt MergeNest(const std::vector>& nest, Stmt body) { + for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) { + body = MergeNest(*ri, body); + } + return body; +} class IRConvertSSA final : public StmtExprMutator { public: @@ -91,7 +103,7 @@ class IRConvertSSA final : public StmtExprMutator { scope_[v.get()].push_back(new_var); PrimExpr body = this->VisitExpr(op->body); scope_[v.get()].pop_back(); - return LetNode::make(new_var, value, body); + return Let(new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitExpr_(op); @@ -101,9 +113,7 @@ class IRConvertSSA final : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (scope_.count(op->buffer_var.get())) { - return LoadNode::make( - op->dtype, scope_[op->buffer_var.get()].back(), - op->index, op->predicate); + return Load(op->dtype, scope_[op->buffer_var.get()].back(), op->index, op->predicate); } else { return expr; } @@ -112,9 +122,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(op->buffer_var.get())) { - return StoreNode::make( - scope_[op->buffer_var.get()].back(), op->value, - 
op->index, op->predicate); + return Store(scope_[op->buffer_var.get()].back(), op->value, op->index, op->predicate); } else { return stmt; } @@ -127,7 +135,7 @@ class IRConvertSSA final : public StmtExprMutator { scope_[v.get()].push_back(new_var); Stmt body = this->VisitStmt(op->body); scope_[v.get()].pop_back(); - return LetStmtNode::make(new_var, value, body); + return LetStmt(new_var, value, body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -141,8 +149,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return ForNode::make( - new_var, op->min, op->extent, op->for_type, op->device_api, op->body); + return For(new_var, op->min, op->extent, op->for_type, op->device_api, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -156,9 +163,7 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return AllocateNode::make( - new_var, op->dtype, op->extents, op->condition, - op->body); + return Allocate(new_var, op->dtype, op->extents, op->condition, op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); @@ -173,15 +178,13 @@ class IRConvertSSA final : public StmtExprMutator { if (new_alloc.same_as(op->body)) return GetRef(op); alloc = new_alloc.as(); CHECK(alloc); - return AttrStmtNode::make( - alloc->buffer_var, op->attr_key, op->value, new_alloc); + return AttrStmt(alloc->buffer_var, op->attr_key, op->value, new_alloc); } } Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (scope_.count(v) && scope_[v].size() != 0) { - return AttrStmtNode::make( - scope_[v].back(), op->attr_key, op->value, op->body); + return AttrStmt(scope_[v].back(), op->attr_key, op->value, op->body); } else { return stmt; } @@ -191,21 +194,11 @@ class IRConvertSSA final : public StmtExprMutator { } private: - std::unordered_map > scope_; + std::unordered_map> scope_; std::unordered_set defined_; }; -} // namespace - -bool VerifySSA(const Stmt& ir) { - IRVerifySSA visitor; - visitor(ir); - return visitor.is_ssa; -} - -Stmt ConvertSSA(Stmt stmt) { - return IRConvertSSA()(std::move(stmt)); -} +Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ir_util.h b/src/tir/transforms/ir_util.h similarity index 66% rename from src/tir/pass/ir_util.h rename to src/tir/transforms/ir_util.h index d8da61fdd961..6c0eeea97278 100644 --- a/src/tir/pass/ir_util.h +++ b/src/tir/transforms/ir_util.h @@ -21,12 +21,13 @@ * \file ir_util.h * \brief Helper functions to construct and compose IR nodes. */ -#ifndef TVM_TIR_PASS_IR_UTIL_H_ -#define TVM_TIR_PASS_IR_UTIL_H_ +#ifndef TVM_TIR_TRANSFORMS_IR_UTIL_H_ +#define TVM_TIR_TRANSFORMS_IR_UTIL_H_ +#include #include #include -#include + #include namespace tvm { @@ -56,7 +57,7 @@ Stmt MergeNest(const std::vector >& nest, Stmt body); * \return if update happens, return the new array, else return the * original array */ -template +template inline Array UpdateArray(Array arr, F fupdate) { std::vector new_arr(arr.size()); bool changed = false; @@ -81,14 +82,11 @@ inline Array UpdateArray(Array arr, F fupdate) { * \param kind The data kind. * \return the get expression. 
*/ -inline PrimExpr TVMStructGet( - DataType dtype, Var handle, int index, - intrinsic::TVMStructFieldKind kind) { - Array args ={ - handle, - make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind))}; - return CallNode::make(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); +inline PrimExpr TVMStructGet(DataType dtype, Var handle, int index, + intrinsic::TVMStructFieldKind kind) { + Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind))}; + return Call(dtype, intrinsic::tvm_struct_get, args, CallNode::PureIntrinsic); } /*! @@ -98,11 +96,10 @@ inline PrimExpr TVMStructGet( * \param offset the offset index. */ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { - return CallNode::make( - DataType::Handle(), intrinsic::tvm_address_of, - {LoadNode::make(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), - const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + return Call(DataType::Handle(), intrinsic::tvm_address_of, + {Load(dtype, handle, make_const(DataType::Int(32), offset * dtype.lanes()), + const_true(dtype.lanes()))}, + CallNode::PureIntrinsic); } /*! @@ -114,13 +111,10 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, int offset) { inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { if (dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); - offset = RampNode::make(offset, make_const(offset.dtype(), 1), dtype.lanes()); + offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); } - return CallNode::make( - DataType::Handle(), intrinsic::tvm_address_of, - {LoadNode::make(dtype, handle, offset, - const_true(dtype.lanes()))}, - CallNode::PureIntrinsic); + return Call(DataType::Handle(), intrinsic::tvm_address_of, + {Load(dtype, handle, offset, const_true(dtype.lanes()))}, CallNode::PureIntrinsic); } /*! @@ -131,16 +125,11 @@ inline PrimExpr AddressOffset(Var handle, DataType dtype, PrimExpr offset) { * \param value The value to be set. * \return the set stmt. */ -inline Stmt TVMStructSet( - Var handle, int index, - intrinsic::TVMStructFieldKind kind, PrimExpr value) { - Array args ={ - handle, - make_const(DataType::Int(32), index), - make_const(DataType::Int(32), static_cast(kind)), - value}; - return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); +inline Stmt TVMStructSet(Var handle, int index, intrinsic::TVMStructFieldKind kind, + PrimExpr value) { + Array args = {handle, make_const(DataType::Int(32), index), + make_const(DataType::Int(32), static_cast(kind)), value}; + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_struct_set, args, CallNode::Intrinsic)); } /*! @@ -150,8 +139,7 @@ inline Stmt TVMStructSet( */ inline DataType APIType(DataType t) { if (t.is_handle()) return t; - CHECK_EQ(t.lanes(), 1) - << "Cannot pass vector type through packed API."; + CHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; if (t.is_uint() || t.is_int()) return DataType::Int(64); CHECK(t.is_float()); return DataType::Float(64); @@ -175,21 +163,12 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { } /*! - * \brief Pattern match index to Ramp with stride=1 - * This is a common pattern in continuous memory load. - * \param index The index formula - * \param lanes number of lanes in the ramp - * \param base The result base. 
- * \return true if pattern match success and store the base to base. + * \brief Convert a IR node to be SSA form. + * \param stmt The source statement to be converted. + * \return The converted form. */ -inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) { - const RampNode* r = index.as(); - if (!r) return false; - if (!is_one(r->stride)) return false; - CHECK_EQ(r->lanes, lanes); - *base = r->base; - return true; -} +Stmt ConvertSSA(Stmt stmt); + } // namespace tir } // namespace tvm -#endif // TVM_TIR_PASS_IR_UTIL_H_ +#endif // TVM_TIR_TRANSFORMS_IR_UTIL_H_ diff --git a/src/tir/pass/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc similarity index 75% rename from src/tir/pass/lift_attr_scope.cc rename to src/tir/transforms/lift_attr_scope.cc index 9aa037feb460..1a1279f0640a 100644 --- a/src/tir/pass/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -23,8 +23,10 @@ * the body contains the same scope. * \file lift_attr_scope.cc */ -#include +#include #include +#include + #include "ir_util.h" namespace tvm { @@ -34,14 +36,12 @@ namespace tir { // to a few specified attr keys class AttrScopeLifter : public StmtMutator { public: - explicit AttrScopeLifter(std::string attr_key) - : attr_key_(attr_key) {} + explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {} Stmt Lift(Stmt stmt) { stmt = operator()(std::move(stmt)); if (attr_node_.defined()) { - stmt = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, stmt); + stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt); } return stmt; } @@ -51,14 +51,11 @@ class AttrScopeLifter : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (attr_node_.defined()) { - Stmt body = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, op->body); + Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); - return AllocateNode::make( - op->buffer_var, op->dtype, - op->extents, op->condition, body); + return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body); } else { return stmt; } @@ -96,8 +93,7 @@ class AttrScopeLifter : public StmtMutator { // check if all decorations are common. 
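// A standalone sketch (plain C++, not TVM code) of what the loop that follows
// does: consecutive statements carrying the same attribute decoration are
// grouped so that one AttrStmt scope is emitted per run instead of one per
// statement. The integer attribute values below are hypothetical stand-ins for
// the (attr node, attr value) pairs compared with same_as/ValueSame.
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> attr = {1, 1, 2, 2, 2, 0};  // per-statement decoration
  for (size_t begin = 0; begin < attr.size();) {
    size_t end = begin + 1;
    while (end < attr.size() && attr[end] == attr[begin]) ++end;
    std::printf("statements [%zu, %zu) share one scope with attr %d\n",
                begin, end, attr[begin]);
    begin = end;
  }
  return 0;
}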
for (size_t begin = 0; begin < attr_node.size();) { size_t end = begin + 1; - while (end < attr_node.size() && - attr_node[end].same_as(attr_node[begin]) && + while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) && ValueSame(attr_value[end], attr_value[begin])) { ++end; } @@ -115,8 +111,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt stmt = SeqStmt::Flatten(seq); if (attr_node[begin].defined()) { - stmt = AttrStmtNode::make( - attr_node[begin], attr_key_, attr_value[begin], stmt); + stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt); } reorg.push_back(stmt); begin = end; @@ -136,35 +131,28 @@ class AttrScopeLifter : public StmtMutator { std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); Stmt else_case = this->VisitStmt(op->else_case); - if (attr_node_.defined() && - attr_value_.defined() && - first_node.defined() && - first_value.defined() && - attr_node_.same_as(first_node) && + if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && + first_value.defined() && attr_node_.same_as(first_node) && ValueSame(attr_value_, first_value)) { - if (then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(op->condition, then_case, else_case); + return IfThenElse(op->condition, then_case, else_case); } } else { if (first_node.defined()) { - then_case = AttrStmtNode::make( - first_node, attr_key_, first_value, then_case); + then_case = AttrStmt(first_node, attr_key_, first_value, then_case); } if (attr_node_.defined()) { - else_case = AttrStmtNode::make( - attr_node_, attr_key_, attr_value_, else_case); + else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case); // undefine them attr_node_ = ObjectRef(); attr_value_ = PrimExpr(); } - if (then_case.same_as(op->then_case) && - else_case.same_as(op->else_case)) { + if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(op->condition, then_case, else_case); + return IfThenElse(op->condition, then_case, else_case); } } } @@ -191,5 +179,20 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { return AttrScopeLifter(attr_key).Lift(std::move(stmt)); } +namespace transform { + +Pass LiftAttrScope(String attr_key) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/loop_partition.cc b/src/tir/transforms/loop_partition.cc similarity index 76% rename from src/tir/pass/loop_partition.cc rename to src/tir/transforms/loop_partition.cc index e9157e796e38..3b2580c60074 100644 --- a/src/tir/pass/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -20,21 +20,42 @@ /*! 
* \file loop_partition.cc */ +#include +#include +#include #include #include -#include -#include +#include + #include #include + #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { -using arith::IntSet; +struct LoopPartitionConfigNode : public tvm::AttrsNode { + bool partition_const_loop; + + TVM_DECLARE_ATTRS(LoopPartitionConfigNode, "tir.transform.LoopPartitionConfig") { + TVM_ATTR_FIELD(partition_const_loop).describe("Split constant loop").set_default(false); + } +}; + +class LoopPartitionConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopPartitionConfig, Attrs, LoopPartitionConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(LoopPartitionConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.LoopPartition", LoopPartitionConfig); + using arith::DeduceBound; using arith::Intersect; +using arith::IntSet; using PartitionKey = std::pair; struct PartitionKeyHash { @@ -69,12 +90,12 @@ bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) class CandidateSelector final : public StmtExprVisitor { public: using VarIsUsed = bool; - explicit CandidateSelector(bool split_const_loop) - : split_const_loop_(split_const_loop) {} + explicit CandidateSelector(bool partition_const_loop) + : partition_const_loop_(partition_const_loop) {} void VisitStmt_(const ForNode* op) final { - // partition const loop when sets split_const_loop_ - if (!is_const(op->min) || !is_const(op->extent) || split_const_loop_) { + // partition const loop when sets partition_const_loop_ + if (!is_const(op->min) || !is_const(op->extent) || partition_const_loop_) { const VarNode* var = op->loop_var.get(); record_.insert({var, false}); StmtExprVisitor::VisitStmt_(op); @@ -89,11 +110,11 @@ class CandidateSelector final : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { - const IterVarNode *iv = op->node.as(); + const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); - if ((scope.rank == 0) && (!is_const(op->value) || split_const_loop_)) { + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); + if ((scope.rank == 0) && (!is_const(op->value) || partition_const_loop_)) { record_.insert({var.get(), false}); StmtExprVisitor::VisitStmt_(op); if (record_.at(var.get()) && !no_split_) { @@ -143,7 +164,7 @@ class CandidateSelector final : public StmtExprVisitor { private: bool in_likely_{false}; bool no_split_{false}; - bool split_const_loop_{false}; + bool partition_const_loop_{false}; std::unordered_map record_; }; @@ -153,16 +174,16 @@ class CandidateSelector final : public StmtExprVisitor { class PartitionFinder : public StmtExprVisitor { public: explicit PartitionFinder(Var current_var, - const std::unordered_map& hint_map, - const std::unordered_map& relax_map) - : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { - for (const auto& kv : hint_map) { - out_vars_.insert(kv.first); - } - for (const auto& kv : relax_map) { - out_vars_.insert(kv.first); - } - } + const std::unordered_map& hint_map, + const std::unordered_map& relax_map) + : current_var_(current_var), hint_map_(hint_map), relax_map_(relax_map) { + for (const auto& kv : hint_map) { + out_vars_.insert(kv.first); + } + for (const auto& kv : relax_map) { + out_vars_.insert(kv.first); + } + } void VisitStmt_(const ForNode* op) final { if (ExprUseVars(op->min, out_vars_) || 
ExprUseVars(op->extent, out_vars_)) return; @@ -195,21 +216,18 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, - std::unordered_set({current_var_.get()}))) { + if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. - IntSet interval = - DeduceBound(current_var_, cond, hint_map_, relax_map_); + IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is true within interval partitions[{cond.get(), true}] = interval; } PrimExpr inverse_cond = InverseCond(cond); if (inverse_cond.defined()) { - IntSet interval = - DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); + IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_); if (!interval.is_nothing()) { // cond is false within interval partitions[{cond.get(), false}] = interval; @@ -228,22 +246,22 @@ class PartitionFinder : public StmtExprVisitor { PrimExpr inverse_cond; if (const LTNode* op = cond.as()) { // a < b -> a >= b - inverse_cond = GENode::make(op->a, op->b); + inverse_cond = GE(op->a, op->b); } else if (const GTNode* op = cond.as()) { // a > b -> a <= b - inverse_cond = LENode::make(op->a, op->b); + inverse_cond = LE(op->a, op->b); } else if (const LENode* op = cond.as()) { // a <= b -> a > b - inverse_cond = GTNode::make(op->a, op->b); + inverse_cond = GT(op->a, op->b); } else if (const GENode* op = cond.as()) { // a >= b -> a < b - inverse_cond = LTNode::make(op->a, op->b); + inverse_cond = LT(op->a, op->b); } else if (const EQNode* op = cond.as()) { // a == b -> a != b - inverse_cond = NENode::make(op->a, op->b); + inverse_cond = NE(op->a, op->b); // a != b -> a == b } else if (const NENode* op = cond.as()) { - inverse_cond = EQNode::make(op->a, op->b); + inverse_cond = EQ(op->a, op->b); } return inverse_cond; } @@ -258,7 +276,7 @@ class PartitionFinder : public StmtExprVisitor { class ConditionEliminator : public StmtExprMutator { public: explicit ConditionEliminator(const std::unordered_set& ps, bool cond_value = true) - : ps_(ps), cond_value_(cond_value) {} + : ps_(ps), cond_value_(cond_value) {} PrimExpr VisitExpr(const PrimExpr& e) final { if (ps_.find(e.get()) != ps_.end()) { @@ -272,12 +290,11 @@ class ConditionEliminator : public StmtExprMutator { bool cond_value_; }; - // Insert the partition branch at the innermost thread scope class ThreadPartitionInserter : public StmtMutator { public: - explicit ThreadPartitionInserter(const std::unordered_set& ps, - PrimExpr cond) : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} + explicit ThreadPartitionInserter(const std::unordered_set& ps, PrimExpr cond) + : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { @@ -286,9 +303,9 @@ class ThreadPartitionInserter : public StmtMutator { // add branch code inside the innermost thread scope if (innermost_thread_scope_) { Stmt simplified_body = ConditionEliminator(ps_)(op->body); - Stmt body = IfThenElseNode::make(cond_, simplified_body, op->body); + Stmt body = IfThenElse(cond_, simplified_body, op->body); PrimExpr value = this->VisitExpr(op->value); - stmt = AttrStmtNode::make(op->node, op->attr_key, value, body); + stmt = 
AttrStmt(op->node, op->attr_key, value, body); } innermost_thread_scope_ = false; return stmt; @@ -307,8 +324,8 @@ class ThreadPartitionInserter : public StmtMutator { // likely conditions class LoopPartitioner : public StmtMutator { public: - explicit LoopPartitioner(bool split_const_loop) - : selector(CandidateSelector(split_const_loop)) {} + explicit LoopPartitioner(bool partition_const_loop) + : selector(CandidateSelector(partition_const_loop)) {} Stmt VisitAndMutate(Stmt stmt) { selector(stmt); @@ -317,15 +334,14 @@ class LoopPartitioner : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (selector.candidates.count(op)) { - Stmt s = TryPartition(op, GetRef(op), op->loop_var, - op->min, op->min + op->extent - 1, op->body, false); + Stmt s = TryPartition(op, GetRef(op), op->loop_var, op->min, op->min + op->extent - 1, + op->body, false); if (s.defined()) return s; } // normal path when loop partition fails // normal loop variable can be put into hint map. - hint_map_.insert({op->loop_var.get(), - IntSet::interval(op->min, op->min + op->extent - 1)}); + hint_map_.insert({op->loop_var.get(), IntSet::interval(op->min, op->min + op->extent - 1)}); Stmt res = StmtMutator::VisitStmt_(op); hint_map_.erase(op->loop_var.get()); return res; @@ -336,7 +352,7 @@ class LoopPartitioner : public StmtMutator { return StmtMutator::VisitStmt_(op); } - const IterVarNode *iv = op->node.as(); + const IterVarNode* iv = op->node.as(); CHECK(iv); Var var = iv->var; if (selector.candidates.count(op)) { @@ -345,17 +361,15 @@ class LoopPartitioner : public StmtMutator { } // normal path when loop parittion fails. - runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); Stmt res; if (scope.rank == 1) { // threadIdx should be put into relax map, in case of divergence. 
- relax_map_.insert({var.get(), - IntSet::interval(make_zero(var.dtype()), op->value - 1)}); + relax_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = StmtMutator::VisitStmt_(op); relax_map_.erase(var.get()); } else { - hint_map_.insert({var.get(), - IntSet::interval(make_zero(var.dtype()), op->value - 1)}); + hint_map_.insert({var.get(), IntSet::interval(make_zero(var.dtype()), op->value - 1)}); res = StmtMutator::VisitStmt_(op); hint_map_.erase(var.get()); } @@ -363,13 +377,11 @@ class LoopPartitioner : public StmtMutator { } private: - Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, - PrimExpr min, PrimExpr max, Stmt body, bool partition_thread_scope); + Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, + Stmt body, bool partition_thread_scope); - std::pair> - GetIntervalAndCondset(const Partition &partitions, - const arith::IntervalSet &for_interval, - bool cond_value); + std::pair> GetIntervalAndCondset( + const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value); inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body); @@ -382,18 +394,15 @@ class LoopPartitioner : public StmtMutator { // Returns an interval (in the first component) in which all the conditions // given in the second component provably have value given by cond_value -std::pair> -LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, - const arith::IntervalSet &for_interval, - bool cond_value) { +std::pair> LoopPartitioner::GetIntervalAndCondset( + const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) { Array sets; std::unordered_set cond_set; - for (const auto &kv : partitions) { + for (const auto& kv : partitions) { if (kv.first.second == cond_value) { arith::IntervalSet interval = Downcast(kv.second); - arith::IntervalSet intersection = arith::Intersect( - &analyzer_, interval, for_interval); + arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval); if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); @@ -450,13 +459,8 @@ LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, * which will eventually be simplified to empty code. And because only one loop was generated * from loop 2 we stop recursing. */ -Stmt LoopPartitioner::TryPartition(const Object* node, - const Stmt& stmt, - Var var, - PrimExpr min, - PrimExpr max, - Stmt body, - bool partition_thread_scope) { +Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min, + PrimExpr max, Stmt body, bool partition_thread_scope) { using namespace arith; // include hint of var. 
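// A standalone numeric sketch (plain C++, not TVM code) of the split computed
// here: within [min, max], an interval [body_begin, post_doubt_begin - 1] is
// deduced on which every likely(...) condition provably holds; the loop is then
// emitted as a pre doubt loop, a main loop with the conditions eliminated, and
// a post doubt loop. The concrete bounds below are hypothetical.
#include <cstdio>

int main() {
  int min = 0, max = 15;                      // inclusive loop bounds
  int body_begin = 4, post_doubt_begin = 12;  // deduced "condition holds" interval
  for (int i = min; i < body_begin; ++i) std::printf("pre   i=%d (condition kept)\n", i);
  for (int i = body_begin; i < post_doubt_begin; ++i)
    std::printf("main  i=%d (condition eliminated)\n", i);
  for (int i = post_doubt_begin; i <= max; ++i) std::printf("post  i=%d (condition kept)\n", i);
  return 0;
}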
hint_map_.insert({var.get(), IntSet::interval(min, max)}); @@ -473,7 +477,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, std::unordered_set cond_set; // find an interval in which all conditions on var are true std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, true); + GetIntervalAndCondset(finder.partitions, for_interval, true); if (middle_interval.is_nothing()) { // if such interval doesn't exist, find an interval in which all // conditions on var are false @@ -500,13 +504,12 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt pre_stmt; bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { - body_begin = tir::Simplify(middle_interval.min()); + body_begin = analyzer_.Simplify(middle_interval.min()); if (!analyzer_.CanProve(body_begin == min)) { PrimExpr cond = (body_begin - min >= 0); if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the pre doubt loop"; - body_begin = MaxNode::make(body_begin, min); + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the pre doubt loop"; + body_begin = Max(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; } @@ -525,20 +528,18 @@ Stmt LoopPartitioner::TryPartition(const Object* node, Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { - post_doubt_begin = tir::Simplify(middle_interval.max() + 1); + post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1); if (!analyzer_.CanProve(middle_interval.max() == max)) { // require the extent to be non-negative PrimExpr cond = (max - post_doubt_begin + 1 >= 0); if (!analyzer_.CanProve(cond)) { - LOG(WARNING) << "Cannot prove: " << cond - << ", when generating the post doubt loop"; - post_doubt_begin = MinNode::make(post_doubt_begin, max+1); + LOG(WARNING) << "Cannot prove: " << cond << ", when generating the post doubt loop"; + post_doubt_begin = Min(post_doubt_begin, max + 1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } if (!partition_thread_scope) { - Stmt post_body = - Substitute(body, {{Var{var}, var + post_doubt_begin}}); + Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body); } } @@ -580,21 +581,21 @@ Stmt LoopPartitioner::TryPartition(const Object* node, return s; } -inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt body) { - const ForNode *for_node = static_cast(node); +inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) { + const ForNode* for_node = static_cast(node); CHECK(for_node); if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1))) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, - for_node->for_type, for_node->device_api, body); + return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type, + for_node->device_api, body); } } class RemoveLikelyTags : public StmtExprMutator { public: - PrimExpr VisitExpr_(const CallNode *op) final { + PrimExpr VisitExpr_(const CallNode* op) final { if (op->is_intrinsic(CallNode::likely)) { CHECK_EQ(op->args.size(), 1); return 
StmtExprMutator::VisitExpr(op->args[0]); @@ -604,11 +605,30 @@ class RemoveLikelyTags : public StmtExprMutator { } }; -Stmt LoopPartition(Stmt stmt, bool split_const_loop) { - stmt = LoopPartitioner(split_const_loop).VisitAndMutate(std::move(stmt)); +Stmt LoopPartition(Stmt stmt, bool partition_const_loop) { + stmt = LoopPartitioner(partition_const_loop).VisitAndMutate(std::move(stmt)); stmt = RemoveLikelyTags()(std::move(stmt)); return stmt; } +namespace transform { + +Pass LoopPartition() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto cfg = ctx->GetConfig("tir.LoopPartition"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + n->body = LoopPartition(std::move(n->body), cfg.value()->partition_const_loop); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LoopPartition", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LoopPartition").set_body_typed(LoopPartition); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index ce81528b8b35..154023c1cf4d 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -21,10 +21,11 @@ * \brief Pass for lowering custom datatypes */ +#include +#include #include #include -#include -#include + #include "../../target/datatype/registry.h" namespace tvm { @@ -79,9 +80,8 @@ class CustomDatatypesLowerer : public StmtExprMutator { if (toBeLowered) { auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); - return AllocateNode::make( - allocate->buffer_var, new_allocate_type, allocate->extents, - allocate->condition, allocate->body); + return Allocate(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body); } return stmt; } @@ -92,24 +92,24 @@ class CustomDatatypesLowerer : public StmtExprMutator { load = expr.as(); if (toBeLowered) { auto new_load_type = DataType::UInt(load->dtype.bits()); - return LoadNode::make(new_load_type, load->buffer_var, load->index, load->predicate); + return Load(new_load_type, load->buffer_var, load->index, load->predicate); } return expr; } -#define DEFINE_MUTATE__(OP, NodeName) \ - inline PrimExpr VisitExpr_(const NodeName* op) final { \ - auto type_code = op->dtype.code(); \ +#define DEFINE_MUTATE__(OP, NodeName) \ + inline PrimExpr VisitExpr_(const NodeName* op) final { \ + auto type_code = op->dtype.code(); \ bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ - if (toBeLowered) { \ - auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ - CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ - << static_cast(type_code) << " not found"; \ - return (*lower)(expr); \ - } \ - return expr; \ + PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ } DEFINE_MUTATE__(Add, AddNode); @@ -131,15 +131,13 @@ class CustomDatatypesLowerer : public StmtExprMutator { std::string target_; }; - namespace transform { Pass LowerCustomDatatypes() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { 
auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerCustomDatatypes: Require the target attribute"; + CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; @@ -147,8 +145,7 @@ Pass LowerCustomDatatypes() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes") -.set_body_typed(LowerCustomDatatypes); +TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes").set_body_typed(LowerCustomDatatypes); } // namespace transform diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index e7f81ed929b9..9d6b47a1ca37 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -21,22 +21,21 @@ * \file lower_device_storage_access.cc * \brief Lower the special device storage access. */ +#include +#include +#include +#include #include #include -#include -#include -#include -#include - -#include "../pass/ir_util.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { -using runtime::StorageScope; using runtime::StorageRank; +using runtime::StorageScope; class StorageAccessInfoLower : public StmtExprMutator { public: @@ -53,8 +52,7 @@ class StorageAccessInfoLower : public StmtExprMutator { << "Double allocation of " << it->second.scope.to_string(); if (info->head_address.defined()) { - return LetStmtNode::make( - op->buffer_var, info->head_address, op->body); + return LetStmt(op->buffer_var, info->head_address, op->body); } else { return op->body; } @@ -65,7 +63,7 @@ class StorageAccessInfoLower : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + StorageScope scope = StorageScope::Create(op->value.as()->value); StorageEntry e; e.scope = scope; if (scope.tag.length() != 0) { @@ -101,30 +99,23 @@ class StorageAccessInfoLower : public StmtExprMutator { PrimExpr offset = op->args[2]; auto it = storage_info_.find(buffer); if (it != storage_info_.end() && it->second.info.defined()) { - return MakeTaggedAccessPtr( - op->dtype, buffer_var, dtype, offset, - it->second.info); + return MakeTaggedAccessPtr(op->dtype, buffer_var, dtype, offset, it->second.info); } CHECK(op->dtype.is_handle()); // Change to address_of return AddressOffset(buffer_var, dtype, offset); } - PrimExpr MakeTaggedAccessPtr(DataType ptr_type, - Var buffer_var, - DataType dtype, - PrimExpr offset, + PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset, const MemoryInfo& info) { if (ptr_type.is_handle()) { - CHECK(info->head_address.defined()) - << buffer_var << " is not adddressable."; + CHECK(info->head_address.defined()) << buffer_var << " is not adddressable."; return AddressOffset(buffer_var, dtype, offset); } int dtype_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(info->unit_bits % dtype_bits, 0); - return cast(ptr_type, - tir::Simplify(offset / make_const( - offset.dtype(), info->unit_bits / dtype_bits))); + return cast(ptr_type, analyzer_.Simplify( + offset / make_const(offset.dtype(), info->unit_bits / dtype_bits))); } // The storage entry. 
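// A standalone sketch (plain C++, not TVM code) of the offset rescaling done by
// MakeTaggedAccessPtr above: tagged memory is addressed in units of unit_bits,
// so an element offset is divided by the number of elements that fit in one
// unit. The concrete bit widths and offset below are hypothetical.
#include <cassert>
#include <cstdio>

int main() {
  int unit_bits = 128;             // unit size reported by the MemoryInfo
  int dtype_bits = 32, lanes = 1;  // e.g. a float32 element
  int elem_bits = dtype_bits * lanes;
  assert(unit_bits % elem_bits == 0);  // mirrors the CHECK_EQ in the pass
  int elem_offset = 8;
  int unit_offset = elem_offset / (unit_bits / elem_bits);
  std::printf("element offset %d -> unit offset %d\n", elem_offset, unit_offset);  // -> 2
  return 0;
}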
struct StorageEntry { @@ -137,12 +128,11 @@ class StorageAccessInfoLower : public StmtExprMutator { }; // The storage scope of each buffer std::unordered_map storage_info_; + // analyzer + arith::Analyzer analyzer_; }; -Stmt LowerStorageAccessInfo(Stmt stmt) { - return StorageAccessInfoLower()(std::move(stmt)); -} - +Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } namespace transform { @@ -152,12 +142,11 @@ Pass LowerDeviceStorageAccessInfo() { n->body = StorageAccessInfoLower()(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); } TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo") -.set_body_typed(LowerDeviceStorageAccessInfo); + .set_body_typed(LowerDeviceStorageAccessInfo); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 6ae638f33474..c7aa949924d7 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -21,24 +21,24 @@ * Lower intrinsic calls and ops to device specific ir when possible. * \file lower_intrin.cc */ +#include +#include #include -#include +#include #include -#include -#include -#include #include -#include "../../arith/pattern_match.h" + #include "../../arith/ir_mutator_with_analyzer.h" +#include "../../arith/pattern_match.h" namespace tvm { namespace tir { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: - using IRMutatorWithAnalyzer::VisitStmt_; using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; IntrinInjecter(arith::Analyzer* analyzer, std::string target_name) : IRMutatorWithAnalyzer(analyzer) { @@ -51,8 +51,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const CallNode* op) final { - if (op->call_type == CallNode::Intrinsic || - op->call_type == CallNode::PureIntrinsic) { + if (op->call_type == CallNode::Intrinsic || op->call_type == CallNode::PureIntrinsic) { PrimExpr r = ApplyPattern(op->name, GetRef(op)); if (r.defined()) return r; } @@ -79,16 +78,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); - if (support_bitwise_op_ && - is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to right shift if possible. return op->a >> make_const(dtype, shift); } if (analyzer_->CanProveGreaterEqual(op->b, 0)) { // Common path, positive divisor - if (analyzer_->CanProveGreaterEqual(op->a, 0) || - analyzer_->CanProveGreaterEqual(e, 0)) { + if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) { return truncdiv(op->a, op->b); } else { DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; @@ -101,7 +98,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return tir::SelectNode::make(rmod >= 0 , rdiv, rdiv - make_const(dtype, 1)); + return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); } } } else { @@ -111,9 +108,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b < 0 => (rmod <= 0 ? 
rdiv : rdiv - 1) PrimExpr rdiv = truncdiv(op->a, op->b); PrimExpr rmod = truncmod(op->a, op->b); - return tir::SelectNode::make( - (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), - rdiv, rdiv - make_const(dtype, 1)); + return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1)); } } @@ -126,11 +122,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { const DataType& dtype = op->dtype; CHECK(dtype.is_int() || dtype.is_uint()); - if (support_bitwise_op_ && - is_const_power_of_two_integer(op->b, &shift)) { + if (support_bitwise_op_ && is_const_power_of_two_integer(op->b, &shift)) { // lower to masking if possible. - int64_t mask = ( - static_cast(1) << static_cast(shift)) - 1; + int64_t mask = (static_cast(1) << static_cast(shift)) - 1; return op->a & make_const(dtype, mask); } @@ -150,7 +144,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // -> rmod >= 0 ? 0 : b return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - return tir::SelectNode::make(rmod >= 0, rmod, rmod + op->b); + return tir::Select(rmod >= 0, rmod, rmod + op->b); } } } else { @@ -161,9 +155,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod // b < 0 && rmod > 0 -> rmod + b - return tir::SelectNode::make( - (op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), - rmod, rmod + op->b); + return tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rmod, rmod + op->b); } } @@ -172,8 +164,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PVar x, y; PVar c; auto e = GetRef(op); - if (max(floordiv(x, y), c).Match(e) && - c.Eval()->value >= 0 && + if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 && analyzer_->CanProveGreaterEqual(y.Eval(), 0)) { return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval()); } @@ -225,28 +216,26 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { }; if (should_swap()) { - PrimExpr new_bcast = BroadcastNode::make(cast->value, bcast->lanes); - return CastNode::make(bcast->dtype, new_bcast); + PrimExpr new_bcast = Broadcast(cast->value, bcast->lanes); + return Cast(bcast->dtype, new_bcast); } } } return e; } - PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, - const AddNode* op) { + PrimExpr MakeFMA(const PrimExpr& a, const PrimExpr& b, const PrimExpr& c, const AddNode* op) { // emit fma instruction: a * b + c PrimExpr lhs = SwapBroadcastCast(a); PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(CallNode::make( - op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); + PrimExpr r = (*fma_)(Call(op->dtype, "fma", {lhs, rhs, c}, CallNode::PureIntrinsic)); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { - PrimExpr mul = this->VisitExpr(MulNode::make(lhs, rhs)); - return AddNode::make(mul, this->VisitExpr(c)); + PrimExpr mul = this->VisitExpr(Mul(lhs, rhs)); + return Add(mul, this->VisitExpr(c)); } } return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -289,18 +278,15 @@ Pass LowerIntrin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerIntrin: Require the target attribute"; + CHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - n->body = - 
IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); + n->body = IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin") -.set_body_typed(LowerIntrin); +TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin").set_body_typed(LowerIntrin); } // namespace transform diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 655a0074c7fd..ee17f081c6d8 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -21,28 +21,27 @@ * Lower allreduce to device implementable ir. * \file lower_thread_allreduce.cc */ +#include +#include +#include #include #include #include -#include -#include -#include #include -#include "../pass/ir_util.h" -#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { class ThreadAllreduceBuilder final : public StmtExprMutator { public: - explicit ThreadAllreduceBuilder(int warp_size) - : warp_size_(warp_size) {} + explicit ThreadAllreduceBuilder(const TargetNode* target) + : target_(target), warp_size_(target->thread_warp_size) {} - Stmt VisitStmt_(const AttrStmtNode *op) final { + Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { thread_extents_.push_back(op); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -58,7 +57,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return ret; } } else if (op->attr_key == attr::reduce_scope) { - const CommReducerNode *combiner = op->node.as(); + const CommReducerNode* combiner = op->node.as(); CHECK(combiner); reduce_combiner_.push_back(combiner); Stmt ret = StmtExprMutator::VisitStmt_(op); @@ -84,15 +83,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto it = alloc_remap_.find(op->buffer_var.get()); if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); - // use volatile access to shared buffer. - stmt = AttrStmtNode::make( - repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = AllocateNode::make( - repl->buffer_var, repl->dtype, - repl->extents, repl->condition, stmt); - stmt = AttrStmtNode::make( - repl->buffer_var, attr::storage_scope, - StringImmNode::make("shared"), stmt); + if (warp_allocs_.count(repl)) { + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); + } else { + // use volatile access to shared buffer. + stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); + stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); + } return stmt; } else { return stmt; @@ -119,29 +118,30 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { return scope.dim_index < other.scope.dim_index; } }; + // make allreduce. 
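// A standalone sketch (plain C++, not TVM code) relating to the FloorDivNode
// lowering in lower_intrin.cc above: when the sign of the dividend is unknown,
// floor division is built from truncating division plus a fix-up selected on
// the sign of the remainder, which is what the Select/shift forms express.
#include <cstdio>

long long floordiv(long long a, long long b) {
  long long q = a / b, r = a % b;  // truncating division, as in truncdiv/truncmod
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;  // Select(rmod >= 0, rdiv, rdiv - 1) for b > 0
}

int main() {
  std::printf("floordiv(-7, 4) = %lld\n", floordiv(-7, 4));  // -2, while -7/4 truncates to -1
  return 0;
}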
Stmt MakeAllreduce(const CallNode* call) { CHECK(!reduce_combiner_.empty()); - const CommReducerNode *combiner = reduce_combiner_.back(); + const CommReducerNode* combiner = reduce_combiner_.back(); size_t size = combiner->result.size(); - const IntImmNode *size_of_args = call->args[0].as(); + const IntImmNode* size_of_args = call->args[0].as(); CHECK(size_of_args) << call->args[0]->GetTypeKey(); CHECK_EQ(size, size_of_args->value); Array inits = combiner->identity_element; std::vector values(size); std::vector types(size); - PrimExpr cond = call->args[size+1]; + PrimExpr cond = call->args[size + 1]; for (size_t idx = 0; idx < size; ++idx) { - values[idx] = call->args[1+idx]; + values[idx] = call->args[1 + idx]; if (!is_one(cond)) { - values[idx] = SelectNode::make(cond, values[idx], inits[idx]); + values[idx] = Select(cond, values[idx], inits[idx]); } types[idx] = values[idx].dtype(); } std::vector buffers(size); for (size_t idx = 0; idx < size; ++idx) { - const VarNode* buffer = call->args[2+size+idx].as(); + const VarNode* buffer = call->args[2 + size + idx].as(); CHECK(buffer); buffers[idx] = buffer; } @@ -149,22 +149,35 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_set reduce_set; for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) { const VarNode* v = call->args[i].as(); - CHECK(v); - reduce_set.insert(v); + // The simply optimization replace a iteration variable with a constant + // when extent of the iteration is 1. As threaded IterVar always started from 0, + // we can just ignore this variable in this case. + if (v) { + reduce_set.insert(v); + } else { + CHECK(call->args[i].as() && call->args[i].as()->value == 0) + << "arg" << i << "should be a VarNode or IntImmNode"; + } } + size_t nmatch = 0; std::vector vred, vpar; for (const AttrStmtNode* attr : thread_extents_) { ThreadEntry e; IterVar iv = Downcast(attr->node); - e.scope = runtime::ThreadScope::make(iv->thread_tag); + e.scope = runtime::ThreadScope::Create(iv->thread_tag); e.iv = iv; CHECK_LE(e.scope.rank, 1); - CHECK_GE(e.scope.dim_index, 0) - << "vthread do not work with cross thread reduction"; + CHECK_GE(e.scope.dim_index, 0) << "vthread do not work with cross thread reduction"; if (e.scope.rank == 1) { - CHECK(arith::GetConstInt(attr->value, &(e.extent))) - << "Need constant extent for reduce set " << iv; + const auto* ptr = attr->value.as(); + CHECK(ptr) << "Need constant extent for reduce set " << iv; + e.extent = static_cast(ptr->value); + // ignore variables equal to 0 + if (e.extent == 1) { + continue; + } + if (reduce_set.count(iv->var.get())) { vred.push_back(e); ++nmatch; @@ -173,66 +186,197 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } } - CHECK_EQ(nmatch, reduce_set.size()) - << "Not all reduce index are presented in the context"; + CHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce index are presented in the context"; std::sort(vred.begin(), vred.end()); std::sort(vpar.begin(), vpar.end()); // the size of each index. int reduce_extent, group_extent; - int threadx_extent = 1; PrimExpr reduce_index = FlattenThread(vred, &reduce_extent); PrimExpr group_index = FlattenThread(vpar, &group_extent); - if (reduce_extent == 1) { - // special case, no reduction is needed. - std::vector stores(size); + std::vector seq; + std::vector shared_bufs(size); + std::vector local_vars; + // + // This is an optimization. For small reduction sizes, it may be beneficial + // for a single warp to performance the entire reduction. 
No trips to shared + // memory and no cross warp synchronizations are required. + // The following code emits the reduction as follows: + // + // Allocate reduction vars v[i], i = 0..size-1 + // + // for offset from WARP_SIZE to 1 by 2 + // + // a <- load(v[i]) + // b <- shuffle_down(load(v[i], offset)) + // v[i] <- reduction(a, b) + // + // broadcast results from lane 0 to all other lanes and store + // the final reduction result to the proper location. + // + if (is_warp_reduction(types)) { + // TODO(tvm-team) sub-warp reduction support. + CHECK_EQ(reduce_extent, warp_size_) << "not a warp reduction"; + // + // This is the index to the reduction variable, one reduction + // variable per warp. Local scope seems easier to reason without + // relying on a pattern match pass to fix it later. + PrimExpr index(0); + + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); + PrimExpr pred = const_true(types[idx].lanes()); + seq.emplace_back(Store(shared_bufs[idx], values[idx], index, pred)); + + // Uses a local variable to store the shuffled data. + // Later on, this allocation will be properly attached to this statement. + Var var("t" + std::to_string(idx), types[idx]); + Stmt s = Allocate(var, var.dtype(), {PrimExpr(1)}, pred, Evaluate(0)); + local_vars.push_back(s); + } + + // The mask for this reducer, as this reducer may sit inside + // a divergent control flow. Here it uses a variable to cache the current + // active channels. + // + Var mask_var("mask", DataType::UInt(32)); + { + PrimExpr pred = const_true(1); + PrimExpr mask = + Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + seq.emplace_back(Store(mask_var, mask, index, pred)); + // Push allocation with an empty body. Later this will be fixed + // when the entire body is ready. + auto stmt = Allocate(mask_var, mask_var->dtype, {PrimExpr(1)}, pred, Evaluate(0)); + local_vars.push_back(stmt); + } + + // Emit reductions within a warp. + for (int offset = warp_size_ / 2; offset > 0; offset /= 2) { + // Load reduction values, no synchronization needed. + Array a, b; + for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; + PrimExpr pred = const_true(types[i].lanes()); + PrimExpr val = Load(types[i], var, index, pred); + a.push_back(val); + + // __shfl_*sync calls shall not appear in if_then_else expressions + // as this is causing extra divergency. E.g. + // + // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0); + // + // behaves differently from + // + // int t = __shfl_sync(mask, v1, 0); + // v1 = (v2 < v3) ? v3 : t; + // + // The former may cause dead lock as there is a divergent + // branch with a warp sync call inside. + // + const char* shfl_func = intrinsic::tvm_warp_shuffle_down; + PrimExpr other = WarpShuffle(shfl_func, mask_var, val, offset); + const AllocateNode* repl = local_vars[i].as(); + Stmt s = Store(repl->buffer_var, other, index, pred); + seq.push_back(s); + + PrimExpr load = Load(types[i], repl->buffer_var, index, pred); + b.push_back(load); + } + + // Do reductions. + Array ret = (*combiner)(a, b); + + // Store the reduction result to itself. + std::vector stores(size); + for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; + PrimExpr pred = const_true(types[i].lanes()); + stores[i] = Store(var, ret[i], index, pred); + } + seq.push_back(SeqStmt::Flatten(stores)); + } + + // Broadcast the reduction result from lane 0 to all other lanes. 
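The loop emitted in the hunk above is the warp-level allreduce: every lane combines its own value with the value held `offset` lanes away, and `offset` is halved each step, so after log2(warp_size) steps lane 0 holds the reduction over all lanes; the broadcast that follows copies that result back to every lane via tvm_warp_shuffle. Below is a minimal standalone C++ sketch of that schedule, purely illustrative and not part of the patch: a plain array stands in for the warp and shuffle_down for the device intrinsic.

// --- illustrative sketch, not part of the patch ---
// Models the shuffle-down loop on a "warp" stored as a plain array.
// shuffle_down(v, lane, offset) stands in for the device intrinsic: it reads
// the value held by lane + offset, or the lane's own value past the warp end.
#include <cstdio>
#include <vector>

static const int kWarpSize = 32;  // 64 on rocm, 32 on cuda

int shuffle_down(const std::vector<int>& v, int lane, int offset) {
  int src = lane + offset;
  return src < kWarpSize ? v[src] : v[lane];
}

int main() {
  std::vector<int> val(kWarpSize);
  int expected = 0;
  for (int lane = 0; lane < kWarpSize; ++lane) {
    val[lane] = lane + 1;  // each lane starts with its partial result
    expected += val[lane];
  }
  // for offset from WARP_SIZE/2 down to 1, halving each step:
  //   a <- val[lane];  b <- shuffle_down(val, lane, offset);  val[lane] <- a + b
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    std::vector<int> next(kWarpSize);
    for (int lane = 0; lane < kWarpSize; ++lane) {
      next[lane] = val[lane] + shuffle_down(val, lane, offset);
    }
    val = next;  // every lane updates in lock step, as a warp would
  }
  // After log2(warp_size) steps lane 0 holds the full reduction; the pass then
  // broadcasts it from lane 0 so all lanes store the same value.
  std::printf("lane0 = %d, expected = %d\n", val[0], expected);
  return 0;
}
// --- end sketch ---

Working it through on the host also makes the note above about __shfl_*sync placement easier to see: the shuffle has to execute unconditionally on every lane, which is why the pass stores the shuffled value into a separate local temporary before it is consumed.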
+ // This avoids to emit predicated stores, as all threads are + // uniformmly writting the same result. + // for (size_t i = 0; i < size; ++i) { + Var var = shared_bufs[i]; PrimExpr pred = const_true(types[i].lanes()); - Var buffer_var = Downcast(call->args[2+size+i]); - stores[i] = StoreNode::make(buffer_var, values[i], 0, pred); + const char* shfl_func = intrinsic::tvm_warp_shuffle; + PrimExpr val = Load(types[i], var, index, pred); + PrimExpr splat = WarpShuffle(shfl_func, mask_var, val, 0); + seq.push_back(Store(var, splat, index, pred)); + } + + // Update existing allocations. + for (size_t i = 0; i < size; ++i) { + CHECK(!load_remap_.count(buffers[i])); + PrimExpr pred = const_true(types[i].lanes()); + Var var = shared_bufs[i]; + load_remap_[buffers[i]] = Load(types[i], var, index, pred); + Array extents{PrimExpr(1)}; + auto node = Allocate(var, types[i], extents, pred, Evaluate(0)); + alloc_remap_[buffers[i]] = node; + warp_allocs_.insert(node.get()); + } + } else { + int threadx_extent = 1; + if (reduce_extent == 1) { + // special case, no reduction is needed. + std::vector stores(size); + for (size_t i = 0; i < size; ++i) { + PrimExpr pred = const_true(types[i].lanes()); + Var buffer_var = Downcast(call->args[2 + size + i]); + stores[i] = Store(buffer_var, values[i], 0, pred); + } + return SeqStmt::Flatten(stores); + } + // Whether the threadIdx.x is involved in reduction. + if (vred[0].scope.dim_index == 0) { + threadx_extent = vred[0].extent; + } + // This sync is necessary because there might be incomplete read of + // previous iteration on the same buffer. + seq.emplace_back(SyncThread("shared")); + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf" + std::to_string(idx), DataType::Handle()); + PrimExpr pred = const_true(types[idx].lanes()); + seq.emplace_back(Store(shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); + } + seq.emplace_back(SyncThread("shared")); + seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index, + reduce_extent, threadx_extent)); + for (size_t idx = 0; idx < size; ++idx) { + CHECK(!load_remap_.count(buffers[idx])); + PrimExpr pred = const_true(types[idx].lanes()); + load_remap_[buffers[idx]] = + Load(types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); + alloc_remap_[buffers[idx]] = + Allocate(shared_bufs[idx], types[idx], + {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0)); } - return SeqStmt::Flatten(stores); - } - // Whether the threadIdx.x is involved in reduction. - if (vred[0].scope.dim_index == 0) { - threadx_extent = vred[0].extent; - } - std::vector seq; - std::vector shared_bufs(size); - // This sync is necessary because there might be incomplete read of - // previous iteration on the same buffer. 
- seq.emplace_back(SyncThread("shared")); - for (size_t idx = 0; idx < size; ++idx) { - shared_bufs[idx] = Var("red_buf"+std::to_string(idx), DataType::Handle()); - PrimExpr pred = const_true(types[idx].lanes()); - seq.emplace_back(StoreNode::make( - shared_bufs[idx], values[idx], - BufIndex(reduce_index, group_index, reduce_extent), pred)); } - seq.emplace_back(SyncThread("shared")); - seq.emplace_back(MakeBufAllreduce( - combiner, types, shared_bufs, - reduce_index, group_index, reduce_extent, threadx_extent)); - for (size_t idx = 0; idx < size; ++idx) { - CHECK(!load_remap_.count(buffers[idx])); - PrimExpr pred = const_true(types[idx].lanes()); - load_remap_[buffers[idx]] = LoadNode::make( - types[idx], shared_bufs[idx], - BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent), pred); - alloc_remap_[buffers[idx]] = AllocateNode::make( - shared_bufs[idx], types[idx], - {PrimExpr(group_extent), PrimExpr(reduce_extent)}, - pred, EvaluateNode::make(0)); + + // Fix all local allocations as all statements are built. + Stmt body = SeqStmt::Flatten(seq); + for (auto var : local_vars) { + const AllocateNode* repl = var.as(); + if (repl) { + body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); + body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body); + } } - return SeqStmt::Flatten(seq); + + return body; } + // make allreduce. - Stmt MakeBufAllreduce(const CommReducerNode *combiner, - const std::vector& types, - const Array& shared_bufs, - PrimExpr reduce_index, - PrimExpr group_index, - int reduce_extent, - int threadx_extent) { + Stmt MakeBufAllreduce(const CommReducerNode* combiner, const std::vector& types, + const Array& shared_bufs, PrimExpr reduce_index, PrimExpr group_index, + int reduce_extent, int threadx_extent) { // Get next power of two int reduce_align = 1; while (reduce_extent > reduce_align) { @@ -247,15 +391,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { auto freduce = [&](int offset) { Array a, b; for (size_t i = 0; i < size; ++i) { - b.push_back(LoadNode::make(types[i], shared_bufs[i], - BufIndex(reduce_index + offset, group_index, reduce_extent), - const_true())); - a.push_back(LoadNode::make(types[i], shared_bufs[i], buf_index, const_true())); + b.push_back(Load(types[i], shared_bufs[i], + BufIndex(reduce_index + offset, group_index, reduce_extent), + const_true())); + a.push_back(Load(types[i], shared_bufs[i], buf_index, const_true())); } Array ret = (*combiner)(a, b); std::vector stores(size); for (size_t i = 0; i < size; ++i) { - stores[i] = StoreNode::make(shared_bufs[i], ret[i], buf_index, const_true()); + stores[i] = Store(shared_bufs[i], ret[i], buf_index, const_true()); } return SeqStmt::Flatten(stores); }; @@ -264,16 +408,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // reduction with the boundary condition reduce_align = reduce_align >> 1; PrimExpr cond = reduce_index < (reduce_extent - reduce_align); - seq.emplace_back(IfThenElseNode::make(cond, freduce(reduce_align))); + seq.emplace_back(IfThenElse(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } CHECK(threadx_extent >= 1 && warp_size_ >= 1); // normal synchronization - while (reduce_align > threadx_extent || - reduce_align > warp_size_) { - reduce_align = reduce_align >> 1; + while (reduce_align > threadx_extent || reduce_align > warp_size_) { + reduce_align = reduce_align >> 1; PrimExpr cond = reduce_index < reduce_align; - seq.emplace_back(IfThenElseNode::make(cond, 
freduce(reduce_align))); + seq.emplace_back(IfThenElse(cond, freduce(reduce_align))); seq.emplace_back(SyncThread("shared")); } // in warp synchronization. @@ -286,15 +429,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } if (in_warp_seq.size() != 0) { Stmt warp_body = SeqStmt::Flatten(in_warp_seq); - seq.emplace_back(IfThenElseNode::make(in_warp_cond, warp_body)); + seq.emplace_back(IfThenElse(in_warp_cond, warp_body)); seq.emplace_back(SyncThread("shared")); } return SeqStmt::Flatten(seq); } // Flatten the thread index. // Also return a warp number, - PrimExpr FlattenThread(const std::vector& tvec, - int* out_total_extent) { + PrimExpr FlattenThread(const std::vector& tvec, int* out_total_extent) { int& total_extent = *out_total_extent; total_extent = 1; if (tvec.size() == 0) { @@ -313,21 +455,79 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } return ret; } - // sync thread op. - static Stmt SyncThread(const std::string& sync) { - return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync)}, - CallNode::Intrinsic)); - } // The local buffer index. - static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { + PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { if (!is_zero(group_index)) { - return tir::Simplify(group_index * reduce_extent + reduce_index); + return analyzer_.Simplify(group_index * reduce_extent + reduce_index); } else { return reduce_index; } } + // sync thread op. + static Stmt SyncThread(const std::string& sync) { + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, {StringImm(sync)}, + CallNode::Intrinsic)); + } + + // Emit warp shuffle intrinsic calls. + PrimExpr WarpShuffle(const char* name, Var mask_var, PrimExpr val, int delta_or_lane) { + PrimExpr pred = const_true(1); + PrimExpr index(0); + PrimExpr mask = Load(DataType::UInt(32), mask_var, index, pred); + PrimExpr width = IntImm(DataType::Int(32), warp_size_); + Array args{mask, val, IntImm(DataType::Int(32), delta_or_lane), width, width}; + return Call(val.dtype(), name, args, CallNode::Intrinsic); + } + + // Check if this is a reduction on threadIdx.x and its extent matches + // the warp size. + // + // TODO(tvm-team) reduction with a sub-warp of 8 or 16 threads. + // Note: The ROCm backend will only have warp reductions for now. + // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda). + bool is_warp_reduction(const std::vector& types) const { + // Only cuda target supports warp reductions. 
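Stepping back to MakeBufAllreduce in the hunk above: the shared-memory path first rounds the reduction extent up to a power of two, folds the ragged tail in with one bound-checked step, and then keeps halving reduce_align with a tvm_storage_sync("shared") between steps. The following is a rough host-side C++ sketch of that schedule, illustrative only: a vector stands in for the shared buffer, one loop iteration for one barrier-separated step, and the final in-warp steps that run without barriers are not modelled.

// --- illustrative sketch, not part of the patch ---
#include <cassert>
#include <vector>

// One reduction group; buf[i] starts as thread i's partial value.
int BufAllreduce(std::vector<int> buf) {
  int reduce_extent = static_cast<int>(buf.size());
  if (reduce_extent == 1) return buf[0];  // the pass special-cases this earlier
  int reduce_align = 1;
  while (reduce_extent > reduce_align) reduce_align <<= 1;  // next power of two
  // Fold the non-power-of-two tail with one bound-checked step, mirroring
  // IfThenElse(reduce_index < reduce_extent - reduce_align, freduce(reduce_align)).
  reduce_align >>= 1;
  for (int i = 0; i < reduce_extent - reduce_align; ++i) buf[i] += buf[i + reduce_align];
  // ... SyncThread("shared") sits here on the device ...
  while (reduce_align > 1) {
    reduce_align >>= 1;
    for (int i = 0; i < reduce_align; ++i) buf[i] += buf[i + reduce_align];
    // ... and after every halving step ...
  }
  return buf[0];  // thread 0's slot ends up with the group-wide result
}

int main() {
  std::vector<int> buf;
  int expected = 0;
  for (int i = 0; i < 13; ++i) {  // deliberately not a power of two
    buf.push_back(i);
    expected += i;
  }
  assert(BufAllreduce(buf) == expected);
  return 0;
}
// --- end sketch ---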
+ if ((target_->target_name != "cuda") && (target_->target_name != "rocm")) return false; + + // rocm only supports 32 bit operands for shuffling at the moment + if ((target_->target_name == "rocm") && + (std::any_of(types.begin(), types.end(), [](DataType ty) { + if (ty.is_vector()) return true; + return ty.bits() != 32; + }))) { + return false; + } + + // Supported types: + // {u}int, {u}long, {u}long long, float, double, half/half2 + if (std::any_of(types.begin(), types.end(), [](DataType ty) { + if (ty.is_float16()) return ty.lanes() > 2; + if (ty.is_vector()) return true; + return ty.bytes() < 4 || ty.bytes() > 8; + })) { + return false; + } + if (thread_extents_.empty()) { + return false; + } + + const AttrStmtNode* op = thread_extents_.back(); + DCHECK_EQ(op->attr_key, attr::thread_extent); + + IterVar iv = Downcast(op->node); + ThreadEntry e; + e.scope = runtime::ThreadScope::Create(iv->thread_tag); + e.extent = 0; + if (auto ptr = op->value.as()) { + e.extent = static_cast(ptr->value); + } + + return e.extent == warp_size_ && e.scope.dim_index == 0 && e.scope.rank == 1; + } + + // The target. + const TargetNode* target_ = nullptr; + // The warp size of the device. int warp_size_{1}; @@ -335,9 +535,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::vector thread_extents_; std::vector reduce_combiner_; // The load remap - std::unordered_map load_remap_; + std::unordered_map load_remap_; // Allocate remap - std::unordered_map alloc_remap_; + std::unordered_map alloc_remap_; + // Allocate from warp reductions + std::unordered_set warp_allocs_; + // Internal analyzer + arith::Analyzer analyzer_; }; namespace transform { @@ -346,16 +550,15 @@ Pass LowerThreadAllreduce() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); + CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; + const TargetNode* target_node = target.as(); + n->body = ThreadAllreduceBuilder(target_node)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce") -.set_body_typed(LowerThreadAllreduce); +TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce").set_body_typed(LowerThreadAllreduce); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 71ba468a950f..7611e0fcc8b3 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -21,16 +21,14 @@ * Lower TVM related builtin intrinsics such as packed call. 
* \file tir/transforms/lower_tvm_buildin.cc */ +#include #include #include #include -#include -#include #include -#include "../pass/ir_util.h" -#include "../../arith/compute_expr.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -41,11 +39,8 @@ inline PrimExpr ConstInt32(size_t index) { } inline PrimExpr StackAlloca(std::string type, size_t num) { - Array args = {StringImmNode::make(type), ConstInt32(num)}; - return CallNode::make( - DataType::Handle(), - intrinsic::tvm_stack_alloca, - args, CallNode::Intrinsic); + Array args = {StringImm(type), ConstInt32(num)}; + return Call(DataType::Handle(), intrinsic::tvm_stack_alloca, args, CallNode::Intrinsic); } // Calculate the statistics of packed function. @@ -59,18 +54,14 @@ class BuiltinLower : public StmtExprMutator { stack_tcode_ = Var("stack_tcode", DataType::Handle()); stmt = this->VisitStmt(stmt); if (max_shape_stack_ != 0) { - stmt = LetStmtNode::make( - stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); + stmt = LetStmt(stack_shape_, StackAlloca("shape", max_shape_stack_), stmt); } if (max_array_stack_ != 0) { - stmt = LetStmtNode::make( - stack_array_, StackAlloca("array", max_array_stack_), stmt); + stmt = LetStmt(stack_array_, StackAlloca("array", max_array_stack_), stmt); } if (max_arg_stack_ != 0) { - stmt = LetStmtNode::make( - stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); - stmt = LetStmtNode::make( - stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); + stmt = LetStmt(stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt); + stmt = LetStmt(stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt); } return stmt; } @@ -94,11 +85,10 @@ class BuiltinLower : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); // Get constant allocation bound. 
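The Build() scaffolding earlier in this hunk sizes one StackAlloca per kind (shape, array, arg value, arg tcode) from the maxima of running counters that the visitor bumps while lowering each packed call and rolls back afterwards. A small sketch of that reserve/rollback bookkeeping, with hypothetical names; the real pass keeps separate run_*_stack_ and max_*_stack_ fields:

// --- illustrative sketch, not part of the patch ---
#include <algorithm>
#include <cassert>
#include <cstddef>

// Tracks how many stack slots the packed calls on the current path need and
// remembers the high-water mark; Build() later emits one StackAlloca sized by it.
struct StackCounter {
  std::size_t run = 0;   // slots claimed by the calls currently being lowered
  std::size_t peak = 0;  // maximum ever needed

  std::size_t Reserve(std::size_t n) {  // claim n slots, return the first index
    std::size_t begin = run;
    run += n;
    peak = std::max(peak, run);
    return begin;
  }
  void Restore(std::size_t begin) { run = begin; }  // roll back once the call is emitted
};

int main() {
  StackCounter arg_stack;
  // A packed call with 3 arguments ...
  std::size_t outer = arg_stack.Reserve(3);
  // ... whose argument expressions contain a nested packed call with 2 arguments.
  std::size_t inner = arg_stack.Reserve(2);
  arg_stack.Restore(inner);
  arg_stack.Restore(outer);
  // A later, independent call with 4 arguments reuses the same slots.
  arg_stack.Restore(arg_stack.Reserve(4));
  assert(arg_stack.peak == 5);  // one 5-slot arg stack serves the whole function
  return 0;
}
// --- end sketch ---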
- int64_t dev_type; int64_t nbytes = GetVectorBytes(op->dtype); if (device_type_.defined()) { - if (arith::GetConst(device_type_, &dev_type)) { - if (dev_type == kDLCPU) { + if (const auto* dev_type = device_type_.as()) { + if (dev_type->value == kDLCPU) { int32_t constant_size = op->constant_allocation_size(); if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { return stmt; @@ -112,44 +102,31 @@ class BuiltinLower : public StmtExprMutator { } CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - Stmt throw_last_error = EvaluateNode::make( - CallNode::make(DataType::Int(32), - intrinsic::tvm_throw_last_error, {}, - CallNode::Intrinsic)); + Stmt throw_last_error = + Evaluate(Call(DataType::Int(32), intrinsic::tvm_throw_last_error, {}, CallNode::Intrinsic)); - Stmt body = SeqStmt({ - IfThenElseNode::make( - CallNode::make(DataType::Bool(1), - intrinsic::tvm_handle_is_null, - {op->buffer_var}, CallNode::PureIntrinsic), - throw_last_error), - op->body}); + Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), intrinsic::tvm_handle_is_null, + {op->buffer_var}, CallNode::PureIntrinsic), + throw_last_error), + op->body}); - Stmt alloca = LetStmtNode::make( + Stmt alloca = LetStmt( op->buffer_var, - CallNode::make(op->buffer_var.dtype(), - "TVMBackendAllocWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - cast(DataType::UInt(64), total_bytes), - IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}, - CallNode::Extern), + Call(op->buffer_var.dtype(), "TVMBackendAllocWorkspace", + {cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_), + cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}, + CallNode::Extern), body); - PrimExpr free_op = CallNode::make(DataType::Int(32), - "TVMBackendFreeWorkspace", - {cast(DataType::Int(32), device_type_), - cast(DataType::Int(32), device_id_), - op->buffer_var}, - CallNode::Extern); - Stmt free_stmt = IfThenElseNode::make( - free_op != make_zero(DataType::Int(32)), throw_last_error); + PrimExpr free_op = Call(DataType::Int(32), "TVMBackendFreeWorkspace", + {cast(DataType::Int(32), device_type_), + cast(DataType::Int(32), device_id_), op->buffer_var}, + CallNode::Extern); + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); body = SeqStmt({alloca, free_stmt}); - body = AttrStmtNode::make( - op->buffer_var, attr::storage_alignment, - make_const(DataType::Int(32), runtime::kTempAllocaAlignment), - body); + body = AttrStmt(op->buffer_var, attr::storage_alignment, + make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; } @@ -188,9 +165,8 @@ class BuiltinLower : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); for (size_t i = 0; i < op->args.size(); ++i) { - prep_seq_.emplace_back( - StoreNode::make(stack_shape_, cast(DataType::Int(64), op->args[i]), - ConstInt32(stack_begin +i), const_true(1))); + prep_seq_.emplace_back(Store(stack_shape_, cast(DataType::Int(64), op->args[i]), + ConstInt32(stack_begin + i), const_true(1))); } return AddressOffset(stack_shape_, DataType::Int(64), stack_begin); } @@ -200,45 +176,36 @@ class BuiltinLower : public StmtExprMutator { run_array_stack_ += 1; PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); - prep_seq_.emplace_back( - 
TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1])); PrimExpr strides = op->args[2]; if (!strides.defined() || is_zero(strides)) { strides = make_zero(DataType::Handle()); } - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides)); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3])); DataType dtype = op->args[4].dtype(); prep_seq_.emplace_back( TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode, make_const(DataType::UInt(8), static_cast(dtype.code())))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, - make_const(DataType::UInt(8), dtype.bits()))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, - make_const(DataType::UInt(16), dtype.lanes()))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits, + make_const(DataType::UInt(8), dtype.bits()))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes, + make_const(DataType::UInt(16), dtype.lanes()))); // set byte offset int data_bytes = GetVectorBytes(dtype); PrimExpr byte_offset = op->args[5]; if (!is_zero(byte_offset)) { byte_offset = byte_offset * make_const(byte_offset.dtype(), data_bytes); } - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, - cast(DataType::UInt(64), byte_offset))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset, + cast(DataType::UInt(64), byte_offset))); CHECK(device_type_.defined()) << "Unknown device type in current IR"; CHECK(device_id_.defined()) << "Unknown device id in current IR"; - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, - cast(DataType::Int(32), device_id_))); - prep_seq_.emplace_back( - TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, - cast(DataType::Int(32), device_type_))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId, + cast(DataType::Int(32), device_id_))); + prep_seq_.emplace_back(TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType, + cast(DataType::Int(32), device_type_))); return TVMStructGet(DataType::Handle(), stack_array_, idx, intrinsic::kArrAddr); } // call packed. 
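The array-construction lowering above is essentially a sequence of field stores into one stack-allocated tensor descriptor. The sketch below shows what the TVMStructSet keys populate, using a hypothetical PackedArray struct rather than the real runtime layout, just to make the field order and the element-to-byte offset scaling concrete:

// --- illustrative sketch, not part of the patch ---
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the descriptor the emitted TVMStructSet calls populate.
struct PackedArray {
  void* data;            // kArrData
  int64_t* shape;        // kArrShape
  int64_t* strides;      // kArrStrides (null when the tensor is compact)
  uint32_t ndim;         // kArrNDim
  uint8_t type_code;     // kArrTypeCode
  uint8_t type_bits;     // kArrTypeBits
  uint16_t type_lanes;   // kArrTypeLanes
  uint64_t byte_offset;  // kArrByteOffset, scaled from elements to bytes
  int32_t device_id;     // kArrDeviceId
  int32_t device_type;   // kArrDeviceType
};

int main() {
  static float storage[6];
  static int64_t shape[2] = {2, 3};
  PackedArray arr;
  arr.data = storage;
  arr.shape = shape;
  arr.strides = nullptr;  // is_zero(strides) above becomes a null strides handle
  arr.ndim = 2;
  arr.type_code = 2;   // kDLFloat
  arr.type_bits = 32;
  arr.type_lanes = 1;
  int64_t elem_offset = 0;
  // byte_offset = element offset * bytes per element (bits / 8 * lanes)
  arr.byte_offset = static_cast<uint64_t>(elem_offset) * (arr.type_bits / 8) * arr.type_lanes;
  arr.device_id = 0;
  arr.device_type = 1;  // kDLCPU
  std::printf("ndim=%u dtype=float%u byte_offset=%llu\n", (unsigned)arr.ndim,
              (unsigned)arr.type_bits, (unsigned long long)arr.byte_offset);
  return 0;
}
// --- end sketch ---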
@@ -256,20 +223,17 @@ class BuiltinLower : public StmtExprMutator { DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = CastNode::make(api_type, arg); + arg = Cast(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet( - stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); if (api_type.is_handle() && arg.as()) { arg_tcode = kTVMStr; } if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, - ConstInt32(arg_tcode), - stack_index, const_true(1))); + Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -278,19 +242,14 @@ class BuiltinLower : public StmtExprMutator { run_shape_stack_ = restore_shape_stack; run_array_stack_ = restore_array_stack; run_arg_stack_ = arg_stack_begin; - Array packed_args = { - op->args[0], - stack_value_, - stack_tcode_, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1) - }; - return CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed_lowered, - packed_args, CallNode::Intrinsic); + Array packed_args = {op->args[0], stack_value_, stack_tcode_, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + op->args.size() - 1)}; + return Call(DataType::Int(32), intrinsic::tvm_call_packed_lowered, packed_args, + CallNode::Intrinsic); } - PrimExpr MakeCallTracePacked(const CallNode *op) { + PrimExpr MakeCallTracePacked(const CallNode* op) { size_t restore_shape_stack = run_shape_stack_; size_t restore_array_stack = run_array_stack_; size_t arg_stack_begin = run_arg_stack_; @@ -305,17 +264,14 @@ class BuiltinLower : public StmtExprMutator { DataType t = arg.dtype(); DataType api_type = APIType(t); if (t != api_type) { - arg = CastNode::make(api_type, arg); + arg = Cast(api_type, arg); } - prep_seq_.emplace_back(TVMStructSet( - stack_value_, static_cast(arg_stack_begin + i - 1), - intrinsic::kTVMValueContent, arg)); + prep_seq_.emplace_back(TVMStructSet(stack_value_, static_cast(arg_stack_begin + i - 1), + intrinsic::kTVMValueContent, arg)); int arg_tcode = api_type.code(); CHECK(!IsArrayHandle(arg)) << "Trace does not support Buffers"; prep_seq_.emplace_back( - StoreNode::make(stack_tcode_, - ConstInt32(arg_tcode), - stack_index, const_true(1))); + Store(stack_tcode_, ConstInt32(arg_tcode), stack_index, const_true(1))); } // UPDATE stack value max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_); @@ -326,18 +282,13 @@ class BuiltinLower : public StmtExprMutator { // Update the top of the stack, so we can use more than one // packed function's arguments with the one stack. run_arg_stack_ = arg_stack_begin + args_size - 1; - Array packed_args = { - op->args[0], - stack_value_, - stack_tcode_, - ConstInt32(arg_stack_begin), - ConstInt32(arg_stack_begin + op->args.size() - 1), - // Pass traced value. - op->args[args_size - 1] - }; - return CallNode::make( - op->dtype, intrinsic::tvm_call_trace_packed_lowered, - packed_args, CallNode::Intrinsic); + Array packed_args = {op->args[0], stack_value_, stack_tcode_, + ConstInt32(arg_stack_begin), + ConstInt32(arg_stack_begin + op->args.size() - 1), + // Pass traced value. 
+ op->args[args_size - 1]}; + return Call(op->dtype, intrinsic::tvm_call_trace_packed_lowered, packed_args, + CallNode::Intrinsic); } private: @@ -382,8 +333,7 @@ Pass LowerTVMBuiltin() { return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin") -.set_body_typed(LowerTVMBuiltin); +TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin").set_body_typed(LowerTVMBuiltin); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 612a8f4d9eef..92f9ab54adb4 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -25,20 +25,19 @@ */ // Thanks to Andrew Adams and Vinod Grover for // explaining the concept of warp shuffle. -#include #include - +#include +#include +#include +#include #include +#include #include #include -#include -#include -#include #include -#include "../pass/ir_util.h" -#include "../../arith/compute_expr.h" +#include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { @@ -60,37 +59,49 @@ namespace tir { // // Before rewrite, // -// alloc warp warp_mem[n * warp_size * m] -// store warp_mem[m * warp_index + (warp_size * m) * y + x] -// load warp_mem[m * z + (warp_size * m) * y + x] +// alloc warp warp_mem[n * width * m] +// store warp_mem[m * warp_index + (width * m) * y + x] +// load warp_mem[m * z + (width * m) * y + x] // subject to x \in [0, m), y \in [0, n) // +// where width equals to the extent of threadIdx.x, which should +// be no larger than the warp size +// // After rewrite: // // alloc local local_mem[n * m] // store warp_mem[m * y + x] // warp_shuffle(load warp_mem[m * y + x], z) // subject to (m * y + x) is invariant to warp_index +// +// If width == warp size, we are shuffling on full warps. +// Otherwise, we are virtually shuffling on sub-warps, +// whose size equals to width. In this case, you can imagine +// a warp only consists of `width` threads. Width is passed +// as an argument to the shuffle primitive, and will be +// lowered to the device code if the target supports. +// +// A limitation of this sub-warp approach is that users +// cannot shuffle across the sub-warp boundary (i.e. shuffle +// with threadIdx.y or threadIdx.z indices). It can be solved +// via fusing threadIdx.x to the warp size, or improving the +// analyzer to detect both 3 thread axes, which is left for +// future improvements. 
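The before/after comment above is easiest to check with concrete numbers. The following standalone C++ sketch (names are illustrative; the real pass does this symbolically in SplitIndexByGroup further down in this file) splits a flat warp-memory index into the local slot kept by the owning lane and the lane to shuffle from, then verifies the round trip:

// --- illustrative sketch, not part of the patch ---
#include <cassert>
#include <utility>

// Split a flat warp-memory index into (local slot on the owning lane, owning lane).
//   index = m * z + (width * m) * y + x,  x in [0, m), z in [0, width), y in [0, n)
std::pair<int, int> SplitIndexByGroup(int index, int m, int width) {
  int x = index % m;
  int z = (index % (width * m)) / m;  // the lane that owns this element
  int y = index / (width * m);
  int local = m * y + x;              // slot inside that lane's n * m local buffer
  return {local, z};
}

int main() {
  const int m = 2, width = 4, n = 3;  // width = extent of threadIdx.x (<= warp size)
  for (int index = 0; index < n * width * m; ++index) {
    auto [local, lane] = SplitIndexByGroup(index, m, width);
    int y = local / m, x = local % m;
    // Reconstructing the original flat index from (lane, local slot) must round-trip.
    assert(index == m * lane + (width * m) * y + x);
    assert(lane >= 0 && lane < width);
    assert(local >= 0 && local < n * m);
  }
  return 0;
}
// --- end sketch ---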
// Algorithm // // To implement this rewrite rule, we can do the follow step: // For each warp memory alloc // - Use linear pattern detector on load index to find m -// - Deduce n given warp_size and alloc size -// - Now that we have m, n, warp_size, we can proceed with the rewrite +// - Deduce n given width and alloc size +// - Now that we have m, n, width, we can proceed with the rewrite // Visitor to find m in pattern -// store warp_mem[m * warp_index + (warp_size * m) * y + x] +// store warp_mem[m * warp_index + (width * m) * y + x] class WarpStoreCoeffFinder : private StmtVisitor { public: - WarpStoreCoeffFinder(const VarNode* buffer, - Var warp_index, - arith::Analyzer* analyzer) - : buffer_(buffer), - warp_index_(warp_index), - analyzer_(analyzer) { - } + WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) + : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { this->VisitStmt(stmt); @@ -99,16 +110,16 @@ class WarpStoreCoeffFinder : private StmtVisitor { private: /// Visitor implementation - void VisitStmt_(const StoreNode *op) final { + void VisitStmt_(const StoreNode* op) final { if (op->buffer_var.get() == buffer_) { if (op->value.dtype().lanes() == 1) { UpdatePattern(op->index); } else { - PrimExpr base; - CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base)) + arith::PVar base; + CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index)) << "LowerWarpMemory failed due to store index=" << op->index << ", can only handle continuous store"; - UpdatePattern(base); + UpdatePattern(base.Eval()); } } else { StmtVisitor::VisitStmt_(op); @@ -116,23 +127,20 @@ class WarpStoreCoeffFinder : private StmtVisitor { } void UpdatePattern(const PrimExpr& index) { - Array m = - arith::DetectLinearEquation(index, {warp_index_}); - CHECK_EQ(m.size(), 2U) - << "LowerWarpMemory failed due to store index=" << index; - int coeff = 0; + Array m = arith::DetectLinearEquation(index, {warp_index_}); + CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); - - CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) + const auto* mcoeff_as_int = mcoeff.as(); + CHECK(mcoeff_as_int && mcoeff_as_int->value > 0) << "LowerWarpMemory failed due to store index=" << index - << ", require positive constant coefficient on warp index " << warp_index_ - << " but get " << mcoeff; + << ", require positive constant coefficient on warp index " << warp_index_ << " but get " + << mcoeff; if (warp_coeff_ != 0) { - CHECK_EQ(warp_coeff_, coeff) + CHECK_EQ(warp_coeff_, mcoeff_as_int->value) << "LowerWarpMemory failed due to two different store coefficient to warp index"; } else { - warp_coeff_ = coeff; + warp_coeff_ = mcoeff_as_int->value; } } @@ -141,45 +149,43 @@ class WarpStoreCoeffFinder : private StmtVisitor { // the warp index Var warp_index_; // the coefficient - int warp_coeff_{0}; + int64_t warp_coeff_{0}; // analyzer. 
arith::Analyzer* analyzer_; }; - // Visitor to find the warp index class WarpIndexFinder : private StmtVisitor { public: - explicit WarpIndexFinder(int warp_size) - : warp_size_(warp_size) { - } - // find the warp co-efficient in the statement given the warp size - IterVar Find(const Stmt& stmt) { + explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {} + // find the warp co-efficient and the shuffle width in the statement + std::pair Find(const Stmt& stmt) { this->VisitStmt(stmt); CHECK(warp_index_.defined()) << "Cannot find warp index(threadIdx.x) within the scope of warp memory"; - return warp_index_; + return std::make_pair(warp_index_->var, width_); } private: /// Visitor implementation - void VisitStmt_(const AttrStmtNode *op) final { + void VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag == "threadIdx.x") { - int value; - CHECK(arith::GetConstInt(op->value, &value) && - value == warp_size_) - << "Expect threadIdx.x 's size to be equal to warp size(" - << warp_size_ << ")" << " to enable warp memory" + auto* value_as_int = op->value.as(); + CHECK(value_as_int && value_as_int->value <= warp_size_ && + warp_size_ % value_as_int->value == 0) + << "Expect threadIdx.x 's size to be no larger than, and a factor of" + << " warp size(" << warp_size_ << ")" + << " to enable warp memory" << " but get " << op->value << " instead"; if (warp_index_.defined()) { CHECK(warp_index_.same_as(iv)) - << "Find two instance of " << warp_index_->thread_tag - << " in the same kernel. " + << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. " << "Please create it using thread_axis once and reuse the axis " << "across multiple binds in the same kernel"; } else { + width_ = value_as_int->value; warp_index_ = iv; } } @@ -188,6 +194,8 @@ class WarpIndexFinder : private StmtVisitor { } // warp size int warp_size_{0}; + // number of threads involved in one shuffle + int width_{0}; // the warp index IterVar warp_index_{nullptr}; }; @@ -201,27 +209,24 @@ class WarpAccessRewriter : protected StmtExprMutator { Stmt Rewrite(const AllocateNode* op) { buffer_ = op->buffer_var.get(); int alloc_size = op->constant_allocation_size(); - CHECK_GT(alloc_size, 0) - << "warp memory only support constant alloc size"; + CHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size"; alloc_size *= op->dtype.lanes(); - warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var; - warp_coeff_ = WarpStoreCoeffFinder( - buffer_, warp_index_, analyzer_).Find(op->body); - CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0) - << "Warp memory must be multiple of warp size"; - warp_group_ = alloc_size / (warp_size_ * warp_coeff_); - return AllocateNode::make( - op->buffer_var, - op->dtype, - {make_const(DataType::Int(32), alloc_size / warp_size_)}, - op->condition, - this->VisitStmt(op->body)); + std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); + warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); + + // Align the local memory size. The number of elements may not + // be a multiple of width_ * warp_coeff_; round it up. 
+ int factor = width_ * warp_coeff_; + warp_group_ = (alloc_size + (factor - 1)) / factor; + alloc_size = warp_group_ * factor; + + return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, + op->condition, this->VisitStmt(op->body)); } protected: PrimExpr VisitExpr_(const VarNode* op) override { - CHECK(op != buffer_) - << "Cannot access address of warp memory directly"; + CHECK(op != buffer_) << "Cannot access address of warp memory directly"; return StmtExprMutator::VisitExpr_(op); } @@ -229,7 +234,7 @@ class WarpAccessRewriter : protected StmtExprMutator { if (op->buffer_var.get() == buffer_) { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); - return StoreNode::make(op->buffer_var, op->value, local_index, op->predicate); + return Store(op->buffer_var, op->value, local_index, op->predicate); } else { return StmtExprMutator::VisitStmt_(op); } @@ -240,15 +245,14 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id - CHECK(!ExprUseVar(local_index, {warp_index_.get()})) - << "LowerWarpMemory failed to rewrite load to shuffle for index " - << op->index << " local_index=" << local_index; - PrimExpr load_value = LoadNode::make( - op->dtype, op->buffer_var, local_index, op->predicate); - return CallNode::make(load_value.dtype(), - intrinsic::tvm_warp_shuffle, - {load_value, group}, - CallNode::Intrinsic); + CHECK(!ExprUseVar(local_index, warp_index_)) + << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index + << " local_index=" << local_index; + PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); + PrimExpr mask = + Call(DataType::UInt(32), intrinsic::tvm_warp_activemask, {}, CallNode::Intrinsic); + return Call(load_value.dtype(), intrinsic::tvm_warp_shuffle, + {mask, load_value, group, width_, warp_size_}, CallNode::Intrinsic); } else { return StmtExprMutator::VisitExpr_(op); } @@ -260,11 +264,13 @@ class WarpAccessRewriter : protected StmtExprMutator { // in this access pattern. 
std::pair SplitIndexByGroup(const PrimExpr& index) { if (index.dtype().lanes() != 1) { - PrimExpr base, local_index, group; - CHECK(GetRamp1Base(index, index.dtype().lanes(), &base)); - std::tie(local_index, group) = SplitIndexByGroup(base); - local_index = - RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); + PrimExpr local_index, group; + + arith::PVar base; + CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); + + std::tie(local_index, group) = SplitIndexByGroup(base.Eval()); + local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } PrimExpr m = make_const(index.dtype(), warp_coeff_); @@ -276,12 +282,10 @@ class WarpAccessRewriter : protected StmtExprMutator { return std::make_pair(x, z); } else { PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); - PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * warp_size_); + PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_); y = y * m + x; - PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * warp_size_)), - m); - return std::make_pair(analyzer_->canonical_simplify(y), - analyzer_->canonical_simplify(z)); + PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m); + return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); } } @@ -290,6 +294,8 @@ class WarpAccessRewriter : protected StmtExprMutator { int warp_size_{0}; // The buffer variable const VarNode* buffer_; + // number of threads involved in one shuffle + int width_{0}; // Warp index Var warp_index_; // the coefficient m @@ -300,14 +306,12 @@ class WarpAccessRewriter : protected StmtExprMutator { arith::Analyzer* analyzer_; }; - // Bind bound information of variables to make analyzer more effective // TODO(tqchen): consider a pass to inline the bound info into the expr // so analysis can be context independent. 
class BindVarBoundInfo : public StmtVisitor { public: - explicit BindVarBoundInfo(arith::Analyzer* analyzer) - : analyzer_(analyzer) {} + explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} void VisitStmt_(const ForNode* op) final { const Var& loop_var = op->loop_var; @@ -316,8 +320,7 @@ class BindVarBoundInfo : public StmtVisitor { } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); if (!var_dom_.count(iv->var.get())) { @@ -339,40 +342,37 @@ class BindVarBoundInfo : public StmtVisitor { // Mutator to change the read pattern class WarpMemoryRewriter : private StmtMutator { public: - explicit WarpMemoryRewriter(int warp_size) - : warp_size_(warp_size) { - } + explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {} Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; BindVarBoundInfo binder(&analyzer_); binder(stmt); stmt = operator()(std::move(stmt)); - stmt = CanonicalSimplify(stmt); return stmt; } private: Stmt VisitStmt_(const AllocateNode* op) { + auto ret = StmtMutator::VisitStmt_(op); + op = ret.as(); if (warp_buffer_.count(op->buffer_var.get())) { WarpAccessRewriter rewriter(warp_size_, &analyzer_); - return rewriter.Rewrite(op); - } else { - return StmtMutator::VisitStmt_(op); + ret = rewriter.Rewrite(op); } + return ret; } Stmt VisitStmt_(const AttrStmtNode* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - StorageScope scope = StorageScope::make(op->value.as()->value); + StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); Stmt ret = StmtMutator::VisitStmt_(op); op = ret.as(); - return AttrStmtNode::make( - op->node, op->attr_key, StringImmNode::make("local"), op->body); + return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); } } return StmtMutator::VisitStmt_(op); @@ -391,16 +391,14 @@ Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "LowerWarpMemory: Require the target attribute"; + CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); } -TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory") -.set_body_typed(LowerWarpMemory); +TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory); } // namespace transform diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 7980a9d7238f..a91e350e6b22 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -20,47 +20,42 @@ /*! * \file make_packed_api.cc Lower PrimFunc to use the packed function API. 
*/ -#include -#include -#include -#include -#include -#include -#include +#include #include #include -#include +#include +#include +#include +#include +#include +#include -#include -#include #include +#include +#include -#include "../pass/ir_util.h" -#include "../pass/arg_binder.h" +#include "arg_binder.h" +#include "ir_util.h" namespace tvm { namespace tir { inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { - return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg), - EvaluateNode::make(0)); + return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0)); } -PrimFunc MakePackedAPI(PrimFunc&& func, - int num_unpacked_args) { +PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol) - << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "MakePackedAPI: Require the target attribute"; + CHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; int target_device_type = target.value()->device_type; std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); - const Stmt nop = EvaluateNode::make(0); + const Stmt nop = Evaluate(0); int num_args = static_cast(func_ptr->params.size()); CHECK_LE(num_unpacked_args, num_args); @@ -86,18 +81,14 @@ PrimFunc MakePackedAPI(PrimFunc&& func, // local function definitions // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { - Array call_args{ - v_packed_args, - IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; + Array call_args{v_packed_args, IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); - PrimExpr res = CallNode::make( - api_type, intrinsic::tvm_struct_get, call_args, - CallNode::PureIntrinsic); + PrimExpr res = Call(api_type, intrinsic::tvm_struct_get, call_args, CallNode::PureIntrinsic); // cast to the target version. if (api_type != t) { - res = CastNode::make(t, res); + res = Cast(t, res); } return res; }; @@ -112,8 +103,7 @@ PrimFunc MakePackedAPI(PrimFunc&& func, std::ostringstream os; os << name_hint << ": num_args should be " << num_packed_args; - seq_init.emplace_back( - MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); + seq_init.emplace_back(MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); } // Need to re-declare vars, in case some arguments also appears in the buffer. 
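The hunks around this point generate, for every packed argument, a load of the 64-bit value slot, a load of its type code, an assertion that the code matches the declared parameter type, and a cast back to the narrower type. A compact host-side sketch of that calling convention, with hypothetical names (the real ABI uses TVMValue and the TVMArgTypeCode constants checked above):

// --- illustrative sketch, not part of the patch ---
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-ins for the packed-call ABI pieces referenced above.
union Value {
  int64_t v_int64;
  double v_float64;
  void* v_handle;
};
enum TypeCode { kInt = 0, kFloat = 2, kHandle = 3 };  // illustrative subset

// What the generated wrapper does for "int add(int a, int b)": check num_args,
// check every type code, then cast each 64-bit API slot back to the declared type.
int32_t wrapped_add(const Value* args, const int* type_codes, int num_args) {
  assert(num_args == 2 && "add: num_args should be 2");
  assert(type_codes[0] == kInt && "add: Expect arg[0] to be int");
  assert(type_codes[1] == kInt && "add: Expect arg[1] to be int");
  int32_t a = static_cast<int32_t>(args[0].v_int64);  // API type is int64, cast down
  int32_t b = static_cast<int32_t>(args[1].v_int64);
  return a + b;
}

int main() {
  Value args[2];
  int codes[2] = {kInt, kInt};
  args[0].v_int64 = 40;
  args[1].v_int64 = 2;
  std::printf("%d\n", wrapped_add(args, codes, 2));
  return 0;
}
// --- end sketch ---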
@@ -132,36 +122,29 @@ PrimFunc MakePackedAPI(PrimFunc&& func, } if (i < num_packed_args) { // Value loads - seq_init.emplace_back(LetStmtNode::make( - v_arg, f_arg_value(v_arg.dtype(), i), nop)); + seq_init.emplace_back(LetStmt(v_arg, f_arg_value(v_arg.dtype(), i), nop)); // type code checks Var tcode(v_arg->name_hint + ".code", DataType::Int(32)); - seq_init.emplace_back(LetStmtNode::make( - tcode, LoadNode::make( - DataType::Int(32), v_packed_arg_type_ids, - IntImm(DataType::Int(32), i), const_true(1)), - nop)); + seq_init.emplace_back(LetStmt(tcode, + Load(DataType::Int(32), v_packed_arg_type_ids, + IntImm(DataType::Int(32), i), const_true(1)), + nop)); DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be pointer"; - seq_check.emplace_back( - AssertStmtNode::make(tcode == kTVMOpaqueHandle || - tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || - tcode == kTVMNullptr, - tvm::tir::StringImmNode::make(msg.str()), nop)); + seq_check.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, + tvm::tir::StringImm(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop)); + seq_check.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); } else { CHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; - seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop)); + seq_check.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); } } else { args.push_back(v_arg); @@ -189,35 +172,30 @@ PrimFunc MakePackedAPI(PrimFunc&& func, } for (const auto& kv : buffer_def) { - binder.BindDLTensor(kv.second, device_type, device_id, - kv.first, kv.first->name_hint); + binder.BindDLTensor(kv.second, device_type, device_id, kv.first, kv.first->name_hint); } if (num_unpacked_args == 0) { func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - auto body = AttrStmtNode::make( - make_zero(DataType::Int(32)), attr::compute_scope, - StringImmNode::make(name_hint + "_compute_"), func_ptr->body); + Stmt body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, + StringImm(name_hint + "_compute_"), func_ptr->body); // Set device context if (vmap.count(device_id.get())) { - PrimExpr node = StringImmNode::make("default"); - seq_check.push_back(AttrStmtNode::make( - node, attr::device_context_id, device_id, nop)); - seq_check.push_back(AttrStmtNode::make( - node, attr::device_context_type, device_type, nop)); + PrimExpr node = StringImm("default"); + seq_check.push_back(AttrStmt(node, attr::device_context_id, device_id, nop)); + seq_check.push_back(AttrStmt(node, attr::device_context_type, device_type, nop)); if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) { - Stmt set_device = EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - {StringImmNode::make(runtime::symbol::tvm_set_device), - device_type, device_id}, CallNode::Intrinsic)); + Stmt set_device = + Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, + {StringImm(runtime::symbol::tvm_set_device), device_type, device_id}, + CallNode::Intrinsic)); body 
= SeqStmt({set_device, body}); } } - func_ptr->body = MergeNest( - {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); + func_ptr->body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); func_ptr->params = args; Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); @@ -230,7 +208,6 @@ PrimFunc MakePackedAPI(PrimFunc&& func, LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); } - func_ptr->buffer_map = Map(); func_ptr->checked_type_ = func_ptr->func_type_annotation(); func_ptr->ret_type = PrimType(DataType::Int(32)); @@ -249,9 +226,8 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr( - tvm::attr::kCallingConv, - Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } @@ -264,12 +240,10 @@ Pass MakePackedAPI(int num_unpacked_args) { return m; }; - return tvm::transform::CreateModulePass( - pass_func, 0, "tir.MakePackedAPI", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); } -TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI") -.set_body_typed(MakePackedAPI); +TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI").set_body_typed(MakePackedAPI); } // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 1f9d976c407d..07b0ea29a52a 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -22,10 +22,10 @@ * \brief narrow the datatype of indexing vars */ -#include +#include #include #include -#include + #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" @@ -56,8 +56,8 @@ namespace tir { // - Use DataTypeRewritter to rewrite the components of an indexing expression. using arith::Analyzer; -using arith::IRMutatorWithAnalyzer; using arith::ConstIntBound; +using arith::IRMutatorWithAnalyzer; // Determine the result dtype for Var, IntImm and Cast, // which will be stored in `vmap` eventually. @@ -71,24 +71,22 @@ using arith::ConstIntBound; // Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()` class DataTypeVisitor final : public StmtExprVisitor { public: - explicit DataTypeVisitor(int target_bits) - : bits_(target_bits), target_bits_(target_bits) {} + explicit DataTypeVisitor(int target_bits) : bits_(target_bits), target_bits_(target_bits) {} void VisitExpr(const PrimExpr& e) { if (e.dtype().is_int()) { int bits = max_bits_; - const PrimExprNode* op = e.as(); - if (bound_.find(op) == bound_.end()) { + if (bound_.find(e) == bound_.end()) { analyzer_.const_int_bound(e, &bound_); } - ConstIntBound bound = bound_[op]; + ConstIntBound bound = bound_[e]; int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; int64_t lbound = Downcast(min_value(DataType::Int(target_bits_)))->value; if (e.dtype().bits() <= target_bits_ || (bound->max_value <= ubound && bound->min_value >= lbound)) { bits = target_bits_; } - int tmp = bits > bits_ ? bits : bits_; + int tmp = bits > bits_ ? 
bits : bits_; std::swap(bits_, tmp); StmtExprVisitor::VisitExpr(e); std::swap(bits_, tmp); @@ -98,19 +96,16 @@ class DataTypeVisitor final : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, - Range::make_by_min_extent(op->min, op->extent)); + analyzer_.Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); vextent_[op->loop_var.as()] = op->extent.dtype(); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { IterVar iv = Downcast(op->node); CHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, - Range::make_by_min_extent(0, op->value)); + analyzer_.Bind(iv->var, Range::make_by_min_extent(0, op->value)); vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); } else { @@ -188,12 +183,12 @@ class DataTypeVisitor final : public StmtExprVisitor { // the extent of vars to be rewritten std::unordered_map vextent_; // the memorized bound generated by ConstIntBoundAnalyzer - std::unordered_map bound_; + arith::ConstIntBoundAnalyzer::BoundMapType bound_; }; class DataTypeRewriter : public StmtExprMutator { public: - explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {} + explicit DataTypeRewriter(int target_bits) : visitor_(target_bits) {} Stmt operator()(Stmt s) { visitor_(s); @@ -213,47 +208,36 @@ class DataTypeRewriter : public StmtExprMutator { is_index_ = true; PrimExpr index = this->VisitExpr(op->index); is_index_ = false; - Stmt s = StoreNode::make(op->buffer_var, - op->value, - index, - op->predicate); + Stmt s = Store(op->buffer_var, op->value, index, op->predicate); return StmtExprMutator::VisitStmt_(s.as()); } Stmt VisitStmt_(const ForNode* op) final { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - CHECK(op != nullptr) - << "Expected type to be ForNode" - << ", but get " << s->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be ForNode" + << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); - return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), - op->for_type, op->device_api, op->body); + return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type, + op->device_api, op->body); } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { Stmt s = StmtExprMutator::VisitStmt_(op); op = s.as(); - CHECK(op != nullptr) - << "Expected type to be AttrStmtNode" - << ", but get " << s->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be AttrStmtNode" + << ", but get " << s->GetTypeKey(); const IterVarNode* iv = op->node.as(); - CHECK(iv != nullptr) - << "Expected type to be IterVarNode" - << ", but get " << op->node->GetTypeKey(); + CHECK(iv != nullptr) << "Expected type to be IterVarNode" + << ", but get " << op->node->GetTypeKey(); PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { - ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag); + ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag); } - return AttrStmtNode::make( - ivmap_[iv], - op->attr_key, - cast(var.dtype(), op->value), - op->body); + return AttrStmt(ivmap_[iv], 
op->attr_key, cast(var.dtype(), op->value), op->body); } return StmtExprMutator::VisitStmt_(op); } @@ -282,7 +266,7 @@ class DataTypeRewriter : public StmtExprMutator { is_index_ = true; PrimExpr index = this->VisitExpr(op->index); is_index_ = false; - PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate); + PrimExpr e = Load(op->dtype, op->buffer_var, index, op->predicate); return StmtExprMutator::VisitExpr_(e.as()); } @@ -299,10 +283,9 @@ class DataTypeRewriter : public StmtExprMutator { if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) { PrimExpr e = StmtExprMutator::VisitExpr_(op); const CastNode* new_op = e.as(); - CHECK(new_op != nullptr) - << "Expected type to be CastNode" - << ", but get " << e->GetTypeKey(); - return CastNode::make(visitor_.vmap[op], new_op->value); + CHECK(new_op != nullptr) << "Expected type to be CastNode" + << ", but get " << e->GetTypeKey(); + return Cast(visitor_.vmap[op], new_op->value); } return StmtExprMutator::VisitExpr_(op); } @@ -337,40 +320,38 @@ class DataTypeRewriter : public StmtExprMutator { bool is_index_{false}; }; -#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ - PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ - PrimExpr a = this->VisitExpr(op->a); \ - PrimExpr b = this->VisitExpr(op->b); \ - if (a.same_as(op->a) && \ - b.same_as(op->b)) { \ - return GetRef(op); \ - } else { \ - return FUNC(a, b); \ - } \ +#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \ + PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \ + PrimExpr a = this->VisitExpr(op->a); \ + PrimExpr b = this->VisitExpr(op->b); \ + if (a.same_as(op->a) && b.same_as(op->b)) { \ + return GetRef(op); \ + } else { \ + return FUNC(a, b); \ + } \ } -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >) -DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=); +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*) +DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*) 
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=); PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); - CHECK(op != nullptr) - << "Expected type to be CallNode" - << ", but get " << e->GetTypeKey(); + CHECK(op != nullptr) << "Expected type to be CallNode" + << ", but get " << e->GetTypeKey(); if (op->call_type == CallNode::PureIntrinsic) { if (op->name == intrinsic::tvm_if_then_else) { return if_then_else(op->args[0], op->args[1], op->args[2]); @@ -391,9 +372,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) { return e; } -Stmt NarrowDataType(Stmt stmt, int target_bits) { - return DataTypeRewriter(target_bits)(stmt); -} +Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); } namespace transform { @@ -403,12 +382,10 @@ Pass NarrowDataType(int target_bits) { n->body = DataTypeRewriter(target_bits)(std::move(n->body)); return f; }; - return CreatePrimFuncPass( - pass_func, 0, "tir.NarrowDataType", {}); + return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); } -TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") -.set_body_typed(NarrowDataType); +TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType").set_body_typed(NarrowDataType); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index fdcfc4d4702e..017d1b4e6c67 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -20,12 +20,12 @@ /*! * \file remap_thread_axis.cc */ +#include #include #include #include -#include -#include +#include namespace tvm { namespace tir { @@ -33,14 +33,9 @@ namespace tir { // Mutator to change the read pattern class ThreadAxisRewriter : private StmtExprMutator { public: - explicit ThreadAxisRewriter( - const std::unordered_map& tmap) - : tmap_(tmap) { - } + explicit ThreadAxisRewriter(const std::unordered_map& tmap) : tmap_(tmap) {} - Stmt Rewrite(Stmt stmt) { - return operator()(std::move(stmt)); - } + Stmt Rewrite(Stmt stmt) { return operator()(std::move(stmt)); } private: Stmt VisitStmt_(const AttrStmtNode* op) final { @@ -57,8 +52,7 @@ class ThreadAxisRewriter : private StmtExprMutator { CHECK(vmap_[v].same_as(new_iv->var)); } Stmt body = this->VisitStmt(op->body); - return AttrStmtNode::make( - new_iv, op->attr_key, op->value, body); + return AttrStmt(new_iv, op->attr_key, op->value, body); } } return StmtExprMutator::VisitStmt_(op); @@ -75,7 +69,6 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; - PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { @@ -83,8 +76,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) } auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(opt_thread_axis != nullptr) - << "Require attribute " << tir::attr::kDeviceThreadAxis; + CHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); @@ -99,7 +91,6 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); } - namespace transform { Pass RemapThreadAxis(Map thread_map) { @@ -109,8 +100,7 @@ Pass RemapThreadAxis(Map thread_map) { return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); } -TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis") 
-.set_body_typed(RemapThreadAxis); +TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis").set_body_typed(RemapThreadAxis); } // namespace transform } // namespace tir diff --git a/src/tir/pass/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc similarity index 79% rename from src/tir/pass/remove_no_op.cc rename to src/tir/transforms/remove_no_op.cc index 181a8c483e4e..cd3a4b7483cc 100644 --- a/src/tir/pass/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -21,9 +21,13 @@ * \file remove_no_op.cc * \brief Remove no op from the stmt */ +#include +#include +#include #include -#include #include +#include + #include namespace tvm { @@ -53,7 +57,7 @@ class NoOpRemover : public StmtMutator { if (is_no_op(op->then_case)) { return MakeEvaluate(op->condition); } else { - return IfThenElseNode::make(op->condition, op->then_case); + return IfThenElse(op->condition, op->then_case); } } else { return stmt; @@ -70,7 +74,7 @@ class NoOpRemover : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (is_zero(op->extent)) { - return EvaluateNode::make(0); + return Evaluate(0); } return is_no_op(op->body) ? MakeEvaluate({op->min, op->extent}) : stmt; } @@ -80,14 +84,14 @@ class NoOpRemover : public StmtMutator { return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; } - Stmt VisitStmt_(const RealizeNode* op) final { + Stmt VisitStmt_(const ProducerRealizeNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); + op = stmt.as(); return is_no_op(op->body) ? op->body : stmt; } Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) return GetRef(op); - return EvaluateNode::make(0); + return Evaluate(0); } Stmt VisitStmt_(const SeqStmtNode* op) final { @@ -102,7 +106,7 @@ class NoOpRemover : public StmtMutator { auto n = CopyOnWrite(op); size_t top = 0; for (size_t i = 0; i < n->seq.size(); ++i) { - if (!is_no_op(n->seq[i])) { + if (!is_no_op(n->seq[i])) { n->seq.Set(top++, n->seq[i]); } } @@ -124,9 +128,9 @@ class NoOpRemover : public StmtMutator { private: Stmt MakeEvaluate(PrimExpr value) { if (HasSideEffect(value)) { - return EvaluateNode::make(value); + return Evaluate(value); } else { - return EvaluateNode::make(0); + return Evaluate(0); } } Stmt MakeEvaluate(const Array& values) { @@ -134,18 +138,32 @@ class NoOpRemover : public StmtMutator { for (PrimExpr e : values) { if (HasSideEffect(e)) { if (stmt.defined()) { - stmt = SeqStmt({stmt, EvaluateNode::make(e)}); + stmt = SeqStmt({stmt, Evaluate(e)}); } else { - stmt = EvaluateNode::make(e); + stmt = Evaluate(e); } } } - return stmt.defined() ? stmt : EvaluateNode::make(0); + return stmt.defined() ? 
stmt : Evaluate(0); } }; -Stmt RemoveNoOp(Stmt stmt) { - return NoOpRemover()(std::move(stmt)); +Stmt RemoveNoOp(Stmt stmt) { return NoOpRemover()(std::move(stmt)); } + +namespace transform { + +Pass RemoveNoOp() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = NoOpRemover()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); } + +TVM_REGISTER_GLOBAL("tir.transform.RemoveNoOp").set_body_typed(RemoveNoOp); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/rewrite_unsafe_select.cc b/src/tir/transforms/rewrite_unsafe_select.cc similarity index 74% rename from src/tir/pass/rewrite_unsafe_select.cc rename to src/tir/transforms/rewrite_unsafe_select.cc index 501649237090..701f0cea1bfa 100644 --- a/src/tir/pass/rewrite_unsafe_select.cc +++ b/src/tir/transforms/rewrite_unsafe_select.cc @@ -21,23 +21,21 @@ * \file unsafe_select_rewrite.cc * \brief Rewrite uinsafe select expression. */ +#include #include #include -#include +#include namespace tvm { namespace tir { - // For now, rewrite unsafe select expression to if_then_else // TODO(tqchen) pattern matching to support masked load class UnsafeExprDetector : public ExprFunctor { public: // select itself is always considered safe if condition is safe // Because we will issue guard to make sure it is. - bool VisitExpr_(const SelectNode* op) { - return VisitExpr(op->condition); - } + bool VisitExpr_(const SelectNode* op) { return VisitExpr(op->condition); } bool VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_if_then_else)) { return VisitExpr(op->args[0]); @@ -74,21 +72,11 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const GENode* op) final { return BinaryOp(op); } bool VisitExpr_(const AndNode* op) final { return BinaryOp(op); } bool VisitExpr_(const OrNode* op) final { return BinaryOp(op); } - bool VisitExpr_(const NotNode* op) final { - return VisitExpr(op->a); - } - bool VisitExpr_(const LetNode* op) final { - return VisitExpr(op->body) || VisitExpr(op->value); - } - bool VisitExpr_(const CastNode* op) final { - return VisitExpr(op->value); - } - bool VisitExpr_(const BroadcastNode* op) final { - return VisitExpr(op->value); - } - bool VisitExpr_(const RampNode* op) final { - return VisitExpr(op->base) && VisitExpr(op->stride); - } + bool VisitExpr_(const NotNode* op) final { return VisitExpr(op->a); } + bool VisitExpr_(const LetNode* op) final { return VisitExpr(op->body) || VisitExpr(op->value); } + bool VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); } + bool VisitExpr_(const BroadcastNode* op) final { return VisitExpr(op->value); } + bool VisitExpr_(const RampNode* op) final { return VisitExpr(op->base) && VisitExpr(op->stride); } bool VisitExpr_(const ShuffleNode* op) final { for (PrimExpr e : op->vectors) { if (VisitExpr(e)) return true; @@ -101,7 +89,7 @@ class UnsafeExprDetector : public ExprFunctor { bool VisitExpr_(const StringImmNode* op) final { return false; } private: - template + template bool BinaryOp(const T* op) { return VisitExpr(op->a) || VisitExpr(op->b); } @@ -114,23 +102,32 @@ class UnsafeSelectRewriter : public StmtExprMutator { op = expr.as(); UnsafeExprDetector unsafe; bool cond_is_scalar_bool = op->condition.dtype().is_bool() && op->condition.dtype().is_scalar(); - if ((unsafe.VisitExpr(op->true_value) || - unsafe.VisitExpr(op->false_value)) && + if ((unsafe.VisitExpr(op->true_value) || 
unsafe.VisitExpr(op->false_value)) && cond_is_scalar_bool) { - return CallNode::make( - op->dtype, - intrinsic::tvm_if_then_else, - {op->condition, op->true_value, op->false_value}, - CallNode::Intrinsic); + return Call(op->dtype, intrinsic::tvm_if_then_else, + {op->condition, op->true_value, op->false_value}, CallNode::Intrinsic); } else { return expr; } } }; -Stmt RewriteUnsafeSelect(Stmt stmt) { - return UnsafeSelectRewriter()(std::move(stmt)); +Stmt RewriteUnsafeSelect(Stmt stmt) { return UnsafeSelectRewriter()(std::move(stmt)); } + +namespace transform { + +Pass RewriteUnsafeSelect() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = UnsafeSelectRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RewriteUnsafeSelect", {}); } +TVM_REGISTER_GLOBAL("tir.transform.RewriteUnsafeSelect").set_body_typed(RewriteUnsafeSelect); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/arith/stmt_simplify.cc b/src/tir/transforms/simplify.cc similarity index 67% rename from src/arith/stmt_simplify.cc rename to src/tir/transforms/simplify.cc index 6c3dd022565c..3be232964f36 100644 --- a/src/arith/stmt_simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -18,17 +18,17 @@ */ /*! - * \file stmt_simplify.cc + * \file simplify.cc * \brief Statement simplifier based on analyzer */ -#include -#include -#include #include - +#include +#include +#include #include -#include -#include "ir_mutator_with_analyzer.h" +#include + +#include "../../arith/ir_mutator_with_analyzer.h" namespace tvm { namespace arith { @@ -37,20 +37,15 @@ using namespace tir; class StmtSimplifier : public IRMutatorWithAnalyzer { public: - explicit StmtSimplifier(Analyzer* analyzer) - : IRMutatorWithAnalyzer(analyzer) {} + explicit StmtSimplifier(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} using Parent = IRMutatorWithAnalyzer; using Parent::VisitStmt; using Parent::VisitStmt_; - PrimExpr VisitExpr(const PrimExpr& expr) final { - return analyzer_->Simplify(expr); - } + PrimExpr VisitExpr(const PrimExpr& expr) final { return analyzer_->Simplify(expr); } - Stmt Simplify(Stmt stmt) { - return operator()(std::move(stmt)); - } + Stmt Simplify(Stmt stmt) { return operator()(std::move(stmt)); } Stmt VisitStmt_(const ForNode* op) final { analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); @@ -68,8 +63,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return this->VisitStmt(op->body); } Stmt body = this->VisitStmt(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { auto n = this->CopyOnWrite(op); @@ -86,7 +80,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (const LoadNode* load = op->value.as()) { if (load->buffer_var.same_as(op->buffer_var) && tir::ExprDeepEqual()(load->index, op->index)) { - return EvaluateNode::make(0); + return Evaluate(0); } } return GetRef(op); @@ -96,34 +90,21 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } // namespace arith namespace tir { - -Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt)); +namespace transform { + +Pass Simplify() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + 
n->body = arith::StmtSimplifier(&analyzer).Simplify(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); } -PrimExpr CanonicalSimplify(PrimExpr expr, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - return analyzer.canonical_simplify(expr); -} +TVM_REGISTER_GLOBAL("tir.transform.Simplify").set_body_typed(Simplify); -PrimExpr Simplify(PrimExpr expr, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - expr = analyzer.Simplify(expr); - return expr; -} +} // namespace transform -Stmt Simplify(Stmt stmt, Map vrange) { - return CanonicalSimplify(std::move(stmt), vrange); -} } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/skip_assert.cc b/src/tir/transforms/skip_assert.cc index 2857639f2e78..d9cd6d35497c 100644 --- a/src/tir/transforms/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -17,11 +17,10 @@ * under the License. */ +#include #include -#include -#include #include -#include +#include namespace tvm { namespace tir { @@ -35,9 +34,7 @@ class AssertSkipper : public StmtMutator { } }; -Stmt SkipAssert(Stmt stmt) { - return AssertSkipper()(std::move(stmt)); -} +Stmt SkipAssert(Stmt stmt) { return AssertSkipper()(std::move(stmt)); } namespace transform { @@ -50,8 +47,7 @@ Pass SkipAssert() { return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SkipAssert") -.set_body_typed(SkipAssert); +TVM_REGISTER_GLOBAL("tir.transform.SkipAssert").set_body_typed(SkipAssert); } // namespace transform diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 927536b5938e..67336d483ca7 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -22,13 +22,14 @@ * \brief Split device function from host. 
*/ #include +#include +#include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include #include @@ -58,7 +59,7 @@ class VarUseDefAnalysis : public StmtExprMutator { if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } - return AttrStmtNode::make(op->node, op->attr_key, value, body); + return AttrStmt(op->node, op->attr_key, value, body); } else { return StmtExprMutator::VisitStmt_(op); } @@ -68,16 +69,14 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); Stmt body = this->VisitStmt(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && - !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && - value.same_as(op->value)) { + if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { - return LetStmtNode::make(op->var, value, body); + return LetStmt(op->var, value, body); } } } @@ -101,16 +100,14 @@ class VarUseDefAnalysis : public StmtExprMutator { this->HandleDef(op->var.get()); PrimExpr body = this->VisitExpr(op->body); // eliminate unreferenced let - if (use_count_.at(op->var.get()) == 0 && - !HasSideEffect(op->value)) { + if (use_count_.at(op->var.get()) == 0 && !HasSideEffect(op->value)) { return body; } else { PrimExpr value = this->VisitExpr(op->value); - if (body.same_as(op->body) && - value.same_as(op->value)) { + if (body.same_as(op->body) && value.same_as(op->value)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } } @@ -126,12 +123,10 @@ class VarUseDefAnalysis : public StmtExprMutator { } void HandleDef(const VarNode* v) { - CHECK(!def_count_.count(v)) - << "variable " << v->name_hint - << " has already been defined, the Stmt is not SSA"; - CHECK(!use_count_.count(v)) - << "variable " << v->name_hint - << " has been used before definition!"; + CHECK(!def_count_.count(v)) << "variable " << v->name_hint + << " has already been defined, the Stmt is not SSA"; + CHECK(!use_count_.count(v)) << "variable " << v->name_hint + << " has been used before definition!"; use_count_[v] = 0; def_count_[v] = 1; } @@ -160,7 +155,6 @@ class VarUseDefAnalysis : public StmtExprMutator { std::unordered_map def_count_; }; - Array UndefinedVars(const Stmt& stmt, const Array& args) { VarUseDefAnalysis m; for (Var arg : args) { @@ -170,16 +164,10 @@ Array UndefinedVars(const Stmt& stmt, const Array& args) { return m.undefined_; } - class HostDeviceSplitter : public StmtMutator { public: - explicit HostDeviceSplitter(IRModule* device_mod, - Target device_target, - std::string name_prefix) - : device_mod_(device_mod), - device_target_(device_target), - name_prefix_(name_prefix) { - } + explicit HostDeviceSplitter(IRModule* device_mod, Target device_target, std::string name_prefix) + : device_mod_(device_mod), device_target_(device_target), name_prefix_(name_prefix) {} Stmt VisitStmt_(const AllocateNode* op) final { handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); @@ -187,8 +175,7 @@ class HostDeviceSplitter : public StmtMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope || op->attr_key == attr::device_scope) { return 
SplitDeviceFunc(GetRef(op)); } @@ -215,8 +202,7 @@ class HostDeviceSplitter : public StmtMutator { // Create a new version of v. auto it = handle_data_type_.find(var.get()); if (it != handle_data_type_.end()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType((*it).second->dtype))); + tir::Var new_var(var->name_hint, PointerType(PrimType((*it).second->dtype))); params.push_back(new_var); remap_vars.Set(var, new_var); } else { @@ -236,24 +222,23 @@ class HostDeviceSplitter : public StmtMutator { device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_); device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); - device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, - runtime::String(kernel_symbol)); + device_func = + WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol, runtime::String(kernel_symbol)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func); // generate calls to the device function Array call_args; - call_args.push_back(StringImmNode::make(kernel_symbol)); + call_args.push_back(StringImm(kernel_symbol)); for (PrimExpr arg : arguments) { call_args.push_back(arg); } for (PrimExpr ext : m.thread_extent_) { call_args.push_back(ext); } - return EvaluateNode::make(CallNode::make( - DataType::Int(32), intrinsic::tvm_call_packed, - call_args, CallNode::Intrinsic)); + return Evaluate( + Call(DataType::Int(32), intrinsic::tvm_call_packed, call_args, CallNode::Intrinsic)); } // target ir module @@ -267,19 +252,15 @@ class HostDeviceSplitter : public StmtMutator { std::unordered_map handle_data_type_; }; - PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { auto target = func->GetAttr(tvm::attr::kTarget); - CHECK(target.defined()) - << "SplitHostDevice: Require the target attribute"; + CHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; - HostDeviceSplitter splitter( - device_mod, - target.value(), - static_cast(global_symbol.value())); + HostDeviceSplitter splitter(device_mod, target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); @@ -288,14 +269,13 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { return std::move(func); } - namespace transform { Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { IRModuleNode* mod_ptr = mod.CopyOnWrite(); auto* func_dict = mod_ptr->functions.CopyOnWrite(); - IRModule device_mod = IRModule::Empty(); + IRModule device_mod = IRModule(); for (auto& kv : func_dict->data) { if (kv.second->IsInstance()) { @@ -307,12 +287,10 @@ Pass SplitHostDevice() { return mod; }; - return tvm::transform::CreateModulePass( - pass_func, 0, "tir.SplitHostDevice", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); } -TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice") -.set_body_typed(SplitHostDevice); +TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice").set_body_typed(SplitHostDevice); } // namespace transform } // namespace tir diff --git a/src/tir/pass/storage_access.cc b/src/tir/transforms/storage_access.cc similarity index 92% 
rename from src/tir/pass/storage_access.cc rename to src/tir/transforms/storage_access.cc index f6bba486c785..20cc6402135f 100644 --- a/src/tir/pass/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -20,13 +20,15 @@ /*! * \file storage_access.cc */ -#include +#include "storage_access.h" + #include +#include + #include #include -#include "storage_access.h" + #include "ir_util.h" -#include "../../arith/compute_expr.h" namespace tvm { namespace tir { @@ -90,8 +92,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = - StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::Create(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else if (op->attr_key == attr::double_buffer_write) { CHECK(double_buffer_write_ == nullptr); @@ -115,7 +116,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); StmtExprVisitor::VisitStmt_(op); - env_threads_.CopyOnWrite()->data.pop_back(); + env_threads_.pop_back(); } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); @@ -130,7 +131,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { } else { StmtExprVisitor::VisitStmt_(op); } - env_threads_.CopyOnWrite()->data.pop_back(); + env_threads_.pop_back(); } else { StmtExprVisitor::VisitStmt_(op); } @@ -146,8 +147,8 @@ void StorageAccessVisitor::VisitStmt_(const ForNode* op) { if (s.access.size() != 0) { // relax the touched set to contain all ranges in the loop. std::unordered_map relax_map; - relax_map[op->loop_var.get()] = arith::IntSet::range( - Range::make_by_min_extent(op->min, op->extent)); + relax_map[op->loop_var.get()] = + arith::IntSet::range(Range::make_by_min_extent(op->min, op->extent)); for (AccessEntry& e : s.access) { if (e.buffer.defined()) { CHECK(e.touched.defined()); @@ -181,7 +182,7 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { void StorageAccessVisitor::VisitExpr_(const CallNode* op) { if (op->is_intrinsic(intrinsic::tvm_address_of)) { - const LoadNode *l = op->args[0].as(); + const LoadNode* l = op->args[0].as(); StmtExprVisitor::VisitExpr_(l); } else if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { CHECK_EQ(op->args.size(), 5U); @@ -198,8 +199,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { e.threads = env_threads(); e.dtype = dtype; e.buffer = Downcast(op->args[1]); - e.touched = arith::IntSet::range( - Range::make_by_min_extent(offset, extent)); + e.touched = arith::IntSet::range(Range::make_by_min_extent(offset, extent)); e.scope = scope; if (flag->value & 1) { e.type = kRead; @@ -215,11 +215,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { CHECK(allow_append_); const std::string& s = op->args[0].as()->value; if (s != "warp") { - StorageScope scope = StorageScope::make(s); + StorageScope scope = StorageScope::Create(s); AccessEntry e; e.threads = env_threads(); e.type = kSync; - e.scope = StorageScope::make(s); + e.scope = StorageScope::Create(s); curr_stmt_.access.emplace_back(std::move(e)); } } else { diff --git a/src/tir/pass/storage_access.h b/src/tir/transforms/storage_access.h similarity index 86% rename from src/tir/pass/storage_access.h rename to src/tir/transforms/storage_access.h index d3614b8fff4e..80bbff4c1fe4 100644 --- 
a/src/tir/pass/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -21,22 +21,24 @@ * \file storage_access.h * \brief Common data structure for storage access analysis. */ -#ifndef TVM_TIR_PASS_STORAGE_ACCESS_H_ -#define TVM_TIR_PASS_STORAGE_ACCESS_H_ +#ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ +#define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ -#include +#include #include -#include +#include #include -#include + #include +#include + #include "../../runtime/thread_storage_scope.h" namespace tvm { namespace tir { -using runtime::StorageScope; using runtime::StorageRank; +using runtime::StorageScope; /*! * \brief Base class of storage access analysis */ @@ -85,31 +87,20 @@ class StorageAccessVisitor : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final; protected: - StorageAccessVisitor() { - scope_.push_back(std::vector()); - } + StorageAccessVisitor() { scope_.push_back(std::vector()); } /*! \return number of conditions in the current scope. */ - int condition_counter() const { - return condition_counter_; - } + int condition_counter() const { return condition_counter_; } /*! \return whether we are in device environment. */ - bool in_device_env() const { - return in_device_env_; - } + bool in_device_env() const { return in_device_env_; } /*! \return environment threads */ - const Array& env_threads() const { - return env_threads_; - } + const Array& env_threads() const { return env_threads_; } /*! * \brief Whether we need analyze the buffer in current scope. * \param buffer The buffer to be checked * \param scope The scope of the buffer. * \return Whether the analysis of buffer is enabled. */ - virtual bool Enabled(const VarNode* buffer, - const StorageScope& scope) const { - return true; - } + virtual bool Enabled(const VarNode* buffer, const StorageScope& scope) const { return true; } /*! * \brief Summarize the sequence of operations into parent. * @@ -121,8 +112,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \return The summarized sequence that represent access that * the parent should taken care of to synchronize. */ - virtual std::vector Summarize( - std::vector seq, const ForNode* loop) = 0; + virtual std::vector Summarize(std::vector seq, const ForNode* loop) = 0; /*! * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. @@ -150,4 +140,4 @@ class StorageAccessVisitor : public StmtExprVisitor { } // namespace tir } // namespace tvm -#endif // TVM_TIR_PASS_STORAGE_ACCESS_H_ +#endif // TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ diff --git a/src/tir/pass/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc similarity index 59% rename from src/tir/pass/storage_flatten.cc rename to src/tir/transforms/storage_flatten.cc index f9533fa4820a..e29d978e0d42 100644 --- a/src/tir/pass/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -19,46 +19,47 @@ /*! * \file storage_flatten.cc + * \brief Flattens storage from multi-dimensional array to 1D buffer access */ -// Flattens storage from multi-dimensional array to 1D -// buffer access as in Halide pipeline. +// The pass definition originates from Halide pipeline. 
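// A minimal sketch of the effect of this pass, assuming a compact 2-D buffer A
// of shape (n, m) with no explicit strides (A, A_data, n and m are illustrative
// names only):
//
//   A[i, j] = v;            // BufferStore on the multi-dimensional buffer
//
// is lowered through Buffer::vstore into a store on the flat backing variable,
// roughly
//
//   A_data[i * m + j] = v;  // Store on the 1-D allocation created by this pass
//
// Buffers with explicit strides or dim_align requirements use the
// corresponding strides instead of the plain row-major index shown here.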
+ #include +#include +#include +#include +#include +#include #include +#include #include -#include #include -#include -#include -#include -#include -#include +#include + #include -#include "ir_util.h" -#include "arg_binder.h" -#include "../../arith/compute_expr.h" + #include "../../arith/ir_visitor_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" +#include "arg_binder.h" +#include "ir_util.h" namespace tvm { namespace tir { +using intrinsic::tvm_address_of; using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; -using intrinsic::tvm_address_of; class StorageFlattener : public StmtExprMutator { public: - explicit StorageFlattener(Map extern_buffer, - int cache_line_size, bool create_bound_attributes, - IRVisitorWithAnalyzer* bounded_analyzer) - : bounded_analyzer_(bounded_analyzer), - create_bound_attributes_(create_bound_attributes) { - for (auto kv : extern_buffer) { + explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, + bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { + for (auto kv : extern_buffer_map) { BufferEntry e; e.buffer = kv.second; e.external = true; - buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e; + buf_map_[kv.second] = e; } cache_line_size_ = cache_line_size; } @@ -67,11 +68,10 @@ class StorageFlattener : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && - !it->second.same_as(op->buffer_var)) { + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); - return StoreNode::make(buf_var, op->value, op->index, op->predicate); + return Store(buf_var, op->value, op->index, op->predicate); } else { return stmt; } @@ -82,21 +82,16 @@ class StorageFlattener : public StmtExprMutator { storage_scope_[op->node.get()] = op->value.as()->value; return this->VisitStmt(op->body); } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { - auto func = Downcast(op->node); + op->node->IsInstance()) { + auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); - for (int i = 0; i < func->num_outputs(); ++i) { - TensorKey key{func, i}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - body = AttrStmtNode::make( - it->second.buffer->data, op->attr_key, op->value, body); - } + auto it = buf_map_.find(buffer); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; + body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); - ThreadScope ts = ThreadScope::make(iv->thread_tag); + ThreadScope ts = ThreadScope::Create(iv->thread_tag); curr_thread_scope_.push_back(ts); Stmt stmt = StmtExprMutator::VisitStmt_(op); curr_thread_scope_.pop_back(); @@ -104,11 +99,10 @@ class StorageFlattener : public StmtExprMutator { } else if (op->attr_key == attr::buffer_bind_scope) { return HandleBufferBindScope(op); } else if (op->attr_key == attr::buffer_dim_align) { - auto tensor = Downcast(op->node); + auto buffer = Downcast(op->node); const CallNode* tuple = op->value.as(); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - TensorKey key{tensor->op, tensor->value_index}; - auto& 
vinfo = dim_align_[key]; + auto& vinfo = dim_align_[buffer]; int dim = tuple->args[0].as()->value; if (static_cast(dim) >= vinfo.size()) { vinfo.resize(dim + 1); @@ -116,50 +110,40 @@ class StorageFlattener : public StmtExprMutator { vinfo[dim].align_factor = tuple->args[1].as()->value; vinfo[dim].align_offset = tuple->args[2].as()->value; return this->VisitStmt(op->body); - } else if (op->attr_key == attr::opengl_stage_scope) { - is_opengl_ = true; } return StmtExprMutator::VisitStmt_(op); } - Stmt VisitStmt_(const ProvideNode* op) final { - if (create_bound_attributes_) - shape_collector_.clear(); + Stmt VisitStmt_(const BufferStoreNode* op) final { + if (create_bound_attributes_) shape_collector_.clear(); Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - TensorKey key{op->func, op->value_index}; + op = stmt.as(); + + const auto& key = op->buffer; + auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; - if (is_opengl_) { - return EvaluateNode::make(CallNode::make( - DataType(), - CallNode::glsl_texture_store, - {e.buffer->data, op->value}, - CallNode::Intrinsic)); - } else { - Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value); - if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back( - std::make_pair(e.buffer->data, e.buffer->shape)); - } - // To create bound attribute collector should has at least one item. - if (create_bound_attributes_ && shape_collector_.size()) { - for (size_t i = 0; i < shape_collector_.size(); ++i) { - body = AttrStmtNode::make( - shape_collector_[i].first, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, shape_collector_[i].second), body); - } + CHECK(!e.released) << "Read a buffer that is already out of scope"; + + Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); + if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { + shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); + } + // To create bound attribute collector should has at least one item. + if (create_bound_attributes_ && shape_collector_.size()) { + for (size_t i = 0; i < shape_collector_.size(); ++i) { + body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound, + MakeBound(e.buffer->dtype, shape_collector_[i].second), body); } - return body; } + return body; } - Stmt VisitStmt_(const RealizeNode* op) final { - TensorKey key{op->func, op->value_index}; + Stmt VisitStmt_(const BufferRealizeNode* op) final { + const auto& key = op->buffer; + if (buf_map_.count(key)) { CHECK(buf_map_.at(key).external); return this->VisitStmt(op->body); @@ -172,29 +156,27 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. 
- auto it = storage_scope_.find(op->func.get()); - CHECK(it != storage_scope_.end()) - << "Cannot find storage scope of " << op->func - << " value_index=" << op->value_index; + auto it = storage_scope_.find(op->buffer.get()); + CHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; const std::string& strkey = it->second; if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { - skey.rank = runtime::DefaultStorageRank( - curr_thread_scope_.back().rank); + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } } else { - skey = StorageScope::make(strkey); + skey = StorageScope::Create(strkey); } // use small alignment for small arrays + auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); - int align = GetTempAllocaAlignment(op->dtype, const_size); + int align = GetTempAllocaAlignment(dtype, const_size); if (skey.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(skey.to_string()); if (info.defined()) { - align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits(); - CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits) + align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits(); + CHECK_LE(const_size * dtype.bits(), info->max_num_bits) << "Allocation exceed bound of memory tag " << skey.to_string(); } } @@ -210,7 +192,7 @@ class StorageFlattener : public StmtExprMutator { PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = tir::Simplify(stride); + stride = bound_analyzer_->Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -218,11 +200,9 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = BufferNode::make( - Var(key.GetName(), DataType::Handle()), - op->dtype, shape, strides, PrimExpr(), - key.GetName(), skey.to_string(), - align, 0, kDefault); + e.buffer = + Buffer(Var(op->buffer->data->name_hint, DataType::Handle()), op->buffer->dtype, shape, + strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -237,26 +217,22 @@ class StorageFlattener : public StmtExprMutator { } if (strides.size() != 0) { int first_dim = 0; - ret = AllocateNode::make( - e.buffer->data, storage_type, - {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = Allocate(e.buffer->data, storage_type, + {e.buffer->strides[first_dim] * e.buffer->shape[first_dim]}, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } else { shape = e.buffer->shape; if (shape.size() == 0) { shape.push_back(make_const(DataType::Int(32), 1)); } - ret = AllocateNode::make( - e.buffer->data, storage_type, shape, - make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); + ret = Allocate(e.buffer->data, storage_type, shape, + make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - ret = AttrStmtNode::make( - e.buffer->data, attr::storage_scope, - StringImmNode::make(e.buffer->scope), ret); + ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - ret = AttrStmtNode::make(e.buffer->data, tir::attr::buffer_bound, - 
MakeBound(e.buffer->dtype, e.buffer->shape), ret); + ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, + MakeBound(e.buffer->dtype, e.buffer->shape), ret); } return ret; } @@ -266,11 +242,10 @@ class StorageFlattener : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = var_remap_.find(op->buffer_var.get()); - if (it != var_remap_.end() && - !it->second.same_as(op->buffer_var)) { + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { CHECK(it->second.as()); Var buf_var = Downcast(it->second); - return LoadNode::make(op->dtype, buf_var, op->index, op->predicate); + return Load(op->dtype, buf_var, op->index, op->predicate); } else { return expr; } @@ -285,51 +260,45 @@ class StorageFlattener : public StmtExprMutator { } } - PrimExpr VisitExpr_(const CallNode* op) final { + PrimExpr VisitExpr_(const BufferLoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - if (op != nullptr && op->call_type == CallNode::Halide) { - TensorKey key{op->func, op->value_index}; - auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; - const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + op = expr.as(); - if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back( - std::make_pair(e.buffer->data, e.buffer->shape)); - } - return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype); - } else { - return expr; + const auto& key = op->buffer; + + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; + CHECK(!e.released) << "Read a buffer that is already out of scope"; + + if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { + shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } + return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); } - Stmt VisitStmt_(const PrefetchNode *op) final { + Stmt VisitStmt_(const PrefetchNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); CHECK(op != nullptr); - TensorKey key{op->func, op->value_index}; + + const auto& key = op->buffer; auto it = buf_map_.find(key); - CHECK(it != buf_map_.end()) - << "Cannot find allocated buffer for " << key.f; + CHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - CHECK(!e.released) - << "Read a buffer that is already out of scope"; + CHECK(!e.released) << "Read a buffer that is already out of scope"; CHECK_EQ(e.buffer->shape.size(), op->bounds.size()) - << "Prefetch dim should be the same as buffer dim"; + << "Prefetch dim should be the same as buffer dim"; - int block_size = 1, - elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(), - shape = 0; + int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); int starts = op->bounds.size() - 1; - while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape) - && elem_cnt >= block_size * shape) { - block_size *= shape; + + while (starts > 0) { + auto* shape_as_int = e.buffer->shape[starts].as(); + if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break; + block_size *= static_cast(shape_as_int->value); starts--; } PrimExpr stride(elem_cnt / block_size); @@ -340,33 +309,47 @@ class StorageFlattener : public StmtExprMutator { for (int i = op->bounds.size() - 1; i > starts; --i) { 
args.push_back(op->bounds[i]->min); } - auto &func_name = op->func->func_name(); - vars.push_back(Var( - "prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); + auto& func_name = op->buffer->name; + vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); args.push_back(op->bounds[starts]->min + stride * vars.back()); for (int i = starts - 1; i >= 0; --i) { - vars.push_back(Var( - "prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); + vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); args.push_back(vars.back() + op->bounds[i]->min); } for (int i = starts; i >= 0; --i) { if (i < starts) { - stmt = ForNode::make( - vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); - PrimExpr address = CallNode::make( - DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); - PrimExpr prefetch = CallNode::make( - op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); - stmt = EvaluateNode::make(prefetch); + PrimExpr address = + Call(DataType::Handle(), tvm_address_of, {load}, CallNode::PureIntrinsic); + PrimExpr prefetch = + Call(op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic); + stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; - stmt = ForNode::make(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); } } return stmt; } + PrimExpr VisitExpr_(const ProducerLoadNode* op) final { + LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc."; + return PrimExpr(); + } + + Stmt VisitStmt_(const ProducerStoreNode* op) final { + LOG(FATAL) << "Cannot handle Provide " + << " please run SchedulePostProcToPrimFunc first"; + return Stmt(); + } + + Stmt VisitStmt_(const ProducerRealizeNode* op) final { + LOG(FATAL) << "Cannot handle Realize " + << " please run SchedulePostProcToPrimFunc first"; + return Stmt(); + } + private: // The specific tensor data layout is not determined before // StorageFlatten pass. 
We use buffer_bind_scope @@ -403,17 +386,18 @@ class StorageFlattener : public StmtExprMutator { // We do support a few relaxed case, such as bindingx // region with shape [1, 1, n, m] to buffer with shape [n, m] Stmt HandleBufferBindScope(const AttrStmtNode* op) { - Array arr = Downcast > (op->node); + Array arr = Downcast>(op->node); CHECK_EQ(arr.size(), 2U); const BufferNode* buffer = arr[0].as(); - const te::TensorNode* tensor = arr[1].as(); + const BufferNode* target = arr[1].as(); const CallNode* tuple = op->value.as(); - CHECK(buffer && tensor); + CHECK(buffer && target); CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple)); - TensorKey key{tensor->op, tensor->value_index}; - CHECK(buf_map_.count(key)) - << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index; - const BufferEntry& be = buf_map_.at(key); + auto key = GetRef(target); + + auto it = buf_map_.find(key); + CHECK(it != buf_map_.end()) << "Cannot find buffer of " << key; + const BufferEntry& be = it->second; CHECK(!be.released); CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); Array begins, extents; @@ -426,15 +410,14 @@ class StorageFlattener : public StmtExprMutator { } else { for (size_t i = 0; i < tuple->args.size(); i += 2) { begins.push_back(tuple->args[i]); - auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]); + auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]); extents.push_back(new_extent); } } Buffer slice = be.buffer.MakeSlice(begins, extents); if (buffer->strides.size() == 0) { CHECK_EQ(slice->strides.size(), 0U) - << "Trying to bind compact buffer to strided one strides=" - << slice->strides; + << "Trying to bind compact buffer to strided one strides=" << slice->strides; } else { slice = slice.MakeStrideView(); } @@ -451,6 +434,7 @@ class StorageFlattener : public StmtExprMutator { } return body; } + // The buffer entry in the flatten map struct DimAlignInfo { int align_factor{0}; @@ -481,26 +465,23 @@ class StorageFlattener : public StmtExprMutator { } }; - bool ShapeIsValid(const Array &shape) { + bool ShapeIsValid(const Array& shape) { // Zero-dimensional tensor does not need boundary check. - if (!shape.size()) - return false; + if (!shape.size()) return false; for (size_t i = 0; i < shape.size(); ++i) { - if (!shape[i].defined() || !shape[i].dtype().is_scalar() || - is_negative_const(shape[i])) { + if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) { return false; } } return true; } - PrimExpr MakeBound(const DataType &type, const Array &shape) { + PrimExpr MakeBound(const DataType& type, const Array& shape) { // We have already checked the shape size to be greater then 0. - PrimExpr bound = MulNode::make(make_const(shape[0].dtype(), type.lanes()), shape[0]); + PrimExpr bound = Mul(make_const(shape[0].dtype(), type.lanes()), shape[0]); for (size_t i = 1; i < shape.size(); ++i) { - bound = MulNode::make( - bound, MulNode::make(make_const(bound.dtype(), type.lanes()), shape[i])); + bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); } return bound; } @@ -509,9 +490,9 @@ class StorageFlattener : public StmtExprMutator { // Variable remap std::unordered_map var_remap_; // Buffer map - std::unordered_map buf_map_; + std::unordered_map buf_map_; // Dimension alignment - std::unordered_map > dim_align_; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; // Storage scope std::unordered_map storage_scope_; // The current thread scope. 
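// A minimal sketch of the pass-registration pattern that the following hunk
// applies to StorageFlatten, and that earlier hunks in this change apply to
// RemoveNoOp, RewriteUnsafeSelect and Simplify. MyPass and MyMutator are
// hypothetical names used only for illustration; the real entry points come
// from tvm/tir/transform.h:
//
//   namespace transform {
//
//   Pass MyPass() {
//     auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
//       auto* n = f.CopyOnWrite();
//       n->body = MyMutator()(std::move(n->body));
//       return f;
//     };
//     return CreatePrimFuncPass(pass_func, 0, "tir.MyPass", {});
//   }
//
//   TVM_REGISTER_GLOBAL("tir.transform.MyPass").set_body_typed(MyPass);
//
//   }  // namespace transform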
@@ -520,24 +501,36 @@ class StorageFlattener : public StmtExprMutator { std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. // However - IRVisitorWithAnalyzer* bounded_analyzer_; + IRVisitorWithAnalyzer* bound_analyzer_; // The size of cacheline int cache_line_size_; - // The current stage is an OpenGL shader. - bool is_opengl_{false}; // Whether to mark load/store with theirs bounds. bool create_bound_attributes_{false}; }; -Stmt StorageFlatten(Stmt stmt, Map extern_buffer, - int cache_line_size, bool create_bound_attributes) { - IRVisitorWithAnalyzer bounded_analyzer; - bounded_analyzer(stmt); - stmt = - StorageFlattener(extern_buffer, cache_line_size, - create_bound_attributes, &bounded_analyzer)(std::move(stmt)); - return stmt; +PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { + auto fptr = func.CopyOnWrite(); + + IRVisitorWithAnalyzer bound_analyzer; + bound_analyzer(fptr->body); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, + &bound_analyzer)(std::move(fptr->body)); + return func; } +namespace transform { + +// TODO(tvm-team): consolidate configs to the PassContext +Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc similarity index 85% rename from src/tir/pass/storage_rewrite.cc rename to src/tir/transforms/storage_rewrite.cc index f3604b640349..283ab0f6f703 100644 --- a/src/tir/pass/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -23,17 +23,19 @@ * Re-write data access to enable memory sharing when possible. */ #include -#include -#include +#include +#include #include +#include #include -#include +#include + #include -#include #include -#include "ir_util.h" -#include "../../arith/compute_expr.h" +#include + #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" namespace tvm { namespace tir { @@ -124,8 +126,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { const VarNode* buf = op->buffer_var.get(); auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - CHECK_LT(it->second.level, scope_.size()) - << "Load memory in places other than store."; + CHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store."; scope_[it->second.level].touched.push_back(buf); } } @@ -141,24 +142,23 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // Directly reference to the variable count as a read. auto it = alloc_info_.find(buf); if (it != alloc_info_.end() && it->second.alloc) { - CHECK_LT(it->second.level, scope_.size()) - << " buf=" << buf->name_hint; + CHECK_LT(it->second.level, scope_.size()) << " buf=" << buf->name_hint; scope_[it->second.level].touched.push_back(buf); } } - template + template void VisitNewScope(const T* op) { scope_.push_back(StmtEntry()); StmtEntry e; e.stmt = op; - int64_t begin_index = static_cast(linear_seq_.size()); + int64_t begin_index = static_cast(linear_seq_.size()); // before scope. 
linear_seq_.push_back(e); StmtExprVisitor::VisitStmt_(op); // after scope. e.touched = std::move(scope_.back().touched); scope_.pop_back(); - int64_t end_index = static_cast(linear_seq_.size()); + int64_t end_index = static_cast(linear_seq_.size()); CHECK_GT(end_index, begin_index); e.scope_pair_offset = begin_index - end_index; linear_seq_.push_back(e); @@ -178,24 +178,17 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = - StorageScope::make(op->value.as()->value); + alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } } - void VisitStmt_(const IfThenElseNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const IfThenElseNode* op) final { VisitNewScope(op); } - void VisitStmt_(const ForNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const ForNode* op) final { VisitNewScope(op); } - void VisitStmt_(const AssertStmtNode* op) final { - VisitNewScope(op); - } + void VisitStmt_(const AssertStmtNode* op) final { VisitNewScope(op); } // linearized access sequence. std::vector linear_seq_; @@ -237,9 +230,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { // class InplaceOpVerifier : public StmtExprVisitor { public: - bool Check(const Object* stmt, - const VarNode* dst, - const VarNode* src) { + bool Check(const Object* stmt, const VarNode* dst, const VarNode* src) { dst_ = dst; src_ = src; result_ = true; @@ -271,7 +262,8 @@ class InplaceOpVerifier : public StmtExprVisitor { void VisitExpr_(const VarNode* op) final { // assume all opaque access is unsafe if (op == dst_ || op == src_) { - result_ = false; return; + result_ = false; + return; } } @@ -292,9 +284,9 @@ class InplaceOpVerifier : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // always reject extern code - if (op->attr_key == attr::extern_scope || - op->attr_key == attr::volatile_scope) { - result_ = false; return; + if (op->attr_key == attr::extern_scope || op->attr_key == attr::volatile_scope) { + result_ = false; + return; } StmtExprVisitor::VisitStmt_(op); } @@ -303,17 +295,19 @@ class InplaceOpVerifier : public StmtExprVisitor { const VarNode* buf = op->buffer_var.get(); // cannot read from dst_ (no reduction) if (buf == dst_) { - result_ = false; return; + result_ = false; + return; } // do not allow indirect memory load if (mem_nest_ != 0) { - result_ = false; return; + result_ = false; + return; } if (src_ == buf) { - if (store_ == nullptr || - store_->value.dtype() != op->dtype || + if (store_ == nullptr || store_->value.dtype() != op->dtype || !tir::ExprDeepEqual()(store_->index, op->index)) { - result_ = false; return; + result_ = false; + return; } } ++mem_nest_; @@ -321,7 +315,6 @@ class InplaceOpVerifier : public StmtExprVisitor { --mem_nest_; } - private: // result of the check bool result_{true}; @@ -357,10 +350,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (StorageEntry* e : attach_map_.at(nullptr)) { // CHECK_EQ(e->scope.rank, 0); if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make( - e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, + StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } @@ -373,20 +364,16 @@ class 
StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return stmt; - return StoreNode::make(it->second->alloc_var, - op->value, - RemapIndex(op->value.dtype(), op->index, it->second), - op->predicate); + return Store(it->second->alloc_var, op->value, + RemapIndex(op->value.dtype(), op->index, it->second), op->predicate); } PrimExpr VisitExpr_(const LoadNode* op) final { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); auto it = alloc_map_.find(op->buffer_var.get()); if (it == alloc_map_.end()) return expr; - return LoadNode::make(op->dtype, - it->second->alloc_var, - RemapIndex(op->dtype, op->index, it->second), - op->predicate); + return Load(op->dtype, it->second->alloc_var, RemapIndex(op->dtype, op->index, it->second), + op->predicate); } PrimExpr VisitExpr_(const VarNode* op) final { auto it = alloc_map_.find(op); @@ -416,10 +403,8 @@ class StoragePlanRewriter : public StmtExprMutator { if (se->bits_offset != 0) { offset = make_const(offset.dtype(), se->bits_offset / elem_bits) + offset; } - return CallNode::make( - op->dtype, op->name, - {op->args[0], se->alloc_var, offset, extent, op->args[4]}, - op->call_type); + return Call(op->dtype, op->name, {op->args[0], se->alloc_var, offset, extent, op->args[4]}, + op->call_type); } else { return StmtExprMutator::VisitExpr_(op); } @@ -428,17 +413,14 @@ class StoragePlanRewriter : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::storage_scope) { return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread || + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return AttrStmtNode::make( - op->node, op->attr_key, op->value, - MakeAttach(svec, op->body)); + return AttrStmt(op->node, op->attr_key, op->value, MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } @@ -447,31 +429,26 @@ class StoragePlanRewriter : public StmtExprMutator { op = stmt.as(); auto it = alloc_map_.find(op->node.as()); if (it == alloc_map_.end()) return stmt; - return AttrStmtNode::make( - it->second->alloc_var, op->attr_key, op->value, op->body); + return AttrStmt(it->second->alloc_var, op->attr_key, op->value, op->body); } else { return StmtExprMutator::VisitStmt_(op); } } Stmt VisitStmt_(const ForNode* op) final { - CHECK(op->for_type != ForType::Vectorized) - << "VectorizeLoop before LiftStorageAlloc"; + CHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. 
if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return ForNode::make( - op->loop_var, op->min, op->extent, op->for_type, op->device_api, - MakeAttach(svec, op->body)); + return For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, + MakeAttach(svec, op->body)); } else { return StmtExprMutator::VisitStmt_(op); } } - Stmt VisitStmt_(const AllocateNode* op) final { - return this->VisitStmt(op->body); - } + Stmt VisitStmt_(const AllocateNode* op) final { return this->VisitStmt(op->body); } private: struct StorageEntry { @@ -516,15 +493,12 @@ class StoragePlanRewriter : public StmtExprMutator { std::vector kill; }; - Stmt MakeAttach(const std::vector& svec, - Stmt body) { + Stmt MakeAttach(const std::vector& svec, Stmt body) { std::vector nest; for (StorageEntry* e : svec) { if (e->new_alloc.defined()) { - nest.emplace_back(AttrStmtNode::make( - e->alloc_var, attr::storage_scope, - StringImmNode::make(e->scope.to_string()), - EvaluateNode::make(0))); + nest.emplace_back(AttrStmt(e->alloc_var, attr::storage_scope, + StringImm(e->scope.to_string()), Evaluate(0))); nest.push_back(e->new_alloc); } } @@ -544,15 +518,14 @@ class StoragePlanRewriter : public StmtExprMutator { attach_map_[e->attach_scope_].push_back(e); } // find allocation via attach map. - for (auto &kv : attach_map_) { + for (auto& kv : attach_map_) { // find the element with the most amount of bytes. std::vector& vec = kv.second; // try to find merge, for tagged memory for (size_t i = 0; i < vec.size(); ++i) { StorageEntry* e = vec[i]; if (e->scope.tag.length() != 0) { - CHECK_NE(e->const_nbits, 0U) - << "Special tagged memory must be const size"; + CHECK_NE(e->const_nbits, 0U) << "Special tagged memory must be const size"; for (size_t j = 0; j < i; ++j) { if (e->scope == vec[j]->scope) { vec[j]->merged_children.push_back(e); @@ -567,7 +540,8 @@ class StoragePlanRewriter : public StmtExprMutator { // already merged if (e->bits_offset != 0) continue; if (e->merged_children.size() != 0) { - NewAllocTagMerged(e); continue; + NewAllocTagMerged(e); + continue; } // Get the allocation size; e->alloc_var = e->allocs[0]->buffer_var; @@ -577,13 +551,14 @@ class StoragePlanRewriter : public StmtExprMutator { alloc_type = op->dtype; } } + + auto fmul = [](PrimExpr a, PrimExpr b) { return a * b; }; + if (e->allocs.size() == 1) { // simply use the original allocation. 
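The size computation just below replaces `arith::ComputeReduce<MulNode>` with a left fold over the allocation extents using a multiply lambda. As a rough standalone illustration of the same fold on plain integers (names and types here are illustrative, not the TVM API):

```cpp
// Sketch: compute the flat element count of an allocation by folding the
// per-dimension extents with multiplication, starting from 1 -- the same
// shape of computation as foldl(fmul, 1, extents) in the rewritten pass.
#include <cstdint>
#include <numeric>
#include <vector>

int64_t FlatExtent(const std::vector<int64_t>& extents) {
  return std::accumulate(extents.begin(), extents.end(), int64_t{1},
                         [](int64_t a, int64_t b) { return a * b; });
}
// FlatExtent({4, 8, 16}) == 512 elements; the pass builds the analogous
// PrimExpr and lets the analyzer simplify it when extents are symbolic.
```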
- PrimExpr sz = arith::ComputeReduce(e->allocs[0]->extents, - make_const(DataType::Int(32), 1)); - e->new_alloc = AllocateNode::make( - e->alloc_var, alloc_type, {sz}, - e->allocs[0]->condition, EvaluateNode::make(0)); + PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), e->allocs[0]->extents); + e->new_alloc = + Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -594,13 +569,11 @@ class StoragePlanRewriter : public StmtExprMutator { // Build a merged allocation PrimExpr combo_size; for (const AllocateNode* op : e->allocs) { - PrimExpr sz = arith::ComputeReduce( - op->extents, make_const(DataType::Int(32), 1)); + PrimExpr sz = foldl(fmul, make_const(DataType::Int(32), 1), op->extents); auto nbits = op->dtype.bits() * op->dtype.lanes(); if (const auto* imm = sz.as()) { if (imm->value > std::numeric_limits::max() / nbits) { - LOG(WARNING) << "The allocation requires : " << imm->value - << " * " << nbits + LOG(WARNING) << "The allocation requires : " << imm->value << " * " << nbits << " bits, which is greater than the maximum of" " int32. The size is cast to int64." << "\n"; @@ -623,10 +596,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (!divided) { combo_size = combo_size + make_const(DataType::Int(32), 1); } - combo_size = tir::Simplify(combo_size); - e->new_alloc = AllocateNode::make( - e->alloc_var, alloc_type, {combo_size}, const_true(), - EvaluateNode::make(0)); + combo_size = analyzer_.Simplify(combo_size); + e->new_alloc = + Allocate(e->alloc_var, alloc_type, {combo_size}, const_true(), Evaluate(0)); if (e->scope.tag.length() != 0) { MemoryInfo info = GetMemoryInfo(e->scope.to_string()); uint64_t total_elem = e->const_nbits / e->elem_type.bits(); @@ -652,7 +624,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Always align to max_simd_bits // so we can remap types by keeping this property if (total_bits % align != 0) { - total_bits += align - (total_bits % align); + total_bits += align - (total_bits % align); } e->alloc_var = e->allocs[0]->buffer_var; for (StorageEntry* child : e->merged_children) { @@ -662,15 +634,13 @@ class StoragePlanRewriter : public StmtExprMutator { child->alloc_var = e->alloc_var; total_bits += child->const_nbits; if (total_bits % align != 0) { - total_bits += align - (total_bits % align); + total_bits += align - (total_bits % align); } } uint64_t type_bits = e->elem_type.bits() * e->elem_type.lanes(); - PrimExpr alloc_size = make_const(e->allocs[0]->extents[0].dtype(), - (total_bits + type_bits - 1) / type_bits); - e->new_alloc = AllocateNode::make( - e->alloc_var, e->elem_type, {alloc_size}, const_true(), - EvaluateNode::make(0)); + PrimExpr alloc_size = + make_const(e->allocs[0]->extents[0].dtype(), (total_bits + type_bits - 1) / type_bits); + e->new_alloc = Allocate(e->alloc_var, e->elem_type, {alloc_size}, const_true(), Evaluate(0)); if (info.defined()) { CHECK_LE(total_bits, info->max_num_bits) << "Allocation exceed bound of memory tag " << e->scope.to_string(); @@ -763,8 +733,7 @@ class StoragePlanRewriter : public StmtExprMutator { visitor.Check(s.stmt, var, src)) { uint64_t const_nbits = static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * - ae.alloc->dtype.lanes(); + ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = 
src_entry; @@ -785,8 +754,7 @@ class StoragePlanRewriter : public StmtExprMutator { // enter/exit new scope if (s.stmt->IsInstance()) { const auto* op = static_cast(s.stmt); - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread || + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || attr::IsPragmaKey(op->attr_key)) { PlanNewScope(op); } else { @@ -815,10 +783,8 @@ class StoragePlanRewriter : public StmtExprMutator { } } // Allocate new storage entry. - StorageEntry* NewAlloc(const AllocateNode* op, - const Object* attach_scope, - const StorageScope& scope, - size_t const_nbits) { + StorageEntry* NewAlloc(const AllocateNode* op, const Object* attach_scope, + const StorageScope& scope, size_t const_nbits) { CHECK(op != nullptr); // Re-use not successful, allocate a new buffer. std::unique_ptr entry(new StorageEntry()); @@ -831,23 +797,21 @@ class StoragePlanRewriter : public StmtExprMutator { return e; } - StorageEntry* FindAlloc(const AllocateNode* op, - const Object* attach_scope, + StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, const StorageScope& scope) { CHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. const uint64_t match_range = 16; uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes(); - uint64_t const_nbits = static_cast( - op->constant_allocation_size() * op_elem_bits); + uint64_t const_nbits = static_cast(op->constant_allocation_size() * op_elem_bits); // disable reuse of small arrays, they will be lowered to registers in LLVM // This rules only apply if we are using non special memory if (scope.tag.length() == 0) { if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) { return NewAlloc(op, attach_scope, scope, const_nbits); } - if (const_nbits > 0 && const_nbits <= 32) { + if (const_nbits > 0 && const_nbits <= 32) { return NewAlloc(op, attach_scope, scope, const_nbits); } } @@ -858,7 +822,7 @@ class StoragePlanRewriter : public StmtExprMutator { auto end = const_free_map_.upper_bound(const_nbits * match_range); // start looking at the buffer that is bigger than the required size first for (auto it = mid; it != end; ++it) { - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; // when not divided, no reuse, eg, float4 vs float3 @@ -870,7 +834,7 @@ class StoragePlanRewriter : public StmtExprMutator { // then start looking at smaller buffers. for (auto it = mid; it != begin;) { --it; - StorageEntry *e = it->second; + StorageEntry* e = it->second; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; if (e->elem_type != op->dtype.element_of()) continue; @@ -880,8 +844,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } else { // Simple strategy: round roubin. - for (auto it = sym_free_list_.begin(); - it != sym_free_list_.end(); ++it) { + for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) { StorageEntry* e = *it; if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; @@ -903,8 +866,7 @@ class StoragePlanRewriter : public StmtExprMutator { // This rules only apply if we are using non special memory if (e->scope.tag.length() == 0) { // Disable sharing of local memory. 
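The `FindAlloc` lookup above keys free storage entries by constant size and only reconsiders candidates within a fixed `match_range` of the request, scanning entries that are already big enough before falling back to smaller ones that will be grown. A simplified sketch of that search using plain STL containers; the real pass additionally filters by attach scope, storage scope, element type, and bit-width divisibility (the names below are hypothetical):

```cpp
// Sketch: find a reusable free buffer for a request of `nbits` bits,
// considering only candidates in [nbits, nbits * match_range).
#include <cstdint>
#include <map>
#include <string>

struct FreeBuffer {
  std::string scope;  // e.g. "shared" or "local"
  uint64_t nbits;     // currently reserved size in bits
};

FreeBuffer* FindReusable(std::multimap<uint64_t, FreeBuffer*>* free_map,
                         const std::string& scope, uint64_t nbits) {
  const uint64_t match_range = 16;  // same constant as in FindAlloc
  auto begin = free_map->begin();
  auto mid = free_map->lower_bound(nbits);
  auto end = free_map->upper_bound(nbits * match_range);
  // First look at buffers that are already big enough.
  for (auto it = mid; it != end; ++it) {
    if (it->second->scope != scope) continue;
    FreeBuffer* e = it->second;
    free_map->erase(it);
    return e;
  }
  // Then consider smaller buffers, which the caller grows to fit.
  for (auto it = mid; it != begin;) {
    --it;
    if (it->second->scope != scope) continue;
    FreeBuffer* e = it->second;
    free_map->erase(it);
    return e;
  }
  return nullptr;  // no candidate: allocate a fresh StorageEntry instead
}
```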
- if (e->scope.rank >= StorageRank::kWarp || - e->allocs[0]->dtype.is_handle()) return; + if (e->scope.rank >= StorageRank::kWarp || e->allocs[0]->dtype.is_handle()) return; // disable reuse of small arrays if (e->const_nbits > 0 && e->const_nbits <= 32) return; } @@ -962,19 +924,15 @@ class VectorAllocRewriter : public StmtExprMutator { op = stmt.as(); const auto& tvec = acc_map_[op->buffer_var.get()]; - if (tvec.size() == 1 && - tvec[0].element_of() == op->dtype.element_of() && - tvec[0].lanes() % op->dtype.lanes() == 0 && - tvec[0].lanes() != op->dtype.lanes()) { + if (tvec.size() == 1 && tvec[0].element_of() == op->dtype.element_of() && + tvec[0].lanes() % op->dtype.lanes() == 0 && tvec[0].lanes() != op->dtype.lanes()) { int factor = tvec[0].lanes() / op->dtype.lanes(); Array extents = op->extents; arith::ModularSet me = analyzer_.modular_set(extents[extents.size() - 1]); if (me->base % factor == 0 && me->coeff % factor == 0) { extents.Set(extents.size() - 1, extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); - return AllocateNode::make( - op->buffer_var, tvec[0], extents, - op->condition, op->body); + return Allocate(op->buffer_var, tvec[0], extents, op->condition, op->body); } } return stmt; @@ -993,6 +951,10 @@ class VectorAllocRewriter : public StmtExprMutator { arith::Analyzer analyzer_; }; +Stmt StorageRewrite(Stmt stmt) { + stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); + return VectorAllocRewriter()(std::move(stmt)); +} PrimFunc PointerValueTypeRewrite(PrimFunc f) { auto* n = f.CopyOnWrite(); @@ -1007,8 +969,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { const auto& tvec = rewriter.acc_map_[var.get()]; if (tvec.size() == 1) { - tir::Var new_var(var->name_hint, - PointerType(PrimType(tvec[0]))); + tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0]))); args.push_back(new_var); remap_vars.Set(var, new_var); @@ -1016,8 +977,7 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { // always set data type to be non vectorized so // load/store can still work via scalarization if (tvec.size() != 0 && !var->type_annotation.defined()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType(tvec[0].with_lanes(1)))); + tir::Var new_var(var->name_hint, PointerType(PrimType(tvec[0].with_lanes(1)))); args.push_back(new_var); remap_vars.Set(var, new_var); } else { @@ -1035,9 +995,31 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f) { return f; } -Stmt StorageRewrite(Stmt stmt) { - stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); - return VectorAllocRewriter()(std::move(stmt)); +namespace transform { + +Pass StorageRewrite() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true); + n->body = VectorAllocRewriter()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.StorageRewrite").set_body_typed(StorageRewrite); + +Pass PointerValueTypeRewrite() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + return PointerValueTypeRewrite(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); } + +TVM_REGISTER_GLOBAL("tir.transform.PointerValueTypeRewrite") + .set_body_typed(PointerValueTypeRewrite); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc index 
1ece078e6c3c..493aa516fbd7 100644 --- a/src/tir/transforms/tensorcore_infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -21,18 +21,17 @@ * \brief Infer TensorCore metadata from tensor intrinsic. * \file tensorcore_fragment.cc */ +#include #include -#include -#include #include -#include +#include #include #include -#include "../pass/storage_access.h" -#include "../pass/ir_util.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" +#include "storage_access.h" namespace tvm { namespace tir { @@ -48,7 +47,7 @@ class FragmentGetter : public StmtExprVisitor { std::string layout; FragmentInfo() = default; FragmentInfo(int _m, int _n, int _k, const std::string& _layout) - : m(_m), n(_n), k(_k), layout(_layout) {} + : m(_m), n(_n), k(_k), layout(_layout) {} }; void VisitExpr_(const CallNode* op) final { @@ -137,13 +136,12 @@ class FragmentGetter : public StmtExprVisitor { // Check shape of fragment making sure it is a valid shape for tvm_mma_sync class FragmentChecker : public StmtExprVisitor { public: - explicit FragmentChecker(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {} void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->is_intrinsic(intrinsic::tvm_mma_sync) || - op->is_intrinsic(intrinsic::tvm_bmma_sync)) { + if (op->is_intrinsic(intrinsic::tvm_mma_sync) || op->is_intrinsic(intrinsic::tvm_bmma_sync)) { CHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); @@ -171,13 +169,13 @@ class FragmentChecker : public StmtExprVisitor { return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; } // Fragment infomation - const FragmentGetter &fragment_getter; + const FragmentGetter& fragment_getter; }; // Store the metadata into attributes class InferFragmenter : public StmtMutator { public: - explicit InferFragmenter(const FragmentGetter &getter) : fragment_getter(getter) {} + explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {} Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); @@ -187,15 +185,14 @@ class InferFragmenter : public StmtMutator { FragmentGetter::FragmentInfo info = fragment_getter.fragments.at(buffer); // Add shape attribute to all fragments - std::string shape = std::to_string(info.m) + ", " + - std::to_string(info.n) + ", " + - std::to_string(info.k); - PrimExpr shape_expr = StringImmNode::make(shape); - Stmt shape_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_shape, shape_expr, stmt); + std::string shape = + std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); + PrimExpr shape_expr = StringImm(shape); + Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt); if (info.layout != "") { // Add shape attribute to matrix_a and matrix_b - Stmt layout_attr = AttrStmtNode::make(op->buffer_var, attr::fragment_layout, - StringImmNode::make(info.layout), shape_attr); + Stmt layout_attr = + AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr); return layout_attr; } else { return shape_attr; @@ -206,7 +203,7 @@ class InferFragmenter : public StmtMutator { private: // Fragment infomation - const FragmentGetter &fragment_getter; + const FragmentGetter& fragment_getter; }; Stmt InferFragment(Stmt stmt) { @@ -229,8 +226,7 @@ Pass 
InferFragment() { return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } -TVM_REGISTER_GLOBAL("tir.transform.InferFragment") -.set_body_typed(InferFragment); +TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment); } // namespace transform } // namespace tir diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index f464af655a15..612efb092395 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -20,40 +20,35 @@ /*! * \file thread_storage_sync.cc */ -#include -#include +#include #include -#include +#include #include #include -#include #include #include -#include "../pass/ir_util.h" -#include "../pass/storage_access.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_util.h" +#include "storage_access.h" namespace tvm { namespace tir { class ThreadSyncPlanner : public StorageAccessVisitor { public: - explicit ThreadSyncPlanner(StorageScope sync_scope) - : sync_scope_(sync_scope) {} + explicit ThreadSyncPlanner(StorageScope sync_scope) : sync_scope_(sync_scope) {} - // The syncs inserted before each statement + // The syncs inserted before each statement std::unordered_set syncs_inserted_; protected: - bool Enabled(const VarNode* buf, - const StorageScope& scope) const final { + bool Enabled(const VarNode* buf, const StorageScope& scope) const final { return in_device_env() && scope == sync_scope_; } // Plan the sync - std::vector Summarize( - std::vector seq, const ForNode* loop) final { + std::vector Summarize(std::vector seq, const ForNode* loop) final { // Unsynced reads and writes std::vector reads; std::vector writes; @@ -71,19 +66,23 @@ class ThreadSyncPlanner : public StorageAccessVisitor { for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { if (FindConflict(writes, acc, false)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kWrite) { if (FindConflict(reads, acc, false)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } // If sync is inserted. remove the irrelevant things. 
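The planning loop here decides where barriers go by walking accesses in program order: a read that may conflict with an earlier un-synced write, or a write that may conflict with an earlier un-synced read, forces a sync before the enclosing statement, and any sync clears the pending sets. A minimal sketch of that bookkeeping over a flat access trace, assuming every read/write pair on the same buffer conflicts (the real planner also compares the touched index sets and exempts provably identical single points):

```cpp
// Sketch: mark which statements need a barrier inserted before them.
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

enum class AccessType { kRead, kWrite, kSync };
struct Access { std::size_t stmt_id; AccessType type; std::string buffer; };

std::unordered_set<std::size_t> PlanSyncs(const std::vector<Access>& trace) {
  std::unordered_set<std::size_t> syncs;
  std::unordered_set<std::string> pending_reads, pending_writes;
  for (const Access& a : trace) {
    bool sync_before = false;
    if (a.type == AccessType::kRead && pending_writes.count(a.buffer)) sync_before = true;
    if (a.type == AccessType::kWrite && pending_reads.count(a.buffer)) sync_before = true;
    // A conflict or an explicit sync resets the un-synced access sets.
    if (sync_before || a.type == AccessType::kSync) {
      pending_reads.clear();
      pending_writes.clear();
    }
    if (sync_before) syncs.insert(a.stmt_id);
    if (a.type == AccessType::kRead) pending_reads.insert(a.buffer);
    if (a.type == AccessType::kWrite) pending_writes.insert(a.buffer);
  }
  return syncs;
}
```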
if (sync_before_stmt) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } // Add the read/write of current statement for (const AccessEntry& acc : s.access) { @@ -92,12 +91,12 @@ class ThreadSyncPlanner : public StorageAccessVisitor { } else if (acc.type == kWrite) { writes.push_back(acc); } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } if (sync_before_stmt) { - CHECK_EQ(condition_counter(), 0) - << "Cannot insert syncs inside condition"; + CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); } } @@ -110,19 +109,21 @@ class ThreadSyncPlanner : public StorageAccessVisitor { for (const AccessEntry& acc : s.access) { if (acc.type == kRead) { if (FindConflict(writes, acc, true)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kWrite) { if (FindConflict(reads, acc, true)) { - sync_before_stmt = true; break; + sync_before_stmt = true; + break; } } else if (acc.type == kSync) { - reads.clear(); writes.clear(); + reads.clear(); + writes.clear(); } } if (sync_before_stmt) { - CHECK_EQ(condition_counter(), 0) - << "Cannot insert syncs inside condition"; + CHECK_EQ(condition_counter(), 0) << "Cannot insert syncs inside condition"; syncs_inserted_.insert(s.stmt); break; } @@ -175,22 +176,16 @@ class ThreadSyncPlanner : public StorageAccessVisitor { private: // find conflicting entry in vec. - bool FindConflict(const std::vector& vec, - const AccessEntry& e, - bool loop_carry) { + bool FindConflict(const std::vector& vec, const AccessEntry& e, bool loop_carry) { for (const AccessEntry& x : vec) { if (x.buffer.same_as(e.buffer)) { // Assumes no race between threads // Same index value means no conflicts // TODO(tqchen) more standard set based testing. - if (e.touched.is_single_point() && - x.touched.is_single_point()) { - if (ExprDeepEqual()(e.touched.point_value(), - x.touched.point_value())) continue; + if (e.touched.is_single_point() && x.touched.is_single_point()) { + if (ExprDeepEqual()(e.touched.point_value(), x.touched.point_value())) continue; } - if (x.double_buffer_write && - e.type == kRead && - !loop_carry) continue; + if (x.double_buffer_write && e.type == kRead && !loop_carry) continue; return true; } } @@ -204,8 +199,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { class ThreadSyncInserter : public StmtExprMutator { public: - ThreadSyncInserter(StorageScope sync_scope, - const std::unordered_set& syncs) + ThreadSyncInserter(StorageScope sync_scope, const std::unordered_set& syncs) : sync_scope_(sync_scope), syncs_(syncs) {} Stmt VisitStmt(const Stmt& stmt) final { @@ -215,10 +209,8 @@ class ThreadSyncInserter : public StmtExprMutator { if (sync_scope_.rank == StorageRank::kGlobal) { barrier = MakeGlobalBarrier(); } else { - barrier = EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string())}, - CallNode::Intrinsic)); + barrier = Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync_scope_.to_string())}, CallNode::Intrinsic)); } // Mutate after query, to avoid stmt change. 
auto ret = StmtExprMutator::VisitStmt(stmt); @@ -259,8 +251,7 @@ class ThreadSyncInserter : public StmtExprMutator { return ret; } else if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - storage_scope_[buf] = - StorageScope::make(op->value.as()->value); + storage_scope_[buf] = StorageScope::Create(op->value.as()->value); return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); @@ -306,24 +297,21 @@ class ThreadSyncInserter : public StmtExprMutator { // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { CHECK(op != nullptr); - Array pargs = {StringImmNode::make(runtime::symbol::tvm_prepare_global_barrier)}; - Stmt prep = EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); + Array pargs = {StringImm(runtime::symbol::tvm_prepare_global_barrier)}; + Stmt prep = + Evaluate(Call(DataType::Int(32), intrinsic::tvm_call_packed, pargs, CallNode::Intrinsic)); Stmt body = op->body; for (const auto& kv : rw_stats_) { const auto& e = kv.second; if (e.read_count != 0 && e.write_count != 0) { - body = AttrStmtNode::make(kv.first, attr::volatile_scope, 1, body); + body = AttrStmt(kv.first, attr::volatile_scope, 1, body); } } rw_stats_.clear(); - Stmt kinit = EvaluateNode::make( - CallNode::make( - DataType::Int(32), - intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); + Stmt kinit = Evaluate( + Call(DataType::Int(32), intrinsic::tvm_global_barrier_kinit, {}, CallNode::Intrinsic)); body = SeqStmt({kinit, body}); - body = AttrStmtNode::make( - op->node, op->attr_key, op->value, body); + body = AttrStmt(op->node, op->attr_key, op->value, body); return SeqStmt({prep, body}); } Stmt MakeGlobalBarrier() { @@ -333,10 +321,9 @@ class ThreadSyncInserter : public StmtExprMutator { num_work_dim_ = thread_extents_.size(); for (const AttrStmtNode* attr : thread_extents_) { IterVar iv = Downcast(attr->node); - runtime::ThreadScope s = runtime::ThreadScope::make(iv->thread_tag); + runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag); if (s.rank == 0) { - num_blocks_ = (num_blocks_.defined() ? - attr->value * num_blocks_ : attr->value); + num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : attr->value); } else if (s.rank == 1) { PrimExpr cond = iv->var == make_zero(iv->var.dtype()); is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond; @@ -345,11 +332,9 @@ class ThreadSyncInserter : public StmtExprMutator { } else { CHECK_EQ(num_work_dim_, thread_extents_.size()); } - return EvaluateNode::make( - CallNode::make(DataType::Int(32), intrinsic::tvm_storage_sync, - {StringImmNode::make(sync_scope_.to_string()), - is_lead_, num_blocks_}, - CallNode::Intrinsic)); + return Evaluate(Call(DataType::Int(32), intrinsic::tvm_storage_sync, + {StringImm(sync_scope_.to_string()), is_lead_, num_blocks_}, + CallNode::Intrinsic)); } // data structure. 
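`MakeGlobalBarrier` above folds the launch configuration into the extra operands of `tvm_storage_sync`: the block count is the product of the rank-0 (`blockIdx.*`) extents, and the leader predicate is the conjunction of `threadIdx.* == 0` over the rank-1 dimensions. A small sketch of that reduction over recorded thread extents, using plain values in place of `IterVar`s (illustrative only):

```cpp
// Sketch: derive the global-barrier operands from the captured thread extents.
// rank 0 ~ blockIdx.* dimensions, rank 1 ~ threadIdx.* dimensions.
#include <cstdint>
#include <vector>

struct ThreadExtent {
  int rank;        // 0 for block dimensions, 1 for thread dimensions
  int64_t extent;  // launch extent of this dimension
  int64_t index;   // runtime index of the current thread in this dimension
};

struct BarrierOperands {
  int64_t num_blocks = 1;
  bool is_lead = true;  // true only for the thread with threadIdx.{x,y,z} == 0
};

BarrierOperands DeriveBarrierOperands(const std::vector<ThreadExtent>& extents) {
  BarrierOperands out;
  for (const ThreadExtent& e : extents) {
    if (e.rank == 0) {
      out.num_blocks *= e.extent;                    // accumulate grid size
    } else if (e.rank == 1) {
      out.is_lead = out.is_lead && (e.index == 0);   // leader within the block
    }
  }
  return out;
}
```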
StorageScope sync_scope_; @@ -357,7 +342,7 @@ class ThreadSyncInserter : public StmtExprMutator { // The storage scope of each buffer std::unordered_map storage_scope_; // The read write statistics of storage - std::unordered_map rw_stats_; + std::unordered_map rw_stats_; // The statistics for global barrier bool in_thread_env_{false}; // memorized results @@ -368,7 +353,7 @@ class ThreadSyncInserter : public StmtExprMutator { }; Stmt ThreadSync(Stmt stmt, std::string storage_scope) { - StorageScope sync_scope = StorageScope::make(storage_scope); + StorageScope sync_scope = StorageScope::Create(storage_scope); ThreadSyncPlanner planner(sync_scope); planner(stmt); return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); @@ -376,7 +361,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { namespace transform { -Pass ThreadSync(std::string storage_scope) { +Pass ThreadSync(String storage_scope) { auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = ThreadSync(std::move(n->body), storage_scope); @@ -385,8 +370,7 @@ Pass ThreadSync(std::string storage_scope) { return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } -TVM_REGISTER_GLOBAL("tir.transform.ThreadSync") -.set_body_typed(ThreadSync); +TVM_REGISTER_GLOBAL("tir.transform.ThreadSync").set_body_typed(ThreadSync); } // namespace transform } // namespace tir diff --git a/src/tir/pass/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc similarity index 62% rename from src/tir/pass/unroll_loop.cc rename to src/tir/transforms/unroll_loop.cc index 0167dbcec5f2..a15190665949 100644 --- a/src/tir/pass/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -22,41 +22,70 @@ * \file unroll_loop.cc */ // Unrolls the loop as in Halide pipeline. 
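The rewritten unroller below replaces the old free-function parameters with a pass-level config registered under "tir.UnrollLoop". For quick reference, a plain struct restating the fields and defaults that `UnrollLoopConfigNode` declares (a summary only, not TVM API):

```cpp
// Defaults declared by UnrollLoopConfigNode in the hunk below.
struct UnrollLoopDefaults {
  int auto_max_step = 0;        // steps threshold for automatic unrolling (0 disables)
  int auto_max_depth = 8;       // maximum nesting depth eligible for auto unroll
  int auto_max_extent = 0;      // maximum loop extent eligible for auto unroll
  bool explicit_unroll = true;  // expand the body instead of emitting an unroll pragma
};
```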
+#include +#include #include -#include +#include #include -#include +#include + #include +#include #include -#include "../../arith/compute_expr.h" + +#include "ir_util.h" namespace tvm { namespace tir { +struct UnrollLoopConfigNode : public tvm::AttrsNode { + int auto_max_step; + int auto_max_depth; + int auto_max_extent; + int explicit_unroll; + + TVM_DECLARE_ATTRS(UnrollLoopConfigNode, "tir.transform.UnrollLoopConfig") { + TVM_ATTR_FIELD(auto_max_step) + .describe("Threshold of number of steps in the loop to be automatically unrolled") + .set_default(0); + TVM_ATTR_FIELD(auto_max_depth) + .describe("The maximum nested level of loops that can be automatically unrolled.") + .set_default(8); + TVM_ATTR_FIELD(auto_max_extent) + .describe("The maximum extent of loop that will be unrolled.") + .set_default(0); + TVM_ATTR_FIELD(explicit_unroll) + .describe("Whether to explicitly unroll the loop instead of setting a pragma") + .set_default(true); + } +}; + +class UnrollLoopConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(UnrollLoopConfig, Attrs, UnrollLoopConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(UnrollLoopConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); + class LoopUnroller : public StmtExprMutator { public: - explicit LoopUnroller(int auto_max_step, - int auto_max_depth, - int auto_max_extent, + explicit LoopUnroller(int auto_max_step, int auto_max_depth, int auto_max_extent, bool explicit_unroll) : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), auto_max_extent_(auto_max_extent), - explicit_unroll_(explicit_unroll) { - } + explicit_unroll_(explicit_unroll) {} Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == "pragma_auto_unroll_max_step") { - int value = 0; - CHECK(arith::GetConstInt(op->value, &value)); + int value = static_cast(Downcast(op->value)->value); std::swap(value, auto_max_step_); Stmt ret = this->VisitStmt(op->body); std::swap(value, auto_max_step_); return ret; } else if (op->attr_key == "pragma_unroll_explicit") { - int value = 0; - CHECK(arith::GetConstInt(op->value, &value)); - bool explicit_unroll = value; + bool explicit_unroll = Downcast(op->value)->value; std::swap(explicit_unroll, explicit_unroll_); Stmt ret = this->VisitStmt(op->body); std::swap(explicit_unroll, explicit_unroll_); @@ -71,24 +100,19 @@ class LoopUnroller : public StmtExprMutator { op = stmt.as(); int value = GetExtent(op); // condition for auto unroll - bool auto_unroll = ( - op->for_type == ForType::Serial && - value >= 0 && - normal_loop_depth_ == 0 && - unroll_depth_ <= auto_max_depth_); + bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 && normal_loop_depth_ == 0 && + unroll_depth_ <= auto_max_depth_); - auto_unroll = auto_unroll && ( - value * step_count_ <= auto_max_step_|| - value <= auto_max_extent_); + auto_unroll = + auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); if (op->for_type == ForType::Unrolled) { - CHECK_GE(value, 0) - << "Cannot unroll non-constant loop"; + CHECK_GE(value, 0) << "Cannot unroll non-constant loop"; auto_unroll = true; } if (auto_unroll) { - step_count_ *= value; + step_count_ *= value; unroll_depth_ += 1; } else { normal_loop_depth_ += 1; @@ -101,9 +125,8 @@ class LoopUnroller : public StmtExprMutator { } else { if (auto_unroll) { if (op->for_type != ForType::Unrolled) { - return ForNode::make( - op->loop_var, op->min, op->extent, - ForType::Unrolled, op->device_api, op->body); + return For(op->loop_var, op->min, 
op->extent, ForType::Unrolled, op->device_api, + op->body); } } return stmt; @@ -141,7 +164,7 @@ class LoopUnroller : public StmtExprMutator { int value = GetExtent(op); // For loop must have a constant integer extent CHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; - if (value == 0) return EvaluateNode::make(0); + if (value == 0) return Evaluate(0); Stmt body = op->body; Map vmap; Array unrolled; @@ -157,8 +180,8 @@ class LoopUnroller : public StmtExprMutator { // returns the extent of the loop if it's a constant integer, otherwise return -1 int GetExtent(const ForNode* op) { // constant folding. - PrimExpr extent = tir::Simplify(op->extent); - const IntImmNode *v1 = extent.as(); + PrimExpr extent = analyzer_.Simplify(op->extent); + const IntImmNode* v1 = extent.as(); int value = -1; // integers that do not fit in int32_t are treated as symbolic, // as it's impossible to unroll such large loops @@ -181,19 +204,13 @@ class LoopUnroller : public StmtExprMutator { int unroll_depth_{0}; // Number of total steps unrolled int step_count_{0}; + // analyzer + arith::Analyzer analyzer_; }; - -Stmt UnrollLoop(Stmt stmt, - int auto_max_step, - int auto_max_depth, - int auto_max_extent, - bool explicit_unroll) { - Stmt ret = LoopUnroller( - auto_max_step, - auto_max_depth, - auto_max_extent, - explicit_unroll)(stmt); +Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { + Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, cfg->auto_max_extent, + cfg->explicit_unroll)(stmt); if (!ret.same_as(stmt)) { return ConvertSSA(ret); } else { @@ -201,13 +218,24 @@ Stmt UnrollLoop(Stmt stmt, } } -Stmt UnrollLoopExplicitly(Stmt stmt) { - const ForNode* op = stmt.as(); - if (!op) { - LOG(FATAL) << "attempted to unroll a non-loop statement"; - } - return LoopUnroller(0, 0, 0, false).Unroll(op); +namespace transform { + +Pass UnrollLoop() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto cfg = ctx->GetConfig("tir.UnrollLoop"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + n->body = UnrollLoop(std::move(f->body), cfg.value()); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); } +TVM_REGISTER_GLOBAL("tir.transform.UnrollLoop").set_body_typed(UnrollLoop); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc similarity index 64% rename from src/tir/pass/vectorize_loop.cc rename to src/tir/transforms/vectorize_loop.cc index b73587db2ab6..227aea2eb575 100644 --- a/src/tir/pass/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -21,14 +21,16 @@ * \file vectorize_loop.cc */ // Loop vectorizer as in Halide pipeline. 
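Before the vectorizer below, it may help to restate what vectorizing a loop means in this IR: the `vectorized` loop variable is substituted with `Ramp(0, 1, lanes)`, scalar operands are widened with `Broadcast`, and loads/stores become single vector accesses, e.g. `for (i, 0, 4) C[i] = A[i] + b` becomes `C[Ramp(0, 1, 4)] = A[Ramp(0, 1, 4)] + Broadcast(b, 4)`. A small scalar simulation of that ramp/broadcast indexing (illustrative only, not the TVM API):

```cpp
// Sketch: enumerate the lane indices of Ramp(base, stride, Lanes) and apply a
// vectorized "C[ramp] = A[ramp] + broadcast(b)" statement lane by lane.
#include <array>
#include <cstddef>

template <std::size_t Lanes>
std::array<int, Lanes> Ramp(int base, int stride) {
  std::array<int, Lanes> idx{};
  for (std::size_t l = 0; l < Lanes; ++l) idx[l] = base + static_cast<int>(l) * stride;
  return idx;
}

template <std::size_t Lanes>
void VectorAddScalar(const int* A, int b, int* C) {
  auto idx = Ramp<Lanes>(/*base=*/0, /*stride=*/1);
  for (std::size_t l = 0; l < Lanes; ++l) C[idx[l]] = A[idx[l]] + b;
}
```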
+#include +#include #include -#include +#include #include -#include -#include +#include + #include +#include #include -#include "../../arith/compute_expr.h" namespace tvm { namespace tir { @@ -37,13 +39,12 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) { if (e.dtype().lanes() == lanes) return e; if (const BroadcastNode* op = e.as()) { if (lanes % op->lanes == 0) { - return BroadcastNode::make(op->value, lanes); + return Broadcast(op->value, lanes); } } - CHECK_EQ(e.dtype().lanes(), 1) - << "Cannot broadcast lane=" << e.dtype().lanes() - << " to " << lanes; - return BroadcastNode::make(e, lanes); + CHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << e.dtype().lanes() << " to " + << lanes; + return Broadcast(e, lanes); } // Rewrite vectorized allocation access @@ -63,9 +64,7 @@ class VecAllocAccess : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); if (op->buffer_var.get() == buf_) { - return LoadNode::make(op->dtype, op->buffer_var, - op->index * var_lanes_ + var_, - op->predicate); + return Load(op->dtype, op->buffer_var, op->index * var_lanes_ + var_, op->predicate); } else { return expr; } @@ -75,10 +74,7 @@ class VecAllocAccess : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); if (op->buffer_var.get() == buf_) { - return StoreNode::make(op->buffer_var, - op->value, - op->index * var_lanes_ + var_, - op->predicate); + return Store(op->buffer_var, op->value, op->index * var_lanes_ + var_, op->predicate); } else { return stmt; } @@ -95,9 +91,8 @@ class VecAllocAccess : public StmtExprMutator { class Vectorizer : public StmtExprMutator { public: - Vectorizer(Var var, int var_lanes) - : var_(var), var_lanes_(var_lanes) { - ramp_ = RampNode::make(0, 1, var_lanes); + Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) { + ramp_ = Ramp(0, 1, var_lanes); } Stmt VisitStmt(const Stmt& stmt) final { @@ -112,16 +107,17 @@ class Vectorizer : public StmtExprMutator { } PrimExpr VisitExpr_(const AddNode* op) final { - return AddSubVec(op); + return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; }); } + PrimExpr VisitExpr_(const SubNode* op) final { - return AddSubVec(op); + return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; }); } + PrimExpr VisitExpr_(const MulNode* op) final { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); @@ -129,67 +125,37 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) { - return RampNode::make( - a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); + return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes); } if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) { - return RampNode::make( - b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); + return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes); } } - return MulNode::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); - } - return BinaryVec(op); - } - PrimExpr VisitExpr_(const DivNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const ModNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const FloorDivNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const FloorModNode* op) 
final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const MinNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const MaxNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const EQNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const NENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const LTNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const LENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const GTNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const GENode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const AndNode* op) final { - return BinaryVec(op); - } - PrimExpr VisitExpr_(const OrNode* op) final { - return BinaryVec(op); - } + return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + } + return BinaryVec(op); + } + PrimExpr VisitExpr_(const DivNode* op) final { return BinaryVec
(op); } + PrimExpr VisitExpr_(const ModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorDivNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const FloorModNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MinNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const MaxNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const EQNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const NENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const LENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GTNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const GENode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const AndNode* op) final { return BinaryVec(op); } + PrimExpr VisitExpr_(const OrNode* op) final { return BinaryVec(op); } PrimExpr VisitExpr_(const RampNode* op) final { PrimExpr base = this->VisitExpr(op->base); PrimExpr stride = this->VisitExpr(op->stride); if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) { const RampNode* base_ramp = base.as(); if (analyzer_.CanProve(base_ramp->stride == stride * make_const(stride.dtype(), op->lanes))) { - return RampNode::make(base_ramp->base, stride, op->lanes * base_ramp->lanes); + return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes); } } int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes()); @@ -198,33 +164,27 @@ class Vectorizer : public StmtExprMutator { Array elems; for (int i = 0; i < lanes; ++i) { elems.push_back( - RampNode::make(ShuffleNode::make_extract_element(base, i), - ShuffleNode::make_extract_element(stride, i), - op->lanes)); + Ramp(Shuffle::ExtractElement(base, i), Shuffle::ExtractElement(stride, i), op->lanes)); } - return ShuffleNode::make_concat(elems); + return Shuffle::Concat(elems); } - PrimExpr VisitExpr_(const SelectNode *op) final { + PrimExpr VisitExpr_(const SelectNode* op) final { PrimExpr cond = this->VisitExpr(op->condition); PrimExpr t = this->VisitExpr(op->true_value); PrimExpr f = this->VisitExpr(op->false_value); - if (cond.same_as(op->condition) && - t.same_as(op->true_value) && - f.same_as(op->false_value)) { + if (cond.same_as(op->condition) && t.same_as(op->true_value) && f.same_as(op->false_value)) { return GetRef(op); } else { - int lanes = std::max(std::max( - cond.dtype().lanes(), - t.dtype().lanes()), f.dtype().lanes()); - return SelectNode::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); + int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), f.dtype().lanes()); + return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes)); } } - PrimExpr VisitExpr_(const CastNode *op) final { + PrimExpr VisitExpr_(const CastNode* op) final { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return CastNode::make(op->dtype.with_lanes(value.dtype().lanes()), value); + return Cast(op->dtype.with_lanes(value.dtype().lanes()), value); } } // Variable @@ -232,31 +192,27 @@ class Vectorizer : public StmtExprMutator { if (v == var_.get()) { return ramp_; } else if (lets_.count(v)) { - return lets_[v]; + return lets_[v]; } else { return GetRef(v); } } // IfThenElse expr - PrimExpr MutateIfThenElseExpr_(const CallNode *op) { + PrimExpr MutateIfThenElseExpr_(const CallNode* op) { PrimExpr cond = this->VisitExpr(op->args[0]); - if (cond.dtype().is_vector()) { + if 
(cond.dtype().is_vector()) { need_scalarize_ = true; return GetRef(op); } PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr f = this->VisitExpr(op->args[2]); - if (cond.same_as(op->args[0]) && - t.same_as(op->args[1]) && - f.same_as(op->args[2])) { + if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && f.same_as(op->args[2])) { return GetRef(op); } else { int lanes = std::max(t.dtype().lanes(), f.dtype().lanes()); t = BroadcastTo(t, lanes); f = BroadcastTo(f, lanes); - return CallNode::make( - op->dtype.with_lanes(lanes), op->name, - {cond, t, f}, op->call_type, op->func, op->value_index); + return Call(op->dtype.with_lanes(lanes), op->name, {cond, t, f}, op->call_type); } } // Call @@ -278,8 +234,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make( - op->dtype, op->name, new_args, op->call_type, op->func, op->value_index); + return Call(op->dtype, op->name, new_args, op->call_type); } } else { int lane = 0; @@ -288,9 +243,7 @@ class Vectorizer : public StmtExprMutator { if (op->args.same_as(new_args)) { return GetRef(op); } else { - return CallNode::make( - op->dtype.with_lanes(lane), op->name, new_args, - op->call_type, op->func, op->value_index); + return Call(op->dtype.with_lanes(lane), op->name, new_args, op->call_type); } } } @@ -302,11 +255,8 @@ class Vectorizer : public StmtExprMutator { return GetRef(op); } else { int lanes = std::max(index.dtype().lanes(), pred.dtype().lanes()); - return LoadNode::make( - op->dtype.with_lanes(lanes), - op->buffer_var, - BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return Load(op->dtype.with_lanes(lanes), op->buffer_var, BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // Let @@ -316,28 +266,19 @@ class Vectorizer : public StmtExprMutator { if (value.dtype().lanes() != op->value.dtype().lanes()) { Var v(op->var->name_hint, value.dtype()); lets_[op->var.get()] = v; - return LetNode::make(v, value, this->VisitExpr(op->body)); + return Let(v, value, this->VisitExpr(op->body)); } else { PrimExpr body = this->VisitExpr(op->body); - if (value.same_as(op->value) && - body.same_as(op->body)) { + if (value.same_as(op->value) && body.same_as(op->body)) { return GetRef(op); } else { - return LetNode::make(op->var, value, body); + return Let(op->var, value, body); } } } - // Provide - Stmt VisitStmt_(const ProvideNode* op) final { - PrimExpr new_value = this->VisitExpr(op->value); - int lane = new_value.dtype().lanes(); - Array new_args = MutateArray(op->args, &lane); - if (op->args.same_as(new_args) && op->value.same_as(new_value)) { - return GetRef(op); - } else { - new_value = BroadcastTo(new_value, lane); - return ProvideNode::make(op->func, op->value_index, new_value, new_args); - } + Stmt VisitStmt_(const ProducerStoreNode* op) final { + LOG(FATAL) << "ProducerProvide is cannot appear in a TIR PrimFunc"; + return Stmt(); } // Store Stmt VisitStmt_(const StoreNode* op) final { @@ -349,10 +290,8 @@ class Vectorizer : public StmtExprMutator { } else { int lanes = std::max(value.dtype().lanes(), index.dtype().lanes()); lanes = std::max(lanes, pred.dtype().lanes()); - return StoreNode::make(op->buffer_var, - BroadcastTo(value, lanes), - BroadcastTo(index, lanes), - BroadcastTo(pred, lanes)); + return Store(op->buffer_var, BroadcastTo(value, lanes), BroadcastTo(index, lanes), + BroadcastTo(pred, lanes)); } } // For @@ -367,13 +306,10 @@ class Vectorizer : public StmtExprMutator { return Scalarize(GetRef(op)); } Stmt body = 
this->VisitStmt(op->body); - if (extent.same_as(op->extent) && - body.same_as(op->body)) { + if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return ForNode::make( - op->loop_var, op->min, extent, - op->for_type, op->device_api, body); + return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body); } } // IfThenElse @@ -388,12 +324,11 @@ class Vectorizer : public StmtExprMutator { if (op->else_case.defined()) { else_case = this->VisitStmt(op->else_case); } - if (condition.same_as(op->condition) && - then_case.same_as(op->then_case) && + if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { return GetRef(op); } else { - return IfThenElseNode::make(condition, then_case, else_case); + return IfThenElse(condition, then_case, else_case); } } // LetStmt @@ -420,19 +355,16 @@ class Vectorizer : public StmtExprMutator { // place the vector lanes in least significant dimension. extents.push_back(var_lanes_); // rewrite access to buffer internally. - Stmt body = VecAllocAccess( - op->buffer_var.get(), var_, var_lanes_)(op->body); + Stmt body = VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body); body = this->VisitStmt(body); - return AllocateNode::make( - op->buffer_var, op->dtype, - extents, condition, body); + return Allocate(op->buffer_var, op->dtype, extents, condition, body); } // scalarize the statment Stmt Scalarize(Stmt stmt) { Var idx(var_->name_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return ForNode::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); + return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); } private: @@ -472,24 +404,23 @@ class Vectorizer : public StmtExprMutator { if (!changed) return arr; return Array(new_arr); } - template + template PrimExpr BinaryVec(const T* op) { + static_assert(std::is_same::value, "constraint"); PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); - return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } - template - PrimExpr AddSubVec(const T* op) { + template + PrimExpr AddSubVec(const T* op, FCompute fcompute) { PrimExpr a = this->VisitExpr(op->a); PrimExpr b = this->VisitExpr(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { + if (a.same_as(op->a) && b.same_as(op->b)) { return GetRef(op); } else { int lanes = std::max(a.dtype().lanes(), b.dtype().lanes()); @@ -497,17 +428,14 @@ class Vectorizer : public StmtExprMutator { const RampNode* b_ramp = b.as(); const RampNode* a_ramp = a.as(); if (a.dtype().lanes() == 1 && b_ramp) { - return RampNode::make( - arith::Compute(a, b_ramp->base), - arith::Compute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), - b_ramp->lanes); + return Ramp(fcompute(a, b_ramp->base), + fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride), b_ramp->lanes); } if (b.dtype().lanes() == 1 && a_ramp) { - return RampNode::make( - arith::Compute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); + return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes); } } - return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); + return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes)); } } }; @@ -517,21 +445,18 @@ class 
LoopVectorizer : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { if (op->for_type == ForType::Vectorized) { CHECK(is_zero(op->min)); - int lanes = 0; - bool succ = arith::GetConstInt(op->extent, &lanes); - if (!succ || lanes < 1) { + auto* extent_as_int = op->extent.as(); + if (!extent_as_int || extent_as_int->value < 1) { LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; } - return Vectorizer(op->loop_var, lanes)(op->body); + return Vectorizer(op->loop_var, static_cast(extent_as_int->value))(op->body); } else { return StmtMutator::VisitStmt_(op); } } }; -Stmt VectorizeLoop(Stmt stmt) { - return LoopVectorizer()(std::move(stmt)); -} +Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); } class VectorizeSkipper : public StmtMutator { public: @@ -539,17 +464,34 @@ class VectorizeSkipper : public StmtMutator { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); if (op->for_type == ForType::Vectorized) { - return ForNode::make(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, - op->body); + return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body); } else { - return stmt; + return stmt; } } }; -Stmt SkipVectorize(Stmt stmt) { - return VectorizeSkipper()(std::move(stmt)); +Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); } + +namespace transform { + +// TODO(tvm-team): Make it as a target property. +Pass VectorizeLoop(bool enable_vectorize) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + if (enable_vectorize) { + n->body = LoopVectorizer()(std::move(n->body)); + } else { + n->body = VectorizeSkipper()(std::move(n->body)); + } + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); } +TVM_REGISTER_GLOBAL("tir.transform.VectorizeLoop").set_body_typed(VectorizeLoop); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/arith_simplify_test.cc similarity index 74% rename from tests/cpp/ir_simplify_test.cc rename to tests/cpp/arith_simplify_test.cc index 69cf1298a320..341d9f8df062 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -19,38 +19,41 @@ #include #include -#include +#include #include -TEST(IRSIMPLIFY, MinMax) { +TEST(Simplify, MinMax) { + tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); - auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ; - auto e1s = tvm::tir::CanonicalSimplify(e1); + auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)); + auto e1s = ana.canonical_simplify(e1); CHECK(tvm::tir::is_zero(e1s)); auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); - auto e2s = tvm::tir::CanonicalSimplify(e2); + auto e2s = ana.canonical_simplify(e2); CHECK(tvm::tir::is_zero(e2s)); } -TEST(IRSIMPLIFY, Mul) { +TEST(Simplify, Mul) { + tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); - auto e = (x * x) - (x * x) ; - auto es = tvm::tir::CanonicalSimplify(e); + auto e = (x * x) - (x * x); + auto es = ana.canonical_simplify(e); CHECK(tvm::tir::is_zero(es)); } -TEST(IRSIMPLIFY, Mod) { +TEST(Simplify, Mod) { + tvm::arith::Analyzer ana; auto x = tvm::Integer(10); auto y = tvm::Integer(12); // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). 
Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = tvm::tir::CanonicalSimplify(tvm::tir::ModNode::make(x, y)); - auto es = tvm::tir::CanonicalSimplify(mod - x); + auto mod = ana.canonical_simplify(tvm::tir::Mod(x, y)); + auto es = ana.canonical_simplify(mod - x); CHECK(tvm::tir::is_zero(es)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/attrs_test.cc b/tests/cpp/attrs_test.cc index ccf1b251482f..7b301bd13f68 100644 --- a/tests/cpp/attrs_test.cc +++ b/tests/cpp/attrs_test.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include namespace tvm { namespace test { @@ -33,23 +33,17 @@ struct TestAttrs : public AttrsNode { double learning_rate; TVM_DECLARE_ATTRS(TestAttrs, "attrs.cpptest.TestAttrs") { - TVM_ATTR_FIELD(axis) - .set_default(10) - .set_lower_bound(1) - .set_upper_bound(10) - .describe("axis field"); - TVM_ATTR_FIELD(name) - .describe("name of the field"); + TVM_ATTR_FIELD(axis).set_default(10).set_lower_bound(1).set_upper_bound(10).describe( + "axis field"); + TVM_ATTR_FIELD(name).describe("name of the field"); TVM_ATTR_FIELD(expr) .describe("expression field") .set_default(tir::make_const(DataType::Int(32), 1)); - TVM_ATTR_FIELD(learning_rate) - .describe("learning_rate") - .set_default(0.1); + TVM_ATTR_FIELD(learning_rate).describe("learning_rate").set_default(0.1); } }; -} -} +} // namespace test +} // namespace tvm TEST(Attrs, Basic) { using namespace tvm; @@ -84,12 +78,11 @@ TEST(Attrs, Basic) { // Check docstring std::ostringstream os; n->PrintDocString(os); - LOG(INFO) << "docstring\n"<< os.str(); + LOG(INFO) << "docstring\n" << os.str(); CHECK(os.str().find("expr : PrimExpr, default=1") != std::string::npos); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 9333a3470715..fc9edf8ababe 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -20,12 +20,12 @@ #include #include #include -#include -#include #include +#include +#include -#include #include +#include TEST(BuildModule, Basic) { using namespace tvm; @@ -37,25 +37,23 @@ TEST(BuildModule, Basic) { auto A = placeholder(shape, DataType::Float(32), "A"); auto B = placeholder(shape, DataType::Float(32), "B"); - auto C = compute(A->shape, [&A, &B](PrimExpr i) { - return A[i] + B[i]; - }, "C"); + auto C = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "C"); - auto s = create_schedule({ C->op }); + auto s = create_schedule({C->op}); auto cAxis = C->op.as()->axis; IterVar bx, tx; s[C].split(cAxis[0], 64, &bx, &tx); - auto args = Array({ A, B, C }); + auto args = Array({A, B, C}); std::unordered_map binds; - auto config = BuildConfig::Create(); auto target = target::llvm(); - auto lowered = lower(s, args, "func", binds, config); - auto module = build(lowered, target, Target(), config); + auto lowered = lower(s, args, "func", binds); + auto module = build(lowered, target, Target()); auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); CHECK_EQ(mali_target->str(), "opencl -model=Mali-T860MP4@800Mhz -device=mali"); @@ -94,32 +92,27 @@ TEST(BuildModule, Heterogeneous) { 
auto B = placeholder(shape, DataType::Float(32), "B"); auto C = placeholder(shape, DataType::Float(32), "C"); - auto elemwise_add = compute(A->shape, [&A, &B](PrimExpr i) { - return A[i] + B[i]; - }, "elemwise_add"); + auto elemwise_add = compute( + A->shape, [&A, &B](PrimExpr i) { return A[i] + B[i]; }, "elemwise_add"); auto copy = placeholder(shape, DataType::Float(32), "__copy"); - auto elemwise_sub = compute(C->shape, [©, &C](PrimExpr i) { - return copy[i] - C[i]; - }, "elemwise_sub"); + auto elemwise_sub = compute( + C->shape, [©, &C](PrimExpr i) { return copy[i] - C[i]; }, "elemwise_sub"); With cuda_scope(target_cuda); auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add}); - With llvm_scope(target_llvm); auto s2 = create_schedule({elemwise_sub->op}); - auto config = BuildConfig::Create(); auto args1 = Array({A, B, elemwise_add}); auto args2 = Array({copy, C, elemwise_sub}); std::unordered_map binds; - auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config); - auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config); - Map inputs = {{target_cuda, lowered_s1}, - {target_llvm, lowered_s2}}; - auto module = build(inputs, Target(), config); + auto lowered_s1 = lower(s1, args1, "elemwise_add", binds); + auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds); + Map inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}}; + auto module = build(inputs, Target()); // Assertion for build. CHECK_EQ(module->imports().size(), 1); @@ -148,16 +141,13 @@ TEST(BuildModule, Heterogeneous) { "\"float32\"]]}}"; // Setup inputs. - auto a_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto b_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto c_val = - runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto a_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto b_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c_val = runtime::NDArray::Empty({n}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pa = (float*)a_val.ToDLPack()->dl_tensor.data; - auto pb = (float*)b_val.ToDLPack()->dl_tensor.data; - auto pc = (float*)c_val.ToDLPack()->dl_tensor.data; + auto pa = (float*)(a_val->data); + auto pb = (float*)(b_val->data); + auto pc = (float*)(c_val->data); // Assign values. for (int i = 0; i < n; i++) { @@ -174,8 +164,17 @@ TEST(BuildModule, Heterogeneous) { const runtime::PackedFunc* graph_runtime = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); - runtime::Module mod = (*graph_runtime)( - json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); + runtime::Module mod = + (*graph_runtime)(json, module, cpu_dev_ty, cpu_dev_id, gpu_dev_ty, gpu_dev_id); + + // test FFI for module. + auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + int tcode = args[1]; + CHECK_EQ(args[0].type_code(), tcode); + }); + + test_ffi(runtime::Module(mod), static_cast(kTVMModuleHandle)); + test_ffi(Optional(mod), static_cast(kTVMModuleHandle)); PackedFunc set_input = mod.GetFunction("set_input", false); PackedFunc run = mod.GetFunction("run", false); @@ -186,7 +185,7 @@ TEST(BuildModule, Heterogeneous) { run(); tvm::runtime::NDArray out = get_output(0); - float* p_out = (float*)out.ToDLPack()->dl_tensor.data; + float* p_out = (float*)out->data; // Check correctness. 
for (int i = 0; i < n; ++i) { @@ -194,7 +193,7 @@ TEST(BuildModule, Heterogeneous) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index c67df63e6e7e..efd6ac7e406f 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -20,8 +20,8 @@ #include #include #include -#include #include +#include #include #include @@ -35,8 +35,7 @@ class TestErrorSwitch { public: // Need this so that destructor of temporary objects don't interrupt our // testing. - TestErrorSwitch(const TestErrorSwitch& other) - : should_fail(other.should_fail) { + TestErrorSwitch(const TestErrorSwitch& other) : should_fail(other.should_fail) { const_cast(other).should_fail = false; } @@ -50,8 +49,7 @@ class TestErrorSwitch { } }; -class TestArrayObj : public Object, - public InplaceArrayBase { +class TestArrayObj : public Object, public InplaceArrayBase { public: static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "test.TestArrayObj"; @@ -112,8 +110,7 @@ TEST(InplaceArrayBase, BadExceptionSafety) { TestErrorSwitch f2{true}; TestErrorSwitch f3{false}; std::vector fields{f1, f2, f3}; - auto ptr = - make_inplace_array_object(fields.size()); + auto ptr = make_inplace_array_object(fields.size()); try { ptr->WrongInit(fields.begin(), fields.end()); } catch (...) { @@ -133,8 +130,7 @@ TEST(InplaceArrayBase, ExceptionSafety) { // since it's not initalized. TestErrorSwitch f2{true}; std::vector fields{f1, f2}; - auto ptr = - make_inplace_array_object(fields.size()); + auto ptr = make_inplace_array_object(fields.size()); try { ptr->Init(fields.begin(), fields.end()); } catch (...) 
{ @@ -175,6 +171,106 @@ TEST(Array, Iterator) { CHECK(vector[1].as()->value == 2); } +TEST(Array, PushPop) { + using namespace tvm; + Array a; + std::vector b; + for (int i = 0; i < 10; ++i) { + a.push_back(i); + b.push_back(i); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + int n = a.size(); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + for (int i = 9; i >= 0; --i) { + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + a.pop_back(); + b.pop_back(); + int n = a.size(); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + ASSERT_EQ(a.empty(), true); +} + +TEST(Array, ResizeReserveClear) { + using namespace tvm; + for (size_t n = 0; n < 10; ++n) { + Array a; + Array b; + a.resize(n); + b.reserve(n); + ASSERT_EQ(a.size(), n); + ASSERT_GE(a.capacity(), n); + a.clear(); + b.clear(); + ASSERT_EQ(a.size(), 0); + ASSERT_EQ(b.size(), 0); + } +} + +TEST(Array, InsertErase) { + using namespace tvm; + Array a; + std::vector b; + for (int n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (int pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, pos); + b.insert(b.begin() + pos, pos); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + 1); + ASSERT_EQ(b.size(), n + 1); + for (int k = 0; k <= n; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos); + b.erase(b.begin() + pos); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + +TEST(Array, InsertEraseRange) { + using namespace tvm; + Array range_a{-1, -2, -3, -4}; + std::vector range_b{-1, -2, -3, -4}; + Array a; + std::vector b; + for (size_t n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (size_t pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, range_a.begin(), range_a.end()); + b.insert(b.begin() + pos, range_b.begin(), range_b.end()); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + range_a.size()); + ASSERT_EQ(b.size(), n + range_b.size()); + size_t m = n + range_a.size(); + for (size_t k = 0; k < m; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos, a.begin() + pos + range_a.size()); + b.erase(b.begin() + pos, b.begin() + pos + range_b.size()); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + TEST(Map, Expr) { using namespace tvm; Var x("x"); @@ -187,11 +283,11 @@ TEST(Map, Expr) { CHECK(!dict.count(zz)); } -TEST(StrMap, Expr) { +TEST(Map, Str) { using namespace tvm; Var x("x"); auto z = max(x + 1 + 2, 100); - Map dict{{"x", z}, {"z", 2}}; + Map dict{{"x", z}, {"z", 2}}; CHECK(dict.size() == 2); CHECK(dict["x"].same_as(z)); } @@ -223,8 +319,8 @@ TEST(Map, Iterator) { using namespace tvm; PrimExpr a = 1, b = 2; Map map1{{a, b}}; - std::unordered_map map2( - map1.begin(), map1.end()); + std::unordered_map map2(map1.begin(), + map1.end()); CHECK(map2[a].as()->value == 2); } @@ -271,11 +367,26 @@ TEST(String, Comparisons) { string source = "a string"; string mismatch = "a string but longer"; String s{source}; + String m{mismatch}; CHECK_EQ(s == source, true); CHECK_EQ(s == mismatch, false); CHECK_EQ(s == source.data(), true); CHECK_EQ(s == mismatch.data(), false); + + CHECK_EQ(s < m, source < mismatch); + CHECK_EQ(s > m, source > mismatch); + CHECK_EQ(s <= m, source <= mismatch); + CHECK_EQ(s >= m, source >= mismatch); + CHECK_EQ(s 
== m, source == mismatch); + CHECK_EQ(s != m, source != mismatch); + + CHECK_EQ(m < s, mismatch < source); + CHECK_EQ(m > s, mismatch > source); + CHECK_EQ(m <= s, mismatch <= source); + CHECK_EQ(m >= s, mismatch >= source); + CHECK_EQ(m == s, mismatch == source); + CHECK_EQ(m != s, mismatch != source); } // Check '\0' handling @@ -402,7 +513,6 @@ TEST(String, Cast) { String s2 = Downcast(r); } - TEST(Optional, Composition) { Optional opt0(nullptr); Optional opt1 = String("xyz"); @@ -468,6 +578,18 @@ TEST(Optional, PackedCall) { CHECK(packedfunc("xyz", false).operator String() == "xyz"); CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); + + // test FFI convention. + auto test_ffi = PackedFunc([](TVMArgs args, TVMRetValue* rv) { + int tcode = args[1]; + CHECK_EQ(args[0].type_code(), tcode); + }); + String s = "xyz"; + auto nd = NDArray::Empty({0, 1}, DataType::Float(32), DLContext{kDLCPU, 0}); + test_ffi(Optional(nd), static_cast(kTVMNDArrayHandle)); + test_ffi(Optional(s), static_cast(kTVMObjectRValueRefArg)); + test_ffi(s, static_cast(kTVMObjectHandle)); + test_ffi(String(s), static_cast(kTVMObjectRValueRefArg)); } int main(int argc, char** argv) { diff --git a/tests/cpp/crt_memory_test.cc b/tests/cpp/crt_memory_test.cc index 1c129166f122..c2582ba02525 100644 --- a/tests/cpp/crt_memory_test.cc +++ b/tests/cpp/crt_memory_test.cc @@ -27,7 +27,7 @@ TEST(CRTMemory, Alloc) { for (int idx = 0; idx < 65536; idx++) { - void * a = vmalloc(1); + void* a = vmalloc(1); EXPECT_EQ(vleak_size, 1); vfree(a); EXPECT_EQ(vleak_size, 0); @@ -36,9 +36,9 @@ TEST(CRTMemory, Alloc) { TEST(CRTMemory, Realloc) { for (int idx = 0; idx < 65536; idx++) { - void * a = vrealloc(0, 1); + void* a = vrealloc(0, 1); EXPECT_EQ(vleak_size, 1); - void * b = vrealloc(a, 1); + void* b = vrealloc(a, 1); EXPECT_EQ(a, b); EXPECT_EQ(vleak_size, 1); vfree(a); @@ -46,7 +46,7 @@ TEST(CRTMemory, Realloc) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index e17cc73b18a1..a5d47dd4d989 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -34,7 +34,6 @@ TEST(Expr, Basic) { CHECK(os.str() == "max(((x + 1) + 2), 100)"); } - TEST(ExprNodeRef, Basic) { using namespace tvm; using namespace tvm::tir; @@ -44,8 +43,7 @@ TEST(ExprNodeRef, Basic) { CHECK(GetRef(op).same_as(z)); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 3941de5eef17..8dae79929fe8 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -19,10 +19,10 @@ #include #include -#include -#include #include +#include #include +#include #include TEST(IRF, Basic) { @@ -32,14 +32,10 @@ TEST(IRF, Basic) { auto z = x + 1; NodeFunctor f; - f.set_dispatch([](const ObjectRef& n, int b) { - return b; - }); - f.set_dispatch([](const ObjectRef& n, int b) { - return b + 2; - }); - CHECK_EQ(f(x, 2), 2); - CHECK_EQ(f(z, 2), 4); + f.set_dispatch([](const ObjectRef& n, int b) { return b; }); + f.set_dispatch([](const ObjectRef& n, int b) { return b + 2; }); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 4); } TEST(IRF, CountVar) { @@ -51,37 +47,31 @@ TEST(IRF, CountVar) { 
auto z = x + 1 + y + y; tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; - }); + }); CHECK_EQ(n_var, 2); } - TEST(IRF, ExprTransform) { using namespace tvm; using namespace tvm::tir; Var x("x"); auto z = x + 1; - class MyExprFunctor - : public tir::ExprFunctor { + class MyExprFunctor : public tir::ExprFunctor { public: - int VisitExpr_(const VarNode* op, int b) final { - return b; - } - int VisitExpr_(const IntImmNode* op, int b) final { - return op->value; - } + int VisitExpr_(const VarNode* op, int b) final { return b; } + int VisitExpr_(const IntImmNode* op, int b) final { return op->value; } int VisitExpr_(const AddNode* op, int b) final { return VisitExpr(op->a, b) + VisitExpr(op->b, b); } }; MyExprFunctor f; - CHECK_EQ(f(x, 2), 2); - CHECK_EQ(f(z, 2), 3); + CHECK_EQ(f(x, 2), 2); + CHECK_EQ(f(z, 2), 3); try { f(z - 1, 2); LOG(FATAL) << "should fail"; - } catch(dmlc::Error) { + } catch (dmlc::Error) { } } @@ -91,50 +81,40 @@ TEST(IRF, ExprVisit) { Var x("x"); auto z = x + 1; - class MyVisitor - : public tir::ExprFunctor, - public tir::StmtFunctor { + class MyVisitor : public tir::ExprFunctor, + public tir::StmtFunctor { public: int count = 0; // implementation - void VisitExpr_(const VarNode* op) final { - ++count; - } - void VisitExpr_(const IntImmNode* op) final { - } + void VisitExpr_(const VarNode* op) final { ++count; } + void VisitExpr_(const IntImmNode* op) final {} void VisitExpr_(const AddNode* op) final { VisitExpr(op->a); VisitExpr(op->b); } - void VisitStmt_(const EvaluateNode* op) final { - VisitExpr(op->value); - } + void VisitStmt_(const EvaluateNode* op) final { VisitExpr(op->value); } }; MyVisitor v; - v.VisitStmt(EvaluateNode::make(z)); + v.VisitStmt(Evaluate(z)); CHECK_EQ(v.count, 1); } - TEST(IRF, StmtVisitor) { using namespace tvm; using namespace tvm::tir; Var x("x"); - class MyVisitor - : public StmtExprVisitor { + class MyVisitor : public StmtExprVisitor { public: int count = 0; // implementation - void VisitExpr_(const VarNode* op) final { - ++count; - } + void VisitExpr_(const VarNode* op) final { ++count; } }; MyVisitor v; auto fmaketest = [&]() { auto z = x + 1; - Stmt body = EvaluateNode::make(z); + Stmt body = Evaluate(z); Var buffer("b", DataType::Handle()); - return AllocateNode::make(buffer, DataType::Float(32), {z, z}, const_true(), body); + return Allocate(buffer, DataType::Float(32), {z, z}, const_true(), body); }; v(fmaketest()); CHECK_EQ(v.count, 3); @@ -145,42 +125,34 @@ TEST(IRF, StmtMutator) { using namespace tvm::tir; Var x("x"); - class MyVisitor - : public tir::StmtMutator, - public tir::ExprMutator { + class MyVisitor : public tir::StmtMutator, public tir::ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); protected: // implementation - PrimExpr VisitExpr_(const AddNode* op) final { - return op->a; - } - Stmt VisitStmt_(const SeqStmtNode* op) final { - return StmtMutator::VisitSeqStmt_(op, true); - } - PrimExpr VisitExpr(const PrimExpr& expr) final { - return ExprMutator::VisitExpr(expr); - } + PrimExpr VisitExpr_(const AddNode* op) final { return op->a; } + Stmt VisitStmt_(const SeqStmtNode* op) final { return StmtMutator::VisitSeqStmt_(op, true); } + PrimExpr VisitExpr(const PrimExpr& expr) final { return ExprMutator::VisitExpr(expr); } }; auto fmakealloc = [&]() { auto z = x + 1; - Stmt body = EvaluateNode::make(z); + Stmt body = Evaluate(z); Var buffer("b", DataType::Handle()); - return AllocateNode::make(buffer, DataType::Float(32), {1, z}, const_true(), body); + return 
Allocate(buffer, DataType::Float(32), {1, z}, const_true(), body); }; auto fmakeif = [&]() { auto z = x + 1; - Stmt body = EvaluateNode::make(z); - return IfThenElseNode::make(x, EvaluateNode::make(0), body); + Stmt body = Evaluate(z); + return IfThenElse(x, Evaluate(0), body); }; MyVisitor v; { auto body = fmakealloc(); - Stmt body2 = EvaluateNode::make(1); + Stmt body2 = Evaluate(1); Stmt bref = body.as()->body; auto* extentptr = body.as()->extents.get(); Array arr{std::move(body), body2, body2}; @@ -220,13 +192,13 @@ TEST(IRF, StmtMutator) { } { - auto body = EvaluateNode::make(CallNode::make(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); + auto body = Evaluate(Call(DataType::Int(32), "xyz", {x + 1}, CallNode::Extern)); auto res = v(std::move(body)); CHECK(res.as()->value.as()->args[0].same_as(x)); } { - auto body = fmakealloc(); - Stmt body2 = EvaluateNode::make(1); + Stmt body = fmakealloc(); + Stmt body2 = Evaluate(1); auto* ref2 = body2.get(); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. @@ -242,8 +214,8 @@ TEST(IRF, StmtMutator) { { // Cannot cow because of bref - auto body = fmakealloc(); - Stmt body2 = EvaluateNode::make(1); + Stmt body = fmakealloc(); + Stmt body2 = Evaluate(1); auto* extentptr = body.as()->extents.get(); // construct a recursive SeqStmt. body = SeqStmt({body}); @@ -255,7 +227,7 @@ TEST(IRF, StmtMutator) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/object_protocol_test.cc b/tests/cpp/object_protocol_test.cc index 2977b6805e5c..0df802497434 100644 --- a/tests/cpp/object_protocol_test.cc +++ b/tests/cpp/object_protocol_test.cc @@ -19,8 +19,8 @@ #include #include -#include #include +#include namespace tvm { namespace test { @@ -47,6 +47,7 @@ class ObjA : public ObjBase { class ObjB : public ObjBase { public: static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const uint32_t _type_child_slots = 0; static constexpr const char* _type_key = "test.ObjB"; TVM_DECLARE_BASE_OBJECT_INFO(ObjB, ObjBase); }; @@ -58,7 +59,6 @@ class ObjAA : public ObjA { TVM_DECLARE_FINAL_OBJECT_INFO(ObjAA, ObjA); }; - TVM_REGISTER_OBJECT_TYPE(ObjBase); TVM_REGISTER_OBJECT_TYPE(ObjA); TVM_REGISTER_OBJECT_TYPE(ObjB); @@ -96,7 +96,7 @@ TEST(ObjectHierachy, Basic) { CHECK(refB.as() != nullptr); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/packed_func_test.cc b/tests/cpp/packed_func_test.cc index 787e0c4b8f4d..523df9891332 100644 --- a/tests/cpp/packed_func_test.cc +++ b/tests/cpp/packed_func_test.cc @@ -19,11 +19,11 @@ #include #include -#include #include +#include #include -#include #include +#include TEST(PackedFunc, Basic) { using namespace tvm; @@ -34,15 +34,15 @@ TEST(PackedFunc, Basic) { DLTensor a; Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 3); - CHECK(args.values[0].v_float64 == 1.0); - CHECK(args.type_codes[0] == kDLFloat); - CHECK(args.values[1].v_handle == &a); - CHECK(args.type_codes[1] == kTVMDLTensorHandle); - CHECK(args.values[2].v_handle == &x); - CHECK(args.type_codes[2] == kTVMOpaqueHandle); - *rv = Var("a"); - })(1.0, &a, handle); + CHECK(args.num_args == 3); + CHECK(args.values[0].v_float64 == 1.0); + CHECK(args.type_codes[0] == 
kDLFloat); + CHECK(args.values[1].v_handle == &a); + CHECK(args.type_codes[1] == kTVMDLTensorHandle); + CHECK(args.values[2].v_handle == &x); + CHECK(args.type_codes[2] == kTVMOpaqueHandle); + *rv = Var("a"); + })(1.0, &a, handle); CHECK(v->name_hint == "a"); } @@ -52,36 +52,32 @@ TEST(PackedFunc, Node) { using namespace tvm::runtime; Var x; Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - CHECK(args[0].IsObjectRef()); - Var b = args[0]; - CHECK(x.same_as(b)); - *rv = b; - })(x); + CHECK(args.num_args == 1); + CHECK(args[0].IsObjectRef()); + Var b = args[0]; + CHECK(x.same_as(b)); + *rv = b; + })(x); CHECK(t.same_as(x)); } TEST(PackedFunc, NDArray) { using namespace tvm; using namespace tvm::runtime; - auto x = NDArray::Empty( - {}, String2DLDataType("float32"), - TVMContext{kDLCPU, 0}); + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); reinterpret_cast(x->data)[0] = 10.0f; CHECK(x.use_count() == 1); - PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { - *rv = args[0]; - }); + PackedFunc forward([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); NDArray ret = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - NDArray y = args[0]; - DLTensor* ptr = args[0]; - CHECK(ptr == x.operator->()); - CHECK(x.same_as(y)); - CHECK(x.use_count() == 2); - *rv = forward(y); - })(x); + NDArray y = args[0]; + DLTensor* ptr = args[0]; + CHECK(ptr == x.operator->()); + CHECK(x.same_as(y)); + CHECK(x.use_count() == 2); + *rv = forward(y); + })(x); CHECK(ret.use_count() == 2); CHECK(ret.same_as(x)); } @@ -90,48 +86,45 @@ TEST(PackedFunc, str) { using namespace tvm; using namespace tvm::runtime; PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - std::string x = args[0]; - CHECK(x == "hello"); - String y = args[0]; - CHECK(y == "hello"); - *rv = x; - })("hello"); + CHECK(args.num_args == 1); + std::string x = args[0]; + CHECK(x == "hello"); + String y = args[0]; + CHECK(y == "hello"); + *rv = x; + })("hello"); PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK(args.num_args == 1); - runtime::String s = args[0]; - CHECK(s == "hello"); + CHECK(args.num_args == 1); + runtime::String s = args[0]; + CHECK(s == "hello"); })(runtime::String("hello")); } - TEST(PackedFunc, func) { using namespace tvm; using namespace tvm::runtime; - PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { - *rv = args[0].operator int() + 1; - }); + PackedFunc addone([&](TVMArgs args, TVMRetValue* rv) { *rv = args[0].operator int() + 1; }); // function as arguments int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); CHECK_EQ(r0, 2); int r1 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - // TVMArgValue -> TVMRetValue - *rv = args[1]; - })(2, 100); + // TVMArgValue -> TVMRetValue + *rv = args[1]; + })(2, 100); CHECK_EQ(r1, 100); int r2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - // re-assignment - *rv = args[0]; - // TVMRetValue -> Function argument - *rv = addone(args[0].operator PackedFunc()(args[1], 1)); - })(addone, 100); + // re-assignment + *rv = args[0]; + // TVMRetValue -> Function argument + *rv = addone(args[0].operator PackedFunc()(args[1], 1)); + })(addone, 100); CHECK_EQ(r2, 102); } @@ -140,14 +133,14 @@ TEST(PackedFunc, Expr) { using namespace tvm::runtime; // 
automatic conversion of int to expr PackedFunc addone([](TVMArgs args, TVMRetValue* rv) { - PrimExpr x = args[0]; - *rv = x.as()->value + 1; + PrimExpr x = args[0]; + *rv = x.as()->value + 1; }); int r0 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - PackedFunc f = args[0]; - // TVMArgValue -> Arguments as function - *rv = f(args[1]).operator int(); - })(addone, 1); + PackedFunc f = args[0]; + // TVMArgValue -> Arguments as function + *rv = f(args[1]).operator int(); + })(addone, 1); CHECK_EQ(r0, 2); } @@ -155,12 +148,10 @@ TEST(PackedFunc, Type) { using namespace tvm; using namespace tvm::runtime; auto get_type = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - DataType x = args[0]; - *rv = x; - }); - auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { - *rv = args[0]; - }); + DataType x = args[0]; + *rv = x; + }); + auto get_type2 = PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; }); CHECK(get_type("int32").operator DataType() == DataType::Int(32)); CHECK(get_type("float").operator DataType() == DataType::Float(32)); CHECK(get_type2("float32x2").operator DataType() == DataType::Float(32, 2)); @@ -174,9 +165,7 @@ TEST(TypedPackedFunc, HighOrder) { using BindFunc = TypedPackedFunc; BindFunc ftyped; ftyped = [](Int2Func f1, int value) -> Int1Func { - auto binded = [f1, value](int x) { - return f1(value, x); - }; + auto binded = [f1, value](int x) { return f1(value, x); }; Int1Func x(binded); return x; }; @@ -194,28 +183,23 @@ TEST(TypedPackedFunc, Deduce) { using tvm::runtime::detail::function_signature; TypedPackedFunc x; - auto f = [](int x) -> int { - return x + 1; - }; + auto f = [](int x) -> int { return x + 1; }; std::function y; - static_assert(std::is_same::FType, - int(float)>::value, "invariant1"); - static_assert(std::is_same::FType, - int(int)>::value, "invariant2"); - static_assert(std::is_same::FType, - void(float)>::value, "invariant3"); + static_assert(std::is_same::FType, int(float)>::value, + "invariant1"); + static_assert(std::is_same::FType, int(int)>::value, + "invariant2"); + static_assert(std::is_same::FType, void(float)>::value, + "invariant3"); } - TEST(PackedFunc, ObjectConversion) { using namespace tvm; using namespace tvm::tir; using namespace tvm::runtime; TVMRetValue rv; - auto x = NDArray::Empty( - {}, String2DLDataType("float32"), - TVMContext{kDLCPU, 0}); + auto x = NDArray::Empty({}, String2DLDataType("float32"), TVMContext{kDLCPU, 0}); // assign null rv = ObjectRef(); CHECK_EQ(rv.type_code(), kTVMNullptr); @@ -232,15 +216,15 @@ TEST(PackedFunc, ObjectConversion) { CHECK(!rv.IsObjectRef()); auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); - CHECK(args[0].operator NDArray().same_as(x)); - CHECK(args[0].operator ObjectRef().same_as(x)); - CHECK(args[1].operator ObjectRef().get() == nullptr); - CHECK(args[1].operator NDArray().get() == nullptr); - CHECK(args[1].operator Module().get() == nullptr); - CHECK(args[1].operator Array().get() == nullptr); - CHECK(!args[0].IsObjectRef()); - }); + CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle); + CHECK(args[0].operator NDArray().same_as(x)); + CHECK(args[0].operator ObjectRef().same_as(x)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(args[1].operator Array().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); pf1(x, ObjectRef()); pf1(ObjectRef(x), NDArray()); @@ -259,14 +243,14 @@ TEST(PackedFunc, 
ObjectConversion) { CHECK(!rv.IsObjectRef()); auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { - CHECK_EQ(args[0].type_code(), kTVMModuleHandle); - CHECK(args[0].operator Module().same_as(m)); - CHECK(args[0].operator ObjectRef().same_as(m)); - CHECK(args[1].operator ObjectRef().get() == nullptr); - CHECK(args[1].operator NDArray().get() == nullptr); - CHECK(args[1].operator Module().get() == nullptr); - CHECK(!args[0].IsObjectRef()); - }); + CHECK_EQ(args[0].type_code(), kTVMModuleHandle); + CHECK(args[0].operator Module().same_as(m)); + CHECK(args[0].operator ObjectRef().same_as(m)); + CHECK(args[1].operator ObjectRef().get() == nullptr); + CHECK(args[1].operator NDArray().get() == nullptr); + CHECK(args[1].operator Module().get() == nullptr); + CHECK(!args[0].IsObjectRef()); + }); pf2(m, ObjectRef()); pf2(ObjectRef(m), Module()); } @@ -275,13 +259,12 @@ TEST(TypedPackedFunc, RValue) { using namespace tvm; using namespace tvm::runtime; { - auto inspect = [](TVMArgs args, TVMRetValue* rv) { for (int i = 0; i < args.size(); ++i) { CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg); } }; - PackedFunc finspect(inspect); + PackedFunc finspect(inspect); finspect(tir::Var("x")); } { @@ -325,7 +308,7 @@ TEST(TypedPackedFunc, RValue) { } } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index 5cb79101a05e..5063509e4e35 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -17,9 +17,10 @@ * under the License. */ +#include "../src/arith/pattern_match.h" + #include #include -#include "../src/arith/pattern_match.h" TEST(Pattern, Basic) { using namespace tvm; @@ -49,6 +50,7 @@ TEST(Pattern, Basic) { CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } CHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); + CHECK((px + min(py, px)).Match(z + min(y, z))); CHECK((px + truncdiv(py, px * py)).Match(x + truncdiv(2, x * 2))); CHECK((px - truncmod(py, px * pz)).Match(x - truncmod(2, x * 2))); @@ -64,8 +66,7 @@ TEST(Pattern, Basic) { CHECK((px >= py && px < pz).Match(x >= y && x < z)); CHECK((!(px > py || px != py)).Match(!(x > y || x != y))); { - CHECK(select(px >= pz, py, py + pz).Match( - tir::SelectNode::make((x + 1) >= 1, y, y + 1))); + CHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1))); CHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics @@ -81,52 +82,42 @@ TEST(Pattern, Basic) { CHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); // select { - CHECK(select(px > pz, py, py + pz).Match( - tir::SelectNode::make(x > 1, y, y + 1))); + CHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } - CHECK(!select(px > pz, py, py + pz).Match( - tir::SelectNode::make(x > 2, y, y + 1))); - CHECK(!select(px > pz, py, py).Match( - tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1))); + CHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1))); { - CHECK(select(px, py, pz).Match( - tir::SelectNode::make(x > 2, y, y + 1))); + CHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1))); CHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else { - CHECK(if_then_else(px > pz, py, py + pz).Match( - if_then_else(x > 1, y, y + 1))); + CHECK(if_then_else(px > pz, py, py + 
pz).Match(if_then_else(x > 1, y, y + 1))); CHECK(is_const_int(pz.Eval(), 1)); } // cast pattern { - CHECK(!cast(PConst( - DataType::Int(32)), px).Match(tir::CastNode::make(DataType::Float(64), x))); - CHECK(cast(pt, px).Match(tir::CastNode::make(DataType::Float(64), x))); + CHECK(!cast(PConst(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x))); + CHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x))); CHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); - CHECK((cast(pt, px) - cast(pt, py)).Match( - tir::CastNode::make(DataType::Float(64), x) - tir::CastNode::make(DataType::Int(64), x))); - auto expr = tir::CastNode::make(DataType::Int(32), tir::CastNode::make(DataType::Float(64), x)); + CHECK((cast(pt, px) - cast(pt, py)) + .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x))); + auto expr = tir::Cast(DataType::Int(32), tir::Cast(DataType::Float(64), x)); CHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern { - CHECK(ramp(px, PConst(1), planes).Match( - tir::RampNode::make(x, 1, 10))); + CHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, 10))); CHECK(planes.Eval() == 10); - CHECK(!ramp(px, PConst(1), planes).Match( - tir::RampNode::make(x, 2, 10))); + CHECK(!ramp(px, PConst(1), planes).Match(tir::Ramp(x, 2, 10))); } // broadcast pattern { - CHECK(broadcast(px, planes).Match( - tir::BroadcastNode::make(x, 10))); + CHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10))); CHECK(planes.Eval() == 10); - CHECK(broadcast(px * py , planes).Match( - tir::BroadcastNode::make(x * 10, 10))); + CHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10))); } } @@ -148,7 +139,7 @@ TEST(Pattern, IntImm) { CHECK(!(v * c).Match((tx + 1) * 3)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index f5658fbce1e9..636593f9803e 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -18,61 +18,59 @@ */ #include -#include -#include -#include -#include -#include -#include -#include -#include #include #include -#include +#include #include +#include +#include +#include +#include +#include +#include #include +#include #include +#include using namespace tvm; using namespace tvm::relay; TVM_REGISTER_GLOBAL("test.strategy") -.set_body_typed([](const Attrs& attrs, const Array& inputs, - const Type& out_type, const Target& target) { - FTVMCompute fcompute = [](const Attrs& attrs, - const Array& inputs, - const Type& out_type) -> Array { + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { CHECK_EQ(inputs.size(), 2U); return {topi::add(inputs[0], inputs[1])}; - }; - FTVMSchedule fschedule = [](const Attrs& attrs, - const Array& outs, - const Target& target) { + }; + FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { With target_scope(target); return topi::generic::schedule_injective(target, outs); - }; + }; - auto n = make_object(); - auto strategy = tvm::relay::OpStrategy(std::move(n)); - strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); - return strategy; -}); + auto n = make_object(); + auto strategy = tvm::relay::OpStrategy(std::move(n)); + 
strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); + return strategy; + }); TVM_REGISTER_GLOBAL("relay.backend.lower_call") -.set_body_typed([](const relay::Call& call, const Array& inputs, - const Target& target) { - static auto fstrategy = Op::GetAttr("FTVMStrategy"); - Op op = Downcast(call->op); - auto out_type = call->checked_type(); - OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); - auto impl = strategy->specializations[0]->implementations[0]; - auto outs = impl.Compute(call->attrs, inputs, out_type); - auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); - if (!f) { - LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; - } - return (*f)(outs, impl); -}); + .set_body_typed([](const relay::Call& call, const Array& inputs, + const Target& target) { + static auto fstrategy = Op::GetAttrMap("FTVMStrategy"); + Op op = Downcast(call->op); + auto out_type = call->checked_type(); + OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); + auto impl = strategy->specializations[0]->implementations[0]; + auto outs = impl.Compute(call->attrs, inputs, out_type); + auto f = tvm::runtime::Registry::Get("relay.backend._make_LoweredOutput"); + if (!f) { + LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; + } + return (*f)(outs, impl); + }); TEST(Relay, BuildModule) { auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32)); @@ -87,9 +85,9 @@ TEST(Relay, BuildModule) { auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pA = (float*)A.ToDLPack()->dl_tensor.data; - auto pB = (float*)B.ToDLPack()->dl_tensor.data; - auto pC = (float*)C.ToDLPack()->dl_tensor.data; + auto pA = (float*)A->data; + auto pB = (float*)B->data; + auto pC = (float*)C->data; for (int i = 0; i < 6; ++i) { pA[i] = i; @@ -97,7 +95,7 @@ TEST(Relay, BuildModule) { pC[i] = i + 2; } // get schedule - auto reg = tvm::runtime::Registry::Get("relay.op._Register"); + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); if (!reg) { LOG(FATAL) << "no _Register"; } @@ -107,6 +105,7 @@ TEST(Relay, BuildModule) { } auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs); (*reg)("add", "FTVMStrategy", fgeneric, 10); + (*reg)("add", "TShapeDataDependant", false, 10); // build auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); tvm::runtime::Module build_mod = (*pfb)(); @@ -132,7 +131,7 @@ TEST(Relay, BuildModule) { set_input_f("c", &C.ToDLPack()->dl_tensor); run_f(); tvm::runtime::NDArray Y = get_output_f(0); - auto pY = (float*)Y.ToDLPack()->dl_tensor.data; + auto pY = (float*)Y->data; for (int i = 0; i < 6; ++i) { CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); } @@ -142,20 +141,20 @@ TEST(Relay, BuildModule) { } run_f(); tvm::runtime::NDArray Y2 = get_output_f(0); - auto pY2 = (float*)Y2.ToDLPack()->dl_tensor.data; + auto pY2 = (float*)Y2->data; for (int i = 0; i < 6; ++i) { CHECK_LT(fabs(pY2[i] - (i + (i + 3) + (i + 2))), 1e-4); } // attach a different input and run it again auto C2 = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pC2 = (float*)C2.ToDLPack()->dl_tensor.data; + auto pC2 = (float*)C2->data; for (int i = 0; i < 6; ++i) { pC2[i] = i + 4; } set_input_f("c", &C2.ToDLPack()->dl_tensor); run_f(); tvm::runtime::NDArray Y3 = get_output_f(0); - auto pY3 = (float*)Y3.ToDLPack()->dl_tensor.data; + auto pY3 = 
(float*)Y3->data; for (int i = 0; i < 6; ++i) { CHECK_LT(fabs(pY3[i] - (i + (i + 3) + (i + 4))), 1e-4); } @@ -178,7 +177,7 @@ TEST(Relay, GetExprRefCount) { CHECK(ref_count[z.get()] == 1); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_pass_type_infer_test.cc b/tests/cpp/relay_pass_type_infer_test.cc index 3c416918e441..cb7330dfab6d 100644 --- a/tests/cpp/relay_pass_type_infer_test.cc +++ b/tests/cpp/relay_pass_type_infer_test.cc @@ -19,30 +19,30 @@ #include #include -#include -#include -#include #include +#include #include +#include +#include TEST(Relay, SelfReference) { using namespace tvm; auto tensor_type = relay::TensorType({}, DataType::Bool()); auto x = relay::Var("x", relay::Type()); - auto f = relay::Function(tvm::Array{ x }, x, relay::Type(), {}); + auto f = relay::Function(tvm::Array{x}, x, relay::Type(), {}); CHECK(f->IsInstance()); auto y = relay::Var("y", tensor_type); - auto call = relay::Call(f, Array{ y }); - auto fx = relay::Function(tvm::Array{ y }, call, relay::Type(), {}); + auto call = relay::Call(f, Array{y}); + auto fx = relay::Function(tvm::Array{y}, call, relay::Type(), {}); auto mod = IRModule::FromExpr(fx); mod = relay::transform::InferType()(mod); auto type_fx = mod->Lookup("main"); - auto expected = relay::FuncType(tvm::Array{ tensor_type }, tensor_type, {}, {}); + auto expected = relay::FuncType(tvm::Array{tensor_type}, tensor_type, {}, {}); CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected)); } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/relay_transform_sequential.cc b/tests/cpp/relay_transform_sequential_test.cc similarity index 61% rename from tests/cpp/relay_transform_sequential.cc rename to tests/cpp/relay_transform_sequential_test.cc index d974f023d74b..bb4bf928b018 100644 --- a/tests/cpp/relay_transform_sequential.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -18,28 +18,46 @@ */ #include +#include #include -#include #include -#include #include +#include #include +#include +#include +#include #include #include #include #include #include -TVM_REGISTER_GLOBAL("schedule") - .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { - *rv = topi::generic::schedule_injective(args[0], args[1]); +using namespace tvm; + +TVM_REGISTER_GLOBAL("test.seq.strategy") + .set_body_typed([](const Attrs& attrs, const Array& inputs, const Type& out_type, + const Target& target) { + relay::FTVMCompute fcompute = [](const Attrs& attrs, const Array& inputs, + const Type& out_type) -> Array { + CHECK_EQ(inputs.size(), 2U); + return {topi::add(inputs[0], inputs[1])}; + }; + relay::FTVMSchedule fschedule = [](const Attrs& attrs, const Array& outs, + const Target& target) { + With target_scope(target); + return topi::generic::schedule_injective(target, outs); + }; + + auto n = make_object(); + auto strategy = relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "test.strategy", 10); + return strategy; }); TEST(Relay, Sequential) { - using namespace tvm; auto tensor_type = relay::TensorType({1, 2, 3}, DataType::Float(32)); - auto c_data = - tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto c_data = tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); 
// Create a function for optimization. auto c = relay::Constant(c_data); @@ -53,30 +71,28 @@ TEST(Relay, Sequential) { auto z2 = relay::Call(add_op, {z, z1}); // Let expression and varaible a should be dead-code eliminated. auto z3 = relay::Let(a, c, z2); - relay::Function func = - relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); + relay::Function func = relay::Function(relay::FreeVars(z3), z3, relay::Type(), {}); - // Get schedule - auto reg = tvm::runtime::Registry::Get("relay.op._Register"); - auto sch = tvm::runtime::Registry::Get("schedule"); - if (!reg || !sch) { - LOG(FATAL) << "Register/schedule is not defined."; + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); + if (!reg) { + LOG(FATAL) << "Register is not defined."; } - - (*reg)("add", "FTVMSchedule", *sch, 10); + auto fs = tvm::runtime::Registry::Get("test.seq.strategy"); + if (!fs) { + LOG(FATAL) << "Strategy is not defined."; + } + auto fgeneric = GenericFunc::Get("test.strategy_generic").set_default(*fs); + (*reg)("add", "FTVMStrategy", fgeneric, 10); // Run sequential passes. tvm::Array pass_seqs{ - relay::transform::InferType(), - relay::transform::DeadCodeElimination(), - relay::transform::EliminateCommonSubexpr(), - relay::transform::AlterOpLayout() - }; + relay::transform::InferType(), relay::transform::DeadCodeElimination(), + relay::transform::EliminateCommonSubexpr(), relay::transform::AlterOpLayout()}; relay::transform::Pass seq = relay::transform::Sequential(pass_seqs); auto mod = IRModule::FromExpr(func); auto pass_ctx = relay::transform::PassContext::Create(); pass_ctx->opt_level = 3; - pass_ctx->fallback_device = 1; + pass_ctx->config.Set("relay.fallback_device_type", Integer(1)); { tvm::With ctx_scope(pass_ctx); tvm::With tctx(tvm::Target::Create("llvm")); @@ -96,8 +112,7 @@ TEST(Relay, Sequential) { y1 = relay::Call(add_op, {x1, y1}); auto zz = relay::Call(add_op, {y1, c1}); zz = relay::Call(add_op, {zz, zz}); - relay::Function expected_func = - relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); + relay::Function expected_func = relay::Function(relay::FreeVars(zz), zz, relay::Type(), {}); // Infer type for the expected function. 
auto mod1 = IRModule::FromExpr(expected_func); diff --git a/tests/cpp/simple_passes_test.cc b/tests/cpp/simple_passes_test.cc index a3c6b07ddc8a..36b36452f4fc 100644 --- a/tests/cpp/simple_passes_test.cc +++ b/tests/cpp/simple_passes_test.cc @@ -19,8 +19,8 @@ #include #include -#include #include +#include TEST(SimplePasses, HasSideEffect) { using namespace tvm; @@ -33,8 +33,7 @@ TEST(SimplePasses, HasSideEffect) { CHECK(!tvm::tir::HasSideEffect(A[0])); } - -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/tensor_test.cc b/tests/cpp/tensor_test.cc index a9566cb6d005..ea02ca656dce 100644 --- a/tests/cpp/tensor_test.cc +++ b/tests/cpp/tensor_test.cc @@ -30,9 +30,8 @@ TEST(Tensor, Basic) { Tensor A = placeholder({m, l}, DataType::Float(32), "A"); Tensor B = placeholder({n, l}, DataType::Float(32), "B"); - auto C = compute({m, n}, [&](Var i, Var j) { - return A[i][j]; - }, "C"); + auto C = compute( + {m, n}, [&](Var i, Var j) { return A[i][j]; }, "C"); Tensor::Slice x = A[n]; } @@ -46,13 +45,12 @@ TEST(Tensor, Reduce) { te::Tensor B = te::placeholder({n, l}, DataType::Float(32), "B"); IterVar rv = reduce_axis(Range{0, l}, "k"); - auto C = te::compute({m, n}, [&](Var i, Var j) { - return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); - }, "C"); + auto C = te::compute( + {m, n}, [&](Var i, Var j) { return sum(max(1 + A[i][rv] + 1, B[j][rv]), {rv}); }, "C"); LOG(INFO) << C->op.as()->body; } -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc index 508705c2630a..cf7434b4b036 100644 --- a/tests/cpp/threading_backend_test.cc +++ b/tests/cpp/threading_backend_test.cc @@ -17,13 +17,13 @@ * under the License. */ +#include +#include + #include #include #include -#include -#include - constexpr size_t N = 128; static FTVMParallelLambda atomic_add_task_id = [](int task_id, TVMParallelGroupEnv* penv, diff --git a/tests/cpp/topi_ewise_test.cc b/tests/cpp/topi_ewise_test.cc index a1ca6d7fd229..10c7b9d7464b 100644 --- a/tests/cpp/topi_ewise_test.cc +++ b/tests/cpp/topi_ewise_test.cc @@ -17,9 +17,9 @@ * under the License. */ -#include -#include #include +#include +#include namespace topi { TEST(Tensor, Basic) { @@ -28,9 +28,9 @@ TEST(Tensor, Basic) { Tensor A = placeholder({m, l}, DataType::Float(32), "A"); auto C = topi::exp(A); } -} +} // namespace topi -int main(int argc, char ** argv) { +int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; return RUN_ALL_TESTS(); diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 14f7de52a22d..70709b0f96a1 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -17,11 +17,11 @@ * under the License. 
*/ -#include - #include #include + #include +#include #include #ifdef USE_MICRO_STANDALONE_RUNTIME @@ -30,9 +30,10 @@ #if defined(__APPLE__) && defined(__MACH__) #include +#include +#include #include #include -#include #include #include #include @@ -41,9 +42,7 @@ #include #include #include - -#include -#include +#include TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { *rv = topi::generic::schedule_injective(args[0], args[1]); @@ -52,20 +51,20 @@ TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* TEST(MicroStandaloneRuntime, BuildModule) { using namespace tvm; auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32)); - auto a = relay::VarNode::make("a", tensor_type); - auto b = relay::VarNode::make("b", tensor_type); + auto a = relay::Var("a", tensor_type); + auto b = relay::Var("b", tensor_type); auto add_op = relay::Op::Get("add"); - auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); - auto c = relay::VarNode::make("c", tensor_type); - auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); + auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {}); + auto c = relay::Var("c", tensor_type); + auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {}); auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); - auto pA = (float*)A.ToDLPack()->dl_tensor.data; - auto pB = (float*)B.ToDLPack()->dl_tensor.data; - auto pC = (float*)C.ToDLPack()->dl_tensor.data; + auto pA = (float*)A->data; + auto pB = (float*)B->data; + auto pC = (float*)C->data; for (int i = 0; i < 6; ++i) { pA[i] = i; @@ -118,7 +117,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { UTVMRuntimeRun(handle); auto Y = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); UTVMRuntimeGetOutput(handle, 0, &Y.ToDLPack()->dl_tensor); - auto* pY = (float*)Y.ToDLPack()->dl_tensor.data; + auto* pY = (float*)Y->data; for (int i = 0; i < 6; ++i) { CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4); } diff --git a/tests/lint/add_asf_header.py b/tests/lint/add_asf_header.py index a44fbd3df1b5..21d25c25e573 100644 --- a/tests/lint/add_asf_header.py +++ b/tests/lint/add_asf_header.py @@ -181,7 +181,9 @@ def add_header(fname, header): skipline = False ext = os.path.splitext(fname)[1][1:] - if lines[0][:2] == "#!": + if not lines: + skipline = False # File is enpty + elif lines[0][:2] == "#!": skipline = True elif lines[0][:2] == "" + echo "" + echo "Run clang-format on files that changed since " + echo "Examples:" + echo "- Compare last one commit: tests/lint/git-clang-format.sh HEAD~1" + echo "- Compare against upstream/master: tests/lint/git-clang-format.sh upstream/master" + echo "You can also add -i option to do inplace format" + exit 1 +fi + +if [[ "$1" == "-i" ]]; then + INPLACE_FORMAT=1 + shift 1 +else + INPLACE_FORMAT=0 +fi + +cleanup() +{ + rm -rf /tmp/$$.clang-format.txt +} +trap cleanup 0 + +CLANG_FORMAT=clang-format-10 + +if [ -x "$(command -v clang-format-10)" ]; then + CLANG_FORMAT=clang-format-10 +elif [ -x "$(command -v clang-format)" ]; then + echo "clang-format might be different from clang-format-10, expect potential difference." 
+ CLANG_FORMAT=clang-format +else + echo "Cannot find clang-format-10" + exit 1 +fi + +# Print out specific version +${CLANG_FORMAT} --version + +if [[ ${INPLACE_FORMAT} -eq 1 ]]; then + echo "Running inplace git-clang-format against" $1 + git-${CLANG_FORMAT} --extensions h,mm,c,cc --binary=${CLANG_FORMAT} $1 + exit 0 +fi + +echo "Running git-clang-format against" $1 +git-${CLANG_FORMAT} --diff --extensions h,mm,c,cc --binary=${CLANG_FORMAT} $1 1> /tmp/$$.clang-format.txt +echo "---------clang-format log----------" +cat /tmp/$$.clang-format.txt +echo "" +if grep --quiet -E "diff" < /tmp/$$.clang-format.txt; then + echo "clang-format lint error found. Consider running clang-format-10 on these files to fix them." + exit 1 +fi diff --git a/tests/lint/rat-excludes b/tests/lint/rat-excludes index 5421d22a08aa..0c3ab601e04a 100644 --- a/tests/lint/rat-excludes +++ b/tests/lint/rat-excludes @@ -28,9 +28,14 @@ core.cpp build _static _build +node_modules +dist .*~ \#..*\# \.#.* +.npm +.node_repl_history +node_modules # Relay parser: they are generated by ANTLR. RelayLexer.py @@ -40,6 +45,7 @@ RelayVisitor.py # Specific files package-list MANIFEST +.eslintignore .gitignore .gitattributes .gitmodules diff --git a/tests/micro/test_runtime_micro_on_arm.py b/tests/micro/test_runtime_micro_on_arm.py new file mode 100644 index 000000000000..301677eaa628 --- /dev/null +++ b/tests/micro/test_runtime_micro_on_arm.py @@ -0,0 +1,387 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os + +import numpy as np +import tvm +from tvm import te +from tvm.contrib import graph_runtime, util +from tvm import relay +import tvm.micro as micro +from tvm.micro import create_micro_mod +from tvm.relay.testing import resnet + +# Use real micro device - an STM32F746 discovery board +# SETUP: +# Be sure to have openocd installed and running +# Ex : openocd -f board/stm32f7discovery.cfg +# Be sure to have the ST CMSIS library downloaded, installed and +# Ex : export CMSIS_ST_PATH="/home/yourid/st/STM32Cube_FW_F7_V1.16.0/Drivers/CMSIS" +DEV_CONFIG_A = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) +DEV_CONFIG_B = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) +TARGET = 'c -device=micro_dev' + +def relay_micro_build(func, dev_config, params=None): + """Create a graph runtime module with a micro device context from a Relay function. 
+ + Parameters + ---------- + func : relay.Function + function to compile + + dev_config : Dict[str, Any] + MicroTVM config dict for the target device + + params : dict + input parameters that do not change during inference + + Return + ------ + mod : tvm.runtime.Module + graph runtime module for the target device + """ + with tvm.transform.PassContext(disabled_pass={'FuseOps'}, config={ + "tir.disable_vectorize": True + }): + graph, c_mod, params = relay.build(func, target=TARGET, params=params) + micro_mod = micro.create_micro_mod(c_mod, dev_config) + ctx = tvm.micro_dev(0) + mod = graph_runtime.create(graph, micro_mod, ctx) + mod.set_input(**params) + return mod + + +GDB_INIT_TEMPLATE = """ +layout asm +target remote localhost:{gdb_port} +set $pc = UTVMInit +break UTVMDone +""" + + +def reset_gdbinit(): + if 'server_port' not in DEV_CONFIG_A: + return + try: + gdb_init_dir = os.environ['MICRO_GDB_INIT_DIR'] + except KeyError: + return + with open(f'{gdb_init_dir}/.gdbinit', 'w') as f: + gdb_port = DEV_CONFIG_A['server_port'] - 3333 + f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port)) + + +def test_alloc(): + """Test tensor allocation on the device.""" + if not tvm.runtime.enabled("micro_dev"): + return + shape = (1024,) + dtype = "float32" + with micro.Session(DEV_CONFIG_A): + ctx = tvm.micro_dev(0) + np_tensor = np.random.uniform(size=shape).astype(dtype) + micro_tensor = tvm.nd.array(np_tensor, ctx) + tvm.testing.assert_allclose(np_tensor, micro_tensor.asnumpy()) + + +def test_add(): + """Test a module which performs addition.""" + if not tvm.runtime.enabled("micro_dev"): + return + shape = (1024,) + dtype = "float32" + + reset_gdbinit() + + # Construct TVM expression. + tvm_shape = tvm.runtime.convert(shape) + A = te.placeholder(tvm_shape, name="A", dtype=dtype) + B = te.placeholder(tvm_shape, name="B", dtype=dtype) + C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") + s = te.create_schedule(C.op) + + func_name = "fadd" + c_mod = tvm.build(s, [A, B, C], target="c", name=func_name) + + with micro.Session(DEV_CONFIG_A) as sess: + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) + micro_func = micro_mod[func_name] + ctx = tvm.micro_dev(0) + + a_np = np.random.uniform(size=shape).astype(dtype) + a = tvm.nd.array(a_np, ctx) + b_np = np.random.uniform(size=shape).astype(dtype) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) + micro_func(a, b, c) + + # ensure inputs weren't corrupted + tvm.testing.assert_allclose( + a.asnumpy(), a_np) + tvm.testing.assert_allclose( + b.asnumpy(), b_np) + # ensure output is correct + tvm.testing.assert_allclose( + c.asnumpy(), a.asnumpy() + b.asnumpy()) + + +def test_workspace_add(): + """Test a module which uses a workspace to compute an intermediate value.""" + if not tvm.runtime.enabled("micro_dev"): + return + shape = (1024,) + dtype = "float32" + + reset_gdbinit() + + # Construct TVM expression. 
+ tvm_shape = tvm.runtime.convert(shape) + A = te.placeholder(tvm_shape, name="A", dtype=dtype) + B = te.placeholder(tvm_shape, name="B", dtype=dtype) + B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") + C = te.compute(A.shape, lambda *i: B(*i) + 1, name="C") + s = te.create_schedule(C.op) + + func_name = "fadd_two_workspace" + c_mod = tvm.build(s, [A, C], target="c", name=func_name) + + with micro.Session(DEV_CONFIG_A) as sess: + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) + micro_func = micro_mod[func_name] + ctx = tvm.micro_dev(0) + a_np = np.random.uniform(size=shape).astype(dtype) + a = tvm.nd.array(a_np, ctx) + c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) + micro_func(a, c) + + # ensure input wasn't corrupted + tvm.testing.assert_allclose( + a.asnumpy(), a_np) + # ensure output is correct + tvm.testing.assert_allclose( + c.asnumpy(), a.asnumpy() + 2.0) + + +def test_graph_runtime(): + """Test a program which uses the graph runtime.""" + if not tvm.runtime.enabled("micro_dev"): + return + shape = (1024,) + dtype = "float32" + + # Construct Relay program. + x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + xx = relay.multiply(x, x) + z = relay.add(xx, relay.const(1.0)) + func = relay.Function([x], z) + + with micro.Session(DEV_CONFIG_A): + mod = relay_micro_build(func, DEV_CONFIG_A) + + x_in = np.random.uniform(size=shape[0]).astype(dtype) + mod.run(x=x_in) + result = mod.get_output(0).asnumpy() + + tvm.testing.assert_allclose( + mod.get_input(0).asnumpy(), x_in) + tvm.testing.assert_allclose( + result, x_in * x_in + 1.0) + + +def test_conv2d(): + if not tvm.runtime.enabled("micro_dev"): + return + + from tvm.relay import create_executor + from tvm.relay import transform + + dshape = (1, 4, 16, 16) + dtype = 'int8' + func_name = 'fused_nn_conv2d' + + reset_gdbinit() + + # Construct Relay program. 
+ x = relay.var("x", shape=dshape, dtype=dtype) + conv_expr = relay.nn.conv2d( + x, relay.var("w"), + kernel_size=(3, 3), + padding=(1, 1), + channels=4) + func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr) + mod = tvm.IRModule.from_expr(func) + mod = transform.InferType()(mod) + + x_shape = list(map(lambda x: x.value, mod['main'].params[0].checked_type.shape)) + w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape)) + out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape)) + + with tvm.transform.PassContext(config={ + "tir.disable_vectorize": True + }): + graph, c_mod, params = relay.build(mod, target="c") + + with micro.Session(DEV_CONFIG_A): + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) + candidate_func_name = func_name + for i in range(100): + try: + micro_func = micro_mod[candidate_func_name] + break + except tvm.TVMError as e: + candidate_func_name = f'{func_name}_{i}' + else: + assert False + ctx = tvm.micro_dev(0) + + x_data = tvm.nd.array(np.random.uniform(size=x_shape).astype(dtype), ctx) + w_data = tvm.nd.array(np.random.uniform(size=w_shape).astype(dtype), ctx) + result = tvm.nd.array(np.zeros(shape=out_shape, dtype=dtype), ctx) + micro_func(x_data, w_data, result) + + out_data = np.zeros(out_shape, dtype=dtype) + params = { 'x': x_data.asnumpy(), 'w': w_data.asnumpy() } + intrp = create_executor('debug') + expected_result = intrp.evaluate(mod['main'])(x_data, w_data) + + tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy()) + + +def test_interleave_sessions(): + """Test closing and reopening sessions.""" + if not tvm.runtime.enabled("micro_dev"): + return + shape = (1024,) + dtype = "float32" + + # Construct Relay add program. + x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + ret = relay.add(x, relay.const(1.0)) + add_const_func = relay.Function([x], ret) + + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) + with sess_a: + np_tensor_a = np.random.uniform(size=shape).astype(dtype) + micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) + with sess_b: + np_tensor_b = np.random.uniform(size=shape).astype(dtype) + micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) + with sess_a: + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) + add_const_mod.run(x=micro_tensor_a) + add_result = add_const_mod.get_output(0).asnumpy() + tvm.testing.assert_allclose( + add_result, np_tensor_a + 1.0) + with sess_b: + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B) + add_const_mod.run(x=micro_tensor_b) + add_result = add_const_mod.get_output(0).asnumpy() + tvm.testing.assert_allclose( + add_result, np_tensor_b + 1.0) + + +def test_nested_sessions(): + """Test entering and exiting nested session contexts.""" + if not tvm.runtime.enabled("micro_dev"): + return + shape = (1024,) + dtype = "float32" + + # Construct Relay add program. 
+    x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
+    ret = relay.add(x, relay.const(1.0))
+    add_const_func = relay.Function([x], ret)
+
+    sess_a = micro.Session(DEV_CONFIG_A)
+    sess_b = micro.Session(DEV_CONFIG_B)
+    with sess_a:
+        np_tensor_a = np.random.uniform(size=shape).astype(dtype)
+        micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0))
+        with sess_b:
+            np_tensor_b = np.random.uniform(size=shape).astype(dtype)
+            micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0))
+        add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
+        add_const_mod.run(x=micro_tensor_a)
+        add_result = add_const_mod.get_output(0).asnumpy()
+        tvm.testing.assert_allclose(
+            add_result, np_tensor_a + 1.0)
+
+
+def test_inactive_session_use():
+    """Test the use of objects allocated in a session that is no longer active."""
+    if not tvm.runtime.enabled("micro_dev"):
+        return
+    shape = (1024,)
+    dtype = "float32"
+
+    # Construct Relay add program.
+    x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
+    ret = relay.add(x, relay.const(1.0))
+    add_const_func = relay.Function([x], ret)
+
+    sess_a = micro.Session(DEV_CONFIG_A)
+    sess_b = micro.Session(DEV_CONFIG_B)
+    with sess_a:
+        np_tensor_a = np.random.uniform(size=shape).astype(dtype)
+        micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0))
+        add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A)
+
+    with sess_b:
+        # These objects belong to `sess_a`.
+        add_const_mod.run(x=micro_tensor_a)
+        add_result = add_const_mod.get_output(0).asnumpy()
+        tvm.testing.assert_allclose(
+            add_result, np_tensor_a + 1.0)
+
+
+# TODO add workspace alloc/free stress test
+
+if __name__ == "__main__":
+    test_alloc()
+    print()
+    print('finished alloc test')
+    input('[press enter to continue]')
+    test_add()
+    print()
+    print('finished add test')
+    input('[press enter to continue]')
+    test_workspace_add()
+    print()
+    print('finished workspace add test')
+    input('[press enter to continue]')
+    test_graph_runtime()
+    print()
+    print('finished graph runtime test')
+    input('[press enter to continue]')
+    test_conv2d()
+    print()
+    print('finished conv2d test')
+    input('[press enter to continue]')
+    # disable for now as these are currently broken
+    #test_interleave_sessions()
+    #print()
+    #print('finished interleaved sessions test')
+    #input('[press enter to continue]')
+    # test_nested_sessions()
+    #print()
+    #print('finished nested sessions test')
+    #input('[press enter to continue]')
+    test_inactive_session_use()
+    print()
+    print('finished use inactive session test')
+    input('[press enter to continue]')
diff --git a/tests/python/contrib/test_binutil.py b/tests/python/contrib/test_binutil.py
index 3106e73136fa..3aa0583b2816 100644
--- a/tests/python/contrib/test_binutil.py
+++ b/tests/python/contrib/test_binutil.py
@@ -43,7 +43,7 @@ def make_binary():
     tmp_obj = tmp_dir.relpath("obj.obj")
     with open(tmp_source, "w") as f:
         f.write(prog)
-    cc.create_shared(tmp_obj, tmp_source, [],
+    cc.create_executable(tmp_obj, tmp_source, [],
                      cc="{}gcc".format(TOOLCHAIN_PREFIX))
     prog_bin = bytearray(open(tmp_obj, "rb").read())
     return prog_bin
diff --git a/tests/python/contrib/test_coreml_codegen.py b/tests/python/contrib/test_coreml_codegen.py
new file mode 100644
index 000000000000..be47b3e4fc2b
--- /dev/null
+++ b/tests/python/contrib/test_coreml_codegen.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+from unittest import mock
+
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.contrib.target import coreml as _coreml
+
+pytest.importorskip("coremltools")
+
+
+def _has_xcode():
+    try:
+        tvm.contrib.xcode.xcrun([])
+        return True
+    except FileNotFoundError:
+        pass
+
+    return False
+
+
+def _create_graph():
+    shape = (10, 10)
+    mod = tvm.IRModule()
+
+    x = relay.var('x', shape=shape)
+    y = relay.var('y', shape=shape)
+    z = x + x
+    p = y * y
+    func = relay.Function([x, y], p - z)
+    mod["main"] = func
+
+    return mod
+
+
+def _create_graph_annotated():
+    shape = (10, 10)
+    target = "coremlcompiler"
+    mod = tvm.IRModule()
+
+    # function 0
+    f0_i0 = relay.var(target + "_0_i0", shape=shape)
+    func0 = relay.Function([f0_i0], f0_i0 * f0_i0)
+
+    func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+    func0 = func0.with_attr("Compiler", target)
+    func0 = func0.with_attr("global_symbol", target + "_0")
+    gv0 = relay.GlobalVar(target + "_0")
+    mod[gv0] = func0
+
+    # function 2
+    f2_i0 = relay.var(target + "_2_i0", shape=shape)
+    func2 = relay.Function([f2_i0], f2_i0 + f2_i0)
+
+    func2 = func2.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func2 = func2.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+    func2 = func2.with_attr("Compiler", target)
+    func2 = func2.with_attr("global_symbol", target + "_2")
+    gv2 = relay.GlobalVar(target + "_2")
+    mod[gv2] = func2
+
+    # body
+    x = relay.var('x', shape=shape)
+    y = relay.var('y', shape=shape)
+    func = relay.Function([x, y], gv0(y) - gv2(x))
+    mod["main"] = func
+
+    return mod
+
+
+def test_annotate():
+    mod = _create_graph()
+    mod = transform.AnnotateTarget("coremlcompiler")(mod)
+    mod = transform.PartitionGraph()(mod)
+
+    expected = _create_graph_annotated()
+    assert tvm.ir.structural_equal(mod, expected, map_free_vars=True)
+
+
+@mock.patch('tvm.contrib.coreml_runtime.create')
+@mock.patch('tvm.contrib.xcode.compile_coreml')
+def test_construct_model(m1, m2):
+    mod = _create_graph_annotated()
+
+    fcompile = tvm._ffi.get_global_func("relay.ext.coremlcompiler")
+
+    for var, func in mod.functions.items():
+        if func.attrs and 'Compiler' in func.attrs and \
+           func.attrs['Compiler'] == 'coremlcompiler':
+            fcompile(tvm.IRModule.from_expr(func.body))
+
+
+@pytest.mark.skipif(not _has_xcode(), reason="Xcode is not available")
+def test_compile_and_run():
+    ctx=tvm.cpu()
+    target="llvm"
+    tol=1e-3
+
+    with relay.build_config(opt_level=3):
+        json, lib, params = relay.build(_create_graph_annotated(), target=target)
+    m = tvm.contrib.graph_runtime.create(json, lib, ctx)
+
+    shape = (10, 10)
+    x_data = np.random.rand(*shape).astype('float32')
+    y_data = np.random.rand(*shape).astype('float32')
+
+    m.set_input("x", x_data)
+    m.set_input("y", y_data)
+    m.set_input(**params)
+    m.run()
+    out = tvm.nd.empty(shape, ctx=ctx)
+    out = m.get_output(0, out)
+
+    expected = (y_data * y_data) - (x_data + x_data)
+    tvm.testing.assert_allclose(out.asnumpy(), expected, rtol=tol, atol=tol)
+
+
+if __name__ == "__main__":
+    test_annotate()
+    test_construct_model()
+    test_compile_and_run()
diff --git a/tests/python/contrib/test_coreml_runtime.py b/tests/python/contrib/test_coreml_runtime.py
new file mode 100644
index 000000000000..78bacfd2f199
--- /dev/null
+++ b/tests/python/contrib/test_coreml_runtime.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+import numpy as np
+from tvm import rpc
+from tvm.contrib import util, xcode, coreml_runtime
+
+import pytest
+import os
+
+proxy_host = os.environ.get("TVM_IOS_RPC_PROXY_HOST", "localhost")
+proxy_port = os.environ.get("TVM_IOS_RPC_PROXY_PORT", 9090)
+destination = os.environ.get("TVM_IOS_RPC_DESTINATION", "")
+key = "iphone"
+
+@pytest.mark.skip('skip because coremltools is not available in CI')
+def test_coreml_runtime():
+
+    import coremltools
+    from coremltools.models.neural_network import NeuralNetworkBuilder
+
+    def create_coreml_model():
+        shape = (2,)
+        alpha = 2
+
+        inputs = [
+            ('input0', coremltools.models.datatypes.Array(*shape)),
+            ('input1', coremltools.models.datatypes.Array(*shape))
+        ]
+        outputs = [
+            ('output0', coremltools.models.datatypes.Array(*shape)),
+            ('output1', coremltools.models.datatypes.Array(*shape)),
+        ]
+        builder = NeuralNetworkBuilder(inputs, outputs)
+        builder.add_elementwise(name='Add',
+                                input_names=['input0', 'input1'],
+                                output_name='output0',
+                                mode='ADD')
+        builder.add_elementwise(name='Mul',
+                                alpha=alpha,
+                                input_names=['input0'],
+                                output_name='output1',
+                                mode='MULTIPLY')
+        return coremltools.models.MLModel(builder.spec)
+
+    def verify(coreml_model, model_dir, ctx):
+        coreml_model = create_coreml_model()
+
+        out_spec = coreml_model.output_description._fd_spec
+        out_names = [spec.name for spec in out_spec]
+
+        # inference via coremltools
+        inputs = {}
+        for in_spec in coreml_model.input_description._fd_spec:
+            name = in_spec.name
+            shape = in_spec.type.multiArrayType.shape
+            inputs[name] = np.random.random_sample(shape)
+
+        coreml_outputs = [coreml_model.predict(inputs)[name] for name in out_names]
+
+        # inference via tvm coreml runtime
+        runtime = coreml_runtime.create(model_dir, ctx)
+        for name in inputs:
+            runtime.set_input(name, tvm.nd.array(inputs[name], ctx))
+        runtime.invoke()
+        tvm_outputs = [runtime.get_output(i).asnumpy() for i in range(runtime.get_num_outputs())]
+
+        for c_out, t_out in zip(coreml_outputs, tvm_outputs):
+            np.testing.assert_almost_equal(c_out, t_out, 3)
+
+    def check_remote(coreml_model):
+        temp = util.tempdir()
+        compiled_model = 
xcode.compile_coreml(coreml_model, out_dir=temp.temp_dir) + xcode.popen_test_rpc(proxy_host, proxy_port, key, destination=destination, + libs=[compiled_model]) + remote = rpc.connect(proxy_host, proxy_port, key=key) + ctx = remote.cpu(0) + verify(coreml_model, "tvm", ctx) + + def check_local(coreml_model): + temp = util.tempdir() + xcode.compile_coreml(coreml_model, out_dir=temp.temp_dir) + ctx = tvm.cpu(0) + verify(coreml_model, temp.temp_dir, ctx) + + coreml_model = create_coreml_model() + check_remote(coreml_model) + check_local(coreml_model) + + +if __name__ == "__main__": + test_coreml_runtime() diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py index 5d1f100c1fc4..17cb0d1f0f1c 100644 --- a/tests/python/contrib/test_cudnn.py +++ b/tests/python/contrib/test_cudnn.py @@ -17,11 +17,11 @@ import tvm from tvm import te from tvm.contrib import cudnn +from tvm.contrib.nvcc import have_fp16 import numpy as np import topi.testing - -def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): +def verify_conv2d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 out_channel = 16 filter_h = 3 @@ -34,7 +34,7 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): dilation_w = 1 batch = 3 height = 32 - weight = 32 + width = 32 if not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled...") @@ -42,12 +42,17 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): if not tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True): print("skip because cudnn is not enabled...") return + if data_dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): + print("Skip because gpu does not have fp16 support") + return + + # schedule if tensor_format == 0: - xshape = [batch, in_channel, height, weight] - wshape = [out_channel, in_channel, filter_h, filter_w] + xshape = [batch, in_channel, height, width] + wshape = [out_channel, in_channel // groups, filter_h, filter_w] else: - xshape = [batch, height, weight, in_channel] - wshape = [out_channel, filter_h, filter_w, in_channel] + xshape = [batch, height, width, in_channel] + wshape = [out_channel, filter_h, filter_w, in_channel // groups] X = te.placeholder(xshape, name='X', dtype=data_dtype) W = te.placeholder(wshape, name='W', dtype=data_dtype) @@ -59,39 +64,41 @@ def verify_conv2d(data_dtype, conv_dtype, tensor_format=0): conv_mode=1, tensor_format=tensor_format, conv_dtype=conv_dtype, - algo=-1) + algo=-1, + groups=groups) yshape = [x.value for x in Y.shape] s = te.create_schedule(Y.op) - def verify(): - ctx = tvm.gpu(0) - f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") - x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) - w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) - y_np = np.zeros(yshape).astype(data_dtype) - x = tvm.nd.array(x_np, ctx) - w = tvm.nd.array(w_np, ctx) - y = tvm.nd.array(y_np, ctx) - if tensor_format == 0: - c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1) - elif tensor_format == 1: - wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO - c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1) - - f(x, w, y) - tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-3) - - verify() + # validation + ctx = tvm.gpu(0) + f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv2d") + x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + y_np = np.zeros(yshape).astype(data_dtype) + x = tvm.nd.array(x_np, ctx) + w = 
tvm.nd.array(w_np, ctx) + y = tvm.nd.array(y_np, ctx) + if tensor_format == 0: + c_np = topi.testing.conv2d_nchw_python(x_np, w_np, 1, 1, groups=groups) + elif tensor_format == 1: + wt = w_np.transpose((1, 2, 3, 0)) #OHWI => HWIO + c_np = topi.testing.conv2d_nhwc_python(x_np, wt, 1, 1, groups=groups) + + f(x, w, y) + tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=1e-2, rtol=1e-2) def test_conv2d(): verify_conv2d("float32", "float32", tensor_format=0) verify_conv2d("float16", "float32", tensor_format=1) - #Not pass accuracy test, need check - #verify_conv2d("float16", "float16", tensor_format=0) + verify_conv2d("float16", "float16", tensor_format=0) verify_conv2d("int8", "int32", tensor_format=1) + verify_conv2d("float32", "float32", tensor_format=0, groups=2) + verify_conv2d("float16", "float32", tensor_format=1, groups=2) + verify_conv2d("float16", "float16", tensor_format=0, groups=2) + verify_conv2d("int8", "int32", tensor_format=1, groups=2) -def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): +def verify_conv3d(data_dtype, conv_dtype, tensor_format=0, groups=1): in_channel = 4 out_channel = 16 filter_d = 3 @@ -109,7 +116,7 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): batch = 3 depth = 32 height = 32 - weight = 32 + width = 32 if not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled...") @@ -118,8 +125,9 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): print("skip because cudnn is not enabled...") return - xshape = [batch, in_channel, depth, height, weight] - wshape = [out_channel, in_channel, filter_d, filter_h, filter_w] + # schedule + xshape = [batch, in_channel, depth, height, width] + wshape = [out_channel, in_channel // groups, filter_d, filter_h, filter_w] X = te.placeholder(xshape, name='X', dtype=data_dtype) W = te.placeholder(wshape, name='W', dtype=data_dtype) @@ -131,33 +139,31 @@ def verify_conv3d(data_dtype, conv_dtype, tensor_format=0): conv_mode=1, tensor_format=tensor_format, algo=-1, - conv_dtype=conv_dtype) + conv_dtype=conv_dtype, + groups=groups) yshape = [x.value for x in Y.shape] s = te.create_schedule(Y.op) - def verify(): - ctx = tvm.gpu(0) - f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") - x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) - w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) - y_np = np.zeros(yshape).astype(data_dtype) - x = tvm.nd.array(x_np, ctx) - w = tvm.nd.array(w_np, ctx) - y = tvm.nd.array(y_np, ctx) - if tensor_format == 0: - c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1) - else: - raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)") - - f(x, w, y) - tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4) - - verify() + # validation + ctx = tvm.gpu(0) + f = tvm.build(s, [X, W, Y], "cuda", target_host="llvm", name="conv3d") + x_np = np.random.uniform(-1, 1, xshape).astype(data_dtype) + w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype) + y_np = np.zeros(yshape).astype(data_dtype) + x = tvm.nd.array(x_np, ctx) + w = tvm.nd.array(w_np, ctx) + y = tvm.nd.array(y_np, ctx) + if tensor_format == 0: + c_np = topi.testing.conv3d_ncdhw_python(x_np, w_np, 1, 1, groups) + else: + raise AssertionError("For now, conv3d tensor format only support: 0(NCHW)") + f(x, w, y) + tvm.testing.assert_allclose(y.asnumpy(), c_np, atol=3e-5, rtol=1e-4) def test_conv3d(): verify_conv3d("float32", "float32", tensor_format=0) - + verify_conv3d("float32", "float32", tensor_format=0, groups=2) def 
verify_softmax(shape, axis, dtype="float32"): A = te.placeholder(shape, dtype=dtype, name='A') diff --git a/tests/python/contrib/test_tflite_runtime.py b/tests/python/contrib/test_tflite_runtime.py index 8c883b031a89..1b911b7eb632 100644 --- a/tests/python/contrib/test_tflite_runtime.py +++ b/tests/python/contrib/test_tflite_runtime.py @@ -14,92 +14,130 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest + import tvm from tvm import te import numpy as np from tvm import rpc from tvm.contrib import util, tflite_runtime -# import tensorflow as tf -# import tflite_runtime.interpreter as tflite - - -def skipped_test_tflite_runtime(): - - def create_tflite_model(): - root = tf.Module() - root.const = tf.constant([1., 2.], tf.float32) - root.f = tf.function(lambda x: root.const * x) - - input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) - concrete_func = root.f.get_concrete_function(input_signature) - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - tflite_model = converter.convert() - return tflite_model - - - def check_local(): - tflite_fname = "model.tflite" - tflite_model = create_tflite_model() - temp = util.tempdir() - tflite_model_path = temp.relpath(tflite_fname) - open(tflite_model_path, 'wb').write(tflite_model) - - # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path=tflite_model_path) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_shape = input_details[0]['shape'] - tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], tflite_input) - interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - - # inference via tvm tflite runtime - with open(tflite_model_path, 'rb') as model_fin: - runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input)) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) - - - def check_remote(): - tflite_fname = "model.tflite" - tflite_model = create_tflite_model() - temp = util.tempdir() - tflite_model_path = temp.relpath(tflite_fname) - open(tflite_model_path, 'wb').write(tflite_model) - - # inference via tflite interpreter python apis - interpreter = tflite.Interpreter(model_path=tflite_model_path) - interpreter.allocate_tensors() - input_details = interpreter.get_input_details() - output_details = interpreter.get_output_details() - - input_shape = input_details[0]['shape'] - tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) - interpreter.set_tensor(input_details[0]['index'], tflite_input) - interpreter.invoke() - tflite_output = interpreter.get_tensor(output_details[0]['index']) - - # inference via remote tvm tflite runtime - server = rpc.Server("localhost") - remote = rpc.connect(server.host, server.port) - ctx = remote.cpu(0) - a = remote.upload(tflite_model_path) - - with open(tflite_model_path, 'rb') as model_fin: - runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) - runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) - runtime.invoke() - out = runtime.get_output(0) - np.testing.assert_equal(out.asnumpy(), tflite_output) - - check_local() - check_remote() + + +def _create_tflite_model(): + if not 
tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + root = tf.Module() + root.const = tf.constant([1., 2.], tf.float32) + root.f = tf.function(lambda x: root.const * x) + + input_signature = tf.TensorSpec(shape=[2, ], dtype=tf.float32) + concrete_func = root.f.get_concrete_function(input_signature) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + tflite_model = converter.convert() + return tflite_model + + +@pytest.mark.skip('skip because accessing output tensor is flakey') +def test_local(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + tflite_fname = "model.tflite" + tflite_model = _create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tf.lite.Interpreter(model_path=tflite_model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via tvm tflite runtime + with open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), tvm.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input)) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + +def test_remote(): + if not tvm.runtime.enabled("tflite"): + print("skip because tflite runtime is not enabled...") + return + if not tvm.get_global_func("tvm.tflite_runtime.create", True): + print("skip because tflite runtime is not enabled...") + return + + try: + import tensorflow as tf + except ImportError: + print('skip because tensorflow not installed...') + return + + tflite_fname = "model.tflite" + tflite_model = _create_tflite_model() + temp = util.tempdir() + tflite_model_path = temp.relpath(tflite_fname) + open(tflite_model_path, 'wb').write(tflite_model) + + # inference via tflite interpreter python apis + interpreter = tf.lite.Interpreter(model_path=tflite_model_path) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_shape = input_details[0]['shape'] + tflite_input = np.array(np.random.random_sample(input_shape), dtype=np.float32) + interpreter.set_tensor(input_details[0]['index'], tflite_input) + interpreter.invoke() + tflite_output = interpreter.get_tensor(output_details[0]['index']) + + # inference via remote tvm tflite runtime + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a = remote.upload(tflite_model_path) + + with 
open(tflite_model_path, 'rb') as model_fin: + runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0)) + runtime.set_input(0, tvm.nd.array(tflite_input, remote.cpu(0))) + runtime.invoke() + out = runtime.get_output(0) + np.testing.assert_equal(out.asnumpy(), tflite_output) + + server.terminate() + if __name__ == "__main__": - # skipped_test_tflite_runtime() - pass + test_local() + test_remote() diff --git a/tests/python/contrib/test_util.py b/tests/python/contrib/test_util.py new file mode 100644 index 000000000000..55a2b7616e84 --- /dev/null +++ b/tests/python/contrib/test_util.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for functions in tvm/python/tvm/contrib/util.py.""" + +import datetime +import os +import shutil +from tvm.contrib import util + + +def validate_debug_dir_path(temp_dir, expected_basename): + dirname, basename = os.path.split(temp_dir.temp_dir) + assert basename == expected_basename, 'unexpected basename: %s' % (basename,) + + parent_dir = os.path.basename(dirname) + create_time = datetime.datetime.strptime(parent_dir.split('___', 1)[0], '%Y-%m-%dT%H-%M-%S') + assert abs(datetime.datetime.now() - create_time) < datetime.timedelta(seconds=60) + + + +def test_tempdir(): + assert util.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True" + + temp_dir = util.tempdir() + assert os.path.exists(temp_dir.temp_dir) + + old_debug_mode = util.TempDirectory._KEEP_FOR_DEBUG + try: + for temp_dir_number in range(0, 3): + with util.TempDirectory.set_keep_for_debug(): + debug_temp_dir = util.tempdir() + try: + validate_debug_dir_path(debug_temp_dir, '0000' + str(temp_dir_number)) + finally: + shutil.rmtree(debug_temp_dir.temp_dir) + + with util.TempDirectory.set_keep_for_debug(): + # Create 2 temp_dir within the same session. + debug_temp_dir = util.tempdir() + try: + validate_debug_dir_path(debug_temp_dir, '00003') + finally: + shutil.rmtree(debug_temp_dir.temp_dir) + + debug_temp_dir = util.tempdir() + try: + validate_debug_dir_path(debug_temp_dir, '00004') + finally: + shutil.rmtree(debug_temp_dir.temp_dir) + + with util.TempDirectory.set_keep_for_debug(False): + debug_temp_dir = util.tempdir() # This one should get deleted. + + # Simulate atexit hook + util.TempDirectory.remove_tempdirs() + + # Calling twice should be a no-op. 
+ util.TempDirectory.remove_tempdirs() + + # Creating a new TempDirectory should fail now + try: + util.tempdir() + assert False, 'creation should fail' + except util.DirectoryCreatedPastAtExit: + pass + + finally: + util.TempDirectory.DEBUG_MODE = old_debug_mode + + +if __name__ == '__main__': + test_tempdir() diff --git a/tests/python/frontend/caffe2/test_forward.py b/tests/python/frontend/caffe2/test_forward.py index f05287216ec9..50a878180ac9 100644 --- a/tests/python/frontend/caffe2/test_forward.py +++ b/tests/python/frontend/caffe2/test_forward.py @@ -43,7 +43,7 @@ def get_tvm_output(model, dtype_dict = {input_names: input_data.dtype} mod, params = relay.frontend.from_caffe2( model.init_net, model.predict_net, shape_dict, dtype_dict) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) m = graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/frontend/coreml/test_forward.py b/tests/python/frontend/coreml/test_forward.py index 3a156385d510..179f5b41c1d7 100644 --- a/tests/python/frontend/coreml/test_forward.py +++ b/tests/python/frontend/coreml/test_forward.py @@ -33,7 +33,7 @@ def get_tvm_output(func, x, params, target, ctx, out_shape=(1, 1000), input_name='image', dtype='float32'): - with relay.transform.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -76,7 +76,7 @@ def run_tvm_graph(coreml_model, target, ctx, input_data, input_name, output_shap dtype_dict = {input_name: input_data.dtype} mod, params = relay.frontend.from_coreml(coreml_model, shape_dict) - with relay.transform.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) from tvm.contrib import graph_runtime diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index c2caf916558e..9b963c396319 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -84,7 +84,7 @@ def get_keras_output(xs, dtype='float32'): def get_tvm_output(xs, target, ctx, dtype='float32'): shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)} mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout=layout) - with relay.transform.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): graph, lib, params = relay.build(mod, target, params=params) @@ -125,6 +125,7 @@ def test_forward_merge(self, keras): keras.layers.Subtract(), keras.layers.Multiply(), keras.layers.Maximum(), + keras.layers.Minimum(), keras.layers.Average(), keras.layers.Concatenate()] for merge_func in merge_funcs: @@ -465,6 +466,35 @@ def test_forward_zero_padding3d(self, keras): keras_model = keras.models.Model(data, x) verify_keras_frontend(keras_model, layout='NDHWC') + + def test_forward_embedding(self, keras): + data = keras.layers.Input(shape=(2, 4), dtype="int32") + x = keras.layers.Embedding(10, 3)(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + + data = keras.layers.Input(shape=(2, 3, 4), dtype="int32") + x = keras.layers.Embedding(4, 5)(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + + data = keras.layers.Input(shape=(6, 2, 3, 4), dtype="int32") + x = 
keras.layers.Embedding(4, 5)(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, need_transpose=False) + + def test_forward_global_pool3d(self, keras): + data = keras.layers.Input(shape=(32, 32, 32, 1)) + pool_funcs = [# global maxpool + keras.layers.GlobalMaxPooling3D(), + # global avgpool + keras.layers.GlobalAveragePooling3D() + ] + for pool_func in pool_funcs: + x = pool_func(data) + keras_model = keras.models.Model(data, x) + verify_keras_frontend(keras_model, layout='NDHWC') + if __name__ == '__main__': for k in [keras, tf_keras]: sut = TestKeras() @@ -494,6 +524,7 @@ def test_forward_zero_padding3d(self, keras): sut.test_forward_mobilenet(keras=k, layout='NHWC') sut.test_forward_conv3d(keras=k) sut.test_forward_pool3d(keras=k) + sut.test_forward_global_pool3d(keras=k) sut.test_forward_upsample3d(keras=k) sut.test_forward_zero_padding3d(keras=k) - + sut.test_forward_embedding(keras=k) diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 4a9848e03b5e..00c077f0d2e0 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -66,7 +66,7 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'): shape_dict, arg_params=args, aux_params=auxs) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -179,6 +179,14 @@ def test_forward_pooling(): mx_sym = mx.sym.Pooling(data, kernel=(3, 3), pad=(1, 1), pool_type='max') verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8), (1, 20, 8, 8)) +def test_forward_pooling3d(): + data = mx.sym.var('data') + mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='avg') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) + + mx_sym = mx.sym.Pooling(data, kernel=(3, 3, 3), pad=(1, 1, 1), pool_type='max') + verify_mxnet_frontend_impl(mx_sym, (1, 20, 8, 8, 8), (1, 20, 8, 8, 8)) + def test_forward_adaptive_pooling(): data = mx.sym.var('data') mx_sym = mx.sym.contrib.AdaptiveAvgPooling2D(data, output_size=(1,)) @@ -301,11 +309,25 @@ def _mx_symbol(F, op_name, inputs): return op(*inputs) def test_forward_broadcast_ops(): - for op in ["broadcast_add", "broadcast_sub", "broadcast_mul", - "broadcast_div", "broadcast_mod", "broadcast_maximum", - "broadcast_minimum", "broadcast_equal", "broadcast_not_equal", - "broadcast_greater", "broadcast_greater_equal", - "broadcast_lesser", "broadcast_lesser_equal"]: + for op in ["broadcast_add", + "broadcast_plus", + "broadcast_sub", + "broadcast_minus", + "broadcast_mul", + "broadcast_div", + "broadcast_mod", + "broadcast_maximum", + "broadcast_minimum", + "broadcast_equal", + "broadcast_not_equal", + "broadcast_greater", + "broadcast_greater_equal", + "broadcast_lesser", + "broadcast_lesser_equal", + "broadcast_power", + "broadcast_logical_or", + "broadcast_logical_and", + "broadcast_logical_xor"]: a_shape = (3, 4, 5) b_shape = (4, 5) if op == "broadcast_mod": @@ -328,13 +350,19 @@ def test_forward_broadcast_ops(): def test_forward_elemwise_ops(): for op in ["elemwise_add", "elemwise_sub", "elemwise_mul", - "elemwise_div", "maximum", "minimum"]: + "elemwise_div", "maximum", "minimum", + operator.lt, operator.le, operator.eq, + operator.ne, operator.gt, operator.ge]: shape = (3, 4, 5) dtype = 'float32' a_np = np.random.uniform(size=shape).astype(dtype) b_np = 
np.random.uniform(size=shape).astype(dtype) - mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) - ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + if type(op) == str: + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)]) + else: + mx_sym = op(mx.sym.var('a'), mx.sym.var('b')) + ref_res = op(mx.nd.array(a_np), mx.nd.array(b_np)) shapes = {'a': shape, 'b': shape} mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) for target, ctx in ctx_list(): @@ -343,6 +371,37 @@ def test_forward_elemwise_ops(): op_res = intrp.evaluate()(a_np, b_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + +def test_forward_softmin(): + data = mx.sym.var('data') + mx_sym = mx.sym.softmin(data) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) + + mx_sym = mx.sym.softmin(data, axis=2) + verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100)) + + +def test_forward_unary_ops(): + for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc", + "softsign", "hard_sigmoid", + "cos", "sin", "tan", + "cosh", "sinh", "tanh", + "arccos", "arcsin", "arctan", + "arccosh", "arcsinh", "arctanh"]: + shape = (1, 3, 4, 5) + dtype = 'float32' + a_np = np.random.uniform(size=shape).astype(dtype) + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np)]) + shapes = {'a': shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + + def test_forward_scalar_ops(): for op in [operator.add, operator.sub, operator.mul, operator.truediv, operator.pow, operator.lt, operator.le, operator.eq, @@ -456,16 +515,51 @@ def verify(shape, axis): def test_forward_broadcast_axis(): def verify(shape, axis, size): x_np = np.random.uniform(size=shape).astype("float32") - ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size) - mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size) - mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for op in ["broadcast_axis", + "broadcast_axes"]: + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('x'),axis,size]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(x_np),axis,size]) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + + verify((1, 2, 1), 2, 3) + verify((1, 2, 1), (0, 2), (2, 3)) + + +def test_forward_broadcast_to(): + def verify(input_shape, shape): + x_np = np.random.uniform(size=input_shape).astype("float32") + ref_res = mx.nd.broadcast_to(mx.nd.array(x_np), shape=shape) + mx_sym = mx.sym.broadcast_to(mx.sym.var("x"), shape=shape) + mod, _ = relay.frontend.from_mxnet(mx_sym, {"x": input_shape}) for target, ctx in ctx_list(): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) - verify((1, 2, 1), 2, 3) - verify((1, 2, 1), (0, 2), (2, 3)) + + verify((1, 2, 3), (3, 2, 3)) + 
verify((4, 1, 32, 32), (4, 8, 32, 32)) + + +def test_forward_logical_not(): + a_shape = (3, 4, 5) + dtype = 'float32' + a_np = np.random.uniform(size=a_shape).astype(dtype) + mx_sym = mx.sym.logical_not(mx.sym.var('a')) + ref_res = mx.nd.logical_not(mx.nd.array(a_np)) + shapes = {'a': a_shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + def test_forward_full(): def verify(val, shape, dtype): @@ -531,7 +625,7 @@ def verify(shape, indices_src, axis, mode="clip"): verify((3,4), [-1, 5], 1, mode="wrap") def test_forward_gather_nd(): - def verify(xshape, yshape, y_data): + def verify(xshape, yshape, y_data, error=False): x_data = np.random.uniform(size=xshape).astype("float32") ref_res = mx.nd.gather_nd(mx.nd.array(x_data), mx.nd.array(y_data)) mx_sym = mx.sym.gather_nd(mx.sym.var("x_data"), mx.sym.var("y_data")) @@ -541,10 +635,12 @@ def verify(xshape, yshape, y_data): intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) op_res = intrp.evaluate()(x_data, y_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((2, 2), (2, 3), [[1, 1, 0], [0, 1, 0]]) verify((2, 2, 2), (2, 2), [[0, 1], [1, 0]]) verify((3, 2, 2), (2, 2), [[0, 1], [1, 0]]) verify((3, 2), (2, 2, 3), [[[0, 1, 2], [2, 0, 1]], [[0, 0, 0], [1, 1, 1]]]) + verify((1, 4), (1, 1), [[0]]) def test_forward_bilinear_resize(): # add tests including scale_height and scale_width when mxnet is updated to version 1.5 @@ -552,6 +648,42 @@ def test_forward_bilinear_resize(): mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10) verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10)) +def test_forward_grid_generator(): + def verify(shape, transform_type, target_shape): + x = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.GridGenerator(mx.nd.array(x), transform_type, target_shape) + mx_sym = mx.sym.GridGenerator(mx.sym.var("x"), transform_type, target_shape) + shape_dict = {"x": x.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor( + kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + verify((4, 6), 'affine', (16, 32)) + verify((4, 2, 16, 16), 'warp', None) + verify((1, 2, 16, 16), 'warp', None) + +def test_forward_bilinear_sampler(): + def verify(data_shape, grid_shape): + data = np.random.uniform(size=data_shape).astype("float32") + grid = np.random.uniform(low=-1.5, high=1.5, size=grid_shape).astype("float32") + ref_res = mx.nd.BilinearSampler(mx.nd.array(data), mx.nd.array(grid)) + mx_sym = mx.sym.BilinearSampler(mx.sym.var("data"), mx.sym.var("grid")) + shape_dict = {"data": data.shape, "grid": grid.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor( + kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data, grid) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + verify((4, 4, 16, 32), (4, 2, 8, 8)) + verify((4, 4, 16, 32), (4, 2, 32, 32)) + def test_forward_rnn_layer(): def verify(mode, seq_len, input_size, hidden_size, 
num_layers, batch=1, init_states=True, bidirectional=False): @@ -901,6 +1033,10 @@ def verify(data_shape, kernel_size, stride, pad, num_filter, is_depthwise=False) verify(data_shape=(20, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=2) verify(data_shape=(1, 8, 32, 32), kernel_size=(3, 3), stride=(1, 1), pad=(1, 1), num_filter=8, is_depthwise=True) + verify(data_shape=(1, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) + verify(data_shape=(20, 1, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) + verify(data_shape=(1, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(2, 2, 2), pad=(1, 1, 1), num_filter=2) + verify(data_shape=(20, 8, 16, 16, 16), kernel_size=(3, 3, 3), stride=(1, 1, 1), pad=(1, 1, 1), num_filter=2) def test_forward_deconvolution(): def verify(data_shape, kernel_size, stride, pad, num_filter): @@ -995,6 +1131,70 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): # _verify_swap_axis((4, 5), (5, 4), 0, 0) +def test_forward_depth_to_space(): + def verify(shape, blocksize=2): + x = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.depth_to_space(mx.nd.array(x), blocksize) + mx_sym = mx.sym.depth_to_space(mx.sym.var("x"), blocksize) + shape_dict = {"x": x.shape, } + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 18, 3, 3), 3) + + +def test_forward_space_to_depth(): + def verify(shape, blocksize=2): + x = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.space_to_depth(mx.nd.array(x), blocksize) + mx_sym = mx.sym.space_to_depth(mx.sym.var("x"), blocksize) + shape_dict = {"x": x.shape, } + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 1, 9, 9), 3) + + +def test_forward_correlation(): + def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply): + data1 = np.random.uniform(size=data_shape).astype("float32") + data2 = np.random.uniform(size=data_shape).astype("float32") + ref_res = mx.nd.Correlation(data1=mx.nd.array(data1), data2=mx.nd.array(data2), + kernel_size=kernel_size, max_displacement=max_displacement, + stride1=stride1, stride2=stride2, pad_size=pad_size, + is_multiply=is_multiply) + mx_sym = mx.sym.Correlation(data1=mx.sym.var('data1'), data2=mx.sym.var('data2'), + kernel_size=kernel_size, max_displacement=max_displacement, + stride1=stride1, stride2=stride2, pad_size=pad_size, + is_multiply=is_multiply) + shape_dict = {"data1": data1.shape, "data2": data2.shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(data1, data2) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 3, 10, 10), kernel_size = 1, max_displacement = 4, stride1 = 1, stride2 = 1, pad_size = 4, is_multiply = False) + verify((5, 1, 15, 15), kernel_size = 1, 
max_displacement = 5, stride1 = 1, stride2 = 1, pad_size = 5, is_multiply = False) + verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 5, stride1 = 1, stride2 = 1, pad_size = 5, is_multiply = True) + verify((5, 1, 15, 15), kernel_size = 1, max_displacement = 10, stride1 = 1, stride2 = 2, pad_size = 10, is_multiply = True) + verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = True) + verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = True) + verify((5, 1, 4, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = False) + verify((5, 1, 6, 4), kernel_size = 3, max_displacement = 1, stride1 = 2, stride2 = 1, pad_size = 2, is_multiply = False) + verify((5, 1, 11, 11), kernel_size = 5, max_displacement = 1, stride1 = 1, stride2 = 1, pad_size = 2, is_multiply = False) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1004,6 +1204,7 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): test_forward_rrelu() test_forward_prelu() test_forward_softrelu() + test_forward_softmin() test_forward_fc_flatten() test_forward_clip() test_forward_split() @@ -1012,6 +1213,7 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): test_forward_pad() test_forward_slice() test_forward_pooling() + test_forward_pooling3d() test_forward_adaptive_pooling() test_forward_lrn() test_forward_ones() @@ -1023,7 +1225,10 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): test_forward_where() test_forward_arange() test_forward_broadcast_ops() + test_forward_broadcast_to() + test_forward_logical_not() test_forward_elemwise_ops() + test_forward_unary_ops() test_forward_scalar_ops() test_forward_slice_like() test_forward_slice_axis() @@ -1047,9 +1252,14 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): test_forward_instance_norm() test_forward_layer_norm() test_forward_one_hot() + test_forward_depth_to_space() + test_forward_space_to_depth() test_forward_convolution() test_forward_deconvolution() test_forward_cond() test_forward_make_loss() test_forward_unravel_index() - test_forward_swap_axis() \ No newline at end of file + test_forward_swap_axis() + test_forward_correlation() + test_forward_grid_generator() + test_forward_bilinear_sampler() diff --git a/tests/python/frontend/mxnet/test_qnn_ops_utils.py b/tests/python/frontend/mxnet/test_qnn_ops_utils.py index 32042562b209..541162d79afe 100644 --- a/tests/python/frontend/mxnet/test_qnn_ops_utils.py +++ b/tests/python/frontend/mxnet/test_qnn_ops_utils.py @@ -15,11 +15,15 @@ # specific language governing permissions and limitations # under the License. 
-import tvm -from tvm import te import numpy as np +import tvm from tvm import relay from tvm.contrib import graph_runtime +from tvm.relay.frontend.mxnet_qnn_op_utils import dequantize_mxnet_min_max, \ + quantize_mxnet_min_max, \ + get_mkldnn_int8_scale, \ + get_mkldnn_uint8_scale, \ + quantize_conv_bias_mkldnn_from_var def test_mkldnn_dequantize(): @@ -29,14 +33,13 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): input_data = relay.var("input_data", shape=shape, dtype=in_dtype) min_range = quant_args['min_range'] max_range = quant_args['max_range'] - dequantized_output = \ - relay.frontend.dequantize_mxnet_min_max(input_data, - min_range=min_range, - max_range=max_range, - in_dtype=in_dtype) + dequantized_output = dequantize_mxnet_min_max(input_data, + min_range=min_range, + max_range=max_range, + in_dtype=in_dtype) mod = relay.Function(relay.analysis.free_vars(dequantized_output), dequantized_output) mod = tvm.IRModule.from_expr(mod) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod.set_input(input_data=in_data) @@ -79,20 +82,18 @@ def test_int8_to_float32(): def test_mkldnn_quantize(): - def quantize_test_driver(out_dtype, quant_args, in_data, verify_output_data): shape = in_data.shape input_data = relay.var("input_data", shape=shape, dtype='float32') min_range = quant_args['min_range'] max_range = quant_args['max_range'] - quantized_output, _, _ = \ - relay.frontend.quantize_mxnet_min_max(input_data, - min_range=min_range, - max_range=max_range, - out_dtype=out_dtype) + quantized_output, _, _ = quantize_mxnet_min_max(input_data, + min_range=min_range, + max_range=max_range, + out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = tvm.IRModule.from_expr(mod) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod.set_input(input_data=in_data) @@ -140,8 +141,8 @@ def test_get_mkldnn_int8_scale(): range_min = -3.904039 range_max = 3.904039 expected = 0.03061991354976495 - output = relay.frontend.get_mkldnn_int8_scale(range_max=range_max, - range_min=range_min) + output = get_mkldnn_int8_scale(range_max=range_max, + range_min=range_min) assert np.allclose(output, expected) @@ -149,15 +150,15 @@ def test_get_mkldnn_uint8_scale(): range_min = 0.0 range_max = 55.77269 expected = 0.21828841189047482 - output = relay.frontend.get_mkldnn_uint8_scale(range_max=range_max, - range_min=range_min) + output = get_mkldnn_uint8_scale(range_max=range_max, + range_min=range_min) assert np.allclose(output, expected) def test_quantize_conv_bias_mkldnn_from_var(): bias_var = relay.var('bias', shape=(3,), dtype='float32') bias_scale = tvm.nd.array(np.array([0.5, 0.6, 0.7])) - output = relay.frontend.quantize_conv_bias_mkldnn_from_var(bias_var, bias_scale) + output = quantize_conv_bias_mkldnn_from_var(bias_var, bias_scale) assert isinstance(output, tvm.relay.expr.Call) attrs = output.attrs assert attrs.axis == 0 @@ -171,4 +172,4 @@ def test_quantize_conv_bias_mkldnn_from_var(): test_mkldnn_quantize() test_get_mkldnn_int8_scale() test_get_mkldnn_uint8_scale() - test_quantize_conv_bias_mkldnn_from_var() \ No newline at end of file + test_quantize_conv_bias_mkldnn_from_var() diff --git 
a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 2c0849451a25..665cb7bffd4f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -64,7 +64,8 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) - with relay.build_config(opt_level=1): + + with tvm.transform.PassContext(opt_level=1): graph, lib, params = relay.build(mod, target, params=params) @@ -407,6 +408,41 @@ def test_gather(): verify_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32') +def verify_scatter(in_shape, indices, axis): + x = np.random.uniform(size=in_shape).astype("float32") + indices = np.array(indices, dtype="int32") + updates = np.random.uniform(size=indices.shape).astype("float32") + + y = helper.make_node("ScatterElements", ['data', 'indices', 'updates'], ['output'], axis=axis) + + graph = helper.make_graph([y], + 'scatter_test', + inputs=[helper.make_tensor_value_info("data", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", + TensorProto.INT32, list(indices.shape)), + helper.make_tensor_value_info("updates", + TensorProto.FLOAT, list(indices.shape))], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(in_shape))]) + model = helper.make_model(graph, producer_name='scatter_test') + onnx_out = get_onnxruntime_output(model, [x, indices, updates]) + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [x, indices, updates], target, ctx, onnx_out[0].shape) + tvm.testing.assert_allclose(onnx_out[0], tvm_out) + + +def test_scatter(): + verify_scatter((4,), [1], 0) + verify_scatter((1, 4), [[0]], 0) + verify_scatter((4,), [2, 3], 0) + verify_scatter((2, 2), [[1, 0], [0, 1]], 1) + verify_scatter((3, 3, 3), [[[-1, -3]]], -1) + verify_scatter((4, 3, 5, 6), [[[[2, 1, 0, 0]]]], 0) + + def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): if axes: y = helper.make_node( @@ -542,6 +578,70 @@ def test_clip(): {'min': -1.0, 'max': 1.0}) + +def test_round(): + _test_onnx_op_elementwise((2, 4, 5, 6), np.round, {}, 'float32', 'Round', {}) + + +def _test_finite_ops(inshape, outfunc, npargs, dtype, opname, kwargs): + indata = np.random.choice(a=[np.nan, np.inf, -np.inf, 0.5, 1.0, 0], size=inshape).astype(dtype) + + outdata = outfunc(indata, **npargs) + y = helper.make_node(opname, ['in'], ['out'], **kwargs) + + graph = helper.make_graph([y], + opname+'_test', + inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(indata.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.BOOL, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name=opname+'_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, indata, target, ctx, outdata.shape, dtype) + + tvm.testing.assert_allclose(outdata, tvm_out) + + +def test_isinf(): + _test_finite_ops((2, 4, 5, 6), np.isinf, {}, 'float32', 'IsInf', {}) + + +def test_isnan(): + _test_finite_ops((2, 4, 5, 6), np.isnan, {}, 'float32', 'IsNaN', {}) + + +def verify_gather_nd(in_shape, indices, dtype): + x = np.random.uniform(size=in_shape).astype(dtype) + indices = np.array(indices, dtype="int32") + out_np = topi.testing.gather_nd_python(x, indices) + + y = helper.make_node("GatherND", ['in', 'indices'], ['out']) + + graph = helper.make_graph([y], + 'gather_test', + 
inputs=[helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape)), + helper.make_tensor_value_info("indices", + TensorProto.INT32, list(indices.shape))], + outputs=[helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(out_np.shape))]) + model = helper.make_model(graph, producer_name='gather_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [x, indices], target, ctx, out_np.shape) + tvm.testing.assert_allclose(out_np, tvm_out) + + +def test_gather_nd(): + verify_gather_nd((2, 2), [[0,0],[1,1]], 'int32') + verify_gather_nd((3, 3, 3), [[0,1],[1,0]] , 'float32') + verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32') + + def test_onehot(): indices_shape = [10] indices_array = np.random.randint( @@ -795,21 +895,22 @@ def _test_upsample_bilinear_opset9(): in_shape = (1, 1, 3, 3) out_shape = (1, 1, 3*scale, 3*scale) y = helper.make_node("Upsample", ['in', 'scales'], ['out'], mode='linear') - scales = [1.0, 1.0, 2.0, 2.0] + scales = [1, 1, 2, 2] in_array = np.random.uniform(size=in_shape).astype(np.float32) out_array = topi.testing.bilinear_resize_python( in_array, (3*scale, 3*scale), "NCHW") - ref_array = np.array(scales) ref_node = helper.make_node('Constant', inputs=[], - outputs=['scales'], + outputs=['const'], value=onnx.helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, - dims=ref_array.shape, - vals=ref_array.flatten().astype(float))) + dims=scales, + vals=np.random.random(scales).flatten().astype(float))) - graph = helper.make_graph([ref_node, y], + shape_node = helper.make_node("Shape", ['const'], ['scales']) + + graph = helper.make_graph([ref_node, shape_node, y], 'upsample_bilinear_opset9_test', inputs=[helper.make_tensor_value_info( "in", TensorProto.FLOAT, list(in_shape))], @@ -1214,7 +1315,63 @@ def verify_pad(indata, pads, mode='constant', value=0.0): # tvm result for target, ctx in ctx_list(): tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, 'float32') + model, indata, target, ctx, outdata.shape, 'float32', opset=2) + tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + + +def verify_pad_v11(indata, pads, mode='constant', value=0.0): + indata = np.array(indata).astype(np.float32) + # numpy expect result + len_dim = len(pads) // 2 + np_pads = [(pads[i], pads[i+len_dim]) for i in range(len_dim)] + pads = np.array(pads) + # onnx graph + if mode in ['edge', 'reflect']: + inputs = [indata, pads] + outdata = np.pad(indata, pad_width=np_pads, mode=mode) + node = helper.make_node( + 'Pad', + inputs=['input', 'pads'], + outputs=['output'], + mode=mode + ) + graph = helper.make_graph([node], + 'pad_test', + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", + TensorProto.INT64,(len(pads),))], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads)], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(outdata.shape))]) + else: + inputs = [indata, pads, np.array([value])] + outdata = np.pad(indata, pad_width=np_pads, + mode='constant', constant_values=value) + node = helper.make_node( + 'Pad', + inputs=['input', 'pads', 'constant_value'], + outputs=['output'], + mode='constant' + ) + graph = helper.make_graph([node], + 'pad_test', + inputs=[helper.make_tensor_value_info("input", + TensorProto.FLOAT, list(indata.shape)), + helper.make_tensor_value_info("pads", + TensorProto.INT64,(len(pads),)), + helper.make_tensor_value_info("constant_value", + 
TensorProto.INT64,(1,)), + ], + initializer=[helper.make_tensor("pads", TensorProto.INT64, (len(pads),), pads), + helper.make_tensor("constant_value", TensorProto.FLOAT, (1,), [value])], + outputs=[helper.make_tensor_value_info("output", + TensorProto.FLOAT, list(outdata.shape))]) + model = helper.make_model(graph, producer_name='pad_test') + # tvm result + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, inputs, target, ctx, outdata.shape, 'float32', opset=11) tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) @@ -1230,89 +1387,83 @@ def test_pad(): verify_pad(np.random.randn(1, 3, 4, 5).astype( np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') + verify_pad_v11(np.random.randn(2, 2).astype( + np.float32), [0, 1, 0, 0], 'constant', 0.0) + verify_pad_v11(np.random.randn(2, 3).astype( + np.float32), [1, 0, 0, 1], 'constant', 0.0) + verify_pad_v11(np.random.randn(3, 2).astype( + np.float32), [0, 0, 1, 0], 'constant', 5.0) + verify_pad_v11(np.random.randn(1, 3, 4, 5).astype( + np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'edge') + verify_pad_v11(np.random.randn(1, 3, 4, 5).astype( + np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') + -def verify_reduce_x(name, indata, axis, keepdims): - indata = np.array(indata).astype(np.float32) - # numpy expect result - if name == 'ReduceMax': - outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceMin': - outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceSum': - outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceMean': - outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1) - else: - raise Exception('unsupport op: {}'.format(name)) - if len(np.asarray(outdata).shape) == 0: - outdata = np.asarray([outdata]) - # onnx graph - if axis is None: - node = helper.make_node(name, inputs=['input'], outputs=['output'], - keepdims=keepdims) +def verify_reduce_func(func, data, axis, keepdims): + inshape = data.shape + outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape + + if axis: + node = onnx.helper.make_node(func, + inputs=['x'], + outputs=['y'], + axes=axis, + keepdims=keepdims) else: - node = helper.make_node(name, inputs=['input'], outputs=['output'], - axes=axis, keepdims=keepdims) + node = onnx.helper.make_node(func, + inputs=['x'], + outputs=['y'], + keepdims=keepdims) + graph = helper.make_graph([node], - '{}_test'.format(name), - inputs=[helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) - model = helper.make_model(graph, producer_name='{}_test'.format(name)) - # tvm result - for target, ctx in ctx_list(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, 'float32') - tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + "reduce_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + model = helper.make_model(graph, producer_name='reduce_test') -def test_reduce_max(): - verify_reduce_x("ReduceMax", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceMax", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceMax", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_min(): - 
verify_reduce_x("ReduceMin", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceMin", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceMin", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_sum(): - verify_reduce_x("ReduceSum", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceSum", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceSum", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_mean(): - verify_reduce_x("ReduceMean", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceMean", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceMean", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) + onnx_out = get_onnxruntime_output(model, data, 'float32') + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, data, target, ctx, outshape, 'float32') + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + +def test_all_reduce_funcs(): + funcs = ["ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + 'ReduceSumSquare', + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceL1", + "ReduceL2"] + + for func in funcs: + for keepdims in [True, False]: + verify_reduce_func(func, + np.random.randn(3, 2, 2).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 2, 3).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 3, 3).astype(np.float32), + axis=(1,), keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1, 2), keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1,), keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(1, 3, 4, 1).astype(np.float32), + axis=(1,), keepdims=keepdims) def verify_split(indata, outdatas, split, axis=0): @@ -1366,16 +1517,16 @@ def test_binary_ops(): dtype = "float32" out_shape = in_shape - def verify_binary_ops(op, x, y, out_np, broadcast=None): + def verify_binary_ops(op, x, y, out_np, x_name='in1', y_name='in2', broadcast=None): if broadcast is None: - z = helper.make_node(op, ['in1', 'in2'], ['out']) + z = helper.make_node(op, [x_name, y_name], ['out']) else: - z = helper.make_node(op, ['in1', 'in2'], ['out'], broadcast=1) + z = helper.make_node(op, [x_name, y_name], ['out'], broadcast=1) graph = helper.make_graph([z], '_test', - inputs=[helper.make_tensor_value_info("in1", + inputs=[helper.make_tensor_value_info(x_name, TensorProto.FLOAT, list(in_shape)), - helper.make_tensor_value_info("in2", + helper.make_tensor_value_info(y_name, TensorProto.FLOAT, list(in_shape))], outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))]) @@ -1393,6 +1544,7 @@ def verify_binary_ops(op, x, y, out_np, broadcast=None): verify_binary_ops("Sub", x, z, x - z, broadcast=True) verify_binary_ops("Mul", x, y, x * y, broadcast=None) verify_binary_ops("Mul", x, z, x * z, broadcast=True) + verify_binary_ops("Mul", x, x, x * x, x_name='in1', y_name='in1', broadcast=None) verify_binary_ops("Div", x, y, x / y, broadcast=None) verify_binary_ops("Div", x, z, x / z, broadcast=True) verify_binary_ops("Sum", x, y, x + y, broadcast=None) @@ 
-1428,6 +1580,17 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): verify_single_ops("Exp", x, np.exp(x)) verify_single_ops("Log", x, np.log(x)) verify_single_ops("Log", x, np.log(x)) + verify_single_ops("ACos", x, np.arccos(x)) + verify_single_ops("ACosh", x, np.arccosh(x)) + verify_single_ops("ASin", x, np.arcsin(x)) + verify_single_ops("ASinh", x, np.arcsinh(x)) + verify_single_ops("ATan", x, np.arctan(x)) + verify_single_ops("ATanh", x, np.arctanh(x)) + verify_single_ops("Cos", x, np.cos(x)) + verify_single_ops("Cosh", x, np.cosh(x)) + verify_single_ops("Sin", x, np.sin(x)) + verify_single_ops("Sinh", x, np.sinh(x)) + verify_single_ops("Tan", x, np.tan(x)) verify_single_ops("Tanh", x, np.tanh(x)) verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x))) verify_single_ops("Softsign", x, x / (1 + np.abs(x))) @@ -1467,6 +1630,34 @@ def selu_x(x, alpha, gamma): {'alpha': 0.25, 'gamma': 0.3}) +def test_prelu(): + def verify_prelu(x_shape, a_shape): + node = helper.make_node('PRelu', + inputs=['X', 'slope'], + outputs=['Y']) + + graph = helper.make_graph([node], + "prelu_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("slope", TensorProto.FLOAT, list(a_shape))], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(x_shape))]) + + model = helper.make_model(graph, producer_name='prelu_test') + + indata = np.random.uniform(-10, 10, x_shape).astype(np.float32) + slopedata = np.random.uniform(-10, 10, a_shape).astype(np.float32) + onnx_out = get_onnxruntime_output(model, [indata, slopedata]) + + for target, ctx in [('llvm', tvm.cpu())]: + tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape), + output_dtype='float32') + tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + + verify_prelu([3,4,5,6], [1, 4, 1, 1]) + verify_prelu([1,8,5,6], [1, 8, 1, 1]) + verify_prelu([2,12,16,16], [1, 12, 1, 1]) + + def test_ThresholdedRelu(): def ThresholdedRelu_x(x, alpha): out_np = np.clip(x, alpha, np.inf) @@ -1830,8 +2021,18 @@ def test_or(): verify_or(indata=[x, y], dtype=bool) -def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET"): - if padding is None: +def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET", unset_pad=False): + if unset_pad: + node = helper.make_node('Conv', + inputs=['x', 'W'], + outputs=['y'], + kernel_shape=kernel_shape, + # Default values for other attributes: + strides=strides, + dilations=dilations, + # groups=1 + ) + elif padding is None: node = helper.make_node('Conv', inputs=['x', 'W'], outputs=['y'], @@ -1897,6 +2098,15 @@ def repeat(N, D): repeat(1, D), repeat(1, D), auto_pad="SAME_UPPER") + # Convolution with unset padding + verify_conv((1, 1) + repeat(5, D), + (1, 1) + repeat(3, D), + (1, 1) + repeat(3, D), + 2 * repeat(0, D), + repeat(3, D), + repeat(1, D), + repeat(1, D), + True) # Convolution with non uniform stride verify_conv((1, 1) + repeat(5, D), (1, 1) + repeat(3, D), @@ -1977,20 +2187,18 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p else: raise ValueError("Pool method {} is not supported.".format(mode)) + pool_node = helper.make_node( + node_type, inputs=["x"], outputs=["y"], kernel_shape=kernel_shape, strides=strides) + if pads is None: - pool_node = helper.make_node(node_type, - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - auto_pad=auto_pad, - strides=strides) + pad_attr 
= helper.make_attribute('auto_pad', auto_pad) else: - pool_node = helper.make_node(node_type, - inputs=["x"], - outputs=["y"], - kernel_shape=kernel_shape, - pads=pads, - strides=strides) + pad_attr = helper.make_attribute('pads', pads) + pool_node.attribute.append(pad_attr) + + if mode == 'max': + storage_attr = helper.make_attribute('storage_order', 0) + pool_node.attribute.append(storage_attr) graph = helper.make_graph([pool_node], "pooling_test", @@ -2075,6 +2283,205 @@ def test_pooling(): auto_pad='SAME_UPPER') +def verify_mod(x_shape, y_shape, fmod, dtype='float32'): + x_np = np.random.uniform(size=x_shape).astype(dtype) + y_np = np.random.uniform(size=y_shape).astype(dtype) + y_np = np.where(y_np==0, 1, y_np) #remove 0's to avoid division by zero error + + if fmod: + np_out = np.fmod(x_np, y_np) + else: + np_out = np.mod(x_np, y_np) + + out_shape = np_out.shape + mod_node = helper.make_node("Mod", + inputs=["x", "y"], + outputs=["z"], + fmod=fmod) + + onnx_dtype = TensorProto.FLOAT if dtype == "float32" else TensorProto.INT32 + graph = helper.make_graph([mod_node], + "mod_test", + inputs=[helper.make_tensor_value_info("x", + onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", + onnx_dtype, list(y_shape))], + outputs=[helper.make_tensor_value_info("z", + onnx_dtype, list(out_shape))]) + model = helper.make_model(graph, producer_name='mod_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [x_np, y_np], target, ctx, out_shape) + tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_mod(): + # Mod + verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=0) + + verify_mod(x_shape=[1, 32, 32], y_shape=[1, 1, 32], fmod=0, dtype="int32") + + # fmod + verify_mod(x_shape=[1, 1, 32], y_shape=[1, 32, 32], fmod=1) + + verify_mod(x_shape=[1, 32, 32], y_shape=[1, 32, 32], fmod=1, dtype="int32") + + +def verify_xor(x_shape, y_shape): + x_np = np.random.choice(a=[False, True], size=x_shape).astype("bool") + y_np = np.random.choice(a=[False, True], size=y_shape).astype("bool") + + np_out = np.logical_xor(x_np, y_np) + out_shape = np_out.shape + + xor_node = helper.make_node("Xor", + inputs=["x", "y"], + outputs=["z"]) + + onnx_dtype = TensorProto.BOOL + graph = helper.make_graph([xor_node], + "xor_test", + inputs=[helper.make_tensor_value_info("x", + onnx_dtype, list(x_shape)), + helper.make_tensor_value_info("y", + onnx_dtype, list(y_shape))], + outputs=[helper.make_tensor_value_info("z", + onnx_dtype, list(out_shape))]) + model = helper.make_model(graph, producer_name='xor_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [x_np, y_np], target, ctx, out_shape) + tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_xor(): + # XOR + verify_xor(x_shape=[1, 32, 32], y_shape=[1, 32, 32]) + + # Xor broadcast + verify_xor(x_shape=[1, 32, 32], y_shape=[1, 1, 32]) + + +def verify_max_roi_pool(x_shape, rois_shape, pooled_shape, spatial_scale, out_shape): + x_np = np.random.uniform(size=x_shape).astype('float32') + rois_np = np.random.uniform(size=rois_shape).astype('float32') + + if spatial_scale is None: + pool_node = helper.make_node("MaxRoiPool", + inputs=["x", "rois"], + outputs=["y"], + pooled_shape=pooled_shape) + else: + pool_node = helper.make_node("MaxRoiPool", + inputs=["x", "rois"], + outputs=["y"], + pooled_shape=pooled_shape, + spatial_scale=spatial_scale) + + graph = helper.make_graph([pool_node], + "pool_test", + inputs=[helper.make_tensor_value_info("x", + 
TensorProto.FLOAT, list(x_shape)), + helper.make_tensor_value_info("rois", + TensorProto.FLOAT, list(rois_shape))], + outputs=[helper.make_tensor_value_info("y", + TensorProto.FLOAT, list(out_shape))]) + + model = helper.make_model(graph, producer_name='pool_test') + + onnx_out = get_onnxruntime_output(model, [x_np, rois_np], 'float32')[0] + for target, ctx in ctx_list(): + tvm_out = get_tvm_output( + model, [x_np, rois_np], target, ctx, out_shape) + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_max_roi_pool(): + verify_max_roi_pool(x_shape=[1, 3, 6, 6], + rois_shape=[3, 5], + pooled_shape=[1, 1], + spatial_scale=None, + out_shape=[3, 3, 1, 1]) + + verify_max_roi_pool(x_shape=[1, 3, 10, 10], + rois_shape=[4, 5], + pooled_shape=[2, 2], + spatial_scale=2.0, + out_shape=[4, 3, 2, 2]) + + +def verify_lppool(x_shape, kernel_shape, p, strides, pads, out_shape, auto_pad="NOTSET"): + x_np = np.random.uniform(size=x_shape).astype('float32') + + if pads is None: + pool_node = helper.make_node("LpPool", + inputs=["x"], + outputs=["y"], + kernel_shape=kernel_shape, + p = p, + auto_pad=auto_pad, + strides=strides) + else: + pool_node = helper.make_node("LpPool", + inputs=["x"], + outputs=["y"], + kernel_shape=kernel_shape, + p = p, + pads=pads, + strides=strides) + + graph = helper.make_graph([pool_node], + "lppool_test", + inputs=[helper.make_tensor_value_info("x", + TensorProto.FLOAT, list(x_shape))], + outputs=[helper.make_tensor_value_info("y", + TensorProto.FLOAT, list(out_shape))]) + + model = helper.make_model(graph, producer_name='lppool_test') + + for target, ctx in ctx_list(): + onnx_out = get_onnxruntime_output(model, x_np, 'float32') + tvm_out = get_tvm_output( + model, [x_np], target, ctx, out_shape) + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_lppool(): + # Pool1D + verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[1], pads=[1, 1], + out_shape=[1, 1, 32]) + + # Pool2D + verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[1, 1], + pads=[1, 1, 1, 1], out_shape=[1, 1, 32, 32]) + + # Pool1D with stride + verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=[1, 1], + out_shape=[1, 1, 16]) + + # Pool2D with stride + verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2], + pads=[1, 1, 1, 1], out_shape=[1, 1, 16, 16]) + + # Pool1D with stride and autopadding + verify_lppool(x_shape=[1, 1, 32], kernel_shape=[3], p=2, strides=[2], pads=None, + out_shape=[1, 1, 16], auto_pad='SAME_UPPER') + + # Pool2D with stride and autopadding + verify_lppool(x_shape=[1, 1, 32, 32], kernel_shape=[3, 3], p=2, strides=[2, 2], + pads=None, out_shape=[1, 1, 16, 16], auto_pad='SAME_UPPER') + + # Pool3D with stride + verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2], + pads=[1, 1, 1, 1, 1, 1], out_shape=[1, 1, 16, 16, 16]) + + # Pool3D with stride and autopadding + verify_lppool(x_shape=[1, 1, 32, 32, 32], kernel_shape=[3, 3, 3], p=2, strides=[2, 2, 2], + pads=None, out_shape=[1, 1, 16, 16, 16], auto_pad='SAME_UPPER') + + def verify_lstm(seq_length, batch_size, input_size, @@ -2329,6 +2736,105 @@ def verify_nonzero(indata, outdata, dtype): result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]] verify_nonzero(input_data, result, dtype=np.int64) +def test_topk(): + def verify_topk(input_dims, K, axis=-1): + output_dims = list(input_dims) + output_dims[axis] = K + + node = 
helper.make_node('TopK', + inputs=['X', 'K'], + outputs=['Values', 'Indicies'], + axis=axis) + + graph = helper.make_graph([node], + "topk_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), + helper.make_tensor_value_info("K", TensorProto.INT64, [1,])], + initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])], + outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), + helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)]) + + model = helper.make_model(graph, producer_name='topk_test') + + indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) + onnx_out = get_onnxruntime_output(model, [indata, k]) + + for target, ctx in [('llvm', tvm.cpu())]: + tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], + output_dtype=['float32', 'int64']) + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + + for n in [12, 32]: + for shape in [[n], [n, n], [n, n, n]]: + for k in [1, 5, 10]: + verify_topk(shape, k) + + verify_topk([n, n, n], 5, 0) + verify_topk([n, n, n], 5, 1) + verify_topk([n, n, n], 5, 2) + + +def test_roi_align(): + def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0): + output_dims = [num_roi, input_dims[1], output_height, output_width] + + node = helper.make_node('RoiAlign', + inputs=['X', 'rois', 'batch_indicies'], + outputs=['Y'], + mode="avg", + output_height=output_height, + output_width=output_width, + sampling_ratio=sampling_ratio, + spatial_scale=spatial_scale, + ) + + graph = helper.make_graph([node], + "roialign_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), + helper.make_tensor_value_info( + "rois", TensorProto.FLOAT, [num_roi, 4]), + helper.make_tensor_value_info( + "batch_indicies", TensorProto.INT64, [num_roi, ]), + ], + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_dims)]) + + model = helper.make_model(graph, producer_name='roialign_test') + + np_data = np.random.uniform(size=input_dims).astype("float32") + np_rois = np.random.uniform(size=[num_roi, 4]).astype( + 'float32') * input_dims[2] + np_batch_indicies = np.random.randint( + low=0, high=input_dims[0], size=num_roi) + + onnx_out = get_onnxruntime_output( + model, [np_data, np_rois, np_batch_indicies]) + for target, ctx in [('llvm', tvm.cpu())]: + tvm_out = get_tvm_output(model, [np_data, np_rois, np_batch_indicies], target, ctx, output_dims, + output_dtype='float32') + tvm.testing.assert_allclose( + onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) + + verify_roi_align((1, 4, 16, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((4, 4, 16, 32), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 8, 16, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 8, 8), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 16, 5, 7, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 12), 8, 7, 3, + sampling_ratio=0, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=0.5) + verify_roi_align((3, 4, 12, 16), 32, 7, 7, + sampling_ratio=0, spatial_scale=1.5) + verify_roi_align((5, 4, 16, 14), 32, 7, 7, + sampling_ratio=1, spatial_scale=1.0) + verify_roi_align((1, 4, 16, 16), 32, 7, 7, + sampling_ratio=2, spatial_scale=1.0) + if __name__ == '__main__': test_flatten() @@ -2341,11 +2847,16 @@ def 
verify_nonzero(indata, outdata, dtype): test_slice() test_floor() test_ceil() + test_round() + test_isinf() + test_isnan() test_clip() test_onehot() test_matmul() test_batch_matmul() test_gather() + test_gather_nd() + test_scatter() test_lrn() test_instance_norm() test_upsample() @@ -2356,10 +2867,7 @@ def verify_nonzero(indata, outdata, dtype): test_forward_arg_min_max() test_softmax() test_constantofshape() - test_reduce_max() - test_reduce_min() - test_reduce_sum() - test_reduce_mean() + test_all_reduce_funcs() test_pad() test_split() test_binary_ops() @@ -2367,6 +2875,7 @@ def verify_nonzero(indata, outdata, dtype): test_leaky_relu() test_elu() test_selu() + test_prelu() test_ThresholdedRelu() test_ScaledTanh() test_ParametricSoftplus() @@ -2388,6 +2897,12 @@ def verify_nonzero(indata, outdata, dtype): test_convtranspose() test_unsqueeze_constant() test_pooling() + test_lppool() test_lstm() test_resize() test_nonzero() + test_topk() + test_mod() + test_xor() + test_max_roi_pool() + test_roi_align() diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py index bf5fa981e6f4..551cdc4cd418 100644 --- a/tests/python/frontend/pytorch/qnn_test.py +++ b/tests/python/frontend/pytorch/qnn_test.py @@ -41,7 +41,7 @@ def get_tvm_runtime(script_module, input_name, ishape): input_shapes = [(input_name, ishape)] mod, params = relay.frontend.from_pytorch(script_module, input_shapes) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): # test on only cpu for now, torch cannot run quant models on cuda # also not to make CI too slow json, lib, params = relay.build(mod, target="llvm", params=params) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4eba4d002fe9..96e9144e03cc 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -27,6 +27,7 @@ from tvm import relay from tvm.contrib import graph_runtime +from tvm.contrib.nvcc import have_fp16 from tvm.relay.testing.config import ctx_list @@ -135,7 +136,8 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40): def verify_model(model_name, input_data=[], custom_convert_map={}, - ctx_list=ctx_list()): + ctx_list=ctx_list(), + rtol=1e-5, atol=1e-5): """Assert that the output of a compiled model matches with that of its baseline.""" if isinstance(model_name, str): @@ -176,7 +178,7 @@ def verify_model(model_name, input_data=[], compiled_input = dict(zip(input_names, [inp.cpu().numpy() for inp in baseline_input])) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): for target, ctx in ctx_list: relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) @@ -190,7 +192,7 @@ def verify_model(model_name, input_data=[], assert_shapes_match(baseline_output, compiled_output) tvm.testing.assert_allclose(baseline_output, compiled_output, - rtol=1e-3, atol=1e-3) + rtol=rtol, atol=atol) del model_name del baseline_model @@ -381,28 +383,61 @@ def test_forward_arange(): class Arange1(Module): def forward(self, *args): return torch.arange(5) + class Arange2(Module): def forward(self, *args): return torch.arange(2.5) + class Arange3(Module): def forward(self, *args): return torch.arange(1, 4) + class Arange4(Module): def forward(self, *args): return torch.arange(1, 2.5, 0.5) + class Arange5(Module): def forward(self, *args): return 
torch.arange(1, 2, 1, dtype=torch.int32) + class Arange6(Module): def forward(self, *args): return torch.arange(start=1, end=6, step=2) + class Arange7(Module): def forward(self, *args): return torch.arange(1, 4, dtype=torch.float32) + class Arange8(Module): def forward(self, *args): return torch.arange(1, 2, 1, dtype=torch.int16) + class Arange9(Module): + def forward(self, *args): + end = torch.add(torch.tensor(4), 1) + return torch.arange(end) + torch.ones((5,), dtype=torch.int64) + + class Arange10(Module): + def forward(self, *args): + end = torch.add(torch.tensor(4.0), torch.tensor(1.0)) + return torch.arange(end) + torch.ones((5,), dtype=torch.float) + + class Arange11(Module): + def forward(self, *args): + start = torch.add(torch.tensor(1), 1) + end = torch.add(torch.tensor(4), 1) + step = torch.add(torch.tensor(2), 1) + out = torch.arange(start, end, step) + return out + torch.ones((3,), dtype=torch.int64) + + class Arange12(Module): + def forward(self, *args): + start = torch.add(torch.tensor(1), 1) + end = torch.add(torch.tensor(4), 1) + step = torch.add(torch.tensor(2.5), torch.tensor(4.1)) + out = torch.arange(start, end, step) + return out + torch.ones((3,), dtype=torch.float) + verify_model(Arange1().float().eval()) verify_model(Arange2().float().eval()) verify_model(Arange3().float().eval()) @@ -411,6 +446,11 @@ def forward(self, *args): verify_model(Arange6().float().eval()) verify_model(Arange7().float().eval()) verify_model(Arange8().float().eval()) + verify_model(Arange9().float().eval()) + verify_model(Arange10().float().eval()) + verify_model(Arange11().float().eval()) + verify_model(Arange12().float().eval()) + def test_forward_abs(): torch.set_grad_enabled(False) @@ -728,6 +768,28 @@ def init_weight(m): init_weight(ln.eval()) verify_model(ln.eval(), input_data=inp) + +def test_forward_groupnorm(): + input_shape = [10, 6, 5, 5] + input_data = torch.rand(input_shape).float() + + # Separate 6 channels into 3 groups + verify_model(torch.nn.GroupNorm(3, 6).eval(), input_data=input_data) + + # Put all 6 channels into a single group (equivalent with LayerNorm) + verify_model(torch.nn.GroupNorm(1, 6).eval(), input_data=input_data) + + # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + verify_model(torch.nn.GroupNorm(6, 6).eval(), input_data=input_data) + + input_shape = [1, 10, 4, 7] + input_data = torch.rand(input_shape).float() + verify_model(torch.nn.GroupNorm(1, 10).eval(), input_data=input_data) + verify_model(torch.nn.GroupNorm(2, 10).eval(), input_data=input_data) + verify_model(torch.nn.GroupNorm(5, 10).eval(), input_data=input_data) + verify_model(torch.nn.GroupNorm(10, 10).eval(), input_data=input_data) + + def test_forward_reshape(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] @@ -756,9 +818,14 @@ class Transpose2(Module): def forward(self, *args): return args[0].transpose(-2, -1) + class Transpose3(Module): + def forward(self, *args): + return args[0].permute(0,2,3,1) + input_data = torch.rand(input_shape).float() verify_model(Transpose1().float().eval(), input_data=input_data) verify_model(Transpose2().float().eval(), input_data=input_data) + verify_model(Transpose3().float().eval(), input_data=input_data) def test_forward_size(): torch.set_grad_enabled(False) @@ -771,6 +838,44 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(Size1().float().eval(), input_data=input_data) + +def test_type_as(): + torch.set_grad_enabled(False) + input_shape = [1, 3] + + def _create_module(dtype): + class 
TypeAs(Module): + def forward(self, *args): + expected_type_tensor = torch.zeros(1, 3, dtype=dtype) + return args[0].type_as(expected_type_tensor) + + return TypeAs() + + input_data = torch.randn(input_shape).float() + verify_model(_create_module(torch.float64), input_data=input_data) + verify_model(_create_module(torch.float32), input_data=input_data) + verify_model(_create_module(torch.int64), input_data=input_data) + verify_model(_create_module(torch.int32), input_data=input_data) + verify_model(_create_module(torch.int16), input_data=input_data) + verify_model(_create_module(torch.int8), input_data=input_data) + + if torch.cuda.is_available(): + check_fp16 = False + try: + # Only check half precision on supported hardwares. + if have_fp16(tvm.gpu(0).compute_version): + check_fp16 = True + except Exception as e: + # If GPU is not enabled in TVM, skip the fp16 test. + pass + + # Temporary disable fp16 test + check_fp16 = False + + if check_fp16: + verify_model(_create_module(torch.float16), input_data=input_data) + + def test_forward_view(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -783,9 +888,15 @@ class View2(Module): def forward(self, *args): return args[0].view(args[0].shape[0], -1) + class View3(Module): + def forward(self, *args): + d1 = torch.tensor(3) * torch.tensor(10) * torch.tensor(10) + return args[0].view(args[0].shape[0], d1) + input_data = torch.rand(input_shape).float() verify_model(View1().float().eval(), input_data=input_data) verify_model(View2().float().eval(), input_data=input_data) + verify_model(View3().float().eval(), input_data=input_data) def test_forward_select(): torch.set_grad_enabled(False) @@ -820,6 +931,91 @@ def forward(self, *args): input_data = torch.rand(input_shape).float() verify_model(LogSoftmax1().float().eval(), input_data=input_data) + +def test_forward_norm(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class Norm1(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('inf'), dim=None, keepdim=False) + + class Norm2(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=False) + + class Norm3(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=True) + + class Norm4(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('inf'), dim=(1, 2), keepdim=False) + + class Norm5(Module): + def forward(self, *args): + return torch.norm(args[0], p=float('inf'), dim=(1), keepdim=True) + + class Norm6(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(0.5), dim=(1), keepdim=True) + + class Norm7(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(1), dim=None, keepdim=False) + + class Norm8(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(2.0), dim=(1), keepdim=True) + + class Norm9(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(-0.5), dim=(1, 2), keepdim=True) + + class Norm10(Module): + def forward(self, *args): + return torch.norm(args[0], p=float(-2), dim=(1), keepdim=False) + + input_data = torch.rand(input_shape).float() + verify_model(Norm1().float().eval(), input_data=input_data) + verify_model(Norm2().float().eval(), input_data=input_data) + verify_model(Norm3().float().eval(), input_data=input_data) + verify_model(Norm4().float().eval(), input_data=input_data) + verify_model(Norm5().float().eval(), input_data=input_data) + verify_model(Norm6().float().eval(), 
input_data=input_data) + verify_model(Norm7().float().eval(), input_data=input_data) + verify_model(Norm8().float().eval(), input_data=input_data) + verify_model(Norm9().float().eval(), input_data=input_data) + verify_model(Norm10().float().eval(), input_data=input_data) + + +def test_forward_frobenius_norm(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class FroNorm1(Module): + def forward(self, *args): + return torch.norm(args[0]) + + class FroNorm2(Module): + def forward(self, *args): + return torch.norm(args[0], p='fro', dim=None, keepdim=True) + + class FroNorm3(Module): + def forward(self, *args): + return torch.norm(args[0], p='fro', dim=(1), keepdim=True) + + class FroNorm4(Module): + def forward(self, *args): + return torch.norm(args[0], dim=None, keepdim=False) + + input_data = torch.rand(input_shape).float() + verify_model(FroNorm1().float().eval(), input_data=input_data) + verify_model(FroNorm2().float().eval(), input_data=input_data) + verify_model(FroNorm3().float().eval(), input_data=input_data) + verify_model(FroNorm4().float().eval(), input_data=input_data) + + def test_forward_sigmoid(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -869,9 +1065,17 @@ class Slice2(Module): def forward(self, *args): return args[0][0, :, :, :] + class Slice3(Module): + def forward(self, *args): + x0 = torch.tensor(2) - torch.tensor(1) + x1 = torch.tensor(3) + torch.tensor(1) + return args[0][:, x0:, :x1, :] + input_data = torch.rand(input_shape).float() verify_model(Slice1().float().eval(), input_data=input_data) verify_model(Slice2().float().eval(), input_data=input_data) + verify_model(Slice3().float().eval(), input_data=input_data) + def test_forward_mean(): torch.set_grad_enabled(False) @@ -886,15 +1090,24 @@ def forward(self, *args): def test_forward_expand(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] class Expand1(Module): def forward(self, *args): return args[0].expand((3, -1, -1, -1)) + input_shape = [1, 3, 10, 10] input_data = torch.rand(input_shape).float() verify_model(Expand1().float().eval(), input_data=input_data) + class Expand2(Module): + def forward(self, *args): + return args[0].expand((3, 3, 3, 1)) + + input_shape = [3, 1] + input_data = torch.rand(input_shape).float() + verify_model(Expand2().float().eval(), input_data=input_data) + + def test_forward_pow(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -984,6 +1197,102 @@ def test_adaptive_pool3d(): verify_model(torch.nn.AdaptiveMaxPool3d((7, 8, 9)).eval(), inp) +def test_forward_functional_pad(): + torch.set_grad_enabled(False) + pad = (0, 0) + class Pad1(Module): + def forward(self, *args): + return torch.nn.functional.pad(args[0], pad, "constant", 0) + + input_data = torch.rand((3, 3, 4, 2)) + pad = (1, 1) + verify_model(Pad1().float().eval(), input_data=input_data) + + pad = (1, 1, 2, 2) + verify_model(Pad1().float().eval(), input_data=input_data) + + pad = (0, 1, 2, 1, 3, 3) + verify_model(Pad1().float().eval(), input_data=input_data) + + +def test_forward_zero_pad2d(): + inp = torch.rand((1, 1, 3, 3)) + verify_model(torch.nn.ZeroPad2d(2).eval(), inp) + verify_model(torch.nn.ZeroPad2d((1, 1, 2, 0)).eval(), inp) + + +def test_forward_constant_pad1d(): + inp = torch.rand((1, 2, 4)) + verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) + + inp = torch.rand((1, 2, 3)) + verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp) + + +def test_forward_constant_pad2d(): + inp = torch.rand((1, 2, 2, 2)) + 
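# Note on the tuple form used by these padding tests: torch's 2d padding tuples are
# ordered (left, right, top, bottom), so ConstantPad2d((3, 0, 2, 1), 3.5) adds 3
# columns on the left, none on the right, 2 rows on top and 1 row on the bottom,
# all filled with 3.5. A quick hand check of the ordering:
#   x = torch.zeros(1, 1, 2, 2)
#   torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5)(x).shape  # -> torch.Size([1, 1, 5, 5])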
verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) + verify_model(torch.nn.ConstantPad2d((3, 0, 2, 1), 3.5).eval(), inp) + + +def test_forward_constant_pad3d(): + inp = torch.rand((1, 3, 2, 2, 2)) + verify_model(torch.nn.ConstantPad3d(3, 3.5).eval(), inp) + verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp) + + +def test_forward_reflection_pad1d(): + inp = torch.rand((1, 2, 4)) + verify_model(torch.nn.ReflectionPad1d(2).eval(), inp) + verify_model(torch.nn.ReflectionPad1d((3, 1)).eval(), inp) + + inp = torch.rand((2, 4, 5)) + verify_model(torch.nn.ReflectionPad1d((2, 3)).eval(), inp) + + +def test_forward_reflection_pad2d(): + inp = torch.rand((1, 1, 3, 3)) + verify_model(torch.nn.ReflectionPad2d(2).eval(), inp) + verify_model(torch.nn.ReflectionPad2d((1, 1, 2, 0)).eval(), inp) + + inp = torch.rand((2, 4, 5, 6)) + verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp) + + +def test_forward_replication_pad1d(): + inp = torch.rand((1, 2, 4)) + verify_model(torch.nn.ReplicationPad1d(2).eval(), inp) + verify_model(torch.nn.ReplicationPad1d((3, 1)).eval(), inp) + + inp = torch.rand((2, 4, 5)) + verify_model(torch.nn.ReplicationPad1d((2, 3)).eval(), inp) + + +def test_forward_replication_pad2d(): + inp = torch.rand((1, 1, 3, 3)) + verify_model(torch.nn.ReplicationPad2d(2).eval(), inp) + verify_model(torch.nn.ReplicationPad2d((1, 1, 2, 0)).eval(), inp) + + inp = torch.rand((2, 4, 5, 6)) + verify_model(torch.nn.ReplicationPad2d((1, 3, 2, 4)).eval(), inp) + + +def test_forward_replication_pad3d(): + inp = torch.rand((1, 1, 3, 3, 3)) + verify_model(torch.nn.ReplicationPad3d(3).eval(), inp) + verify_model(torch.nn.ReplicationPad3d((1, 1, 2, 2, 1, 1)).eval(), inp) + + inp = torch.rand((7, 5, 4, 5, 6)) + verify_model(torch.nn.ReplicationPad3d((2, 3, 2, 5, 1, 4)).eval(), inp) + + +def test_forward_upsample3d(): + inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2) + verify_model(torch.nn.Upsample(scale_factor=2, mode='nearest').eval(), inp) + verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear').eval(), inp) + verify_model(torch.nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True).eval(), inp) + + def test_conv3d(): for ishape in [(1, 32, 16, 16, 16), (1, 32, 9, 15, 15), @@ -1002,38 +1311,65 @@ def test_conv3d(): inp) +def test_conv3d_transpose(): + for ishape in [(1, 8, 10, 5, 10), + (1, 8, 5, 8, 8), + (1, 8, 13, 7, 7)]: + inp = torch.rand(ishape) + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=33, + kernel_size=3, + stride=2).eval(), + inp), + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=20, + kernel_size=(3, 5, 2), + stride=(2, 1, 1), + padding=(0, 4, 2)).eval(), + inp), + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=20, + kernel_size=1).eval(), + inp) + verify_model(torch.nn.ConvTranspose3d(in_channels=8, + out_channels=5, + kernel_size=1, + stride=2).eval(), + inp) + + # Model tests def test_resnet18(): torch.set_grad_enabled(False) - verify_model("resnet18") + verify_model("resnet18", atol=1e-4, rtol=1e-4) def test_squeezenet1_0(): torch.set_grad_enabled(False) - verify_model("squeezenet1_0") + verify_model("squeezenet1_0", atol=1e-4, rtol=1e-4) def test_squeezenet1_1(): torch.set_grad_enabled(False) - verify_model("squeezenet1_1") + verify_model("squeezenet1_1", atol=1e-4, rtol=1e-4) def test_densenet121(): torch.set_grad_enabled(False) - verify_model("densenet121") + verify_model("densenet121", atol=1e-4, rtol=1e-4) def test_inception_v3(): 
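# These pretrained-model tests compare Torch and TVM outputs at rtol/atol = 1e-4
# rather than verify_model's 1e-5 default, presumably because full networks
# accumulate enough float32 rounding difference between the two runtimes to make
# the tighter tolerance flaky.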
torch.set_grad_enabled(False) - verify_model("inception_v3") + verify_model("inception_v3", atol=1e-4, rtol=1e-4) def test_googlenet(): torch.set_grad_enabled(False) - verify_model("googlenet") + verify_model("googlenet", atol=1e-4, rtol=1e-4) def test_mnasnet0_5(): torch.set_grad_enabled(False) - verify_model("mnasnet0_5") + verify_model("mnasnet0_5", atol=1e-4, rtol=1e-4) def test_mobilenet_v2(): torch.set_grad_enabled(False) - verify_model("mobilenet_v2") + verify_model("mobilenet_v2", atol=1e-4, rtol=1e-4) """ #TODO: Fix VGG and AlexNet issues (probably due to pooling) @@ -1094,19 +1430,19 @@ def forward(self, inp): inp = [torch.rand((1, 3, 300, 300), dtype=torch.float)] - verify_model(SegmentationModelWrapper(fcn.eval()), inp) + verify_model(SegmentationModelWrapper(fcn.eval()), inp, atol=1e-4, rtol=1e-4) # depthwise + dilated covolution not supported on x86 # see https://github.com/apache/incubator-tvm/issues/4962 cuda_ctx = ("cuda", tvm.gpu(0)) if cuda_ctx[1].exist: - verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx]) + verify_model(SegmentationModelWrapper(deeplab.eval()), inp, [cuda_ctx], atol=1e-4, rtol=1e-4) def test_3d_models(): input_shape = (1, 3, 4, 56, 56) resnet3d = torchvision.models.video.r3d_18(pretrained=True).eval() - verify_model(resnet3d, [torch.rand(input_shape)]) + verify_model(resnet3d, [torch.rand(input_shape)], atol=1e-4, rtol=1e-4) def verify_script_model(pt_model, ishapes): @@ -1452,6 +1788,56 @@ def forward(self, *args): verify_model(Variance5().float().eval(), input_data=input_data) +def test_forward_rsub(): + torch.set_grad_enabled(False) + + class Rsub1(Module): + def forward(self, *args): + return torch.rsub(args[0], args[1]) + + class Rsub2(Module): + def forward(self, *args): + return torch.rsub(args[0], args[1], alpha=0.5) + + d1 = torch.rand([1, 3]).float() + d2 = torch.rand([1, 3]).float() + d3 = torch.rand([1, 3]).int() + verify_model(Rsub1().float().eval(), input_data=[d1, d2]) + verify_model(Rsub1().float().eval(), input_data=[d1, d3]) + verify_model(Rsub2().float().eval(), input_data=[d1, d2]) + verify_model(Rsub2().float().eval(), input_data=[d1, d3]) + + +def test_forward_embedding(): + torch.set_grad_enabled(False) + + input_data = torch.randint(0, 10, [2, 4]).long() + verify_model(torch.nn.Embedding(10, 3).float().eval(), input_data=input_data) + + input_data = torch.randint(0, 4, [2, 3, 4]).long() + verify_model(torch.nn.Embedding(4, 5, sparse=False).float().eval(), input_data=input_data) + + input_data = torch.randint(0, 4, [2, 3, 4]).long() + verify_model(torch.nn.Embedding(4, 5, sparse=True).float().eval(), input_data=input_data) + + +def test_forward_onehot(): + torch.set_grad_enabled(False) + + class OneHot1(Module): + def forward(self, *args): + return torch.nn.functional.one_hot(args[0], num_classes=3) + + class OneHot2(Module): + def forward(self, *args): + return torch.nn.functional.one_hot(args[0], num_classes=5) + + input_data = torch.arange(0, 5) % 3 + verify_model(OneHot1().float().eval(), input_data=input_data) + + input_data = torch.arange(0, 5) % 4 + verify_model(OneHot2().float().eval(), input_data=input_data) + def test_forward_isfinite(): torch.set_grad_enabled(False) @@ -1486,74 +1872,164 @@ def forward(self, *args): verify_model(IsInf1().float().eval(), input_data=input_data) -def test_forward_rsqrt(): +def test_forward_clamp(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - class Rsqrt1(Module): + class Clamp1(Module): def forward(self, *args): - return torch.rsqrt(args[0]) + return 
torch.clamp(args[0], min=-0.5, max=0.5) + + class Clamp2(Module): + def forward(self, *args): + return torch.clamp(args[0], min=-0.3) + + class Clamp3(Module): + def forward(self, *args): + return torch.clamp(args[0], max=1.0) input_data = torch.rand(input_shape).float() - verify_model(Rsqrt1().float().eval(), input_data=input_data) + verify_model(Clamp1().float().eval(), input_data=input_data) + verify_model(Clamp2().float().eval(), input_data=input_data) + verify_model(Clamp3().float().eval(), input_data=input_data) -def test_forward_ceil(): +def test_forward_ones(): torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] - class Ceil1(Module): + class Ones1(Module): def forward(self, *args): - return torch.ceil(args[0]) + return torch.ones(2,3) - input_data = torch.rand(input_shape).float() - verify_model(Ceil1().float().eval(), input_data=input_data) + verify_model(Ones1().float().eval(), input_data=[]) -def test_forward_clamp(): +def test_forward_ones_like(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - class Clamp1(Module): + class OnesLike1(Module): def forward(self, *args): - return torch.clamp(args[0], min=-0.5, max=0.5) + return torch.ones_like(args[0]) - class Clamp2(Module): + class OnesLike2(Module): def forward(self, *args): - return torch.clamp(args[0], min=-0.3) + return torch.ones_like(args[0], dtype=torch.int8) - class Clamp3(Module): + class OnesLike3(Module): def forward(self, *args): - return torch.clamp(args[0], max=1.0) + return torch.ones_like(args[0], dtype=torch.float) input_data = torch.rand(input_shape).float() - verify_model(Clamp1().float().eval(), input_data=input_data) - verify_model(Clamp2().float().eval(), input_data=input_data) - verify_model(Clamp3().float().eval(), input_data=input_data) + verify_model(OnesLike1().float().eval(), input_data=input_data) + verify_model(OnesLike2().float().eval(), input_data=input_data) + verify_model(OnesLike3().float().eval(), input_data=input_data) + + +def test_forward_zeros(): + torch.set_grad_enabled(False) + + class Zeros1(Module): + def forward(self, *args): + return torch.zeros(2,3) + + verify_model(Zeros1().float().eval(), input_data=[]) -def test_forward_floor(): +def test_forward_zeros_like(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - class Floor1(Module): + class ZerosLike1(Module): def forward(self, *args): - return torch.floor(args[0]) + return torch.zeros_like(args[0]) + + class ZerosLike2(Module): + def forward(self, *args): + return torch.zeros_like(args[0], dtype=torch.int32) + + class ZerosLike3(Module): + def forward(self, *args): + return torch.zeros_like(args[0], dtype=torch.float) input_data = torch.rand(input_shape).float() - verify_model(Floor1().float().eval(), input_data=input_data) + verify_model(ZerosLike1().float().eval(), input_data=input_data) + verify_model(ZerosLike2().float().eval(), input_data=input_data) + verify_model(ZerosLike3().float().eval(), input_data=input_data) + + +def test_forward_full(): + torch.set_grad_enabled(False) + + class Full1(Module): + def forward(self, *args): + return torch.full((2,3), 3.14) + + class Full2(Module): + def forward(self, *args): + return torch.full((1, 2,3), 1.0, dtype=torch.int32) + + verify_model(Full1().float().eval(), input_data=[]) + verify_model(Full2().float().eval(), input_data=[]) -def test_forward_round(): +def test_forward_full_like(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - class Round1(Module): + class FullLike1(Module): def forward(self, *args): - return 
torch.round(args[0]) + return torch.full_like(args[0], 3.14) + + class FullLike2(Module): + def forward(self, *args): + return torch.full_like(args[0], 22.22, dtype=torch.int32) + + class FullLike3(Module): + def forward(self, *args): + return torch.full_like(args[0], 1.4, dtype=torch.float) input_data = torch.rand(input_shape).float() - verify_model(Round1().float().eval(), input_data=input_data) + verify_model(FullLike1().float().eval(), input_data=input_data) + verify_model(FullLike2().float().eval(), input_data=input_data) + verify_model(FullLike3().float().eval(), input_data=input_data) + +def test_forward_linspace(): + torch.set_grad_enabled(False) + + class Linspace1(Module): + def forward(self, *args): + return torch.linspace(5, 10) + class Linspace2(Module): + def forward(self, *args): + return torch.linspace(-10, 10, steps=5) + class Linspace3(Module): + def forward(self, *args): + return torch.linspace(start=-10, end=10, steps=5) + class Linspace4(Module): + def forward(self, *args): + return torch.linspace(start=-10, end=10, steps=1) + class Linspace5(Module): + def forward(self, *args): + return torch.linspace(1, 2, 1, dtype=torch.int32) + class Linspace6(Module): + def forward(self, *args): + return torch.linspace(start=1, end=6, steps=2) + class Linspace7(Module): + def forward(self, *args): + return torch.linspace(1, 4, dtype=torch.float32) + class Linspace8(Module): + def forward(self, *args): + return torch.linspace(1, 2, 1, dtype=torch.int16) + + verify_model(Linspace1().float().eval()) + verify_model(Linspace2().float().eval()) + verify_model(Linspace3().float().eval()) + verify_model(Linspace4().float().eval()) + verify_model(Linspace5().float().eval()) + verify_model(Linspace6().float().eval()) + verify_model(Linspace7().float().eval()) + verify_model(Linspace8().float().eval()) def test_forward_take(): @@ -1700,11 +2176,364 @@ def forward(self, *args): verify_model(LogicalXor2().float().eval(), input_data=[lhs]) +def test_forward_unary(): + torch.set_grad_enabled(False) + + class Sqrt1(Module): + def forward(self, *args): + return torch.sqrt(args[0]) + + class RSqrt1(Module): + def forward(self, *args): + return torch.rsqrt(args[0]) + + class Ceil1(Module): + def forward(self, *args): + return torch.ceil(args[0]) + + class Floor1(Module): + def forward(self, *args): + return torch.floor(args[0]) + + class Round1(Module): + def forward(self, *args): + return torch.round(args[0]) + + class Cos1(Module): + def forward(self, *args): + return torch.cos(args[0]) + + class Sin1(Module): + def forward(self, *args): + return torch.sin(args[0]) + + class Tan1(Module): + def forward(self, *args): + return torch.tan(args[0]) + + class Tanh1(Module): + def forward(self, *args): + return torch.tanh(args[0]) + + class Acos1(Module): + def forward(self, *args): + return torch.acos(args[0]) + + class Asin1(Module): + def forward(self, *args): + return torch.asin(args[0]) + + class Atan1(Module): + def forward(self, *args): + return torch.atan(args[0]) + + class Log1(Module): + def forward(self, *args): + return torch.log(args[0]) + + class Exp1(Module): + def forward(self, *args): + return torch.exp(args[0]) + + class Erf1(Module): + def forward(self, *args): + return torch.erf(args[0]) + + class Trunc1(Module): + def forward(self, *args): + return torch.trunc(args[0]) + + class Sign1(Module): + def forward(self, *args): + return torch.sign(args[0]) + + class Neg1(Module): + def forward(self, *args): + return torch.neg(args[0]) + + class Sinh1(Module): + def forward(self, *args): + 
return torch.sinh(args[0]) + + class Cosh1(Module): + def forward(self, *args): + return torch.cosh(args[0]) + + class Log2_1(Module): + def forward(self, *args): + return torch.log2(args[0]) + + class Log10_1(Module): + def forward(self, *args): + return torch.log10(args[0]) + + class Log1p_1(Module): + def forward(self, *args): + return torch.log1p(args[0]) + + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(Sqrt1().float().eval(), input_data=input_data) + verify_model(RSqrt1().float().eval(), input_data=input_data) + verify_model(Ceil1().float().eval(), input_data=input_data) + verify_model(Floor1().float().eval(), input_data=input_data) + verify_model(Round1().float().eval(), input_data=input_data) + verify_model(Cos1().float().eval(), input_data=input_data) + verify_model(Cosh1().float().eval(), input_data=input_data) + verify_model(Sin1().float().eval(), input_data=input_data) + verify_model(Sinh1().float().eval(), input_data=input_data) + verify_model(Tan1().float().eval(), input_data=input_data) + verify_model(Tanh1().float().eval(), input_data=input_data) + verify_model(Acos1().float().eval(), input_data=input_data) + verify_model(Asin1().float().eval(), input_data=input_data) + verify_model(Atan1().float().eval(), input_data=input_data) + verify_model(Log1().float().eval(), input_data=input_data) + verify_model(Log2_1().float().eval(), input_data=input_data) + verify_model(Log10_1().float().eval(), input_data=input_data) + verify_model(Log1p_1().float().eval(), input_data=input_data) + verify_model(Exp1().float().eval(), input_data=input_data) + verify_model(Erf1().float().eval(), input_data=input_data) + verify_model(Trunc1().float().eval(), input_data=input_data) + verify_model(Sign1().float().eval(), input_data=input_data) + verify_model(Neg1().float().eval(), input_data=input_data) + + +def test_forward_where(): + torch.set_grad_enabled(False) + + class Where1(Module): + def forward(self, *args): + y = torch.ones([3, 2]) + if torch.cuda.is_available(): + y = y.cuda() + return torch.where(args[0] > 0, args[0], y) + + class Where2(Module): + def forward(self, *args): + return torch.where(args[0] > 0, args[0], args[1]) + + x = torch.rand([3, 2]).float() + verify_model(Where1().float().eval(), input_data=[x]) + y = torch.rand([3, 2]) + verify_model(Where2().float().eval(), input_data=[x, y]) + + +def test_forward_addcdiv(): + torch.set_grad_enabled(False) + + class Addcdiv1(Module): + def forward(self, *args): + t1 = torch.ones([3, 1]) + t2 = torch.ones([1, 3]) + if torch.cuda.is_available(): + t1 = t1.cuda() + t2 = t2.cuda() + return torch.addcdiv(args[0], 0.1, t1, t2) + + class Addcdiv2(Module): + def forward(self, *args): + return torch.addcdiv(args[0], 0.5, args[1], args[2]) + + input_data = torch.rand([1, 3]).float() + verify_model(Addcdiv1().float().eval(), input_data=input_data) + t1 = torch.rand([3, 1]).float() + t2 = torch.rand([1, 3]).float() + verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2]) + + +def test_forward_addcmul(): + torch.set_grad_enabled(False) + + class Addcmul1(Module): + def forward(self, *args): + t1 = torch.ones([3, 1]) + t2 = torch.ones([1, 3]) + if torch.cuda.is_available(): + t1 = t1.cuda() + t2 = t2.cuda() + return torch.addcmul(args[0], 0.1, t1, t2) + + class Addcmul2(Module): + def forward(self, *args): + return torch.addcmul(args[0], 0.5, args[1], args[2]) + + input_data = torch.rand([1, 3]).float() + verify_model(Addcmul1().float().eval(), input_data=input_data) + t1 = 
torch.rand([3, 1]).float() + t2 = torch.rand([1, 3]).float() + verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2]) + + +def test_forward_matmul(): + torch.set_grad_enabled(False) + + class MatMul1(Module): + def forward(self, *args): + return torch.matmul(args[0], args[1]) + + # matrix x vector + tensor1 = torch.randn(3, 4) + tensor2 = torch.randn(4) + verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + + # matrix x matrix + tensor1 = torch.randn(10, 4) + tensor2 = torch.randn(4, 10) + verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + + # batched matrix x batched matrix + tensor1 = torch.randn(10, 3, 4) + tensor2 = torch.randn(10, 4, 5) + verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + + # batched matrix x broadcasted matrix + tensor1 = torch.randn(10, 3, 4) + tensor2 = torch.randn(4, 5) + verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + + # batched matrix x batched matrix + tensor1 = torch.randn(1, 12, 14, 64) + tensor2 = torch.randn(1, 12, 64, 14) + verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + + +def test_forward_pretrained_bert_base_uncased(): + ###################################################################### + # This is an example how to run BERT models using TVM + # --------------------------------------------------- + """ + Refer the bert example given in https://pypi.org/project/pytorch-pretrained-bert + + # To get started, pretrained bert package needs to be installed as prerequisite. + + .. code-block:: bash + + # install bert package + pip install pytorch_pretrained_bert==0.6.2 --user + """ + + try: + from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM + except: + print("Torch pretrained bert package must be installed to run this script.") + return + + ###################################################################### + # Load the tokenizer and tokenize the input + # ----------------------------------------- + + # Load pre-trained model tokenizer (vocabulary) + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + + # Tokenized input + text = "[CLS] Who was Jim Henson ? 
[SEP] Jim Henson was a puppeteer [SEP]" + tokenized_text = tokenizer.tokenize(text) + + # Mask a token that we will try to predict back with `BertForMaskedLM` + masked_index = 8 + tokenized_text[masked_index] = '[MASK]' + assert tokenized_text == ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', + '##eer', '[SEP]'] + + # Convert token to vocabulary indices + indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) + # Define sentence A and B indices associated to 1st and 2nd sentences (see paper) + segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] + + # Convert inputs to PyTorch tensors + tokens_tensor = torch.tensor([indexed_tokens]) + segments_tensors = torch.tensor([segments_ids]) + + ###################################################################### + # Load a pretrained PyTorch model bert-base-uncased + # ------------------------------------------------- + + # Bert Model with a language modeling + model = BertForMaskedLM.from_pretrained('bert-base-uncased') + model.eval() + + ###################################################################### + # Predict all tokens with pytorch + # ------------------------------- + + with torch.no_grad(): + torch_preds = model(tokens_tensor, segments_tensors) + + ###################################################################### + # Make TorchScripted model via jit trace + # -------------------------------------- + + scripted_model = torch.jit.trace(model, (tokens_tensor, segments_tensors)).eval() + + ###################################################################### + # Import the graph to Relay + # ------------------------- + # Convert PyTorch graph to Relay graph. The input name can be arbitrary. + + input_1 = 'input_ids' + input_2 = 'input.2' + shape_list = [(input_1, list(tokens_tensor.shape)), + (input_2, list(segments_tensors.shape))] + + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + + ###################################################################### + # Compile the model with relay + # ---------------------------- + + target = 'llvm' + with tvm.transform.PassContext(opt_level=3): + relay_graph, relay_lib, relay_params = relay.build(mod, target=target, params=params) + + ###################################################################### + # Execute on TVM + # -------------- + + ctx = tvm.context(target, 0) + relay_model = graph_runtime.create(relay_graph, relay_lib, ctx) + relay_model.set_input(**relay_params) + relay_model.set_input(input_1, tokens_tensor) + relay_model.set_input(input_2, segments_tensors) + relay_model.run() + compiled_output = relay_model.get_output(0).asnumpy() + + ###################################################################### + # Validate the outputs + # -------------------- + # Compare the torch and tvm outputs + + tvm.testing.assert_allclose(torch_preds, compiled_output, rtol=1e-3, atol=1e-3) + + ###################################################################### + # Process the output + # ------------------ + # Process the model output to token. 
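# With bert-base-uncased the masked position in this sentence is normally predicted
# as "henson" (as in the upstream pytorch-pretrained-bert example); the assertions
# below only require that the Torch and TVM argmax ids and tokens agree, not that a
# particular token is produced.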
+ + # Torch output to token + torch_pred_idx = torch.argmax(torch_preds[0, masked_index]).item() + torch_pred_token = tokenizer.convert_ids_to_tokens([torch_pred_idx])[0] + + # TVM output to token + tvm_pred_idx = compiled_output[0, masked_index].argmax() + tvm_pred_token = tokenizer.convert_ids_to_tokens([tvm_pred_idx])[0] + + assert torch_pred_idx == tvm_pred_idx + assert torch_pred_token == tvm_pred_token + + # Print the outputs + print('Torch top-1 id: {}, token: {}'.format(torch_pred_idx, torch_pred_token)) + print('TVM top-1 id: {}, token: {}'.format(tvm_pred_idx, tvm_pred_token)) + + if __name__ == "__main__": # Single operator tests test_forward_add() test_forward_subtract() test_forward_multiply() + test_forward_matmul() + test_forward_rsub() + test_forward_onehot() + test_forward_embedding() test_forward_reshape() test_forward_reciprocal() test_forward_repeat() @@ -1716,6 +2545,8 @@ def forward(self, *args): test_forward_reduce_prod() test_forward_argmin() test_forward_argmax() + test_forward_norm() + test_forward_frobenius_norm() test_forward_std() test_forward_variance() test_forward_relu() @@ -1738,12 +2569,16 @@ def forward(self, *args): test_forward_batchnorm() test_forward_instancenorm() test_forward_layernorm() + test_forward_groupnorm() test_forward_transpose() test_forward_size() test_forward_view() test_forward_select() test_forward_take() test_forward_topk() + test_forward_where() + test_forward_addcdiv() + test_forward_addcmul() test_forward_clone() test_forward_softplus() test_forward_softsign() @@ -1757,12 +2592,8 @@ def forward(self, *args): test_forward_mean() test_forward_expand() test_forward_pow() - test_forward_abs() - test_forward_rsqrt() - test_forward_ceil() + test_forward_unary() test_forward_clamp() - test_forward_floor() - test_forward_round() test_forward_logical_not() test_forward_bitwise_not() test_forward_bitwise_xor() @@ -1770,13 +2601,33 @@ def forward(self, *args): test_forward_isfinite() test_forward_isnan() test_forward_isinf() + test_forward_ones() + test_forward_ones_like() + test_forward_zeros() + test_forward_zeros_like() + test_forward_full() + test_forward_full_like() + test_forward_linspace() test_forward_arange() test_forward_chunk() test_forward_split() test_upsample() + test_forward_upsample3d() test_to() + test_type_as() + test_forward_functional_pad() + test_forward_zero_pad2d() + test_forward_constant_pad1d() + test_forward_constant_pad2d() + test_forward_constant_pad3d() + test_forward_reflection_pad1d() + test_forward_reflection_pad2d() + test_forward_replication_pad1d() + test_forward_replication_pad2d() + test_forward_replication_pad3d() test_adaptive_pool3d() test_conv3d() + test_conv3d_transpose() # Model tests test_resnet18() @@ -1809,3 +2660,6 @@ def forward(self, *args): from lstm_test import custom_lstm_test custom_lstm_test() + + # Test bert model + test_forward_pretrained_bert_base_uncased() diff --git a/tests/python/frontend/tensorflow/test_bn_dynamic.py b/tests/python/frontend/tensorflow/test_bn_dynamic.py index 4be838e331ef..e80d774408a3 100644 --- a/tests/python/frontend/tensorflow/test_bn_dynamic.py +++ b/tests/python/frontend/tensorflow/test_bn_dynamic.py @@ -22,7 +22,10 @@ """ import tvm import numpy as np -import tensorflow as tf +try: + import tensorflow.compat.v1 as tf +except ImportError: + import tensorflow as tf from tvm import relay from tensorflow.python.framework import graph_util @@ -47,7 +50,7 @@ def verify_fused_batch_norm(shape): continue mod, params = relay.frontend.from_tensorflow(constant_graph, 
outputs=['output']) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=device, params=params) diff --git a/tests/python/frontend/tensorflow/test_control_flow.py b/tests/python/frontend/tensorflow/test_control_flow.py index 9777a8dc4462..3ec04bf38490 100644 --- a/tests/python/frontend/tensorflow/test_control_flow.py +++ b/tests/python/frontend/tensorflow/test_control_flow.py @@ -21,6 +21,7 @@ tf.disable_v2_behavior() except ImportError: import tensorflow as tf +from tensorflow.python.ops import control_flow_ops import numpy as np from tvm import nd from tvm import relay @@ -45,7 +46,7 @@ def check_equal(graph, tf_out, input_map=None): def test_vanilla_loop(): graph = tf.Graph() with graph.as_default(): - i = tf.constant(0) + i = tf.constant(0, name="while/constant") def c(i): return tf.less(i, 10) @@ -367,6 +368,44 @@ def condition(x, y): check_equal(graph, tf_out, {dname: np_data}) +def test_switch(): + graph = tf.Graph() + + with graph.as_default(): + data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32') + dname = 'data' + flag_name = 'flag' + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname) + split = tf.split(data, 2, axis=0) + flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name) + output_false, output_true = control_flow_ops.switch(split[1], flag) + with tf.Session() as sess: + tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False}) + + check_equal(graph, tf_out, {dname: data_np, flag_name: False}) + +def test_loop_tuple_input(): + graph = tf.Graph() + + with graph.as_default(): + data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32') + dname = 'data' + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname) + split = tf.split(data, 2, axis=0) + + def body(x, y): + return x + 2, y + 1 + + start = tf.constant(0) + def condition(x, y): + return tf.less(y, 20) + + r = tf.while_loop(condition, body, loop_vars=[split[1], start]) + with tf.Session() as sess: + tf_out = sess.run(r, feed_dict={data.name: data_np}) + + check_equal(graph, tf_out, {dname: data_np}) + if __name__ == "__main__": # tf.while_loop @@ -390,3 +429,6 @@ def condition(x, y): test_cond_in_loop() test_vanilla_loop_bound() test_nested_loop_bound() + + test_switch() + test_loop_tuple_input() diff --git a/tests/python/frontend/tensorflow/test_debugging.py b/tests/python/frontend/tensorflow/test_debugging.py index 01ad6a256f88..a6df6ffb63a1 100644 --- a/tests/python/frontend/tensorflow/test_debugging.py +++ b/tests/python/frontend/tensorflow/test_debugging.py @@ -51,7 +51,8 @@ def test_assert_true(): # do that, it's happening in Relay, and that optimization shouldn't # affect the arity of the main function. We should have to pass in # x_value here. - np.testing.assert_allclose(0, run_relay(g, {'input':shape}).asnumpy()) + np.testing.assert_allclose(0, run_relay(g, {'input': shape}).asnumpy()) + def test_assert_true_var_capture(): g = tf.Graph() @@ -71,7 +72,7 @@ def test_assert_true_var_capture(): # the graph as a boolean, which is not correct - as you can see above, # TF believes that the value of this graph is None. np.testing.assert_allclose(True, - run_relay(g, None, x_value).asnumpy()) + run_relay(g, None, x_value).asnumpy()) def test_assert_false(): g = tf.Graph() @@ -91,9 +92,7 @@ def test_assert_false(): # argument is false. 
np.testing.assert_allclose(0, run_relay(g).asnumpy()) - if __name__ == "__main__": test_assert_true() test_assert_true_var_capture() test_assert_false() - diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index bc884bbbfa9b..fbe0060f8ad9 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -21,6 +21,7 @@ This article is a test script to test tensorflow operator with Relay. """ from __future__ import print_function +import threading import numpy as np import pytest try: @@ -36,11 +37,16 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops import init_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_functional_ops from distutils.version import LooseVersion import tvm from tvm import te from tvm import relay import tvm.relay.testing.tf as tf_testing +from tvm.runtime.vm import VirtualMachine from packaging import version as package_version ####################################################################### @@ -94,20 +100,20 @@ def vmobj_to_list(o): def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None, opt_level=3, mode='graph_runtime', - cuda_layout="NCHW"): + cuda_layout="NCHW", layout=None, disabled_pass=None): """ Generic function to compile on relay and execute on tvm """ input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) - layout = None if target == "cuda": layout = cuda_layout target_host = None - shape_dict = {e: i.shape for e, i in zip(input_node, input_data)} + shape_dict = {e: i.shape if hasattr(i, "shape") else () for e, i in zip(input_node, input_data)} mod, params = relay.frontend.from_tensorflow(graph_def, layout=layout, shape=shape_dict, outputs=out_names) - if mode in ['debug', 'vm']: + ctx = tvm.context(target, 0) + if mode == 'debug': ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm") inputs = [] for param in mod['main'].params: @@ -122,11 +128,19 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, inputs.append(tvm.nd.array(params[param.name_hint])) result = ex.evaluate()(*inputs) return vmobj_to_list(result) + elif mode == 'vm': + with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): + vm_exec = relay.vm.compile(mod, target="llvm", params=params) + vm = VirtualMachine(vm_exec) + vm.init(tvm.cpu()) + inputs = {} + for e, i in zip(input_node, input_data): + inputs[e] = tvm.nd.array(i) + result = vm.invoke("main", **inputs) + return vmobj_to_list(result) else: - with relay.build_config(opt_level=opt_level): + with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass): graph, lib, params = relay.build(mod, target, target_host, params) - - ctx = tvm.context(target, 0) from tvm.contrib import graph_runtime m = graph_runtime.create(graph, lib, ctx) # set inputs @@ -176,6 +190,7 @@ def name_without_num(name): if init_global_variables: sess.run(variables.global_variables_initializer()) final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) + tf_output = run_tf_graph(sess, in_data, in_name, out_name) for device in ["llvm", "cuda"]: @@ -469,7 +484,7 @@ def test_forward_convolution(): ####################################################################### # Convolution3D -# 
----------- +# ------------- def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes, @@ -518,6 +533,92 @@ def test_forward_convolution3d(): _test_convolution3d('conv', [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC') +####################################################################### +# Convolution3D Transpose +# ----------------------- + +def _test_convolution3d_transpose(data_shape, filter_shape, strides, + padding, output_shape, data_format='NCDHW'): + """ One iteration of 3D convolution transpose with given shapes and attributes """ + + dtype = 'float32' + data_array = np.random.uniform(size=data_shape).astype(dtype) + filter_array = np.random.uniform(size=filter_shape).astype(dtype) + if data_format == 'NDHWC': + strides = [1] + strides + [1] + else: + strides = [1, 1] + strides + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data_shape, dtype=dtype) + in_filter = constant_op.constant( + filter_array, shape=filter_shape, dtype=dtype) + + nn_ops.conv3d_transpose(in_data, + in_filter, + output_shape=output_shape, + strides=strides, + padding=padding, + data_format=data_format) + + compare_tf_with_tvm(data_array, 'Placeholder:0', 'conv3d_transpose:0', cuda_layout="NDHWC") + + +def test_forward_convolution3d_transpose(): + if is_gpu_available(): + _test_convolution3d_transpose(data_shape=[1, 10, 8, 8, 8], + filter_shape=[1, 1, 1, 6, 10], + strides=[1, 1, 1], + padding='VALID', + output_shape=[1, 6, 8, 8, 8]) + + _test_convolution3d_transpose(data_shape=[4, 9, 8, 8, 8], + filter_shape=[1, 1, 1, 6, 9], + strides=[1, 1, 1], + padding='VALID', + output_shape=[4, 6, 8, 8, 8]) + + _test_convolution3d_transpose(data_shape=[1, 3, 8, 8, 8], + filter_shape=[1, 1, 1, 6, 3], + strides=[2, 2, 2], + padding='SAME', + output_shape=[1, 6, 15, 15, 15]) + + _test_convolution3d_transpose(data_shape=[1, 16, 8, 8, 8], + filter_shape=[3, 3, 3, 6, 16], + strides=[3, 3, 3], + padding='VALID', + output_shape=[1, 6, 24, 24, 24]) + + _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 10], + filter_shape=[1, 1, 1, 6, 10], + strides=[1, 1, 1], + padding='VALID', + output_shape=[1, 8, 8, 8, 6], + data_format='NDHWC') + + _test_convolution3d_transpose(data_shape=[4, 8, 8, 8, 9], + filter_shape=[1, 1, 1, 6, 9], + strides=[1, 1, 1], + padding='VALID', + output_shape=[4, 8, 8, 8, 6], + data_format='NDHWC') + + _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 3], + filter_shape=[1, 1, 1, 6, 3], + strides=[2, 2, 2], + padding='SAME', + output_shape=[1, 15, 15, 15, 6], + data_format='NDHWC') + + _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 16], + filter_shape=[3, 3, 3, 6, 16], + strides=[3, 3, 3], + padding='VALID', + output_shape=[1, 24, 24, 24, 6], + data_format='NDHWC') + + ####################################################################### # BiasAdd # ----------- @@ -747,6 +848,17 @@ def _test_reshape_like(data, shape_like): compare_tf_with_tvm(data, 'Placeholder:0', 'Reshape:0') +def _test_reshape_symbolic(data, a_data, b_data): + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + a = array_ops.placeholder(shape=a_data.shape, dtype=a_data.dtype) + b = array_ops.placeholder(shape=b_data.shape, dtype=b_data.dtype) + newshape = tf.add(a, b) + out = array_ops.reshape(in_data, newshape) + + for mode in ["debug", "vm"]: + compare_tf_with_tvm([data, a_data, b_data], [in_data.name, a.name, b.name], out.name, mode=mode) + def test_forward_reshape(): _test_reshape(np.arange(6.0), [2, 3]) 
_test_reshape(np.arange(6), [-1, 2]) @@ -754,6 +866,10 @@ def test_forward_reshape(): _test_reshape(np.arange(6), [-1]) _test_reshape_with_call() _test_reshape_like(np.zeros((3, 6)), np.zeros((9, 2))) + _test_reshape_symbolic(np.arange(6.0), np.array([2, 0]), np.array([0, 3])) + _test_reshape_symbolic(np.arange(6), np.array([-1, 0]), np.array([0, 2])) + _test_reshape_symbolic(np.arange(6), np.array([3, 0]), np.array([3, -1])) + _test_reshape_symbolic(np.arange(6), np.array([0]), np.array([-1])) ####################################################################### # DepthToSpace @@ -843,6 +959,9 @@ def test_forward_squeeze(): # TensorArray # ----------- def test_tensor_array_write_read(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + def run(dtype_str, infer_shape, element_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] @@ -865,13 +984,21 @@ def run(dtype_str, infer_shape, element_shape): def test_tensor_array_scatter(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] + if infer_shape: + element_shape = tf.TensorShape([tf.Dimension(None)]) + else: + element_shape = None t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str), dtype=dtype) indices = tf.constant([2, 1, 0]) ta1 = tf.TensorArray(dtype=dtype, size=3, - infer_shape=infer_shape) + infer_shape=infer_shape, + element_shape=element_shape) ta2 = ta1.scatter(indices, t) out0 = ta2.read(0) out1 = ta2.read(1) @@ -886,6 +1013,9 @@ def run(dtype_str, infer_shape): def test_tensor_array_gather(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] @@ -902,6 +1032,9 @@ def run(dtype_str, infer_shape): def test_tensor_array_split(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] @@ -924,6 +1057,9 @@ def run(dtype_str, infer_shape): def test_tensor_array_concat(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] @@ -941,11 +1077,20 @@ def run(dtype_str, infer_shape): def test_tensor_array_size(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + def run(dtype_str, infer_shape): with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] + np_data = np.array([[1.0, 2.0], [3.0, 4.0]]).astype(dtype_str) + in_data = [np_data, np_data] + t1 = tf.constant(np_data, dtype=dtype) + t2 = tf.constant(np_data, dtype=dtype) ta1 = tf.TensorArray(dtype=dtype, size=2, infer_shape=infer_shape) - out = ta1.size() + ta2 = ta1.write(0, t1) + ta3 = ta2.write(1, t2) + out = ta3.size() g = tf.get_default_graph() compare_tf_with_tvm([], [], 'TensorArraySizeV3:0', mode='debug') for dtype in ["float32", "int8"]: @@ -955,6 +1100,9 @@ def run(dtype_str, infer_shape): def test_tensor_array_stack(): def run(dtype_str, infer_shape): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + 
pytest.skip("Needs fixing for tflite >= 1.15.0") + with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] t = tf.constant(np.array([[1.0], [2.0], [3.0]]).astype(dtype_str)) @@ -972,6 +1120,9 @@ def run(dtype_str, infer_shape): def test_tensor_array_unstack(): def run(dtype_str, input_shape, infer_shape): + if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): + pytest.skip("Needs fixing for tflite >= 1.15.0") + with tf.Graph().as_default(): dtype = tf_dtypes[dtype_str] t = tf.constant(np.random.choice([0, 1, 2, 3], @@ -1114,13 +1265,13 @@ def test_read_variable_op(): tf_output = run_tf_graph(sess, in_data, in_name, out_name) shape_dict = {e: i.shape for e, i in zip(in_name, in_data)} - with pytest.raises(Exception) as exexcinfo: + with pytest.raises(Exception) as execinfo: mod, params = relay.frontend.from_tensorflow(final_graph_def, layout=None, shape=shape_dict, outputs=None) - assert exexcinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph.") + assert execinfo.value.args[0].startswith("Graph is not frozen. Provide a frozen graph") # Now convert the variables to constant and run inference on the converted graph final_graph_def = tf.graph_util.convert_variables_to_constants( @@ -1355,7 +1506,7 @@ def test_forward_truncatemod(): ####################################################################### -# Gather, GatherV2, GatherNd +# Gather, GatherV2 # -------------------------- def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype): @@ -1394,16 +1545,32 @@ def test_forward_gather(): _test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32') _test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32') +####################################################################### +# GatherND +# -------------------------- -def test_forward_gather_nd(): +def _test_gather_nd(ip_shape, indice_value, dtype): """test operator GatherNd""" - np_data = np.random.uniform(1, 100, size=(2, 2, 2)).astype(np.float32) + np_data = np.random.uniform(1, 100, size=ip_shape).astype(dtype) tf.reset_default_graph() with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 2, 2), name="in_data") - tf.gather_nd(in_data, indices=[[1, 0, 0], [0, 0, 0]], name="gather_nd") + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + tf.gather_nd(in_data, indices=indice_value, name="gather_nd") compare_tf_with_tvm([np_data], ['in_data:0'], 'gather_nd:0') +def test_forward_gather_nd(): + """test operator GatherNd""" + _test_gather_nd((2, 2), [[0, 0], [1, 1]], 'float32') + _test_gather_nd((2, 2, 2), [[1, 0, 0], [0, 0, 0]], 'float32') + _test_gather_nd((4,), [1], 'float32') + _test_gather_nd((4,), [1], 'int32') + _test_gather_nd((1, 4), [0, 3], 'int32') + _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'int32') + _test_gather_nd((2, 2), [[[1, 0], [0, 1]]], 'float32') + _test_gather_nd((3, 3, 3), [[[1, 0]]], 'int32') + _test_gather_nd((3, 3, 3), [[[1, 0]]], 'int32') + _test_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], 'float32') + _test_gather_nd((3, 3, 3), [[[2, 1]]], 'int32') ####################################################################### # BiasAdd @@ -1771,12 +1938,24 @@ def _test_fill_from_tensor(in_shape): compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0') +def _test_fill_symbolic_inputs(in_shape_data, in_value_data, dtype): + with tf.Graph().as_default(): + in_shape = tf.placeholder(shape=[in_shape_data.shape[0]], dtype=in_shape_data.dtype) + in_value = tf.placeholder(shape=(), dtype=dtype) + out = tf.fill(in_shape, in_value) + for mode in 
['debug', 'vm']: + compare_tf_with_tvm([in_shape_data, in_value_data], [in_shape.name, in_value.name], out.name, mode=mode) + + def test_forward_fill(): """ Resize Bilinear """ _test_fill((32)) _test_fill((6, 32, 64, 64)) _test_fill_from_tensor((6, 32, 64, 64)) + _test_fill_symbolic_inputs(np.array((2,)), np.int32(9), tf.int32) + _test_fill_symbolic_inputs(np.array((2, 3)), 9, tf.int64) + _test_fill_symbolic_inputs(np.array((2, 3, 4)), np.float32(9.0), tf.float32) ####################################################################### # Crop to bounding box @@ -1846,6 +2025,30 @@ def test_forward_crop_and_resize(): extrapolation_value=0.2, method='nearest') +####################################################################### +# Non Max Suppression +# ------------------- +def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"): + boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype) + scores = np.random.uniform(size=score_shape).astype(dtype) + tf.reset_default_graph() + in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1") + in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2") + tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, + max_output_size=out_size, iou_threshold=iou_threshold, + score_threshold=score_threshold, name="nms") + compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + 'nms/NonMaxSuppressionV3:0', mode='vm') + compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + 'nms/NonMaxSuppressionV3:0', mode='debug') + +def test_forward_nms_v3(): + """ NonMaxSuppressionV3 """ + _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5) + _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10) + _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000) + + ####################################################################### # LSTM # ---- @@ -1901,7 +2104,9 @@ def _get_tensorflow_output(): def test_forward_lstm(): '''test LSTM block cell''' - _test_lstm_cell(1, 2, 1, 0.5, 'float32') + if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): + #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed + _test_lstm_cell(1, 2, 1, 0.5, 'float32') ####################################################################### @@ -2196,6 +2401,49 @@ def test_forward_resnetv2(): tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tf_output[0]), rtol=1e-5, atol=1e-5) +####################################################################### +# SSD +# --- + + +def _test_ssd_impl(): + '''Test SSD with backbone MobileNet V1''' + with tf.Graph().as_default(): + graph_def = tf_testing.get_workload( + "object_detection/ssd_mobilenet_v1_ppn_shared_" + "box_predictor_300x300_coco14_sync_2018_07_03.pb") + # Call the utility to import the graph definition into default graph. + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + data = np.random.uniform(0.0, 255.0, size=(1, 512, 512, 3)).astype('uint8') + in_node = "image_tensor" + out_node = ['detection_boxes', "detection_scores", "detection_classes"] + + with tf.Session() as sess: + tf_output = run_tf_graph( + sess, data, '{}:0'.format(in_node), ["{}:0".format(oname) for oname in out_node]) + # TODO(kevinthesun): enable gpu test when VM heterogeneous execution is ready. 
+ for device in ["llvm"]: + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + continue + tvm_output = run_tvm_graph(graph_def, data, in_node, len(out_node), + target=device, layout="NCHW", out_names=out_node, + mode="vm", disabled_pass=["FoldScaleAxis"]) + for i in range(len(out_node)): + tvm.testing.assert_allclose(tvm_output[i], tf_output[i], + rtol=1e-3, atol=1e-3) + +@pytest.mark.skip('neo-ai/tvm: skip because stack limit of 100mb is exceeded by WellFormedChecker') +def test_forward_ssd(): + run_thread = threading.Thread(target=_test_ssd_impl, args=()) + old_stack_size = threading.stack_size(100 * 1024 * 1024) + run_thread.start() + run_thread.join() + threading.stack_size(old_stack_size) + + ####################################################################### # Placeholder # ----------- @@ -2265,7 +2513,7 @@ def _get_tvm_graph_module(graph_def): 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_c': 'float32', 'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_h': 'float32'} target = 'llvm' - with relay.build_config(opt_level=0): + with tvm.transform.PassContext(opt_level=0): graph, lib, params = relay.build(mod, target, params=params) @@ -2584,15 +2832,6 @@ def test_forward_zeros_like(): _test_forward_zeros_like((2, 3, 11), "float64") -def test_forward_erf(): - ishape = (1, 3, 10, 10) - inp_array = np.random.uniform(-5, 5, size=ishape).astype(np.float32) - with tf.Graph().as_default(): - in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype) - tf.math.erf(in1) - compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0') - - def test_forward_squared_difference(): ishape = (1, 3, 10, 14) inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32) @@ -2659,53 +2898,33 @@ def test_forward_pow_exp(): compare_tf_with_tvm([np_in1], ['in1:0'], 'exp:0') -def test_forward_log(): - """test operator Log """ - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.log(in_data, name="log") - compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0') - - -def test_forward_log1p(): - """test operator Log1p """ - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.log1p(in_data, name="log1p") - compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0') - - -def test_forward_cos(): - """test operator cos """ - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.cos(in_data, name="cos") - compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0') +def test_forward_unary(): + def _test_forward_unary(op, a_min=1, a_max=5, dtype=np.float32): + """test unary operators""" + np_data = np.random.uniform(a_min, a_max, size=(2, 3, 5)).astype(dtype) + tf.reset_default_graph() + with tf.Graph().as_default(): + in_data = tf.placeholder(dtype, (2, 3, 5), name="in_data") + out = op(in_data) + compare_tf_with_tvm([np_data], ['in_data:0'], out.name) + + _test_forward_unary(tf.acos, -1, 1) + _test_forward_unary(tf.asin, -1, 1) + _test_forward_unary(tf.atanh, -1, 1) + _test_forward_unary(tf.sinh) + _test_forward_unary(tf.cosh) + _test_forward_unary(tf.acosh) + _test_forward_unary(tf.asinh) + 
_test_forward_unary(tf.atan) + _test_forward_unary(tf.sin) + _test_forward_unary(tf.cos) + _test_forward_unary(tf.tan) + _test_forward_unary(tf.tanh) + _test_forward_unary(tf.erf) + _test_forward_unary(tf.log) + _test_forward_unary(tf.log1p) -def test_forward_tan(): - """test operator tan """ - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.tan(in_data, name="tan") - compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0') - -def test_forward_atan(): - """test operator tan """ - tf.disable_eager_execution() - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.atan(in_data, name="atan") - compare_tf_with_tvm([np_data], ['in_data:0'], 'atan:0') - def test_forward_atan2(): """test operator tan """ tf.disable_eager_execution() @@ -2718,16 +2937,6 @@ def test_forward_atan2(): compare_tf_with_tvm([np_data_1, np_data_2], ['in_data_1:0', 'in_data_2:0'], 'atan2:0') -def test_forward_sin(): - """test operator sin """ - np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32) - tf.reset_default_graph() - with tf.Graph().as_default(): - in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data") - tf.sin(in_data, name="sin") - compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0') - - def test_forward_negative(): """test tf operator Neg """ np_data = np.random.uniform(-100, 255, @@ -3192,6 +3401,337 @@ def test_forward_isfinite(): _verify_infiniteness_ops(tf.is_finite, "isfinite") +def _test_spop_placeholder_without_shape_info(): + with tf.Graph().as_default(): + + @function.Defun(*[tf.int32]*2) + def Forward(x,y): + print(x.name) + print(y.name) + b = tf.add(x, y) + return b + pl1 = tf.placeholder(tf.int32,name="pl1") + pl2 = tf.placeholder(tf.int32,name="pl2") + pl3 = tf.placeholder(tf.int32, name="pl3") + data = np.array([[-1, 1], [2, -2]], dtype=np.int32) + data2 = np.array([[-2, 3], [4, -6]], dtype=np.int32) + data3 = np.array([[-2, 3], [4, -6]], dtype=np.int32) + z1 = gen_functional_ops.StatefulPartitionedCall(args=[pl1,pl2], Tout=[tf.int32],f=Forward) + z2 = z1 + pl3 + compare_tf_with_tvm([data, data2, data3], ['pl1:0', 'pl2:0', 'pl3:0'], + ['StatefulPartitionedCall:0',z2.name], mode='vm', init_global_variables=True) + + +def _test_spop_placeholder_with_shape_and_default_value(): + with tf.Graph().as_default(): + data = np.ones([1], dtype=int).astype(np.int32) + dataVar = tf.Variable(data, shape=data.shape) + pl1 = array_ops.placeholder_with_default(dataVar,shape=data.shape,name="pl1") + tpl = tf.convert_to_tensor(pl1, dtype=tf.int32) + + @function.Defun(*[tf.int32]) + def pl_with_default(pl): + return tf.expand_dims(tf.multiply(pl, pl), 0) + + z = gen_functional_ops.StatefulPartitionedCall(args=[tpl], Tout=[tf.int32], f=pl_with_default) + compare_tf_with_tvm(data, ['pl1:0'], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_placeholder_numpy_arange_feed(): + with tf.Graph().as_default(): + t1 = tf.placeholder(tf.int32, (3, 3, 3), "t1") + t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + t2 = tf.placeholder(tf.int32, (3, 3, 3), "t2") + t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + + @tf.function + def add(x, y): + return tf.add(x, y, "add_t1_t2") + + t3 = add(t1, t2) + compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True) + + 
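+# These SPOP helpers share one pattern: wrap the computation in @tf.function or
+# function.Defun so TensorFlow serializes it as a StatefulPartitionedCall /
+# PartitionedCall node, then compare the Relay VM result against TensorFlow via
+# compare_tf_with_tvm(..., mode='vm', init_global_variables=True). A minimal
+# sketch of the pattern (illustration only, not an additional test case; the
+# placeholder and data names are made up):
+#
+#     @function.Defun(tf.int32, tf.int32)
+#     def forward(x, y):
+#         return tf.add(x, y)
+#
+#     out = gen_functional_ops.StatefulPartitionedCall(
+#         args=[pl1, pl2], Tout=[tf.int32], f=forward)
+#     compare_tf_with_tvm([data1, data2], ['pl1:0', 'pl2:0'],
+#                         ['StatefulPartitionedCall:0'], mode='vm',
+#                         init_global_variables=True)
+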
+def _test_spop_placeholder_numpy_array_feed(): + with tf.Graph().as_default(): + t1_data = np.array([[-1, 1, 3], [2, -2, 4], [2, -3, 14]], dtype=np.int32) + t2_data = np.array([[-2, 1, 2], [12, -2, 14], [12, -3, 4]], dtype=np.int32) + t1 = tf.placeholder(tf.int32, name="t1") + t2 = tf.placeholder(tf.int32, name="t2") + + @tf.function + def add(x, y): + return tf.add(x, y, "add_t1_t2") + + t3 = add(t1, t2) + compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_basic(): + with tf.Graph().as_default(): + + def fun1(a): + return tf.multiply(a,a) + + def fun2(b): + return tf.multiply(b,10) + + @tf.function + def fun3(x,y): + x = fun2(x) + y = fun1(y) + z = tf.add(x,y) + return z + + t3 = fun3(tf.constant(10.5), tf.constant(20.4)) + + compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_nested(): + with tf.Graph().as_default(): + t1 = tf.placeholder(tf.int32, (3, 3, 3), name="t1") + t1_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + t2 = tf.placeholder(tf.int32, name="t2") + t2_data = np.arange(27, dtype=np.int32).reshape((3, 3, 3)) + + @tf.function + def myfunc(x, y): + return tf.add(x, y, "myfunc") + + @tf.function + def myfunc2(x, y): + z = myfunc(x, y) + l = myfunc(z, y) + m = myfunc(l,z) + return tf.add(l, m, "myfunc2") + + res1 = myfunc(t1, t2) + res2 = myfunc2(res1, t1) + + compare_tf_with_tvm([t1_data, t2_data], ['t1:0', 't2:0'], [res2.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_no_autograph(): + with tf.Graph().as_default(): + + @tf.function(autograph=False) + def fun1(a): + return tf.multiply(a,a) + + @tf.function(autograph=False) + def fun2(b): + return tf.multiply(b,10) + + @tf.function + def fun3(x,y): + x = fun2(x) + y = fun1(y) + z = tf.add(x,y) + return z + + t3 = fun3(tf.constant(10.5), tf.constant(20.4)) + + compare_tf_with_tvm([], [], [t3.name], mode='vm', init_global_variables=True) + + +def _test_spop_function_invocation_defun(): + with tf.Graph().as_default(): + + def fun1(a): + return tf.multiply(a,a) + + def fun2(b): + return tf.multiply(b,b) + + @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3") + def fun3(x,y): + x = fun2(x) + y = fun1(y) + z = tf.add(x,y) + return z + + op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)], + Tout=[dtypes.float32], f=fun3, name="SpopFnInvocation") + compare_tf_with_tvm([],[], 'SpopFnInvocation:0', mode='vm', init_global_variables=True) + + +def _test_spop_arithmetic(): + with tf.Graph().as_default(): + @function.Defun(*[dtypes.int32]*3) + def arithmetic(m,x,c): + z = tf.add(tf.multiply(m, x), c) + return z + + m = tf.constant(10) + x = tf.constant(20) + c = tf.constant(2) + spopFn = gen_functional_ops.StatefulPartitionedCall(args=[m,x,c],Tout=[tf.int32], f=arithmetic) + + compare_tf_with_tvm([],[],'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_control_flow(): + with tf.Graph().as_default(): + + @function.Defun(*[dtypes.float32] * 2) + def Body1(x, y): + with ops.device("/job:localhost/replica:0/task:0/device:CPU:0"): + z = math_ops.multiply(x, y) + i = 0 + while i<10 : + i +=1 + if i == 5: + continue + z = math_ops.multiply(x, y*i) + return z + + op = gen_functional_ops.StatefulPartitionedCall( + args=[constant_op.constant(32.), constant_op.constant(100.)], + Tout=[dtypes.float32], f=Body1) + compare_tf_with_tvm([], [], 
'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_variables(): + with tf.Graph().as_default(): + const1 = tf.constant(10) + const2 = tf.constant(20) + var1 = tf.Variable(const1, dtype=tf.int32) + var2 = tf.Variable(const2, dtype=tf.int32) + + @function.Defun(tf.int32,tf.int32) + def Forward(x,y): + return tf.multiply(x,y) + + z = gen_functional_ops.StatefulPartitionedCall(args=[var1,var2],Tout=[tf.int32], f=Forward) + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', init_global_variables=True, mode="vm") + + +def _test_spop_constants(): + with tf.Graph().as_default(): + @function.Defun(*[dtypes.int32] * 2) + def constantsFn(x, y): + vv = tf.constant([2, 3, 4], name="vv") + z = tf.add(vv + x, y) + return z + + a = tf.constant(20000, name = "a") + b = tf.constant(40000, name = "b") + spopFn = gen_functional_ops.StatefulPartitionedCall(args=[a, b], Tout=[tf.int32], f=constantsFn) + + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', mode='vm', init_global_variables=True) + + +def _test_spop_stateful(): + # This test case is to test that TVM rejects any TF stateful operations + # (including Resource Variables) except StatefulPartitionedCall/PartitionedCall + # (as these two operators can still be used as container graphs to execute + # "stateless" operations internally. + tf.reset_default_graph() + with tf.Graph().as_default(): + + @tf.function + def FunctionWithStatefulOp_One(i): + b = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed=10) + y = tf.multiply(b, i) + return y + + @tf.function + def FunctionWithStatefulOp(m, n): + a = tf.random.uniform(shape=[2, 4], maxval=10, dtype=tf.float32, seed = 10) + x = tf.multiply(a,m) + y = FunctionWithStatefulOp_One(n) + z = tf.multiply(x,y) + return z + + op = FunctionWithStatefulOp(constant_op.constant(1.), constant_op.constant(2.)) + with pytest.raises(Exception) as execinfo: + compare_tf_with_tvm([], [], [op.name], init_global_variables=True, mode="vm") + assert execinfo.value.args[0].startswith( + "The following operators are not implemented") + + +def _test_spop_device_assignment(): + # This test case is to test that TVM rejects inconsistent device assignment + # while using StatefulPartitionedCall/PartitionedCall operators which in case of TVM will + # be used as container graphs to internally execute "stateless" operations. + + tf.reset_default_graph() + with tf.Graph().as_default(): + + def fun1(a): + with ops.device("/GPU:0"): + return tf.multiply(a,a) + + def fun2(b): + with ops.device("/job:localhost/replica:0/task:0/device:CPU:1"): + return tf.multiply(b,b) + + @function.Defun(dtypes.float32, dtypes.float32, func_name="Fun3") + def fun3(x,y): + with ops.device("/CPU:0"): + x = fun2(x) + with ops.device("/job:localhost/replica:0/task:0/device:CPU:2"): + y = fun1(y) + with ops.device("/job:localhost/replica:0/task:0/device:CPU:3"): + z = tf.add(x,y) + return z + + op = gen_functional_ops.StatefulPartitionedCall(args=[tf.constant(10.5),tf.constant(20.4)], + Tout=[dtypes.float32], f=fun3) + with pytest.raises(Exception) as execinfo: + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', + mode='vm', init_global_variables=True) + assert execinfo.value.args[0].startswith("Found inconsistent Device assignment") + + +def _test_spop_resource_variables(): + # This test case is to test that TVM rejects any graph containing + # resource variables with StatefulPartitionedOp. 
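+    # The importer is expected to raise "Graph is not frozen. Provide a frozen
+    # graph", which the assertion at the end of this helper checks for.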
+ + tf.reset_default_graph() + with tf.Graph().as_default(): + + const1 = tf.constant(10) + const2 = tf.constant(20) + var1 = tf.Variable(const1, dtype=tf.int32, use_resource=True) + var2 = tf.Variable(const2, dtype=tf.int32, use_resource=True) + + @tf.function + def resourceVariablesTest(x, y): + return tf.multiply(x, y) + + op = resourceVariablesTest(var1,var2) + with pytest.raises(Exception) as execinfo: + compare_tf_with_tvm([], [], 'StatefulPartitionedCall:0', + mode='vm', init_global_variables=True) + assert execinfo.value.args[0].startswith("Graph is not frozen." + " Provide a frozen graph") + +def test_forward_spop(): + _test_spop_stateful() + _test_spop_device_assignment() + _test_spop_resource_variables() + + #Placeholder test cases + _test_spop_placeholder_without_shape_info() + _test_spop_placeholder_with_shape_and_default_value() + _test_spop_placeholder_numpy_arange_feed() + _test_spop_placeholder_numpy_array_feed() + + #Function Invocation test cases + _test_spop_function_invocation_basic() + _test_spop_function_invocation_nested() + _test_spop_function_invocation_no_autograph() + _test_spop_function_invocation_defun() + + #Test cases for various other TF constructs + _test_spop_arithmetic() + _test_spop_control_flow() + _test_spop_variables() + _test_spop_constants() + + ####################################################################### # Main # ---- @@ -3227,8 +3767,8 @@ def test_forward_isfinite(): test_forward_left_shift() test_forward_truncatemod() test_forward_one_hot() - test_forward_atan() test_forward_atan2() + test_forward_nms_v3() # Activations test_forward_sigmoid() @@ -3243,11 +3783,6 @@ def test_forward_isfinite(): test_forward_reverse_v2() test_forward_pow_exp() test_forward_sign() - test_forward_log() - test_forward_log1p() - test_forward_tan() - test_forward_cos() - test_forward_sin() test_forward_negative() test_forward_divide() test_forward_abs() @@ -3260,13 +3795,13 @@ def test_forward_isfinite(): test_forward_log_softmax() test_forward_bias_add() test_forward_zeros_like() - test_forward_erf() test_forward_squared_difference() test_forward_add_n() test_forward_floormod() test_forward_isfinite() test_forward_isinf() test_forward_unravel_index() + test_forward_unary() # Reductions test_forward_argminmax() @@ -3291,6 +3826,8 @@ def test_forward_isfinite(): # NN test_forward_convolution() + test_forward_convolution3d() + test_forward_convolution3d_transpose() test_forward_pooling() test_forward_concat_v2() test_forward_lrn() @@ -3304,13 +3841,12 @@ def test_forward_isfinite(): test_forward_inception_v1() test_forward_mobilenet() test_forward_resnetv2() + test_forward_ssd() test_forward_placeholder() test_forward_ptb() # RNN - if package_version.parse(tf.VERSION) < package_version.parse('2.0.0'): - #in 2.0, tf.contrib.rnn.LSTMBlockCell is removed - test_forward_lstm() + test_forward_lstm() # Elementwise test_forward_ceil() @@ -3328,3 +3864,6 @@ def test_forward_isfinite(): # Sharing params case using Mean ops test_sharing_node() + + # StatefulPartitionedCall + test_forward_spop() diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 13d47b61478b..1ce4997aa73b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -22,6 +22,7 @@ """ from __future__ import print_function from functools import partial +import pytest import numpy as np import tvm from tvm import te @@ -72,17 +73,55 @@ def get_real_image(im_height, im_width): data = np.reshape(x, 
(1, im_height, im_width, 3)) return data +def get_real_image_object_detection(im_height, im_width): + repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/' + img_name = 'street_small.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype('uint8') + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + +def vmobj_to_list(o): + if isinstance(o, tvm.nd.NDArray): + return [o.asnumpy().tolist()] + elif isinstance(o, tvm.runtime.container.ADT): + result = [] + for f in o: + result.extend(vmobj_to_list(f)) + return result + elif isinstance(o, tvm.relay.backend.interpreter.ConstructorValue): + if o.constructor.name_hint == 'Cons': + tl = vmobj_to_list(o.fields[1]) + hd = vmobj_to_list(o.fields[0]) + hd.extend(tl) + return hd + elif o.constructor.name_hint == 'Nil': + return [] + elif 'tensor_nil' in o.constructor.name_hint: + return [0] + elif 'tensor' in o.constructor.name_hint: + return [o.fields[0].asnumpy()] + else: + raise RuntimeError("Unknown object type: %s" % + o.constructor.name_hint) + else: + raise RuntimeError("Unknown object type: %s" % type(o)) + def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm', - out_names=None): + out_names=None, mode='graph_runtime'): """ Generic function to compile on relay and execute on tvm """ + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: import tflite.Model + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) except ImportError: raise ImportError("The tflite package must be installed") - # get TFLite model from buffer - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - input_data = convert_to_list(input_data) input_node = convert_to_list(input_node) @@ -95,27 +134,44 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) - with relay.build_config(opt_level=3): - graph, lib, params = relay.build(mod, target, params=params) - ctx = tvm.context(target, 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - # set inputs - for i, e in enumerate(input_node): - m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) - - m.set_input(**params) - # execute - m.run() - # get outputs - assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( - out_names, num_output) - tvm_output_list = [] - for i in range(0, num_output): - tvm_output = m.get_output(i) - tvm_output_list.append(tvm_output.asnumpy()) - return tvm_output_list + if mode in ['debug', 'vm']: + ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target="llvm") + inputs = [] + for param in mod['main'].params: + found = False + for i, n in enumerate(input_node): + if n == param.name_hint: + found = True + inputs.append(tvm.nd.array(input_data[i])) + break + # Interpreter doesn't bind constants, so still need to find in params + if not found: + inputs.append(tvm.nd.array(params[param.name_hint])) + result = ex.evaluate()(*inputs) + return vmobj_to_list(result) + else: + with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build(mod, target, params=params) + + ctx = 
tvm.context(target, 0) + from tvm.contrib import graph_runtime + m = graph_runtime.create(graph, lib, ctx) + # set inputs + for i, e in enumerate(input_node): + m.set_input(e, tvm.nd.array(input_data[i].astype(input_data[i].dtype))) + + m.set_input(**params) + # execute + m.run() + # get outputs + assert out_names is None or num_output == len(out_names), "out_names: {} num_output: {}".format( + out_names, num_output) + tvm_output_list = [] + for i in range(0, num_output): + tvm_output = m.get_output(i) + tvm_output_list.append(tvm_output.asnumpy()) + return tvm_output_list def run_tflite_graph(tflite_model_buf, input_data): @@ -146,7 +202,7 @@ def run_tflite_graph(tflite_model_buf, input_data): def compare_tflite_with_tvm(in_data, in_name, input_tensors, output_tensors, init_global_variables=False, - out_names=None, quantized=False, input_range=None): + out_names=None, quantized=False, input_range=None, mode='graph_runtime'): """Generic function to generate and compare TFLite and TVM output""" in_data = convert_to_list(in_data) in_name = convert_to_list(in_name) @@ -188,7 +244,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors, continue tvm_output = run_tvm_graph(tflite_model_buffer, in_data, in_node, target=device, - num_output=len(out_names), out_names=out_names) + num_output=len(out_names), out_names=out_names, mode=mode) # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output # range for the specific operator. While adding test ensure that we aren't getting only clipped values @@ -215,15 +271,19 @@ def with_fused_activation_function(input_tensor, fn_name): return math_ops.tanh(input_tensor) raise AssertionError("Unknown fused_activation_function {}".format(fn_name)) -def _test_split(in_shape, axis, num_Splits, dtype): - '''internal split tester taking as parameters in_shape, number of tensors to split into - and dtype (data type)''' + +def _test_split(in_shape, axis, num_splits, dtype): + """internal split tester taking as parameters in_shape, number of tensors to split into + and dtype (data type)""" + np_data = np.random.uniform(-5, 5, size=in_shape).astype(dtype) with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=in_shape, dtype=dtype) - out = array_ops.split(in_data, num_Splits, axis=axis) - out_names = ['out_' + str(n) + ':0' for n in range(num_Splits)] - compare_tflite_with_tvm([np_data], ['Placeholder:0'], [in_data], out, + in_data = array_ops.placeholder(shape=in_shape, dtype=dtype, name="in_data") + out = array_ops.split(in_data, num_splits, axis=axis) + num_splits = len(num_splits) if isinstance(num_splits, list) \ + else num_splits + out_names = ['out_' + str(n) + ':0' for n in range(num_splits)] + compare_tflite_with_tvm([np_data], ['in_data'], [in_data], out, out_names=out_names) def test_forward_split(): @@ -251,6 +311,9 @@ def test_forward_split(): _test_split((1, 6, 3, 5), -3, 3, 'float32') _test_split((1, 3, 6, 5), -2, 3, 'float32') _test_split((1, 3, 5, 6), -1, 3, 'float32') + # size_splits split + _test_split((6,), 0, [1, 2, 3], 'float32') + _test_split((3, 6, 4), -2, [1, 4, 1], 'float32') ####################################################################### # slice @@ -290,6 +353,121 @@ def test_forward_topk(): _test_topk((3, 5, 7), 3) _test_topk((3, 5, 7), 3) +####################################################################### +# Gather +# ------ + +def _test_gather(dshape, indices, axis, dtype, quantized=False, oob=False): + """ One iteration of Gather """ + indices = 
np.asarray(indices).astype('int32') + data = np.random.uniform(1, 10, size=dshape) + data = data.astype(np.uint8) if quantized else data.astype(dtype) + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in_data") + if axis: + out = array_ops.gather(in_data, indices, axis=axis) + else: + out = array_ops.gather(in_data, indices) #tflite conversion fails for None axis + input_range = {'in_data': (-100, 100)} if quantized else None + try: + compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], + quantized=quantized, input_range=input_range) + except ValueError as e: + if not oob: + raise e + except Exception as e: + raise e + +def test_forward_gather(): + """ GATHER """ + for quantized in [False, True]: + _test_gather((4,), [1], 0, 'float32', quantized) + _test_gather((4,), [1], None, 'int32', quantized) + _test_gather((1, 4), [0], 0, 'int32', quantized) + _test_gather((4,), [[[1, 0], [0, 1]]], 0, 'float32', quantized) + _test_gather((2, 2), [[[1, 0], [0, 1]]], 1, 'int32', quantized) + _test_gather((2, 2), [[[1, 0], [0, 1]]], None, 'float32', quantized) + _test_gather((3, 3, 3), [[[1, 0]]], 0, 'int32', quantized) + _test_gather((3, 3, 3), [[[1, 0]]], 2, 'int32', quantized) + _test_gather((4, 3, 5, 6), [[2, 1, 0, 0]], 0, 'float32', quantized) + _test_gather((3, 3, 3), [[[2, 1]]], -1, 'int32', quantized) + _test_gather((4,), [16], 0, 'float32', quantized, oob=True) + _test_gather((1, 3, 3), [12], 0, 'int32', quantized, oob=True) + _test_gather((1, 3, 3), [20], 1, 'float32', quantized, oob=True) + _test_gather((1, 3, 3), [20, 20], 2, 'float32', quantized, oob=True) + +####################################################################### +# Gather_ND +# --------- + +def _test_gather_nd(data, indices): + """ One iteration of GATHER_ND """ + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=data.shape, dtype=data.dtype, name="data") + indices_data = tf.placeholder(shape=indices.shape, dtype=indices.dtype, + name="indices") + out = tf.gather_nd(in_data, indices_data) + + compare_tflite_with_tvm([data, indices], ['data:0', 'indices:0'], + [in_data, indices_data], [out]) + +def test_forward_gather_nd(): + """ GATHER_ND """ + _test_gather_nd( + np.array([[[1.2, 2.0], [3.1, 4.1]], [[5.1, 6.1], [7.1, 8.1]]]).astype('float32'), + np.asarray([[0, 1], [1, 0]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(30), [5, 6]).astype('int32'), + np.asarray([[1, 2]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(12), [2, 3, 2]).astype('int32'), + np.asarray([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(4), [4]).astype('float32'), + np.asarray([1]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(4), [1, 4]).astype('float32'), + np.asarray([0]).astype('int32') + ) + _test_gather_nd( + np.reshape(np.arange(4), [1, 4]).astype('float32'), + np.asarray([0, 3]).astype('int32') + ) + +####################################################################### +# StridedSlice +# ------------ + +def _test_stridedslice(ip_shape, begin, end, stride, dtype, + begin_mask=0, end_mask=0, new_axis_mask=0, + shrink_axis_mask=0, ellipsis_mask=0, quantized=False): + """ One iteration of a Stridedslice """ + data = np.random.uniform(size=ip_shape).astype(dtype) + data = data.astype(np.uint8) if quantized else data.astype(dtype) + with tf.Graph().as_default(): + in_data = tf.placeholder(dtype, ip_shape, name="in_data") + out = array_ops.strided_slice(in_data, 
begin, end, stride, + begin_mask=begin_mask, + end_mask=end_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask, + ellipsis_mask=ellipsis_mask) + input_range = {'in_data': (-100, 100)} if quantized else None + compare_tflite_with_tvm([data], ['in_data:0'], [in_data], [out], quantized=quantized, + input_range=input_range) + +def test_forward_stridedslice(): + '''test StridedSlice''' + for quantized in [False, True]: + _test_stridedslice((2), [1], [1], [1], 'float32', shrink_axis_mask=1, quantized=quantized) + _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32', quantized=quantized) + _test_stridedslice((3, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=0, quantized=quantized) + _test_stridedslice((4, 4), [1, 0], [4, 4], [1, 1], 'float32', shrink_axis_mask=2, quantized=quantized) + ####################################################################### # transpose # --------- @@ -336,6 +514,29 @@ def test_forward_cast(): _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8) _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64) +####################################################################### +# Batch Mat Mul +# ---- +def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False): + with tf.Graph().as_default(): + A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A') + B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B') + result = math_ops.matmul(A, B, adjoint_a=adjoint_a, + adjoint_b=adjoint_b, name='batchmatmul') + + A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype) + B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype) + compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result]) + + +def test_forward_batch_matmul(): + """ BATCH_MAT_MUL """ + _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32') + _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True) + _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False) + _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True) + _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32') + ####################################################################### # Tile # ---- @@ -487,13 +688,38 @@ def test_forward_pooling(): strides=[2, 1]) +def _test_l2_pool2d(input_shape, ksize, strides, padding, data_format, fused_func_name=None): + x = np.arange(np.prod(input_shape), dtype=np.float32).reshape(input_shape) - 1 + + with tf.Graph().as_default(): + in_data = tf.placeholder( + dtype=tf.float32, name="input", shape=input_shape) + out = tf.sqrt(tf.nn.avg_pool( + tf.square(in_data), ksize=ksize, strides=strides, + padding=padding, data_format=data_format)) + out = with_fused_activation_function(out, fused_func_name) + + compare_tflite_with_tvm(x, 'input', [in_data], [out]) + + +def test_forward_l2_pool2d(): + _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6") + _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC", "RELU6") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'SAME', "NHWC") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'SAME', "NHWC") + _test_l2_pool2d([1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC", "RELU") + _test_l2_pool2d([2, 9, 10, 2], [1, 1, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 1, 1], 'VALID', "NHWC") + _test_l2_pool2d([2, 9, 10, 2], [1, 2, 1, 1], [1, 1, 2, 1], 'VALID', "NHWC", 
"RELU6") + + ####################################################################### # Convolution # ----------- def _test_convolution(tensor_in_sizes, filter_in_sizes, dilations, strides, padding, data_format, - is_depthwise=False): + is_depthwise=False, quantized=False): """ One iteration of convolution with given shapes and attributes """ total_size_1 = 1 @@ -504,12 +730,16 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, total_size_2 *= s # Initializes the input tensor with array containing incrementing # numbers from 1. - data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] - filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] + if quantized: + data_array = np.random.uniform(0, 255, tensor_in_sizes).astype('uint8') + filter_array = np.random.uniform(0, 255, filter_in_sizes).astype('uint8') + else: + data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] + filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)] with tf.Graph().as_default(): - in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32') - in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32') + in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32', name='in_data') + in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32', name='in_filter') strides = [1] + strides + [1] dilations = [1] + dilations + [1] @@ -525,15 +755,37 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, strides=strides, padding=padding, data_format=data_format) - data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') - compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out]) + + if quantized: + # For now only quantized conv2d is supported + assert not is_depthwise + + # Quantized the inputs and feed them to the convolution + inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data') + inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter') + out = nn_ops.conv2d(inq_data, + inq_filter, + strides=strides, + padding=padding, + data_format=data_format) + out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out") + + # Set the input quantization range + input_range = {'in_data': (-100, 100)} if quantized else None + + # Compare + compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range) + else: + data_array = np.reshape(data_array, tensor_in_sizes).astype('float32') + compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out]) def test_forward_convolution(): - _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') - _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC') - _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC') - _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') + for quantized in [False, True]: + _test_convolution([4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized) + _test_convolution([4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized) + _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized) + _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized) # depthwise convolution _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], 
[1, 1], [1, 1], 'SAME', 'NHWC', True) @@ -649,6 +901,79 @@ def test_all_resize(): if 'RESIZE_NEAREST_NEIGHBOR' in dir(BuiltinOperator()): _test_resize(tf.image.resize_nearest_neighbor, data, align_corners=False) +####################################################################### +# Range +# ----- +def _test_range(start, limit, delta): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + start_scalar, limit_scalar, delta_scalar = \ + tf.placeholder(dtype=start.dtype, shape=(), name="start"), \ + tf.placeholder(dtype=limit.dtype, shape=(), name="limit"), \ + tf.placeholder(dtype=delta.dtype, shape=(), name="delta") + + out = tf.range(start_scalar, limit_scalar, delta_scalar, name="range") + + compare_tflite_with_tvm( + [start, limit, delta], + ["start", "limit", "delta"], + [start_scalar, limit_scalar, delta_scalar], + [out], + mode="vm", + quantized=False + ) + +def _test_range_default(): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + inputs = [ + tf.placeholder(dtype=tf.int32, shape=(), name="p1"), + tf.placeholder(dtype=tf.int32, shape=(), name="p2") + ] + outputs = [ + tf.range(start = inputs[0], limit = inputs[1]), # use default delta + tf.range(start = inputs[1]) # use start as limit with 0 as the first item in the range + ] + + compare_tflite_with_tvm( + [np.int32(1), np.int32(18)], + ["p1", "p2"], + inputs, + outputs, + mode="vm" + ) + +def test_forward_range(): + _test_range(np.int32(1), np.int32(18), np.int32(3)) + _test_range(np.int32(1), np.int32(18), np.float32(3.1)) # increment is of type float + _test_range(np.float32(1.0), np.int32(18), np.int32(3.1)) # start is of type float + _test_range_default() + +####################################################################### +# Shape +# ----- +def test_forward_shape(): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + tf.reset_default_graph() + with tf.Graph().as_default(): + data = np.array([1, 18, 3], dtype=np.int32) + start = tf.placeholder(dtype=tf.int32, shape=[], name="start") + limit = tf.placeholder(dtype=tf.int32, shape=[], name="limit") + delta = tf.placeholder(dtype=tf.int32, shape=[], name="delta") + r = tf.range(start, limit, delta, tf.int32, name="range") + out = tf.shape(r, out_type=tf.dtypes.int32) + compare_tflite_with_tvm( + [x for x in np.nditer(data)], + ["start", "limit", "delta"], + [start, limit, delta], + [out], + mode="vm" + ) ####################################################################### # Concatenation @@ -820,7 +1145,11 @@ def test_all_unary_elemwise(): _test_forward_unary_elemwise(_test_ceil) _test_forward_unary_elemwise(_test_cos) _test_forward_unary_elemwise(_test_round) - _test_forward_unary_elemwise(_test_tan) + # This fails with TF and Tflite 1.15.2, this could not have been tested + # in CI or anywhere else. The failure mode is that we see a backtrace + # from the converter that we need to provide a custom Tan operator + # implementation. 
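+    # The Tan test below is kept commented out rather than removed, so it can be
+    # re-enabled once the TFLite converter provides a Tan implementation.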
+ #_test_forward_unary_elemwise(_test_tan) _test_forward_unary_elemwise(_test_elu) ####################################################################### @@ -1036,7 +1365,9 @@ def test_all_elemwise(): _test_forward_elemwise(_test_add) _test_forward_elemwise_quantized(_test_add) _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU")) - _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6")) + # this is broken with tf upgrade 1.15.2 and hits a segfault that needs + # further investigation. + # _test_forward_elemwise(partial(_test_add, fused_activation_function="RELU6")) _test_forward_elemwise(_test_sub) _test_forward_elemwise_quantized(_test_sub) _test_forward_elemwise(partial(_test_sub, fused_activation_function="RELU")) @@ -1062,6 +1393,43 @@ def test_all_elemwise(): _test_forward_elemwise(_test_floor_divide) _test_forward_elemwise(_test_floor_mod) + +####################################################################### +# AddN +# ---- + + +def _test_forward_add_n(inputs): + tf.reset_default_graph() + with tf.Graph().as_default(): + temp = [] + for each in inputs: + temp.append(tf.placeholder(shape=each.shape, dtype=each.dtype)) + output = tf.add_n(temp) + compare_tflite_with_tvm([each for each in inputs], [each.name for each in temp], + [each for each in temp], [output]) + + +def test_forward_add_n(): + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + x = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + y = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + z = np.random.randint(1, 100, size=(3, 3, 3), dtype=np.int32) + m, n, o = x.astype(np.float32), y.astype(np.float32), z.astype(np.float32) + in0 = x + in1 = [x, y] + in2 = (x, y, z) + in3 = m + in4 = [m, n] + in5 = (m, n, o) + _test_forward_add_n(in0) + _test_forward_add_n(in1) + _test_forward_add_n(in2) + _test_forward_add_n(in3) + _test_forward_add_n(in4) + _test_forward_add_n(in5) + + ####################################################################### # Logical operators # ----------------- @@ -1071,7 +1439,12 @@ def _test_logical_binary(logical_bin_op, data): with tf.Graph().as_default(): in_data = [array_ops.placeholder(shape=data[0].shape, dtype='bool', name='in_0'), array_ops.placeholder(shape=data[1].shape, dtype='bool', name='in_1')] - out = logical_bin_op(in_data[0], in_data[1], name='out') + if logical_bin_op == math_ops.logical_not: + out = math_ops.logical_or(in_data[0], in_data[1], name='out1') + out = logical_bin_op(out, name='out') + else: + out = logical_bin_op(in_data[0], in_data[1], name='out') + compare_tflite_with_tvm(data, ['in_0:0', 'in_1:0'], in_data, [out]) def _test_forward_logical_and(data): @@ -1082,6 +1455,10 @@ def _test_forward_logical_or(data): """ One iteration of logical or """ return _test_logical_binary(math_ops.logical_or, data) +def _test_forward_logical_not(data): + """ One iteration of logical not """ + return _test_logical_binary(math_ops.logical_not, data) + def test_all_logical(): data = [np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool'), np.random.choice(a=[False, True], size=(2, 3, 4)).astype('bool')] @@ -1089,6 +1466,7 @@ def test_all_logical(): if package_version.parse(tf.VERSION) >= package_version.parse('1.15.0'): _test_forward_logical_and(data) _test_forward_logical_or(data) + _test_forward_logical_not(data) ####################################################################### # Zeros like @@ -1105,6 +1483,39 @@ def test_forward_zeros_like(): """ ZEROS LIKE """ 
_test_zeros_like(np.arange(6.0, dtype=np.float32).reshape((1, 6))) + +####################################################################### +# Fill +# ---- + +def _test_fill(dims, value_data, value_dtype): + """ Use the fill op to create a tensor of value_data with constant dims.""" + + value_data = np.array(value_data, dtype=value_dtype) + # TF 1.13 TFLite convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + with tf.Graph().as_default(): + value = array_ops.placeholder(dtype=value_dtype, name="value", shape=[]) + out = tf.fill(dims, value) + compare_tflite_with_tvm([value_data], ["value"], [value], [out]) + + with tf.Graph().as_default(): + input1 = array_ops.placeholder(dtype=value_dtype, name="input1", shape=dims) + # Fill op gets converted to static tensor during conversion + out = tf.fill(dims, value_data) + out1 = tf.add(out, input1) + input1_data = np.random.uniform(0, 5, size=dims).astype(value_dtype) + compare_tflite_with_tvm([input1_data], ["input1"], [input1], [out1]) + + +def test_forward_fill(): + """ Test FILL op """ + + _test_fill((1, 2, 2, 4), 5, "int32") + _test_fill((1, 2, 2, 4), 5, "float32") + _test_fill((5, ), 5, "int32") + + ####################################################################### # Reduce # ------ @@ -1221,6 +1632,27 @@ def test_all_reduce(): ####################################################################### +# Select, Where +# ------------- + +def test_forward_select(): + with tf.Graph().as_default(): + with tf.Session() as sess: + input1 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input1') + input2 = tf.placeholder( + tf.int32, shape=[1, 4, 4, 3], name='input2') + mask = input1 > input2 + out = tf.where(mask, input1 + 1, input2 * 2) + in_data1 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + in_data2 = np.random.uniform( + 0, 10, size=(1, 4, 4, 3)).astype("int32") + + compare_tflite_with_tvm([in_data1, in_data2], [ + 'input1:0', 'input2:0'], [input1, input2], [out]) + + # Squeeze # ------- @@ -1247,6 +1679,48 @@ def test_forward_squeeze(): _test_squeeze(np.arange(6).reshape((2, 1, 3, 1)), [1, 3]) +####################################################################### +# Quantize/DeQuantize +# ------------------- + +def _test_quantize_dequantize(data): + """ One iteration of quantize and dequantize """ + + # Define a dummy model + data_in = tf.keras.layers.Input(shape=data.shape[1:]) + act_func = tf.keras.layers.Activation('linear') + keras_model = tf.keras.models.Model(data_in, act_func(data_in)) + + # Load the model + converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model) + + # To create quantized values with dynamic range of activations, needs representative dataset + def representative_data_gen(): + for i in range(100): + yield [data] + + converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] + converter.representative_dataset = representative_data_gen + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.uint8 + converter.inference_output_type = tf.uint8 + + # Convert the model to TensorFlow Lite format + tflite_model_quant = converter.convert() + + tflite_output = run_tflite_graph(tflite_model_quant, data) + tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1') + tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), + rtol=1e-5, atol=1e-5) + + +def test_forward_quantize_dequantize(): + """ Quantize Dequantize 
""" + data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32") + if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'): + _test_quantize_dequantize(data) + + ####################################################################### # Pad # --- @@ -1503,6 +1977,80 @@ def test_forward_spacetodepth(): _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2) _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4) +####################################################################### +# Sparse To Dense +# --------------- +def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): + # tflite 1.13 convert method does not accept empty shapes + if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'): + with tf.Graph().as_default(): + indices = tf.placeholder(shape=sparse_indices.shape, dtype=str(sparse_indices.dtype), name="indices") + values = tf.placeholder(shape=sparse_values.shape, dtype=str(sparse_values.dtype), name="values") + oshape = tf.constant(output_shape, shape=output_shape.shape, dtype=str(output_shape.dtype)) + + if default_value == None: + output = tf.sparse_to_dense(indices, oshape, values) + compare_tflite_with_tvm( + [sparse_indices, sparse_values], + ["indices", "values"], + [indices, values], + [output] + ) + else: + dv = tf.placeholder(shape=(), dtype=str(default_value.dtype), name="default_value") + output = tf.sparse_to_dense(indices, oshape, values, dv) + compare_tflite_with_tvm( + [sparse_indices, sparse_values, default_value], + ["indices", "values", "default_value"], + [indices, values, dv], + [output] + ) + +def test_forward_sparse_to_dense(): + ''' + Works in tvm/topi/tensorflow. But tflite converter breaks this test case + _test_sparse_to_dense( + np.int32(1), + np.int32(3), + np.int32(0), + np.array([5]).astype("int32") + ) + ''' + # vector + _test_sparse_to_dense( + np.array([0, 1, 4]).astype("int32"), + np.array([3, 3, 3]).astype("int32"), + np.int32(0), + np.array([5]).astype("int32") + ) + # vector nXd + _test_sparse_to_dense( + np.array([[0, 0], [1, 2]]).astype("int32"), + np.array([1, 2]).astype("int32"), + np.int32(0), + np.array([3, 4]).astype("int32") + ) + _test_sparse_to_dense( + np.array([[0, 0, 0], [1, 2, 3]]).astype("int32"), + np.array([1, 2]).astype("int32"), + np.int32(4), + np.array([2, 3, 4]).astype("int32") + ) + # floats + _test_sparse_to_dense( + np.array([0, 1, 4]).astype("int32"), + np.array([3.1, 3.1, 3.1]).astype("float32"), + np.float32(3.5), + np.array([5]).astype("int32") + ) + # default value not specified + _test_sparse_to_dense( + np.array([0, 1, 4]).astype("int32"), + np.array([3.1, 3.1, 3.1]).astype("float32"), + None, + np.array([5]).astype("int32") + ) + ####################################################################### # Fully Connected # --------------- @@ -1584,16 +2132,30 @@ def test_detection_postprocess(): tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions]) tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions], ["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4) - # check valid count is the same + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same assert tvm_output[3] == tflite_output[3] valid_count = tvm_output[3][0] - tvm_boxes = tvm_output[0][0][:valid_count] - tvm_classes = 
tvm_output[1][0][:valid_count] - tvm_scores = tvm_output[2][0][:valid_count] - # check the output data is correct - tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5) - tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5) + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) ####################################################################### @@ -1754,7 +2316,9 @@ def test_forward_qnn_mobilenet_v3_net(): """Test the Quantized TFLite Mobilenet V3 model.""" # In MobilenetV3, some ops are not supported before tf 1.15 fbs schema if package_version.parse(tf.VERSION) < package_version.parse('1.15.0'): - return + pytest.skip("Unsupported in tflite < 1.15.0") + else: + pytest.skip("This segfaults with tensorflow 1.15.2 and above") tflite_model_file = tf_testing.get_workload_official( "https://storage.googleapis.com/mobilenet_v3/checkpoints/v3-large_224_1.0_uint8.tgz", @@ -1777,14 +2341,107 @@ def test_forward_qnn_mobilenet_v3_net(): ####################################################################### -# MediaPipe +# Quantized SSD Mobilenet +# ----------------------- + +def test_forward_qnn_coco_ssd_mobilenet_v1(): + """Test the quantized Coco SSD Mobilenet V1 TF Lite model.""" + pytest.skip("LLVM bug - getExtendedVectorNumElements - " + + "https://discuss.tvm.ai/t/segfault-in-llvm/3567. The workaround is to use a " + + "specific target, for example, llvm -mpcu=core-avx2") + + tflite_model_file = tf_testing.get_workload_official( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip", + "detect.tflite") + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + data = get_real_image_object_detection(300, 300) + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # We compare the bounding boxes whose prediction score is above 60%. This is typical in end + # to end application where a low prediction score is discarded. This is also needed because + # multiple low score bounding boxes can have same score and TFlite and TVM can have + # different orderings for same score bounding boxes. 
Another reason for minor differences in + # low score bounding boxes is the difference between TVM and TFLite for requantize operator. + if tvm_output[2][0][i] > 0.6: + # Check bounding box co-ords. The tolerances have to be adjusted, from 1e-5 to 1e-2, + # because of differences between for requantiize operator in TFLite and TVM. + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), + np.squeeze(tflite_output[0][0][i]), + rtol=1e-2, atol=1e-2) + + # Check the class + # Stricter check to ensure class remains same + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), + np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), + np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + + +####################################################################### +# SSD Mobilenet # ------------- +def test_forward_coco_ssd_mobilenet_v1(): + """Test the FP32 Coco SSD Mobilenet V1 TF Lite model.""" + tflite_model_file = tf_testing.get_workload_official( + "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/object_detection/ssd_mobilenet_v1_coco_2018_01_28.tgz", + "ssd_mobilenet_v1_coco_2018_01_28.tflite") + + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + + np.random.seed(0) + data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. 
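+    # Box coordinates and scores are compared with a small floating point
+    # tolerance, while the predicted class of each valid detection must match
+    # exactly.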
+ for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + np.testing.assert_equal(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i])) + + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + +####################################################################### +# MediaPipe +# ------------- def test_forward_mediapipe_hand_landmark(): """Test MediaPipe 2D hand landmark TF Lite model.""" # MediaPipe 2D hand landmark TF tflite_model_file = download_testdata( - "https://github.com/google/mediapipe/raw/master/mediapipe/models/hand_landmark.tflite", + "https://github.com/google/mediapipe/raw/v0.7.4/mediapipe/models/hand_landmark.tflite", "hand_landmark.tflite") with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() @@ -1795,6 +2452,7 @@ def test_forward_mediapipe_hand_landmark(): tvm.testing.assert_allclose(np.squeeze(tvm_output[i]), np.squeeze(tflite_output[i]), rtol=1e-5, atol=1e-5) + ####################################################################### # Main # ---- @@ -1814,9 +2472,15 @@ def test_forward_mediapipe_hand_landmark(): # Cast test_forward_cast() + # BatchMatMul + test_forward_batch_matmul() + # Tile test_forward_tile() + # Query + test_forward_shape() + # Transforms test_forward_concatenation() test_forward_pad() @@ -1824,17 +2488,25 @@ def test_forward_mediapipe_hand_landmark(): test_forward_unpack() test_forward_reshape() test_all_resize() + test_forward_range() test_forward_squeeze() test_forward_slice() test_forward_topk() + test_forward_gather() + test_forward_gather_nd() + test_forward_stridedslice() test_forward_depthtospace() test_forward_spacetodepth() + test_forward_sparse_to_dense() + test_forward_select() + test_forward_quantize_dequantize() # NN test_forward_convolution() test_forward_transpose_conv() test_forward_logistic() test_forward_pooling() + test_forward_l2_pool2d() test_forward_softmax() test_forward_tanh() test_forward_relu() @@ -1845,13 +2517,16 @@ def test_forward_mediapipe_hand_landmark(): # Elemwise test_all_elemwise() + test_forward_add_n() # Unary elemwise test_all_unary_elemwise() - # Zeros Like test_forward_zeros_like() + # Fill + test_forward_fill() + # Reduce test_all_reduce() @@ -1867,10 +2542,14 @@ def test_forward_mediapipe_hand_landmark(): test_forward_mobilenet_v3() test_forward_inception_v3_net() test_forward_inception_v4_net() + test_forward_coco_ssd_mobilenet_v1() test_forward_mediapipe_hand_landmark() # End to End quantized test_forward_qnn_inception_v1_net() test_forward_qnn_mobilenet_v1_net() test_forward_qnn_mobilenet_v2_net() + #This also fails with a segmentation fault in my run + #with Tflite 1.15.2 test_forward_qnn_mobilenet_v3_net() + test_forward_qnn_coco_ssd_mobilenet_v1() diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py index a8f2db19a9b0..dfa247e5a09a 100644 --- a/tests/python/integration/test_ewise.py +++ b/tests/python/integration/test_ewise.py @@ -80,8 +80,15 @@ def check_device(device): # launch the kernel. 
n = 1024 - a = tvm.nd.array((np.random.uniform(size=n) * 256).astype(A.dtype), ctx) - b = tvm.nd.array((np.random.uniform(size=n) * 256).astype(B.dtype), ctx) + a_np = (np.random.uniform(size=n) * 256).astype(A.dtype) + b_np = (np.random.uniform(size=n) * 256).astype(B.dtype) + + # "fix" the values in a and b to avoid the result being too small + b_np += ((b_np < 2.0) * 2) + a_np[np.abs(np.fmod(a_np, b_np)) < 1] += 1 + + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) ftimer = fmod.time_evaluator(fmod.entry_name, ctx, number=1) tcost = ftimer(a, b, c).mean diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index 82ade4478bea..c5d9d0875c3e 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -65,6 +65,7 @@ def check_device(device, host="llvm"): check_device("vulkan") check_device("cuda") check_device("opencl") + check_device("rocm") test_prim(te.sum, np.sum) test_prim(tvm.te.min, np.amin) test_prim(tvm.te.max, np.amax) @@ -179,7 +180,7 @@ def check_target(device, host="stackvm"): check_target("cuda") check_target("metal") check_target("opencl") - + check_target("rocm") def test_rfactor_elemwise_threads(): n = 1025 @@ -230,6 +231,7 @@ def check_target(device, host="stackvm"): check_target("cuda") check_target("metal") check_target("opencl") + check_target("rocm") def test_argmax(): def fcombine(x, y): @@ -337,6 +339,110 @@ def check_target(device): check_target("cuda") check_target("vulkan") + check_target("rocm") + +def test_warp_reduction1(): + nthx = 32 + nthy = 4 + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, nthx), "threadIdx.x") + thread_y = te.thread_axis((0, nthy), "threadIdx.y") + + def check_target(device, m, n): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." % device) + return + + # compute + A = te.placeholder((m, n), name='A') + k = te.reduce_axis((0, n)) + B = te.compute((m,), lambda i: te.max(A[i][k], axis=k), name='B') + s = te.create_schedule(B.op) + + # schedule + k = s[B].op.reduce_axis[0] + ko, _ = s[B].split(k, nparts=nthx) + s[B].bind(ko, thread_x) + xo, xi = s[B].split(s[B].op.axis[0], factor=nthy) + s[B].bind(xi, thread_y) + s[B].bind(xo, block_x) + + tvm.lower(s, [A, B], simple_mode=True) + + # validation + func = tvm.build(s, [A, B], device, name="warp_reduction") + a_np = np.random.uniform(size=(m,n)).astype(A.dtype) + b_np = np.zeros((m,), dtype=A.dtype) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + b_np = np.max(a_np, axis=1) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-3, atol=1e-3) + + check_target("cuda", m=32, n=256) + check_target("cuda", m=10, n=20) + check_target("rocm", m=32, n=256) + check_target("rocm", m=10, n=20) + # This is a bug in normal reduction. 
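+    # The (m=10, n=37) case below is therefore left disabled for now: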
+ # check_target("cuda", m=10, n=37) + +def test_warp_reduction2(): + def fcombine(x, y): + return x[0] + y[0], x[1] * y[1] + + def fidentity(t0, t1): + return tvm.tir.const(0, t0), tvm.tir.const(1, t1) + + add_mul_reducer = te.comm_reducer(fcombine, fidentity, name='add_mul_reducer') + + # compute + m = 16 + n = 256 + A0 = te.placeholder((m, n), name='A0', dtype='float32') + A1 = te.placeholder((m, n), name='Al', dtype='float32') + k = te.reduce_axis((0, n), 'k') + T0, T1 = te.compute((m, ), lambda i: \ + add_mul_reducer((A0[i, k], A1[i, k]), axis=k), name='T') + + nthdx, nthdy = 32, 2 + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, nthdx), "threadIdx.x") + thread_y = te.thread_axis((0, nthdy), "threadIdx.y") + + def check_target(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("skip because %s is not enabled.." % device) + return + + # schedule + s = te.create_schedule(T0.op) + ko, _ = s[T0].split(k, nparts=nthdx) + xo, xi = s[T0].split(s[T0].op.axis[0], factor=nthdy) + s[T0].bind(ko, thread_x) + s[T0].bind(xi, thread_y) + s[T0].bind(xo, block_x) + + # validation + ctx = tvm.context(device, 0) + a0_np = np.random.uniform(size=(m,n)).astype(A0.dtype) + a1_np = np.random.uniform(size=(m,n)).astype(A1.dtype) + t0_np = np.zeros((m,), dtype=A0.dtype) + t1_np = np.zeros((m,), dtype=A1.dtype) + a0 = tvm.nd.array(a0_np, ctx) + a1 = tvm.nd.array(a1_np, ctx) + t0 = tvm.nd.array(t0_np, ctx) + t1 = tvm.nd.array(t1_np, ctx) + func = tvm.build(s, [A0, A1, T0, T1], device, name="reduction") + func(a0, a1, t0, t1) + t0_np = np.sum(a0_np, axis=1) + t1_np = np.product(a1_np, axis=1) + tvm.testing.assert_allclose(t0.asnumpy(), t0_np, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose(t1.asnumpy(), t1_np, rtol=1e-3, atol=1e-3) + + check_target("cuda") + check_target("rocm") if __name__ == "__main__": test_rfactor_elemwise_threads() @@ -346,3 +452,5 @@ def check_target(device): test_reduce_prims() test_argmax() test_rfactor_argmax() + test_warp_reduction1() + test_warp_reduction2() diff --git a/tests/python/nightly/quantization/test_quantization_accuracy.py b/tests/python/nightly/quantization/test_quantization_accuracy.py index 4818cc651b94..d4b55f14100b 100644 --- a/tests/python/nightly/quantization/test_quantization_accuracy.py +++ b/tests/python/nightly/quantization/test_quantization_accuracy.py @@ -66,7 +66,7 @@ def get_model(model_name, batch_size, qconfig, target=None, original=False, simu mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) net = mod['main'] - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): qfunc = relay.quantize.prerequisite_optimize(net, params=params) logging.debug('original') logging.debug(qfunc.astext(show_meta_data=False)) @@ -83,7 +83,7 @@ def get_model(model_name, batch_size, qconfig, target=None, original=False, simu def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(model, target) # create runtime module m = tvm.contrib.graph_runtime.create(graph, lib, ctx) diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py index 1e9030c5d8e6..a6e05bee5ca2 100644 --- a/tests/python/relay/benchmarking/benchmark_vm.py +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -36,7 +36,7 @@ def benchmark_execution(mod, model="unknown"): def get_graph_runtime_output(mod, 
data, params, target, ctx, dtype='float32', number=2, repeat=20): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) m = graph_runtime.create(graph, lib, ctx) @@ -59,7 +59,7 @@ def get_graph_runtime_output(mod, data, params, target, ctx, def get_vm_output(mod, data, params, target, ctx, dtype='float32', number=2, repeat=20): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): exe = vm.compile(mod, target, params=params) rly_vm = vm_rt.VirtualMachine(exe) rly_vm.init(ctx) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index c9b13d26894f..ff76e1c64bcb 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -1336,7 +1336,6 @@ def run(dtype, shape): p = Prelude(mod) static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape) static_tensor_array_ops.register() - static_tensor_array_ops.define_tensor_get_data(shape) np_data_list = [] ta_length = 3 diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index aa81e3113b7f..8e535a692b88 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -22,6 +22,7 @@ from tvm import relay from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type +import topi.testing def int32(val): return relay.const(val, 'int32') @@ -96,31 +97,48 @@ def check_fail(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): check_fail((relay.Any(),), (3, 2), (2), (4, 2), relay.add, np.add) -def verify_any_full(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): +def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): x = relay.var('x', shape=x_shape, dtype=dtype) mod = tvm.IRModule() - mod['main'] = relay.Function([x], relay.zeros_like(x)) + mod['main'] = relay.Function([x], relay_op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) - res_np = np.zeros_like(x_np) + res_np = np_op(x_np) + for kind in ['debug', 'vm']: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') + result = ex.evaluate()(x_np).asnumpy() + tvm.testing.assert_allclose(result, res_np) + +def test_any_full_like(): + # zeros_like, ones_like + verify_any_full_like(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32") + verify_any_full_like(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32") + verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32") + verify_any_full_like(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32") + verify_any_full_like(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32") + verify_any_full_like(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32") + +def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None): + x = relay.var('x', shape=(len(x_np_shape),), dtype="int32") + mod = tvm.IRModule() + out = relay_op(x, dtype) if value is None else relay_op(relay.expr.const(value), x, dtype) + mod['main'] = relay.Function([x], out) + res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value) + x_np = np.array(x_np_shape).astype("int32") for kind in ['debug', 'vm']: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') result = ex.evaluate()(x_np).asnumpy() tvm.testing.assert_allclose(result, res_np) def test_any_full(): - # zeros, zeros_like, ones, ones_like - 
verify_any_full(any_dims(3), (2, 3, 5), relay.zeros, np.zeros, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.zeros, np.zeros, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32") - verify_any_full(any_dims(3), (2, 3, 5), relay.zeros_like, np.zeros_like, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.zeros_like, np.zeros_like, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.zeros_like, np.zeros_like, "int32") - verify_any_full(any_dims(3), (2, 3, 5), relay.ones, np.ones, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.ones, np.ones, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones, np.ones, "int32") - verify_any_full(any_dims(3), (2, 3, 5), relay.ones_like, np.ones_like, "float32") - verify_any_full(any_dims(3), (225, 115, 15), relay.ones_like, np.ones_like, "float32") - verify_any_full(any_dims(5), (10, 11, 12, 13, 14), relay.ones_like, np.ones_like, "int32") + # zeros, ones, full + verify_any_full((2, 3, 5), relay.zeros, np.zeros, "float32") + verify_any_full((225, 115, 15), relay.zeros, np.zeros, "float32") + verify_any_full((10, 11, 12, 13, 14), relay.zeros, np.zeros, "int32") + verify_any_full((2, 3, 5), relay.ones, np.ones, "float32") + verify_any_full((225, 115, 15), relay.ones, np.ones, "float32") + verify_any_full((10, 11, 12, 13, 14), relay.ones, np.ones, "int32") + verify_any_full((10, 11, 12, 13, 14), relay.full, np.full, "float32", 2.0) + verify_any_full((1, 2, 3, 4), relay.full, np.full, "int32", -2) def test_any_concat(): x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") @@ -138,23 +156,36 @@ def test_any_concat(): result = ex.evaluate()(x_np, y_np) tvm.testing.assert_allclose(result.asnumpy(), ref) -def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape): +def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False): x = relay.var('x', shape=x_shape, dtype="float32") - y = relay.reshape(x, newshape=newshape) - mod = tvm.IRModule() - mod["main"] = relay.Function([x], y) + relu_x = relay.nn.relu(x) data = np.random.uniform(size=x_np_shape).astype('float32') + params = [x] + args = [data] + + if variable_newshape: + newshape_var = relay.var('newshape', shape=(len(newshape),), dtype='int64') + params.append(newshape_var) + args.append(np.array(newshape, dtype='int64')) + newshape = newshape_var + + y = relay.reshape(relu_x, newshape=newshape) + mod = tvm.IRModule() + mod["main"] = relay.Function(params, y) + for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data).asnumpy() + result = ex.evaluate()(*args).asnumpy() assert result.shape == out_shape tvm.testing.assert_allclose(result.flatten(), data.flatten()) def test_any_reshape(): - verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24)) - verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12)) + for variable_newshape in [False, True]: + # Variable newshape only supports that output rank is the same as newshape + verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24), variable_newshape) + verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12), variable_newshape) + verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4), variable_newshape) verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4)) - verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 
4), (3, 2, 12)) def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): @@ -508,6 +539,34 @@ def test_any_pad(): verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3)) verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1)) +def verify_any_dilate(data_shape, strides, static_data_shape): + assert len(data_shape) == len(strides) + mod = tvm.IRModule() + dtype = "float32" + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.nn.dilate(data, strides) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + ref_shape = tuple((static_data_shape[i] - 1) * strides[i] + 1 + for i in range(len(static_data_shape))) + ref_out = np.zeros(shape=ref_shape, dtype=dtype) + ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_dilate(): + verify_any_dilate(any_dims(1), (1,), (1,)) + verify_any_dilate(any_dims(1), (1,), (5,)) + verify_any_dilate(any_dims(1), (5,), (5,)) + verify_any_dilate(any_dims(3), (1, 1, 1), (1, 2, 3)) + verify_any_dilate(any_dims(3), (1, 1, 2), (1, 2, 3)) + verify_any_dilate(any_dims(3), (1, 1, 5), (1, 2, 3)) + verify_any_dilate(any_dims(3), (3, 7, 5), (1, 2, 3)) + verify_any_dilate(any_dims(4), (3, 7, 1, 5), (1, 2, 3, 4)) + def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape): mod = tvm.IRModule() dtype = "float32" @@ -525,6 +584,37 @@ def test_any_softmax(): verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3)) verify_any_softmax(any_dims(4), 2, (13, 11, 3, 1), (13, 11, 3, 1)) +def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): + mod = tvm.IRModule() + data = relay.var('data', shape=data_shape, dtype=dtype) + np_data = np.random.uniform(size=np_dshape).astype(dtype) + if const_k: + k = relay.const(kval) + args = [data] + in_vals = [np_data] + else: + k = relay.var('k', shape=(), dtype="int32") + args = [data, k] + in_vals = [np_data, kval] + out = relay.topk(data, k, ret_type="indices") + mod["main"] = relay.Function(args, out) + + sorted = np.argsort(-np_data) + if len(np_dshape) == 2: + ref_out = sorted[:, 0:kval] + else: + ref_out = sorted[0:kval] + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(*in_vals) + tvm.testing.assert_allclose(result.asnumpy(), ref_out) + +def test_any_topk(): + verify_any_topk(any_dims(1), 5, (10,), "float32") + verify_any_topk(any_dims(2), 2, (6, 3), "int32") + verify_any_topk(any_dims(2), 3, (6, 3), "float32", True) + def test_fused_ops(): x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32') y0 = x + relay.const(1.0, 'float32') @@ -553,6 +643,52 @@ def test_arange_with_dynamic_shape(): result = ex.evaluate()(data) tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1) +def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, + data_np_shape, slice_mode="end", const_attrs=False): + # Generate random numpy input data + np_data = np.random.uniform(size=data_np_shape).astype('float32') + np_begin = np.random.randint(2, size=begin_shape, dtype="int32") + np_end = np.random.randint(5, 10, size=end_shape, dtype="int32") + np_strides = np.random.randint(1, 2 if slice_mode == "size" else 3, size=strides_shape, 
dtype="int32") + # target numpy result + ref_res = topi.testing.strided_slice_python(np_data, np_begin, np_end, np_strides, slice_mode) + + # Relay Module + mod = tvm.IRModule() + data = relay.var('data', shape=data_shape, dtype='float32') + if const_attrs: + data = relay.var('data', shape=data_np_shape, dtype='float32') + begin = relay.const(np_begin) + end = relay.const(np_end) + strides = relay.const(np_strides) + args = [data] + np_inputs = [np_data] + else: + begin = relay.var('begin', shape=begin_shape, dtype="int32") + end = relay.var('end', shape=end_shape, dtype="int32") + strides = relay.var('strides', shape=strides_shape, dtype="int32") + args = [data, begin, end, strides] + np_inputs = [np_data, np_begin, np_end, np_strides] + + y = relay.strided_slice(data, begin=begin, end=end, + strides=strides, slice_mode=slice_mode) + mod["main"] = relay.Function(args, y) + + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(*np_inputs) + tvm.testing.assert_allclose(result.asnumpy(), ref_res) + + +def test_any_strided_slice(): + verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21)) + verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21)) + verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (23, 29, 41)) + verify_any_strided_slice(any_dims(4), (4,), (4,), (4,), (40, 50, 60, 70)) + verify_any_strided_slice(any_dims(3), (3,), (3,), (3,), (15, 17, 21), slice_mode="size") + verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21), const_attrs=True) + + def test_recursive_concat(): """ fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) { @@ -639,8 +775,50 @@ def _body(i, st): except Exception as e: assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) +def test_tuple_get_item(): + mod = tvm.IRModule() + dtype = "float32" + static_data_shape = (9, 4) + data_shape = (relay.Any(), 4) + indices_or_sections = 2 + axis = 1 + data = relay.var('data', shape=data_shape, dtype=dtype) + y = relay.split(data, indices_or_sections, axis) + y = relay.expr.TupleGetItem(y.astuple(), 0) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + ref_out_shape = (9, 2) + for kind in ["vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." 
% (str(ref_out_shape), str(ret.asnumpy().shape)) + +def test_mixed_input_type(): + mod = tvm.IRModule() + dtype = "float32" + static_data_shape = (9, 4) + data_shape = (relay.Any(), 4) + tensor_type = relay.TensorType(data_shape, dtype) + tuple_type = relay.TupleType([tensor_type, tensor_type]) + data0 = relay.var("d0", type_annotation=relay.TupleType([tuple_type, tensor_type])) + data1 = relay.var("d1", shape=(relay.Any(), 4), dtype=dtype) + data_tuple = relay.expr.TupleWrapper(data0, 2) + nested_data_tuple = relay.expr.TupleWrapper(data_tuple[0], 2) + y = nested_data_tuple[1] * data_tuple[1] + data1 + mod["main"] = relay.Function([data0, data1], y) + data_np0 = np.random.uniform(size=static_data_shape).astype(dtype) + data_np1 = np.random.uniform(size=static_data_shape).astype(dtype) + ref_out_shape = (9, 4) + for kind in ["vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1) + assert result.asnumpy().shape == ref_out_shape, \ + "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + if __name__ == "__main__": test_any_full() + test_any_full_like() test_any_broadcast() test_any_elemwise() test_any_broadcast_fail() @@ -663,7 +841,12 @@ def _body(i, st): test_any_dense() test_any_pad() test_any_softmax() + test_any_topk() test_fused_ops() + test_any_argwhere() test_arange_with_dynamic_shape() + test_any_strided_slice() test_recursive_concat() test_recursive_concat_with_wrong_annotation() + test_tuple_get_item() + test_mixed_input_type() diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py index eb018fed96e7..1b4e08f7eb7b 100644 --- a/tests/python/relay/test_backend_compile_engine.py +++ b/tests/python/relay/test_backend_compile_engine.py @@ -184,7 +184,7 @@ def test_compile_placeholder_bypass(): z = relay.var("z", shape=(2, 3)) result = relay.Tuple([x, relay.op.concatenate([y, z], axis=0)]) func = relay.Function(relay.analysis.free_vars(result), result) - with relay.build_config(opt_level=0): + with tvm.transform.PassContext(opt_level=0): graph, lib, params = relay.build(tvm.IRModule.from_expr(func), 'llvm') diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index b0399a53a732..f0785bcf1c09 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -105,7 +105,7 @@ def test_with_params(): mod.run() res = mod.get_output(0).asnumpy() ref_res = np.exp(y_data + x_data) - tvm.testing.assert_allclose(res, ref_res) + tvm.testing.assert_allclose(res, ref_res, atol=1e-5, rtol=1e-5) def test_plan_memory(): @@ -166,7 +166,7 @@ def unit_numpy(X, W): z = unit(rnn_dim) for target, ctx in ctx_list(): - with relay.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): graph, lib, params = relay.build(tvm.IRModule.from_expr(z), target) m = graph_runtime.create(graph, lib, ctx) m.set_input("X", tvm.nd.array(x.astype(dtype))) diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py index 171b6b0b77b0..8d5438424e32 100644 --- a/tests/python/relay/test_cpp_build_module.py +++ b/tests/python/relay/test_cpp_build_module.py @@ -115,7 +115,7 @@ def check_conversion(tgt, ctx): X = tvm.nd.array(n * np.random.randn(n).astype(src) - n / 2) # build - with relay.build_config(opt_level=1): + with tvm.transform.PassContext(opt_level=1): g_json, 
mmod, params = relay.build(tvm.IRModule.from_expr(func), tgt) # test diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py new file mode 100644 index 000000000000..9727e53bab0a --- /dev/null +++ b/tests/python/relay/test_dataflow_pattern.py @@ -0,0 +1,1363 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-wildcard-import +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.dataflow_pattern import * +from tvm.relay.testing import run_opt_pass + +# NB: 1 corresponds to the C++ enum that specicfies this +# we loose the type safety due to the Python/C++ calling +# convention. +K_ELEMWISE = 0 +K_BROADCAST = 1 + + +## NODE TESTS +def test_expr_pattern(): + ep = is_expr(relay.var('x', shape=(4, 1))) + assert isinstance(ep, ExprPattern) + assert isinstance(ep.expr, relay.Var) + + +def test_var_pattern(): + v = is_var("x") + assert isinstance(v, VarPattern) + assert v.name == "x" + + +def test_constant_pattern(): + c = is_constant() + assert isinstance(c, ConstantPattern) + + +def test_wildcard_pattern(): + wc = wildcard() + assert isinstance(wc, WildcardPattern) + + +def test_CallPattern(): + wc1 = wildcard() + wc2 = wildcard() + c = is_op("add")(wc1, wc2) + assert isinstance(c, CallPattern) + assert isinstance(c.args[0], WildcardPattern) + assert isinstance(c.args[1], WildcardPattern) + + +def test_TuplePattern(): + wc1 = wildcard() + wc2 = wildcard() + t = is_tuple([wc1, wc2]) + assert isinstance(t, TuplePattern) + assert isinstance(t.fields[0], WildcardPattern) + assert isinstance(t.fields[1], WildcardPattern) + + +def test_TupleGetItemPattern(): + wc1 = wildcard() + wc2 = wildcard() + t = is_tuple([wc1, wc2]) + tgi = is_tuple_get_item(t, 1) + assert isinstance(tgi, TupleGetItemPattern) + assert isinstance(tgi.tuple, TuplePattern) + assert isinstance(tgi.tuple.fields[0], WildcardPattern) + assert isinstance(tgi.tuple.fields[1], WildcardPattern) + + +def test_AltPattern(): + is_add_or_sub = is_op('add') | is_op('subtract') + assert isinstance(is_add_or_sub, AltPattern) + + +def test_TypePattern(): + ttype = relay.TensorType((10, 10), "float32") + ty_pat = has_type(ttype) + assert isinstance(ty_pat, TypePattern) + assert ty_pat.type == ttype + + +def test_DataTypePattern(): + dtype = "float16" + pattern = has_dtype(dtype) + assert isinstance(pattern, DataTypePattern) + assert pattern.dtype == dtype + + +def test_ShapePattern(): + shape = [10, 10] + pattern = has_shape(shape) + assert isinstance(pattern, ShapePattern) + assert tvm.ir.structural_equal(pattern.shape, shape) + + +def test_AttrPattern(): + op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE}) + assert isinstance(op, AttrPattern) + assert 
op.attrs["TOpPattern"] == K_ELEMWISE + + +## MATCHER TESTS + + +def test_match_op(): + assert is_op('add').match(relay.op.op.get("add")) + + +def test_no_match_op(): + assert not is_op('add').match(relay.op.op.get("subtract")) + + +def test_match_op_or(): + is_add_or_sub = is_op('add') | is_op('subtract') + assert is_add_or_sub.match(relay.op.op.get("add")) + assert is_add_or_sub.match(relay.op.op.get("subtract")) + + +def test_match_call_commutive(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(is_var("x"), is_var("y")) + assert add_pattern.match(x + y) + assert add_pattern.match(y + x) + mul_pattern = is_op('multiply')(is_var("x"), is_var("y")) + assert mul_pattern.match(x * y) + assert mul_pattern.match(y * x) + + +def test_no_match_call_commutive(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('subtract')(is_var("x"), is_var("y")) + assert add_pattern.match(x - y) + assert not add_pattern.match(y - x) + add_pattern = is_op('divide')(is_var("x"), is_var("y")) + assert add_pattern.match(x / y) + assert not add_pattern.match(y / x) + + +def test_match_call(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + assert add_pattern.match(x + y) + + +def test_no_match_call(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + assert not add_pattern.match(x - y) + + +def test_match_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + pattern = is_op("nn.relu")(is_op("nn.conv2d")( + wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + assert pattern.match(relu) + + conv2d = relay.op.nn.conv2d(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + assert pattern.match(relu) + + pattern = is_op("nn.conv2d")(wildcard(), wildcard()) + pattern = pattern.optional(is_op('nn.relu')).optional(is_op("tanh")) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + tanh = relay.op.tanh(conv2d) + tanh2 = relay.op.tanh(relu) + relu2 = relay.op.nn.relu(tanh) + assert pattern.match(conv2d) + assert pattern.match(relu) + assert pattern.match(tanh) + assert pattern.match(tanh2) + assert not pattern.match(relu2) + + +def test_no_match_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + pattern = is_op("nn.relu")(is_op("nn.conv2d")( + wildcard(), wildcard()).optional(lambda x: is_op("nn.bias_add")(x, wildcard()))) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.tanh(conv2d) + assert not pattern.match(relu) + + conv2d = relay.op.nn.dense(x, w) + relu = relay.op.tanh(conv2d) + assert not pattern.match(relu) + + conv2d = relay.op.nn.dense(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + assert not pattern.match(relu) + + conv2d = relay.op.nn.conv2d(x, w) + bias_add = conv2d + w + relu = relay.op.nn.relu(bias_add) + assert not pattern.match(relu) + + +def test_match_const(): + conv2d = is_op('nn.conv2d')(wildcard(), is_constant()) + pattern = is_op('nn.bias_add')(conv2d, wildcard()) + + x = relay.var('x', shape=(1, 3, 224, 224)) + w = relay.var('w', shape=(3, 3, 3, 3)) + b = relay.var('b', shape=(3, )) + conv2d = relay.op.nn.conv2d(x, w) + out = relay.op.nn.bias_add(conv2d, b) + func = relay.Function([x, w, b], out) + mod = tvm.IRModule.from_expr(func) + + assert not pattern.match(mod['main'].body) + mod["main"] = 
bind_params_by_name(mod["main"], + {'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))}) + assert pattern.match(mod['main'].body) + + +def test_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"))) + assert tuple_pattern.match(relay.expr.Tuple((x, y, z))) + + tuple_pattern = is_tuple((is_var("x"), wildcard(), is_op("add"))) + tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y, z)), 1)) + + +def test_no_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"), wildcard())) + assert not tuple_pattern.match(relay.expr.Tuple((x, y, z))) + + tuple_pattern = is_tuple((is_var('x'), wildcard(), is_op("add"))) + tuple_get_item_pattern = is_tuple_get_item(tuple_pattern, 1) + assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple( + (x, y, z)), 2)) + + +def test_match_type(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + assert ty_pat.match(x) + + +def test_no_match_type(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + assert not ty_pat.match(x) + + +def test_match_dtype(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_dtype("float32") + assert ty_pat.match(x) + + +def test_no_match_dtype(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_dtype("float32") + assert not ty_pat.match(x) + + +def test_match_shape(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_shape((10, 10)) + assert ty_pat.match(x) + + +def test_no_match_shape(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_shape((10, 5)) + assert not ty_pat.match(x) + + +def test_match_op_attr(): + op = is_op('add').has_attr({"TOpPattern": K_BROADCAST}) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert op_pat.match(x + y) + + +def test_no_match_op_attr(): + op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE}) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(relay.op.nn.dense(x, y)) + op = is_op('add').has_attr({"TOpPattern": K_BROADCAST}) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(x - y) + + +def test_match_func_attr(): + pattern = wildcard().has_attr({"Composite": "add"}) + x = relay.var('x') + y = relay.var('y') + f = relay.Function([x, y], x + y).with_attr("Composite", "add") + assert pattern.match(f) + + +def test_no_match_func_attr(): + pattern = wildcard().has_attr({"Composite": "add"}) + x = relay.var('x') + y = relay.var('y') + + f = relay.Function([x, y], x + y).with_attr("RandomTest", "add") + assert not pattern.match(f) + f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias") + assert not pattern.match(f) + + +def test_match_call_attr(): + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"}) + x = relay.var('x') + y = relay.var('y') + assert is_conv2d.match(relay.op.nn.conv2d(x, y)) + + +def test_no_match_call_attr(): + x = relay.var('x') + y = relay.var('y') + + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"}) + assert not is_conv2d.match(relay.op.nn.conv2d(x, 
y)) + + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"}) + assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) + + +def test_match_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + + +def test_no_match_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + + # Check + assert not diamond.match(leaky_relu) + assert not diamond.match(relu) + + +def test_match_fake_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + input1 = relay.var('input1') + weight1 = relay.var('weight1') + conv2d1 = relay.op.nn.conv2d(input1, weight1) + inp2 = relay.var('input2') + weight2 = relay.var('weight2') + conv2d2 = relay.op.nn.conv2d(inp2, weight2) + relu = relay.op.nn.relu(conv2d1) + leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + + +def test_match_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Classic Diamond + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + + # Deeper Branch + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + + # Single Branch + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Check + assert diamond.match(out) + + # Fuzzy path/nested Diamond + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))( + wildcard()) | is_op('add')(wildcard(), wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, 
alpha=0) + out = tanh + leaky_relu + + assert diamond.match(out) + + +def test_not_match_dominator(): + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Fake Diamond + input1 = relay.var('input1') + weight1 = relay.var('weight1') + conv2d1 = relay.op.nn.conv2d(input1, weight1) + inp2 = relay.var('input2') + weight2 = relay.var('weight2') + conv2d2 = relay.op.nn.conv2d(inp2, weight2) + relu = relay.op.nn.relu(conv2d1) + leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + + # Add op that doesn't match K_ELEMWISE + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + + # Relu on the input instead of the conv + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(inp) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert not diamond.match(out) + + # No conv + inp = relay.var('input') + relu = relay.op.nn.relu(inp) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Check + assert not diamond.match(out) + + +def test_match_typed_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32") + reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10]) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Classic Diamond + inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32")) + weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32")) + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +def test_no_match_typed_dominator(): + # Classic Diamond + inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32")) + weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32")) + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32") + reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 1, 10, 10]) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Check + assert not diamond.match(out) + + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float16") + reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10]) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Check + assert not diamond.match(out) + + +def test_rewrite(): + x = relay.var('x') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), 
wildcard()) + sub_pattern = is_op('subtract')(wildcard(), wildcard()) + + class TestRewrite(DFPatternCallback): + def __init__(self): + self.pattern = add_pattern + + def callback(self, pre, post, node_map): + return post.args[0] - post.args[1] + + out = rewrite(TestRewrite(), x + y) + assert sub_pattern.match(out) + + +def test_rewrite_func(): + x = relay.var('x') + w = relay.var('w') + y = relay.var('y') + add_pattern = is_op('add')(wildcard(), wildcard()) + sub_pattern = is_op('subtract')(wildcard(), wildcard()) + + class TestRewrite(DFPatternCallback): + def __init__(self): + self.pattern = add_pattern + + def callback(self, pre, post, node_map): + return post.args[0] - post.args[1] + + inpf = relay.var("input") + weightf = relay.var("weight") + func = relay.Function([inpf, weightf], + relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), + attrs=None) + out = rewrite(TestRewrite(), func(x, w) + y) + assert sub_pattern.match(out) + + +def test_nested_rewrite(): + class PatternCallback(DFPatternCallback): + def __init__(self, pattern): + self.pattern = pattern + + def callback(self, pre, post, node_map): + return post + + def gen(): + x = relay.var('x') + y = relay.var('y') + y_add = relay.add(y, y) + n0 = relay.add(x, y_add) + n1 = relay.add(x, n0) + return relay.add(n1, n0) + + def pattern(): + a = wildcard() + b = wildcard() + n0 = is_op('add')(a, b) + n1 = is_op('add')(n0, a) + return is_op('add')(n0, n1) + + out = gen() + pat = pattern() + new_out = rewrite(PatternCallback(pat), out) + + assert tvm.ir.structural_equal(out, new_out) + + +def test_not_fuse_multi_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + out = out + conv2d + # Check + assert not diamond.match(out) + + +class BatchnormCallback(DFPatternCallback): + def __init__(self): + self.x = wildcard() + self.var = wildcard() + self.mean = wildcard() + self.beta = wildcard() + self.gamma = wildcard() + self.eps = is_constant() + + self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + \ + self.beta + + def callback(self, pre, post, node_map): + x = node_map[self.x][0] + var = node_map[self.var][0] + mean = node_map[self.mean][0] + beta = node_map[self.beta][0] + gamma = node_map[self.gamma][0] + eps = node_map[self.eps][0] + return relay.op.nn.batch_norm(x, gamma, beta, mean, var, + epsilon=eps.data.asnumpy().item())[0] + + +def test_fuse_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta + + out = rewrite(BatchnormCallback(), BN) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) + + +def test_no_fuse_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta + + out = rewrite(BatchnormCallback(), fake_BN) + assert tvm.ir.structural_equal(out, fake_BN) + + +def test_fuse_double_batchnorm(): + x = relay.var('x') + var = relay.var('var') 
+ mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta + BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta + + out = rewrite(BatchnormCallback(), BN2) + + bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0] + bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0] + + assert tvm.ir.structural_equal(out, bn2) + + +def test_partial_fuse_double_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta + BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta + + out = rewrite(BatchnormCallback(), BN2) + + bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0] + + assert tvm.ir.structural_equal(out, bn2) + + +def test_fuse_batchnorm_commutation(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + + #commute add + BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + out = rewrite(BatchnormCallback(), BN) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) + + # associate divide/multiply + BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta + out = rewrite(BatchnormCallback(), BN) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) + + # associate multiply/divide + BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta + out = rewrite(BatchnormCallback(), BN) + assert tvm.ir.structural_equal( + out, + relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]) + + +def test_quadruple_rewrite_dominator(): + class DominatorRemovalCallback(DFPatternCallback): + def __init__(self): + self.inp = wildcard() + self.weight = wildcard() + is_conv2d = is_op('nn.conv2d')(self.inp, self.weight) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))( + wildcard()) | is_op('add')(wildcard(), wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction) + + def callback(self, pre, post, node_map): + inp = node_map[self.inp][0] + weight = node_map[self.weight][0] + return relay.op.nn.conv2d(inp, weight) + + inp = relay.var('input') + weight = relay.var('weight') + # Classic Diamond + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Deeper Branch + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Single Branch + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + out = relu + tanh + + # Fuzzy path/nested Diamond + conv2d = relay.op.nn.conv2d(out, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = tanh + leaky_relu + one = relay.op.nn.conv2d(inp, weight) + two = relay.op.nn.conv2d(one, weight) + three = 
relay.op.nn.conv2d(two, weight) + four = relay.op.nn.conv2d(three, weight) + + assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four) + + +def algebraic_simplify(expr): + zero = (is_expr(relay.const(0)) | is_expr(relay.const(0.0))) + one = (is_expr(relay.const(1)) | is_expr(relay.const(1.0))) + + class ElwiseNullCallback(DFPatternCallback): + def callback(self, pre, post, node_map): + return node_map[self.x][0] # pylint: disable=no-member + + class AddCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x + zero + + class SubCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x - zero + + class MulCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x * one + + class DivCallback(ElwiseNullCallback): + def __init__(self): + self.x = wildcard() + self.pattern = self.x / one + + class MulZeroCallback(ElwiseNullCallback): + def __init__(self): + self.x = zero + self.pattern = self.x * wildcard() + + class ZeroDivCallback(ElwiseNullCallback): + def __init__(self): + self.x = zero + self.pattern = self.x / wildcard() + + return rewrite([ + AddCallback(), + SubCallback(), + MulCallback(), + DivCallback(), + MulZeroCallback(), + ZeroDivCallback() + ], expr) + + +def test_algebraic_simplify(): + x = relay.Var('x') + y = relay.Var('y') + + one = relay.const(1) + zero = relay.const(0) + onef = relay.const(1.0) + zerof = relay.const(0.0) + + assert algebraic_simplify(x + zero) == x + assert algebraic_simplify(x + zerof) == x + assert algebraic_simplify(zero + x) == x + assert algebraic_simplify(zerof + x) == x + + assert algebraic_simplify(x - zero) == x + assert algebraic_simplify(x - zerof) == x + + assert algebraic_simplify(x * one) == x + assert algebraic_simplify(x * onef) == x + assert algebraic_simplify(one * x) == x + assert algebraic_simplify(onef * x) == x + assert algebraic_simplify(x * zero) == zero + assert algebraic_simplify(x * zerof) == zerof + + assert algebraic_simplify(x / one) == x + assert algebraic_simplify(x / onef) == x + assert algebraic_simplify(zero / x) == zero + assert algebraic_simplify(zerof / x) == zerof + + assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), + x + y) + + +def test_double_partition(): + # Pattern 1 + conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard()) + bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard()) + relu_p = is_op('nn.relu')(bias_add_p) + + # Graph + x = relay.var('input') + w = relay.var('weight') + b = relay.var('bias') + w2 = relay.var('weight') + b2 = relay.var('bias') + conv2d = relay.op.nn.conv2d(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + conv2d2 = relay.op.nn.conv2d(relu, w2) + bias_add2 = relay.op.nn.bias_add(conv2d2, b2) + + partitioned = bias_add2 + for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]: + partitioned = pat.partition(partitioned, {"Composite": label}) + + inpf = relay.var("input") + weightf = relay.var("weight") + biasf = relay.var("bias") + func0 = relay.Function( + [inpf, weightf, biasf], + relay.op.nn.relu(relay.op.nn.bias_add( + relay.op.nn.conv2d(inpf, weightf), + biasf))).with_attr("Composite", + "conv_bias_relu").with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") + inpf = relay.var("input") + weightf = relay.var("weight") + biasf = relay.var("bias") + func1 = relay.Function([inpf, weightf, biasf], + 
relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), + biasf)).with_attr("Composite", + "conv_bias").with_attr( + "PartitionedFromPattern", + "nn.conv2d_nn.bias_add_") + + expected = func1(func0(x, w, b), w2, b2) + assert tvm.ir.structural_equal(partitioned, expected) + + +def test_partition_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Classic Diamond + inp = relay.var('input') + weight = relay.var('weight') + + def generate_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return relu + leaky_relu + + out = generate_diamond(inp * inp, weight * weight) + # Check + partitioned = diamond.partition(out) + + i = relay.Var("input") + w = relay.Var("weight") + f = relay.Function([i, w], generate_diamond(i, w)).with_attr( + "PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_") + assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight)) + + +def test_quadruple_partition_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))( + wildcard()) | is_op('add')(wildcard(), wildcard()) + reduction = is_op('add')(wildcard(), wildcard()) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + inp = relay.var('input') + weight = relay.var('weight') + + # Classic Diamond + def classic_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return relu + leaky_relu + + # Deeper Branch + def deeper_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + relu = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return relu + leaky_relu + + # Single Branch + def single_branch(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + tanh = relay.op.tanh(relu) + return relu + tanh + + # Fuzzy path/nested Diamond + def nested_diamond(inp, weight): + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relu + relu + tanh = relay.op.tanh(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + return tanh + leaky_relu + + partitioned = diamond.partition( + nested_diamond(single_branch(deeper_diamond(classic_diamond(inp, weight), weight), weight), + weight)) + + functions = [] + partition_names = [ + "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_", + "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_", "nn.conv2d_nn.relu_nn.relu_tanh_add_", + "nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_" + ] + for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]): + inpf = relay.var("input") + weightf = relay.var("weight") + functions.append( + relay.Function([inpf, weightf], f(inpf, + weightf)).with_attr("PartitionedFromPattern", + partition_names[i])) + + reference = functions[3](functions[2](functions[1](functions[0](inp, weight), weight), weight), + weight) + assert tvm.ir.structural_equal(partitioned, reference) + + +def get_BN(x, var, mean, 
beta, gamma, eps): + return gamma * (x - mean) / relay.op.sqrt(var + eps) + beta + + +def test_partition_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + eps = relay.const(1e-5) + BN = get_BN(x, var, mean, beta, gamma, eps) + + xf = relay.var('xf') + varf = relay.var('varf') + meanf = relay.var('meanf') + betaf = relay.var('betaf') + gammaf = relay.var('gammaf') + # Put the arguments in toplogological order for the reference + f = relay.Function([gammaf, xf, meanf, varf, betaf], + get_BN(xf, varf, meanf, betaf, gammaf, + eps)).with_attr("PartitionedFromPattern", + "subtract_multiply_add_sqrt_divide_add_") + + partitioned = BatchnormCallback().pattern.partition(BN) + reference = f(gamma, x, mean, var, beta) + assert tvm.ir.structural_equal(partitioned, reference) + + +def test_partition_double_batchnorm(): + x = relay.var('x') + var = relay.var('var') + mean = relay.var('mean') + beta = relay.var('beta') + gamma = relay.var('gamma') + eps = relay.const(1e-5) + + BN = gamma * (x - mean) / relay.op.sqrt(var + eps) + beta + BN2 = gamma * (BN - mean) / relay.op.sqrt(var + eps) + beta + + xf = relay.var('xf') + varf = relay.var('varf') + meanf = relay.var('meanf') + betaf = relay.var('betaf') + gammaf = relay.var('gammaf') + f1 = relay.Function([gammaf, xf, meanf, varf, betaf], + get_BN(xf, varf, meanf, betaf, gammaf, + eps)).with_attr("PartitionedFromPattern", + "subtract_multiply_add_sqrt_divide_add_") + # The partitioner doesn't replace duplicates, so we use two copies of the function + xf2 = relay.var('xf2') + varf2 = relay.var('varf2') + meanf2 = relay.var('meanf2') + betaf2 = relay.var('betaf2') + gammaf2 = relay.var('gammaf2') + f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], + get_BN(xf2, varf2, meanf2, betaf2, gammaf2, + eps)).with_attr("PartitionedFromPattern", + "subtract_multiply_add_sqrt_divide_add_") + + partitioned = BatchnormCallback().pattern.partition(BN2) + reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta) + assert tvm.ir.structural_equal(partitioned, reference) + + +def test_partition_check(): + pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + + def check(pre): + return pre.args[0].attrs.data_layout == "NCHW" + + x = relay.var('input') + w = relay.var('weight') + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + + xf = relay.var('input') + wf = relay.var('weight') + conv2df = relay.op.nn.conv2d(xf, wf) + reluf = relay.op.nn.relu(conv2df) + func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", + "nn.conv2d_nn.relu_") + + reference = func(x, w) + partitioned = pattern.partition(relu, check=check) + assert tvm.ir.structural_equal(partitioned, reference) + + conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC") + relu = relay.op.nn.relu(conv2d) + assert relu == pattern.partition(relu, check=check) + + +def test_partition_check_types(): + pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + + def check(pre): + conv = pre.args[0] + return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1) + + x = relay.var('input', shape=(1, 10, 10, 10)) + w = relay.var('weight', shape=(10, 10, 3, 3)) + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + relu = run_opt_pass(relu, relay.transform.InferType()) + + partitioned = pattern.partition(relu, check=check) + assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_" + + 
conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC") + relu = relay.op.nn.relu(conv2d) + relu = run_opt_pass(relu, relay.transform.InferType()) + assert relu == pattern.partition(relu, check=check) + + x = relay.var('input', shape=(2, 10, 10, 10)) + w = relay.var('weight', shape=(10, 10, 3, 3)) + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + relu = run_opt_pass(relu, relay.transform.InferType()) + assert relu == pattern.partition(relu, check=check) + + +def conv_bias_relu(x, w, b): + conv2d = relay.op.nn.conv2d(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + return relu + + +def test_partition_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + + conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + bias = conv2d.optional(lambda x: is_op('nn.bias_add')(x, wildcard())) + pattern1 = is_op('nn.relu')(bias) + + conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + bias = is_op('nn.bias_add')(conv2d, wildcard()) + pattern2 = bias.optional(lambda x: is_op('nn.relu')(x)) + + relu = conv_bias_relu(x, w, b) + + xf = relay.var('x') + wf = relay.var('w') + bf = relay.var('b') + func = relay.Function([xf, wf, bf], + conv_bias_relu(xf, wf, bf)).with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") + + assert pattern1.match(relu) + assert tvm.ir.structural_equal(func(x, w, b), pattern1.partition(relu)) + + assert pattern2.match(relu) + assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu)) + +def test_match_match(): + add_pattern = is_op('add')(wildcard(), wildcard()) + class TestRewrite(DFPatternCallback): + def __init__(self): + self.pattern = add_pattern + def callback(self, pre, post, node_map): + return post.args[0] - post.args[1] + mod = tvm.IRModule({}) + tvm.relay.prelude.Prelude(mod) + # Apply rewrite on IR including relay.Match + out = rewrite(TestRewrite(), mod['tensor_concatenate_int64']) + assert tvm.ir.structural_equal(mod['tensor_concatenate_int64'], out) + +def test_partition_constant_embedding(): + x = relay.var('x') + w = relay.var('w') + wc = relay.const(1) + b = relay.var('b') + + xf = relay.var('x') + wf = relay.var('w') + bf = relay.var('b') + embeded_func = relay.Function([xf, bf], + conv_bias_relu(xf, wc, + bf)).with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") + xf = relay.var('x') + wf = relay.var('w') + bf = relay.var('b') + lifted_func = relay.Function([xf, wf, bf], + conv_bias_relu(xf, wf, + bf)).with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") + relu = conv_bias_relu(x, w, b) + reluc = conv_bias_relu(x, wc, b) + + # Check lifting of wildcard matches + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), wildcard()), + wildcard())) + assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc)) + + # Check lifting of input matches + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_var()), + wildcard())) + assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + assert tvm.ir.structural_equal(reluc, pattern.partition(reluc)) #Constants are not Inputs + + # Check embedding of constant matches + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_constant()), + wildcard())) + assert tvm.ir.structural_equal(relu, pattern.partition(relu)) + assert tvm.ir.structural_equal(embeded_func(x, b), 
pattern.partition(reluc)) + + # Check embedding of constant ExprPatterns + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')(wildcard(), is_expr(wc)), + wildcard())) + assert tvm.ir.structural_equal(relu, pattern.partition(relu)) + assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + + # Check lifting/embedding of Alt matches + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')( + wildcard(), is_var() | is_constant()), wildcard())) + assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + + # Check lifting/embedding of Alt matches with the other ordering + pattern = is_op('nn.relu')(is_op('nn.bias_add')(is_op('nn.conv2d')( + wildcard(), is_constant() | is_var()), wildcard())) + assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu)) + assert tvm.ir.structural_equal(embeded_func(x, b), pattern.partition(reluc)) + + +if __name__ == "__main__": + test_expr_pattern() + test_var_pattern() + test_constant_pattern() + test_wildcard_pattern() + test_CallPattern() + test_TuplePattern() + test_TupleGetItemPattern() + test_AltPattern() + test_TypePattern() + test_DataTypePattern() + test_ShapePattern() + test_AttrPattern() + test_match_op() + test_no_match_op() + test_match_op_or() + test_match_call_commutive() + test_no_match_call_commutive() + test_match_call() + test_no_match_call() + test_match_option() + test_no_match_option() + test_match_const() + test_match_tuple() + test_no_match_tuple() + test_match_type() + test_no_match_type() + test_match_dtype() + test_no_match_dtype() + test_match_shape() + test_no_match_shape() + test_match_op_attr() + test_no_match_op_attr() + test_match_func_attr() + test_no_match_func_attr() + test_match_call_attr() + test_no_match_call_attr() + test_match_diamond() + test_no_match_diamond() + test_match_fake_diamond() + test_match_dominator() + test_not_match_dominator() + test_rewrite() + test_rewrite_func() + test_nested_rewrite() + test_not_fuse_multi_diamond() + test_fuse_batchnorm() + test_no_fuse_batchnorm() + test_fuse_double_batchnorm() + test_partial_fuse_double_batchnorm() + test_fuse_batchnorm_commutation() + test_quadruple_rewrite_dominator() + test_algebraic_simplify() + test_double_partition() + test_partition_dominator() + test_quadruple_partition_dominator() + test_partition_batchnorm() + test_partition_double_batchnorm() + test_partition_check() + test_partition_check_types() + test_partition_option() + test_match_match() + test_partition_constant_embedding() diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 3797910080a1..c449ce39ff01 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -49,7 +49,8 @@ def update_lib(lib): return lib def check_vm_result(): - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with tvm.transform.PassContext(opt_level=3, + disabled_pass=["AlterOpLayout"]): exe = relay.vm.compile(mod, target=target) code, lib = exe.save() lib = update_lib(lib) @@ -60,7 +61,8 @@ def check_vm_result(): tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) def check_graph_runtime_result(): - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with tvm.transform.PassContext(opt_level=3, + disabled_pass=["AlterOpLayout"]): json, lib, _ = relay.build(mod, target=target) lib = 
update_lib(lib) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) diff --git a/tests/python/relay/test_ir_op.py b/tests/python/relay/test_ir_op.py index 1fd68b391d14..46e4b025fab7 100644 --- a/tests/python/relay/test_ir_op.py +++ b/tests/python/relay/test_ir_op.py @@ -14,13 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tvm from tvm import relay from tvm.relay.testing.temp_op_attr import TempOpAttr def test_op_attr(): log_op = relay.op.get("log") - @relay.op.register("exp", "ftest") + @tvm.ir.register_op_attr("exp", "ftest") def test(x): return x + 1 @@ -37,9 +38,9 @@ def add2(x): return x + 2 # Register fadd1 and fadd2 attributes. - relay.op.register("exp", "fadd1", add1) - relay.op.register("log", "fadd1", add1) - relay.op.register("log", "fadd2", add2) + tvm.ir.register_op_attr("exp", "fadd1", add1) + tvm.ir.register_op_attr("log", "fadd1", add1) + tvm.ir.register_op_attr("log", "fadd2", add2) # Reset log fadd1 attr. log_op = relay.op.get("log") @@ -63,7 +64,7 @@ def add2(x): return x + 2 # Set original attr value is add1. - relay.op.register("sqrt", "ftest", add1) + tvm.ir.register_op_attr("sqrt", "ftest", add1) with TempOpAttr("sqrt", "ftest", add2): # Check that the attr value is updated to add2. diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 9e624917ab1a..c4ac042bdb22 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -219,7 +219,7 @@ def test_vars(): # operator id op = parse_text("foo") - assert isinstance(op, relay.Op) + assert isinstance(op, tvm.ir.Op) assert op.name == "foo" diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 61dbca33ca7a..2a88c0c99ae7 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -240,6 +240,15 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { assert main_def_str.strip() in mod_str +def test_null_attribute(): + x = relay.var("x") + y = relay.var("y") + z = relay.Function([x], y) + z = z.with_attr("TestAttribute", None) + txt = astext(z) + assert "TestAttribute=(nullptr)" in txt + + if __name__ == "__main__": do_print[0] = True test_lstm() @@ -262,3 +271,4 @@ def @main[A]() -> fn (A, List[A]) -> List[A] { test_variable_name() test_call_node_order() test_unapplied_constructor() + test_null_attribute() diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 16d02d2cc224..be3e2a0be9e1 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -43,6 +43,44 @@ def test_type_var(): assert isinstance(tvar, tvm.ir.GlobalTypeVar) assert tvar.name_hint == "in0" +def test_var(): + # type var in 0.6 + nodes = [ + {"type_key": ""}, + {"type_key": "relay.Var", + "attrs": { + "_checked_type_": "0", + "span": "0", + "type_annotation": "0", + "vid": "2" + } + }, + {"type_key": "relay.Id", + "attrs": {"name_hint": "a3"}}, + {"type_key": "relay.TensorType", + "attrs": { + "dtype": "float32", + "shape": "4", + "span": "0" + } + }, + {"type_key": "Array", + "data": [5, 6] + }, + {"type_key": "IntImm", + "attrs": {"dtype": "int32", "value": "16"}}, + {"type_key": "IntImm", + "attrs": {"dtype": "int32", "value": "8"}} + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert 
isinstance(tvar, relay.Var) + assert tvar.name_hint == "a3" def test_incomplete_type(): nodes = [ @@ -107,6 +145,24 @@ def test_global_var(): } tvar = tvm.ir.load_json(json.dumps(data)) assert isinstance(tvar, tvm.ir.GlobalVar) + nodes = [ + {"type_key": ""}, + {"type_key": "GlobalVar", + "attrs": { + "_checked_type_": "0", + "name_hint": "x", + "span": "0" + } + } + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.GlobalVar) def test_op(): @@ -148,10 +204,40 @@ def test_tir_var(): assert y.name == "y" +def test_str_map(): + nodes = [ + {'type_key': ''}, + {'type_key': 'StrMap', 'keys': ['z', 'x'], 'data': [2, 3]}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}}, + {'type_key': 'Max', 'attrs': {'a': '4', 'b': '10', 'dtype': 'int32'}}, + {'type_key': 'Add', 'attrs': {'a': '5', 'b': '9', 'dtype': 'int32'}}, + {'type_key': 'Add', 'attrs': {'a': '6', 'b': '8', 'dtype': 'int32'}}, + {'type_key': 'tir.Var', 'attrs': {'dtype': 'int32', 'name': '7', 'type_annotation': '0'}}, + {'type_key': 'runtime.String', 'repr_str': 'x'}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '1'}}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '2'}}, + {'type_key': 'IntImm', 'attrs': {'dtype': 'int32', 'value': '100'}} + ] + data = { + "root" : 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + x = tvm.ir.load_json(json.dumps(data)) + assert(isinstance(x, tvm.ir.container.Map)) + assert(len(x) == 2) + assert('x' in x) + assert('z' in x) + assert(bool(x['z'] == 2)) + + if __name__ == "__main__": test_op() test_type_var() + test_var() test_incomplete_type() test_func_tuple_type() test_global_var() test_tir_var() + test_str_map() diff --git a/tests/python/relay/test_pass_memory_alloc.py b/tests/python/relay/test_memory_passes.py similarity index 62% rename from tests/python/relay/test_pass_memory_alloc.py rename to tests/python/relay/test_memory_passes.py index c3c53121d934..dc16865aa620 100644 --- a/tests/python/relay/test_pass_memory_alloc.py +++ b/tests/python/relay/test_memory_passes.py @@ -18,21 +18,38 @@ from tvm import te import numpy as np from tvm import relay -from tvm.relay.transform import memory_alloc +from tvm.relay import memory_alloc -def check_vm_alloc(func, check_fn): - mod = tvm.IRModule() - mod['main'] = func - ex = relay.create_executor('vm', mod) +def check_memory_plan(func, check_fn): + # Build Module + mod = tvm.IRModule().from_expr(func) + + # Convert arguments. args = [] for param in func.params: param = param.type_annotation sh = [int(sh) for sh in param.shape] data = np.random.rand(*sh).astype(param.dtype) args.append(tvm.nd.array(data)) - result = ex.evaluate(mod['main'])(*args) + + # Compute without memory planning. + ex = relay.create_executor('vm', mod) + no_plan_result = ex.evaluate(mod['main'])(*args) + + # Compute with memory planning. + with tvm.transform.PassContext(opt_level=1, disabled_pass=["MemoryPlan"]): + plan_result = ex.evaluate(mod['main'])(*args) + + # Compute Python result. py_res = check_fn(*[arg.asnumpy() for arg in args]) - np.testing.assert_allclose(result.asnumpy(), py_res) + + # First check that the two VM results agree. + np.testing.assert_allclose( + no_plan_result.asnumpy(), + plan_result.asnumpy()) + + # Finally check that the results match the Python result. 
+ np.testing.assert_allclose(plan_result.asnumpy(), py_res) def storage_type(mod): return relay.TypeCall(mod.get_global_type_var("Storage"), []) @@ -46,7 +63,7 @@ def test_tyck_alloc_tensor(): mod.import_from_std("core.rly") sto = relay.Var("x", storage_type(mod)) sh = relay.const(np.array([1, 2]), dtype="int64") - at = relay.op.memory.alloc_tensor(sto, sh) + at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh) mod['main'] = relay.Function([sto], at) relay.transform.InferType()(mod) @@ -58,20 +75,34 @@ def test_add(): x = relay.var('x', shape=(2,)) z = x + x func = relay.Function([x,], z) - check_vm_alloc(func, check_add) + check_memory_plan(func, check_add) def check_add_sub(x, y): z = x + x return z - y + def test_add_sub(): x = relay.var('x', shape=(10,)) y = relay.var('y', shape=(10,)) z = x + x z = z - y func = relay.Function([x, y], z) - check_vm_alloc(func, check_add_sub) + check_memory_plan(func, check_add_sub) + +def check_no_fuse(x, y, w): + z = x + y + return np.matmul(z, np.transpose(w)) + +def test_no_fuse(): + x = relay.var('x', shape=(5, 1)) + y = relay.var('y', shape=(5, 1)) + w = relay.var('w', shape=(5, 1)) + z = x + y + out = relay.op.nn.dense(z, w) + func = relay.Function([x, y, w], out) + check_memory_plan(func, check_no_fuse) if __name__ == "__main__": test_tyck_alloc_tensor() diff --git a/tests/python/relay/test_op_fast_math.py b/tests/python/relay/test_op_fast_math.py index 215b83e8e80d..a771d29a431d 100644 --- a/tests/python/relay/test_op_fast_math.py +++ b/tests/python/relay/test_op_fast_math.py @@ -34,7 +34,7 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"): func = relay.Function([x], y) mod = tvm.IRModule.from_expr(func) - with relay.build_config(opt_level=3, required_pass=['FastMath']): + with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']): graph, lib, params = relay.build(mod, target="llvm", params=None) # Check that the op related to fast math have been convered to function in lib diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py index 0579441166ae..9faf6d903a9c 100644 --- a/tests/python/relay/test_op_grad_level1.py +++ b/tests/python/relay/test_op_grad_level1.py @@ -65,7 +65,16 @@ def check_single_op(opfunc, ref): (tvm.relay.cos, lambda x: -1.0 * np.sin(x)), (tvm.relay.sin, lambda x: np.cos(x)), (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)), - (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]: + (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))), + (tvm.relay.log2, lambda x: 1 / (np.log(2) * x)), + (tvm.relay.log10, lambda x: 1 / (np.log(10) * x)), + (tvm.relay.cosh, lambda x: np.sinh(x)), + (tvm.relay.sinh, lambda x: np.cosh(x)), + (tvm.relay.asin, lambda x: 1. / (1. - x**2) ** (1./2.)), + (tvm.relay.acos, lambda x: -1. / (1. - x**2.) 
** (1./2.)), + (tvm.relay.acosh, lambda x: 1./ (x**2 - 1.)**(1./2.)), + (tvm.relay.asinh, lambda x: 1./ (x**2 + 1.)**(1./2.)), + (tvm.relay.atanh, lambda x: -1./ (x**2 - 1.))]: check_single_op(opfunc, ref) diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 2b5a1c29e0de..d898451ff6ac 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -162,6 +162,7 @@ def verify_dense_grad(d_shape, w_shape): def test_dense_grad(): verify_dense_grad((1, 8), (16, 8)) verify_dense_grad((1, 4), (3, 4)) + verify_dense_grad((5, 4), (3, 4)) def verify_batch_flatten_grad(d_shape): diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 771a63deec69..d45372e1ce0d 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -262,7 +262,7 @@ def compile_test_conv2d_arm_cpu(dtype, out_dtype, scale, dshape, kshape, with open(temp.relpath("temp.log"), "w") as log_file: log_file.write(test_schedule) with autotvm.apply_history_best(temp.relpath("temp.log")): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): print('Compiling...') graph_json, mod, params = tvm.relay.build(mod, target="llvm -device=arm_cpu") @@ -356,7 +356,7 @@ def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape, data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, groups=groups) - with WinogradFallback(), relay.build_config(opt_level=3): + with WinogradFallback(), tvm.transform.PassContext(opt_level=3): for target, ctx in ctx_list(): if target != 'cuda': continue @@ -578,7 +578,7 @@ def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape, data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, groups=groups) - with WinogradFallback(), relay.build_config(opt_level=3): + with WinogradFallback(), tvm.transform.PassContext(opt_level=3): for target, ctx in ctx_list(): if target != 'cuda': continue @@ -612,6 +612,66 @@ def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape, padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5)) +def test_conv3d_transpose_infer_type(): + # symbolic in batch dimension + n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32")) + w = relay.var("w") + y = relay.nn.conv3d_transpose(x, w, + kernel_size=(3, 3, 3), + padding=(1, 1, 1), + channels=2) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 2, 224, 224, 224), "float32") + + assert yy.args[1].checked_type == relay.TensorType( + (10, 2, 3, 3, 3), "float32") + + # infer by shape of w, mixed precision + n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) + w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8")) + y = relay.nn.conv3d_transpose(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 12, 226, 226, 226), "int32") + + # infer shape in case of different dtypes for input and weight. 
+ n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224 + x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8")) + w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8")) + y = relay.nn.conv3d_transpose(x, w, out_dtype="int32") + assert "out_dtype=\"int32\"" in y.astext() + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType( + (n, 12, 226, 226, 226), "int32") + + +def test_conv3d_transpose_ncdhw_run(): + dshape = (1, 3, 24, 24, 24) + kshape = (3, 4, 2, 2, 2) + + x = relay.var("x", shape=dshape) + w = relay.var("w") + y = relay.nn.conv3d_transpose(x, w, + channels=4, kernel_size=(2, 2, 2), strides=(1, 1, 1), + padding=(1, 1, 1)) + func = relay.Function([x, w], y) + dtype = "float32" + + data = np.random.uniform(size=dshape).astype(dtype) + kernel = np.random.uniform(size=kshape).astype(dtype) + + ref_res = topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + def test_conv2d_transpose_infer_type(): # symbolic in batch dimension n, c, h, w = te.size_var("n"), 10, 10, 12 @@ -747,7 +807,7 @@ def test_upsampling3d_infer_type(): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32") -def _test_pool2d(opfunc, reffunc): +def _test_pool2d(opfunc, reffunc, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)): n, c, h, w = te.size_var("n"), 10, 224, 224 x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) y = opfunc(x, pool_size=(1, 1)) @@ -758,7 +818,7 @@ def _test_pool2d(opfunc, reffunc): dtype = "float32" dshape = (1, 3, 28, 28) x = relay.var("x", shape=dshape) - y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) + y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)) @@ -780,7 +840,7 @@ def _test_pool2d_int(opfunc, reffunc, dtype): x = relay.var("x", shape=dshape, dtype=dtype) y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0)) func = relay.Function([x], y) - data = np.random.random_integers(low=-128, high=128, size=dshape) + data = np.random.randint(low=-128, high=128, size=dshape) ref_res = reffunc(data.reshape(1,3,14,2,14,2), axis=(3,5)).astype(dtype) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) @@ -815,7 +875,9 @@ def _test_global_pool2d(opfunc, reffunc): def test_pool2d(): _test_pool2d(relay.nn.max_pool2d, np.max) + _test_pool2d(relay.nn.max_pool2d, np.max, pool_size=2, strides=2, padding=0) _test_pool2d(relay.nn.avg_pool2d, np.mean) + _test_pool2d(relay.nn.avg_pool2d, np.mean, pool_size=2, strides=2, padding=0) _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'int32') _test_pool2d_int(relay.nn.avg_pool2d, np.mean, 'uint16') _test_global_pool2d(relay.nn.global_max_pool2d, np.max) @@ -824,7 +886,7 @@ def test_pool2d(): def test_pool1d(): - def _test_pool1d(opfunc): + def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)): n, c, w = te.var("n"), 10, 224 x = relay.var("x", relay.TensorType((n, c, w), "float32")) y = opfunc(x, pool_size=(1,)) @@ -836,7 +898,7 @@ def _test_pool1d(opfunc): dshape = (1, 3, 32) x = relay.var("x", shape=dshape) pool_type = 'max' if 'max' in str(opfunc) else 'avg' - y = opfunc(x, pool_size=(2,), 
strides=(2,), padding=(0, 0)) + y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) func = relay.Function([x], y) data = np.random.uniform(size=dshape).astype(dtype) ref_res = topi.testing.pool1d_ncw_python(data, (2,), (2,), @@ -847,12 +909,18 @@ def _test_pool1d(opfunc): tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) _test_pool1d(relay.nn.max_pool1d) + _test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0) _test_pool1d(relay.nn.avg_pool1d) + _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0) def test_pool3d(): - def _test_pool3d(opfunc, padding=(0, 0, 0, 0, 0, 0), out_shape=(1, 3, 16, 16, 16)): + def _test_pool3d(opfunc, + pool_size=(2, 2, 2), + strides=(2, 2, 2), + padding=(0, 0, 0, 0, 0, 0), + out_shape=(1, 3, 16, 16, 16)): n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224 x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32")) y = opfunc(x, pool_size=(1, 1, 1)) @@ -864,14 +932,14 @@ def _test_pool3d(opfunc, padding=(0, 0, 0, 0, 0, 0), out_shape=(1, 3, 16, 16, 16 dshape = (1, 3, 32, 32, 32) x = relay.var("x", shape=dshape) pool_type = 'max' if 'max' in str(opfunc) else 'avg' - y = opfunc(x, pool_size=(2, 2, 2), strides=(2, 2, 2), padding=padding) + y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding) func = relay.Function([x], y) # check output shape f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape)) assert out_shape == f_out_shape, \ "Output shape mismatch. expected {}, actual {}".format(out_shape, f_out_shape) data = np.random.uniform(size=dshape).astype(dtype) - ref_res = topi.testing.pool3d_ncdhw_python(data, (2, 2, 2), (2, 2, 2), + ref_res = topi.testing.pool3d_ncdhw_python(data, pool_size, strides, padding, out_shape, pool_type, False) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) @@ -882,10 +950,12 @@ def _test_pool3d(opfunc, padding=(0, 0, 0, 0, 0, 0), out_shape=(1, 3, 16, 16, 16 _test_pool3d(relay.nn.max_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16)) _test_pool3d(relay.nn.max_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16)) _test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20)) + _test_pool3d(relay.nn.max_pool3d, pool_size=2, padding=0, strides=2) _test_pool3d(relay.nn.avg_pool3d) _test_pool3d(relay.nn.avg_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16)) _test_pool3d(relay.nn.avg_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16)) _test_pool3d(relay.nn.avg_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20)) + _test_pool3d(relay.nn.avg_pool3d, pool_size=2, padding=0, strides=2) def test_avg_pool2d_no_count_pad(): @@ -1189,7 +1259,7 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): wdata = np.random.rand(*kernel_shape) * 10 parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=parameters) assembly = lib.get_source("asm") @@ -1304,7 +1374,7 @@ def test_depthwise_conv2d_int8(): llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if llvm_version >= 8: - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=parameters) @@ -1332,6 +1402,45 @@ def test_bitpack_infer_type(): # TODO(@jwfromm): Need to add 
bitserial_conv2d & bitpack run test cases +def test_correlation(): + def _test_correlation(data_shape, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, dtype='float32'): + data1 = relay.var("data1", relay.ty.TensorType(data_shape, dtype)) + data2 = relay.var("data2", relay.ty.TensorType(data_shape, dtype)) + y = relay.nn.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, + padding, is_multiply, "NCHW") + yy = run_infer_type(y) + padded_height = data_shape[2] + 2 * padding + padded_width = data_shape[3] + 2 * padding + border_size = (kernel_size - 1) // 2 + max_displacement + displacement_radius = max_displacement // stride2 + out_channel = ((2 * displacement_radius) + 1) ** 2 + out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1 + out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1 + assert yy.checked_type == relay.TensorType( + (data_shape[0], out_channel, out_height, out_width), dtype + ) + func = relay.Function([data1, data2], y) + data1_np = np.random.uniform(size=data_shape).astype(dtype) + data2_np = np.random.uniform(size=data_shape).astype(dtype) + ref_res = topi.testing.correlation_nchw_python(data1_np, data2_np, kernel_size, max_displacement, stride1, stride2, padding, is_multiply) + + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data1_np, data2_np) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=4, + stride1=1, stride2=1, padding=4, is_multiply=True) + _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=5, + stride1=1, stride2=1, padding=5, is_multiply=True) + _test_correlation((5, 1, 4, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=1, padding=2, is_multiply=True) + _test_correlation((5, 1, 6, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=2, padding=2, is_multiply=False) + _test_correlation((5, 1, 11, 11), kernel_size=5, max_displacement=1, + stride1=1, stride2=1, padding=2, is_multiply=False) + + if __name__ == "__main__": test_pool1d() test_pool2d() @@ -1348,6 +1457,8 @@ def test_bitpack_infer_type(): test_flatten_infer_type() test_pad_infer_type() test_pad_run() + test_conv3d_transpose_infer_type() + test_conv3d_transpose_ncdhw_run() test_conv2d_transpose_infer_type() test_conv2d_transpose_nchw_run() test_conv2d_transpose_nhwc_run() @@ -1364,3 +1475,4 @@ def test_bitpack_infer_type(): test_upsampling3d() test_conv2d_int8_intrinsics() test_depthwise_conv2d_int8() + test_correlation() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 4deed4232d3d..f50a69278402 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -663,6 +663,106 @@ def verify_reverse(dshape, axis): verify_reverse((2, 3, 4), -1) +def test_scatter(): + + def ref_scatter(data, indices, updates, axis=0): + idx = np.indices(indices.shape).reshape(indices.ndim, -1) + + updated_idx = np.copy(idx) + indices = indices.reshape(-1) + for i in range(len(indices)): + updated_idx[axis, i] = indices[i] + scattered = np.copy(data) + scattered[tuple(updated_idx)] = updates[tuple(idx)] + return scattered + + def verify_scatter(dshape, ishape, axis=0): + d = relay.var("d", relay.TensorType(dshape, "float32")) + i = relay.var("i", relay.TensorType(ishape, "int64")) + u = relay.var("u", relay.TensorType(ishape, "float32")) + z = 
relay.op.scatter(d, i, u, axis) + + func = relay.Function([d, i, u], z) + + data_np = np.random.uniform(size=dshape).astype("float32") + updates_np = np.random.uniform(size=ishape).astype("float32") + indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") + + ref_res = ref_scatter(data_np, indices_np, updates_np, axis) + # TODO(mbrookhart): expand testing when adding more backend schedules + for target, ctx in [("llvm", tvm.cpu())]: + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) + tvm.testing.assert_allclose( + op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_scatter((10, ), (10, ), 0) + verify_scatter((10, 5), (10, 5), -2) + verify_scatter((10, 5), (10, 5), -1) + verify_scatter((10, 5), (3, 5), 0) + verify_scatter((12, 4), (7, 2), 1) + verify_scatter((2, 3, 4), (1, 3, 4), 0) + verify_scatter((2, 3, 4), (2, 1, 4), 1) + verify_scatter((2, 3, 4), (2, 3, 1), 2) + verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0) + verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1) + verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2) + verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3) + + +def test_gather(): + def verify_gather(data, axis, indices, ref_res): + data = np.asarray(data, dtype='float32') + indices = np.asarray(indices, dtype='int32') + ref_res = np.asarray(ref_res) + + d = relay.var("x", relay.TensorType(data.shape, "float32")) + i = relay.var("y", relay.TensorType(indices.shape, "int32")) + z = relay.gather(d, axis, i) + + func = relay.Function([d, i], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data, indices) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, + rtol=1e-5) + + verify_gather([[1, 2], [3, 4]], + 1, + [[0, 0], [1, 0]], + [[1, 1], [4, 3]]) + verify_gather([[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]], + 0, + [[[1, 0, 1], [1, 1, 0]]], + [[[6, 1, 8], [9, 10, 5]]]) + verify_gather([[[-0.2321, -0.2024, -1.7624], [-0.3829, -0.4246, 0.2448], + [0.1822, 0.2360, -0.8965], [0.4497, -0.2224, 0.6103]], + [[0.0408, -0.7667, -0.4303], [-0.3216, 0.7489, -0.1502], + [0.0144, -0.4699, -0.0064], [-0.0768, -1.6064, 1.3390]]], + 1, + [[[2, 2, 0], [1, 0, 3]], [[3, 2, 0], [1, 0, 0]]], + [[[0.1822, 0.2360, -1.7624], [-0.3829, -0.2024, 0.6103]], + [[-0.0768, -0.4699, -0.4303], [-0.3216, -0.7667, -0.4303]]]) + verify_gather([[[0.3050, 1.6986, 1.1034], [0.7020, -0.6960, -2.1818], + [0.3116, -0.5773, -0.9912], [0.0835, -1.3915, -1.0720]], + [[0.1694, -0.6091, -0.6539], [-0.5234, -0.1218, 0.5084], + [0.2374, -1.9537, -2.0078], [-0.5700, -1.0302, 0.1558]]], + 2, + [[[1, 1, 0, 1], [0, 0, 2, 2], [1, 2, 1, 2], [2, 2, 1, 0]], + [[0, 0, 1, 2], [2, 2, 1, 0], [1, 2, 0, 0], [0, 2, 0, 2]]], + [[[1.6986, 1.6986, 0.3050, 1.6986], + [0.7020, 0.7020, -2.1818, -2.1818], + [-0.5773, -0.9912, -0.5773, -0.9912], + [-1.0720, -1.0720, -1.3915, 0.0835]], + [[0.1694, 0.1694, -0.6091, -0.6539], + [0.5084, 0.5084, -0.1218, -0.5234], + [-1.9537, -2.0078, 0.2374, 0.2374], + [-0.5700, 0.1558, -0.5700, 0.1558]]]) + + def test_gather_nd(): def verify_gather_nd(xshape, yshape, y_data): x = relay.var("x", relay.TensorType(xshape, "float32")) @@ -747,6 +847,58 @@ def verify_unravel_index(indices, shape, dtype): # output which is inline with Tensorflow # verify_unravel_index([0, 1, 2, 5], [2, 2], dtype) +def test_sparse_to_dense(): + def verify_sparse_to_dense(sparse_indices, 
sparse_values, default_value, output_shape, xpected): + sparse_indices_data = np.array(sparse_indices) + sparse_values_data = np.array(sparse_values) + default_value_data = np.array(default_value) + + a = relay.var("a", relay.TensorType(sparse_indices_data.shape, str(sparse_indices_data.dtype))) + b = relay.var("b", relay.TensorType(sparse_values_data.shape, str(sparse_values_data.dtype))) + if default_value is None: + args = [a, b] + d = relay.sparse_to_dense(a, output_shape, b) + else: + c = relay.var("c", relay.TensorType(default_value_data.shape, str(default_value_data.dtype))) + args = [a, b, c] + d = relay.sparse_to_dense(a, output_shape, b, c) + + zz = run_infer_type(d) + assert zz.checked_type == relay.ty.TensorType(output_shape, str(sparse_values_data.dtype)) + + func = relay.Function(args, d) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + if default_value is None: + op_res = intrp.evaluate(func)(sparse_indices_data, sparse_values_data) + else: + op_res = intrp.evaluate(func)( + sparse_indices_data, sparse_values_data, default_value_data + ) + tvm.testing.assert_allclose(op_res.asnumpy(), xpected, rtol=1e-5) + + + verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) # scalar + verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) # vector + verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]) # nXd + verify_sparse_to_dense( + [[0, 0, 0], [1, 2, 3]], + [1, 2], + 4, + [2, 3, 4], + [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]] + ) # nXd + verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) # floats + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # default value not specified + + #negative test cases + #sparse indices should be ints + #verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_values should be 0d or 1d only + #verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_indices should not be > 2d tensor + #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[[3.1, 3.1, 3.1]]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) if __name__ == "__main__": test_arange() @@ -780,4 +932,5 @@ def verify_unravel_index(indices, shape, dtype): test_gather_nd() test_isfinite() test_isinf() - test_unravel_index() \ No newline at end of file + test_unravel_index() + test_sparse_to_dense() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index bbe2c69d6294..74231cb0d5a1 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -58,11 +58,11 @@ def check_binary_op(opfunc, ref): def test_cmp_type(): for op, ref in ((relay.greater, np.greater), - (relay.greater_equal, np.greater_equal), - (relay.less, np.less), - (relay.less_equal, np.less_equal), - (relay.equal, np.equal), - (relay.not_equal, np.not_equal)): + (relay.greater_equal, np.greater_equal), + (relay.less, np.less), + (relay.less_equal, np.less_equal), + (relay.equal, np.equal), + (relay.not_equal, np.not_equal)): x = relay.var("x", relay.TensorType((10, 4), "float32")) y = relay.var("y", relay.TensorType((5, 10, 1), "float32")) z = op(x, y) @@ -165,7 +165,10 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") dtype = "bool" if ref_func in [np.all, np.any] else dtype x = 
relay.var("x", relay.TensorType(data, dtype)) - z = test_func(x, axis, keepdims, exclude) + if test_func == relay.logsumexp: + z = test_func(x, axis, keepdims) + else: + z = test_func(x, axis, keepdims, exclude) zz = run_infer_type(z) if axis: assert "axis=" in z.astext() @@ -215,6 +218,14 @@ def _wrapper(data, axis=None, keepdims=False): return func(data, axis=axis).reshape(out_shape) return _wrapper + def _np_log_sum_exp(x, axis, keepdims=False): + max_x = np.max(x, axis=axis, keepdims=True) + x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) + x = x + max_x + if not keepdims: + x = np.squeeze(x, axis=axis) + return x + d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") for func in [[relay.sum, np.sum], [relay.max, np.max], @@ -225,6 +236,7 @@ def _wrapper(data, axis=None, keepdims=False): [relay.prod, np.prod], [relay.all, np.all], [relay.any, np.any], + [relay.logsumexp, _np_log_sum_exp], [relay.argmin, _with_keepdims(np.argmin)], [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) @@ -284,38 +296,68 @@ def test_mean_var_std(): def test_strided_slice(): - def verify(dshape, begin, end, strides, output, test_ref=True): + def verify(dshape, begin, end, strides, output, slice_mode="end", + attr_const=True, test_ref=True, dtype="int32"): x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.strided_slice(x, begin=begin, end=end, strides=strides) + ndim = len(dshape) + begin = begin if begin else [0] * ndim + end = end if end else list(dshape) + + # target numpy result + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = topi.testing.strided_slice_python( + x_data, begin, end, strides, slice_mode) + + if attr_const: + begin = relay.const(begin, dtype=dtype) + end = relay.const(end, dtype=dtype) + + if strides: + if attr_const: + strides = relay.const(strides, dtype=dtype) + z = relay.strided_slice(x, + begin=begin, + end=end, + strides=strides, + slice_mode=slice_mode) + else: + z = relay.strided_slice(x, + begin=begin, + end=end, + slice_mode=slice_mode) func = relay.Function([x], z) + func = run_infer_type(func) text = func.astext() assert "begin=" in text assert "end=" in text + if output: assert func.body.checked_type == relay.ty.TensorType(output, "float32") + if not test_ref: return - x_data = np.random.uniform(size=dshape).astype("float32") - ref_res = topi.testing.strided_slice_python( - x_data, begin, end, strides) for target, ctx in ctx_list(): intrp = relay.create_executor("graph", ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) - d1, d2, d3, d4 = te.var("d1"), te.var("d2"), te.var("d3"), te.var("d4") - verify((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False) + verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") + verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3], + [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64") + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) - verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)) - verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) - verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2)) - verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2), attr_const=False) verify((3, 4, 3), [1, 1, 0], [4, 
4, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3)) - + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], + (2, 4, 3), slice_mode="size", test_ref=False) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], + (2, 2, 3), slice_mode="size", test_ref=True) def test_strided_set(): def verify(dshape, begin, end, strides, vshape, test_ref=True): diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index b29b69667653..14d43c0a5fca 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -63,11 +63,53 @@ def verify_resize(dshape, scale, method, layout): for kind in ["graph", "debug"]: intrp = relay.create_executor(kind, ctx=ctx, target=target) op_res = intrp.evaluate(func)(x_data) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) for method in ["bilinear", "nearest_neighbor"]: for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout) +def test_resize3d_infer_type(): + n, c, d, h, w = te.size_var("n"), te.size_var("c"), te.size_var("d"), te.size_var("h"), te.size_var("w") + x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) + td, th, tw = te.var("td"), te.var("th"), te.var("tw") + z = relay.image.resize3d(x, (td, th, tw)) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, td, th, tw), "int8") + + x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8")) + z= relay.image.resize3d(x, (10, 10, 20), "NCDHW", "trilinear", "align_corners") + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType((n, c, 10, 10, 20), "int8") + +def test_resize3d(): + def verify_resize(dshape, scale, method, layout): + if layout == "NDHWC": + size = (dshape[1] * scale, dshape[2] * scale, dshape[3] * scale) + else: + size = (dshape[2] * scale, dshape[3] * scale, dshape[4] * scale) + + x_data = np.random.uniform(size=dshape).astype("float32") + if method == "trilinear": + ref_res = topi.testing.trilinear_resize3d_python(x_data, size, layout) + else: + ref_res = topi.testing.upsampling3d_python(x_data, (scale, scale, scale), layout) + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.image.resize3d(x, size, layout, method, "align_corners") + assert "size=" in z.astext() + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") + func = relay.Function([x], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4) + for method in ["trilinear", "nearest_neighbor"]: + for layout in ["NDHWC", "NCDHW"]: + verify_resize((1, 4, 4, 4, 4), 2, method, layout) + def test_crop_and_resize(): def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size, layout, method, extrapolation_value=0.0): @@ -202,6 +244,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) + np_out3 
= np.zeros(shape=(batch_size, num_anchor)) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 @@ -211,10 +254,12 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 + np_out3[i, inter_idx] = j inter_idx += 1 if j >= np_out1[i]: for k in range(elem_length): np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 x = relay.var("x", relay.ty.TensorType(dshape, dtype)) z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index) @@ -229,6 +274,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): if target == 'cuda': return tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) + tvm.testing.assert_allclose(out[2].asnumpy(), np_out3, rtol=1e-3, atol=1e-04) verify_get_valid_counts((1, 2500, 6), 0, 0, 1) verify_get_valid_counts((1, 2500, 5), -1, -1, 0) @@ -237,69 +283,79 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): def test_non_max_suppression(): - def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, + def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res, iou_threshold=0.5, force_suppress=False, top_k=-1, check_type_only=False): x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32")) - z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ - iou_threshold = iou_threshold, force_suppress = force_suppress, \ - top_k = top_k, return_indices=False) - z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ - iou_threshold = iou_threshold, force_suppress = force_suppress, \ - top_k = top_k) + x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32")) + z = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \ + iou_threshold=iou_threshold, force_suppress=force_suppress, \ + top_k=top_k, return_indices=False) + z_indices = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \ + iou_threshold=iou_threshold, force_suppress=force_suppress, \ + top_k=top_k, return_indices=True) + if isinstance(z_indices, relay.expr.TupleWrapper): + z_indices = z_indices.astuple() assert "iou_threshold" in z.astext() assert "iou_threshold" in z_indices.astext() zz = run_infer_type(z) zz_indices = run_infer_type(z_indices) assert zz.checked_type == relay.ty.TensorType(dshape, "float32") - assert zz_indices.checked_type == relay.ty.TensorType((dshape[0], dshape[1]), "int32") + assert zz_indices.checked_type == relay.ty.TupleType( + [relay.ty.TensorType((dshape[0], dshape[1]), "int32"), + relay.ty.TensorType((dshape[0], 1), "int32")]) if check_type_only: return - func = relay.Function([x0, x1], z) + func = relay.Function([x0, x1, x2], z) func = run_infer_type(func) - func_indices = relay.Function([x0, x1], z_indices) + func_indices = relay.Function([x0, x1, x2], z_indices) func_indices = run_infer_type(func_indices) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x0_data, x1_data) - op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data) + op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) - tvm.testing.assert_allclose(op_indices_res1.asnumpy(), ref_indices_res, rtol=1e-5) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res2 = 
intrp2.evaluate(func)(x0_data, x1_data) - op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data) + op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) - tvm.testing.assert_allclose(op_indices_res2.asnumpy(), ref_indices_res, rtol=1e-5) + if target == 'cuda': + return + op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data) + tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5) + op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data) + tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5) np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") + + np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32") + np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[3, 0, -1, -1, -1]]) + np_indices_result = np.array([[4, 0, -1, -1, -1]]) num_anchors = 5 dshape = (te.size_var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, force_suppress=True, top_k=2, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, np_indices_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, force_suppress=True, top_k=2, check_type_only=False) np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[3, 0, 1, -1, -1]]) + np_indices_result = np.array([[4, 0, 1, -1, -1]]) dshape = (te.size_var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, dshape, np_result, + verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, top_k=3) @@ -342,7 +398,7 @@ def test_default_value(): assert ret.checked_type == ref_type - nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False) + nms = relay.vision.non_max_suppression(mtl[0], mtl[1], mtl[0], return_indices=False) func = relay.Function([cls_prob, loc_pred, anchors], nms) func = run_infer_type(func) for target, ctx in ctx_list(): @@ -781,9 +837,61 @@ def _convert_data(indata, kernel, out, layout=None): data_layout='NHWC', kernel_layout='HWI') +def test_affine_grid(): + def verify_affine_grid(num_batch, target_shape): + dtype = 'float32' + data_shape = (num_batch, 2, 3) + data = relay.var("data", relay.ty.TensorType(data_shape, dtype)) + y = relay.image.affine_grid(data, target_shape) + yy = run_infer_type(y) + assert yy.checked_type == relay.ty.TensorType((num_batch, len(target_shape), *target_shape), dtype) + + func = relay.Function([data], y) + data_np = np.random.uniform(size=data_shape).astype(dtype) + ref_res = topi.testing.affine_grid_python(data_np, target_shape) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp1 = relay.create_executor(kind, ctx=ctx, target=target) + 
op_res1 = intrp1.evaluate(func)(data_np) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + verify_affine_grid(1, (16, 32)) + verify_affine_grid(4, (16, 32)) + + +def test_grid_sample(): + def verify_grid_sample(data_shape, grid_shape): + dtype = 'float32' + batch, channel, _, _ = data_shape + _, _, out_height, out_width = grid_shape + data = relay.var("data", relay.ty.TensorType(data_shape, dtype)) + grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype)) + y = relay.image.grid_sample(data, grid, method='bilinear', layout='NCHW') + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype) + func = relay.Function([data, grid], y) + + data_np = np.random.uniform(size=data_shape).astype(dtype) + grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) + ref_res = topi.testing.grid_sample_nchw_python(data_np, grid_np, method='bilinear') + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp1 = relay.create_executor(kind, ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data_np, grid_np) + tvm.testing.assert_allclose( + op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) + verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) + + if __name__ == "__main__": test_resize_infer_type() test_resize() + test_resize3d_infer_type() + test_resize3d() test_crop_and_resize() test_multibox_prior() test_multibox_transform_loc() @@ -799,3 +907,5 @@ def _convert_data(indata, kernel, out, layout=None): test_space_to_depth() test_dilation2d_infer_type() test_dilation2d_run() + test_affine_grid() + test_grid_sample() diff --git a/tests/python/relay/test_op_qnn_concatenate.py b/tests/python/relay/test_op_qnn_concatenate.py index 03ab9eeb1321..fb60e9805206 100644 --- a/tests/python/relay/test_op_qnn_concatenate.py +++ b/tests/python/relay/test_op_qnn_concatenate.py @@ -144,7 +144,32 @@ def test_same_i_qnn_params(): op_res = intrp.evaluate(func)(x_data, y_data) np.testing.assert_equal(op_res.asnumpy(), golden_output) +def test_call_input(): + # This tests the case where the input to concatenate is not explicitly a + # tuple node but is instead a call node. 
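+    # A minimal sketch of the case exercised below (assumed, mirroring the test):
+    #   tup = relay.split(x, 2, axis=0)      # TupleWrapper around a call node,
+    #                                        # not an explicit relay.Tuple
+    #   relay.qnn.op.concatenate(tup, ...)   # op must accept this form as well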
+ x_data = np.ones(shape=(64,)).astype('uint8') + + x = relay.var("x", shape=(64,), dtype='uint8') + x_scale = relay.const(1, 'float32') + y_scale = relay.const(1, 'float32') + x_zero_point = relay.const(0, 'int32') + y_zero_point = relay.const(0, 'int32') + + tup = relay.split(x, 2, axis=0) + z = relay.qnn.op.concatenate(tup, + input_scales=(x_scale, y_scale), + input_zero_points=(x_zero_point, y_zero_point), + output_scale=y_scale, + output_zero_point=relay.const(0, 'int32'), + axis=0) + func = relay.Function([x], z) + + intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm") + op_res = intrp.evaluate(func)(x_data) + np.testing.assert_equal(op_res.asnumpy(), x_data) + if __name__ == '__main__': + test_call_input() test_same_io_qnn_params() test_different_io_qnn_params() test_few_same_io_qnn_params() diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 6911c52958f0..fcb335fd0d63 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -182,7 +182,7 @@ def get_inputs(data_shape, data_dtype, kernel_shape, kernel_dtype): def get_output(func, golden_inputs): - with relay.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): golden_data, golden_weight = golden_inputs params = {'kernel': golden_weight} graph, lib, params = relay.build(func, "llvm", params=params) @@ -655,7 +655,7 @@ def test_tflite_large_irregular(): golden_data = np.full(data_shape, 127).astype('uint8') golden_weight = np.full(kernel_shape, 127).astype('uint8') - with relay.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): params = {'kernel': golden_weight} graph, lib, params = relay.build(qnn_func, "llvm", params=params) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) @@ -698,7 +698,7 @@ def test_tflite_output_multiplier_greater_than_one(): -1, -1, 1, 1)).reshape(kernel_shape) golden_weight = golden_weight.astype('uint8') - with relay.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): params = {'kernel': golden_weight} graph, lib, params = relay.build(qnn_func, "llvm", params=params) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) @@ -744,7 +744,7 @@ def test_tflite_anistropic_strides(): golden_weight = np.array((129, 131, 133, 135)).reshape(kernel_shape) golden_weight = golden_weight.astype('uint8') - with relay.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): params = {'kernel': golden_weight} graph, lib, params = relay.build(qnn_func, "llvm", params=params) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) @@ -789,7 +789,7 @@ def test_broadcast_layout(): func = relay.add(func, bias) func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") def test_depthwise_depth_multiplier(): diff --git a/tests/python/relay/test_op_qnn_dense.py b/tests/python/relay/test_op_qnn_dense.py index 3cfcfd165b46..0ba3210e8d8b 100644 --- a/tests/python/relay/test_op_qnn_dense.py +++ b/tests/python/relay/test_op_qnn_dense.py @@ -167,7 +167,7 @@ def qnn_dense_driver(test_configuration): mod = relay.Function(relay.analysis.free_vars(mod), mod) mod = tvm.IRModule.from_expr(mod) mod = relay.qnn.transform.CanonicalizeOps()(mod) - with relay.build_config(opt_level=2): + with tvm.transform.PassContext(opt_level=2): graph, lib, params = 
relay.build(mod, "llvm", params=None) mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) mod.set_input(quantized_data_name, test_configuration[quantized_data_name]) diff --git a/tests/python/relay/test_op_qnn_dequantize.py b/tests/python/relay/test_op_qnn_dequantize.py index febf5c5e6ecc..3c82b7fa0afa 100644 --- a/tests/python/relay/test_op_qnn_dequantize.py +++ b/tests/python/relay/test_op_qnn_dequantize.py @@ -30,7 +30,7 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data): input_zero_point=input_zero_point) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = tvm.IRModule.from_expr(mod) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod.set_input(input_data=in_data) diff --git a/tests/python/relay/test_op_qnn_quantize.py b/tests/python/relay/test_op_qnn_quantize.py index 09b04d8925c6..a284e8bdbc82 100644 --- a/tests/python/relay/test_op_qnn_quantize.py +++ b/tests/python/relay/test_op_qnn_quantize.py @@ -32,7 +32,7 @@ def quantize_test_driver(in_dtype, quant_args, axis, out_dtype, in_data, verify_ out_dtype=out_dtype) mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output) mod = tvm.IRModule.from_expr(mod) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) rt_mod.set_input(input_data=in_data) diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py index 81233972cb28..fb52b3030582 100644 --- a/tests/python/relay/test_op_qnn_requantize.py +++ b/tests/python/relay/test_op_qnn_requantize.py @@ -24,7 +24,7 @@ roundings = ["UPWARD", "TONEAREST"] def verify(mod, goldens): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, "llvm", params=None) golden_data, golden_output = goldens rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 2a2e265dbe5b..bbe10c773ff9 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -18,10 +18,11 @@ import pytest import tvm -from tvm import te from tvm import relay from tvm.relay import transform, analysis from tvm.relay.testing.temp_op_attr import TempOpAttr +from tvm.relay.testing import ctx_list, run_infer_type +import numpy as np def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] @@ -153,6 +154,56 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_alter_layout_lrn(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. 
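+    The expected graph keeps the blocked NCHW16c layout through conv2d and
+    max_pool2d, then transforms back to NCHW right before the lrn op.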
+ """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias") + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.max_pool2d(y, pool_size=(2, 2)) + y = relay.nn.lrn(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs['data_layout'] = 'NCHW16c' + new_attrs['kernel_layout'] = 'OIHW16i' + return relay.nn.conv2d(data, weight, **new_attrs) + + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + w = relay.layout_transform(weight, "OIHW", "OIHW16i") + y = relay.nn.conv2d(y, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + kernel_layout="OIHW16i", + data_layout="NCHW16c") + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW16c") + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.nn.lrn(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = before() + a = run_opt_pass(a, [transform.CanonicalizeOps(), + transform.AlterOpLayout()]) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_alter_layout_dual_path(): """ @@ -570,7 +621,10 @@ def before(): x = relay.var("x", shape=(1, 32, 28, 28)) weight = relay.var('weight', shape=(32, 32, 3, 3)) y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) - y = relay.strided_slice(y, begin=[0, 16], end=[None, None]) + y = relay.strided_slice(y, + begin=relay.const([0, 16], "int32"), + end=relay.const([1, 33], "int32"), + strides=relay.const([1, 1], "int32")) y = relay.Function(analysis.free_vars(y), y) return y @@ -582,22 +636,41 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): def expected(): x = relay.var("x", shape=(1, 32, 28, 28)) - weight = relay.var("weight") + weight = relay.var("weight", shape=(32, 32, 3, 3)) + weight = relay.layout_transform(weight, "OIHW", "OIHW4i4o") x = relay.layout_transform(x, "NCHW", "NCHW4c") - y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), - data_layout="NCHW4c") - y = relay.strided_slice(y, begin=[0, 4], end=[None, 8]) + y = relay.op.nn.contrib_conv2d_nchwc(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1), + data_layout="NCHW4c") + + y = relay.strided_slice(y, + begin=relay.const([0, 4], "int32"), + end=relay.const([1, 21], "int32"), + strides=relay.const([1, 1], "int32")) + y = relay.layout_transform(y, "NCHW4c", "NCHW") y = relay.Function(analysis.free_vars(y), y) return y with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = before() - a = run_opt_pass(a, [transform.CanonicalizeOps(), - transform.AlterOpLayout()]) b = run_opt_pass(expected(), transform.InferType()) - assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + # Verify inference result + mod_before = tvm.IRModule() + mod_new = tvm.IRModule() + mod_before['main'] = a + mod_new['main'] = b + with relay.build_config(opt_level=3): + for target, ctx in ctx_list(): + for kind in ["graph", "debug", "vm"]: + ex_before = relay.create_executor(kind, mod=mod_before, ctx=ctx, target=target) + ex_new = relay.create_executor(kind, mod=mod_new, ctx=ctx, target=target) + np_data = np.random.uniform(size=(1, 32, 28, 
28)).astype("float32") + np_weight = np.random.uniform(size=(32, 32, 3, 3)).astype("float32") + result_before = ex_before.evaluate()(np_data, np_weight) + result_new = ex_new.evaluate()(np_data, np_weight) + tvm.testing.assert_allclose(result_before.asnumpy(), result_new.asnumpy(), rtol=1e-5, atol=1e-5) + def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -940,11 +1013,8 @@ def expected_nhwc(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) -# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be the -# right behavior of alter_layout -@pytest.mark.skip -def test_alter_layout_nhwc_nchw_arm(): - """ Check NHWC to NHCW conversion for a small sequence of ops.""" +def test_alter_layout_nhwc_arm(): + """ Check that AlterOplayout does not alter NHWC data layout. """ def alter_conv2d(attrs, inputs, tinfos, out_type): import topi with tvm.target.create("llvm -device=arm_cpu"): @@ -974,25 +1044,7 @@ def before_nhwc(): return y def expected_nhwc(): - x = relay.var("x", shape=(1, 56, 56, 64)) - weight1 = relay.var('weight1', shape=(3, 3, 64, 64)) - weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) - y = relay.layout_transform(x, "NHWC", "NCHW") - weight1 = relay.layout_transform(weight1, "HWIO", "OIHW") - weight2 = relay.layout_transform(weight2, "HWIO", "OIHW") - y = relay.nn.conv2d(y, weight1, - channels=64, - kernel_size=(3, 3)) - y = relay.nn.relu(y) - y = relay.nn.avg_pool2d(y, - pool_size=(1,1)) - y = relay.nn.conv2d(y, weight2, - channels=64, - kernel_size=(3, 3)) - y = relay.nn.relu(y) - y = relay.layout_transform(y, "NCHW", "NHWC") - y = relay.Function(analysis.free_vars(y), y) - return y + return before_nhwc() with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): a = before_nhwc() @@ -1048,6 +1100,7 @@ def expected(): test_alter_return_none() test_alter_layout() test_alter_layout_dual_path() + test_alter_layout_lrn() test_alter_layout_resnet() test_alter_layout_broadcast_op() test_alter_layout_broadcast_scalar_op() @@ -1060,5 +1113,5 @@ def expected(): test_alter_layout_pad() test_alter_layout_pool() test_alter_layout_sum() - # test_alter_layout_nhwc_nchw_arm() + test_alter_layout_nhwc_arm() test_alter_op_with_global_var() diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 01ba9b619205..273c27b0d05f 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -22,7 +22,6 @@ import tvm import tvm.relay.testing -import tvm.relay.op as reg import tvm.relay.transform as transform from tvm import relay from tvm import runtime @@ -52,7 +51,7 @@ def update_lib(lib): return lib def check_vm_result(): - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() lib = update_lib(lib) @@ -63,7 +62,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) def check_graph_runtime_result(): - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) @@ -187,7 +186,7 @@ def test_extern_dnnl_mobilenet(): def test_multiple_ends(): - @reg.register("nn.relu", 
"target.test") + @tvm.ir.register_op_attr("nn.relu", "target.test") def relu(attrs, args): # pylint: disable=unused-variable return True @@ -229,7 +228,7 @@ def after(): def test_type_propagation(): target = "test_type_propagation" - @reg.register("nn.relu", "target." + target) + @tvm.ir.register_op_attr("nn.relu", "target." + target) def relu(attrs, args): # pylint: disable=unused-variable return args[0].checked_type.dtype == "float32" @@ -248,11 +247,11 @@ def before(): def test_tuple(): target = "test_tuple" - @reg.register("nn.relu", "target." + target) + @tvm.ir.register_op_attr("nn.relu", "target." + target) def relu(attrs, args): # pylint: disable=unused-variable return True - @reg.register("concatenate", "target." + target) + @tvm.ir.register_op_attr("concatenate", "target." + target) def concatenate(attrs, args): # pylint: disable=unused-variable return True @@ -338,11 +337,11 @@ def after(): def test_multiple_runs(): - @reg.register("nn.relu", "target.A") + @tvm.ir.register_op_attr("nn.relu", "target.A") def relu(attrs, args): # pylint: disable=unused-variable return True - @reg.register("add", "target.B") + @tvm.ir.register_op_attr("add", "target.B") def add(attrs, args): # pylint: disable=unused-variable return True diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 582d46aa40cf..0ecb2b522103 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -344,10 +344,10 @@ def get_func(): def test_runtime(target, device, func, fallback_device=None, expected_index=None): params = {"x": x_data, "y": y_data} - config = {"opt_level": 1} + config = {} if fallback_device: - config["fallback_device"] = fallback_device - with relay.build_config(**config): + config["relay.fallback_device_type"] = fallback_device.device_type + with tvm.transform.PassContext(opt_level=1, config=config): graph, lib, params = relay.build( func, target, @@ -538,9 +538,9 @@ def expected(): expected_index = [2, 2, 2, 1, 1, 1, 2, 2] check_annotated_graph(annotated_func, expected_func) params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} - config = {"opt_level": 0} - config["fallback_device"] = fallback_device - with relay.build_config(**config): + with tvm.transform.PassContext(opt_level=0, + config={"relay.fallback_device_type": + fallback_device.device_type}): graph, lib, params = relay.build(annotated_func, target, params=params) contexts = [tvm.cpu(0), tvm.context(dev)] graph_json = json.loads(graph) diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 7f7f18598589..68e7fece7e98 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. 
import tvm -from tvm import te from tvm import relay from tvm.relay import transform @@ -50,17 +49,28 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): args = [x, w1, w2, w3, w4] w = relay.concatenate((w1, w2, w4), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels1], "int64"), + strides=relay.const([1, 1], 'int64'), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels1], "int64"), + end=relay.const([-1, channels2], "int64"), + strides=relay.const([1, 1], 'int64'), + slice_mode="size") y3 = relay.nn.conv2d(x, w3) - y4 = relay.strided_slice(y, [0, channels1 + channels2], - [None, channels1 + channels2 + channels4]) + y4 = relay.strided_slice(y, + begin=relay.const([0, channels1 + channels2], "int64"), + end=relay.const([-1, channels4], "int64"), + strides=relay.const([1, 1], 'int64'), + slice_mode="size") y5 = relay.nn.max_pool2d(x) y = relay.Tuple((y1, y2, y3, y4, y5)) return relay.Function(args, y) def check(x_shape, channels1, channels2, channels3, channels4): - x = relay.var("x", shape=x_shape) + x = relay.var("x", shape=x_shape) in_c = x_shape[1] w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) @@ -99,8 +109,16 @@ def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2): y = relay.nn.conv2d(x, w, channels=channels1 + channels2) y = relay.multiply(y, scale) y = relay.nn.relu(y) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels1], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels1], "int64"), + end=relay.const([-1, channels2], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") y2 = relay.add(y2, bias) y = relay.Tuple((y1, y2)) return relay.Function(args, y) @@ -138,8 +156,16 @@ def expected(x, w1, w2, scale1, scale2, channels1, channels2): args = [x, w1, w2, scale1, scale2] w = relay.concatenate((w1, w2), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2) - y1 = relay.strided_slice(y, [0, 0], [None, channels1]) - y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels1], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") + y2 = relay.strided_slice(y, + begin=relay.const([0, channels1], "int64"), + end=relay.const([-1, channels2], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") y1 = relay.multiply(y1, scale1) y2 = relay.multiply(y2, scale2) y = relay.Tuple((y1, y2)) @@ -178,8 +204,16 @@ def expected(x, w, channels, repeat): for i in range(repeat): w_concat = relay.concatenate((w, w), axis=0) y = relay.nn.conv2d(y, w_concat, channels=channels*2) - y1 = relay.strided_slice(y, [0, 0], [None, channels]) - y2 = relay.strided_slice(y, [0, channels], [None, channels * 2]) + y1 = relay.strided_slice(y, + begin=relay.const([0, 0], "int64"), + end=relay.const([-1, channels], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") + y2 = 
relay.strided_slice(y, + begin=relay.const([0, channels], "int64"), + end=relay.const([-1, channels], "int64"), + strides=relay.const([1, 1], "int64"), + slice_mode="size") y = relay.concatenate((y1, y2), axis=1) return relay.Function(args, y) diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index c5a7b0e0c14d..f3cdbfc86e51 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -49,7 +49,7 @@ def expected(): return before() a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -84,7 +84,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -129,7 +129,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -177,7 +177,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -232,7 +232,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -256,7 +256,7 @@ def before(): return relay.Function(analysis.free_vars(y), y) a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) # Check that there is only 1 NHWC to NCHW transform. 
has_lt = list() @@ -312,7 +312,7 @@ def expected(): return relay.Function(analysis.free_vars(y), y) a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -344,7 +344,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -392,7 +392,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -448,7 +448,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -526,7 +526,7 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -606,7 +606,132 @@ def expected(): return y a = before() - a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + a = run_opt_pass(a, transform.ConvertLayout({'qnn.conv2d': ['NCHW', 'default']})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + +def test_conv_convert_kernel_layout(): + """ Check that convolution kernel layout is correctly transformed. """ + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout='NHWC', kernel_layout='HWIO') + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + w = relay.var("weight", shape=(3, 3, 64, 64)) + w = relay.layout_transform(w, 'HWIO', 'OHWI') + y = relay.nn.conv2d(x, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='OHWI') + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NHWC', 'OHWI']})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + +def test_default_keyword(): + """ Check that the default keyword selects correct TVM default layout. 
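+    For example, with data in NCHW the 'default' kernel layout resolves to OIHW,
+    so the original OHWI weight receives a layout_transform.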
""" + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 3, 3, 64)) + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout='NCHW', kernel_layout='OHWI') + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + w = relay.var("weight", shape=(64, 3, 3, 64)) + w = relay.layout_transform(w, 'OHWI', 'OIHW') + y = relay.nn.conv2d(x, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW', + kernel_layout='OIHW') + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['NCHW', 'default']})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + +def test_different_ops_convert_layout(): + """ Check convert layout correctly supports converting the layout of + different ops in the same graph. + """ + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var("weight1", shape=(64, 3, 3, 64)) + weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype='int8') + out = relay.nn.conv2d(x, weight1, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW', + kernel_layout='OHWI') + out = relay.cast(out, 'int8') + out = relay.qnn.op.conv2d(out, weight2, + relay.const(1, 'int32'), + relay.const(1, 'int32'), + relay.const(1, 'float32'), + relay.const(1, 'float32'), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW', + kernel_layout='OHWI') + out = relay.Function(analysis.free_vars(out), out) + return out + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight1 = relay.var("weight1", shape=(64, 3, 3, 64)) + weight2 = relay.var("weight2", shape=(64, 3, 3, 64), dtype='int8') + x = relay.layout_transform(x, 'NCHW', 'NHWC') + weight1 = relay.layout_transform(weight1, 'OHWI', 'HWIO') + out = relay.nn.conv2d(x, weight1, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + out = relay.cast(out, 'int8') + out = relay.layout_transform(out, 'NHWC', 'NCHW') + weight2 = relay.layout_transform(weight2, 'OHWI', 'OIHW') + out = relay.qnn.op.conv2d(out, weight2, + relay.const(1, 'int32'), + relay.const(1, 'int32'), + relay.const(1, 'float32'), + relay.const(1, 'float32'), + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NCHW', + kernel_layout='OIHW') + out = relay.Function(analysis.free_vars(out), out) + return out + + a = before() + desired_layouts = {'nn.conv2d': ['NHWC', 'HWIO'], + 'qnn.conv2d': ['NCHW', 'OIHW']} + a = run_opt_pass(a, transform.ConvertLayout(desired_layouts)) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) @@ -625,3 +750,6 @@ def expected(): test_qnn_conv_requantize_convert_layout() test_qnn_conv_concat_convert_layout() test_qnn_conv_add_convert_layout() + test_conv_convert_kernel_layout() + test_default_keyword() + test_different_ops_convert_layout() diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py index e75316f1e04b..93ad034be2ef 100644 --- a/tests/python/relay/test_pass_fast_math.py +++ b/tests/python/relay/test_pass_fast_math.py @@ -29,7 +29,7 @@ def test_exp(): assert "fast_exp" in fast_mod.astext() # Check that FastMath option works for relay.build. 
- with relay.build_config(opt_level=3, required_pass=['FastMath']): + with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']): fast_mod = relay.optimize(mod, target='llvm', params=None) assert "fast_exp" in fast_mod[0].astext() @@ -43,7 +43,7 @@ def test_tanh(): assert "fast_tanh" in fast_mod.astext() # Check that FastMath option works for relay.build. - with relay.build_config(opt_level=3, required_pass=['FastMath']): + with tvm.transform.PassContext(opt_level=3, required_pass=['FastMath']): fast_mod = relay.optimize(mod, target='llvm', params=None) assert "fast_tanh" in fast_mod[0].astext() diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 4f44d2b3043f..fcccab5c6b97 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -32,6 +32,25 @@ def run_opt_pass(expr, opt_pass): return entry if isinstance(expr, relay.Function) else entry.body +def test_concatenate_const(): + def before(): + data = tvm.nd.array(np.array([1.0, 2.0, 3.0])) + const = relay.const(data) + concat = relay.op.concatenate([const, const], axis=0) + func = relay.Function([], concat) + return func + + def expected(): + data = tvm.nd.array(np.array([1.0, 2.0, 3.0, 1.0, 2.0, 3.0])) + const = relay.const(data) + func = relay.Function([], const) + return func + + zz = run_opt_pass(before(), transform.FoldConstant()) + zexpected = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, zexpected) + + def test_fold_const(): c_data = np.array([1, 2, 3]).astype("float32") t = relay.TensorType([1, 2, 3], "float32") @@ -51,13 +70,9 @@ def expected(): z = relay.add(y, relay.const(c_data)) return relay.Function([x], z) - def fail(x): - raise RuntimeError() - # the fold constant should work on any context. 
- with tvm.target.build_config(add_lower_pass=[(0, fail)]): - with tvm.target.create("cuda"): - zz = run_opt_pass(before(), transform.FoldConstant()) + with tvm.target.create("cuda"): + zz = run_opt_pass(before(), transform.FoldConstant()) zexpected = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(zz, zexpected) @@ -198,7 +213,7 @@ def initializer(_, param): mod, params = create_workload(bn_output[0], initializer) mod["main"] = bind_params_by_name(mod["main"], params) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): mod = remove_bn_pass(mod) expect = run_infer_type(expected()) diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index d7c437adcc99..8aecf3f891f3 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -35,58 +35,75 @@ def run_opt_pass(expr, opt_pass): def test_fold_fwd_simple(): """Simple testcase.""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): args = [x, conv_weight, in_bias] - in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) x = relay.multiply(x, in_scale) x = relay.nn.relu(x) x = relay.add(x, in_bias) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def expected(x, conv_weight, in_bias, in_scale, channels): + def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, in_bias] - in_bias = relay.expand_dims(in_bias, axis=1, num_newaxis=2) - squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) - x = relay.nn.relu(x) - in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) - x = relay.add(x, in_bias) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + if blocking: + squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3]) + x = relay.nn.relu(x) + in_bias = relay.divide(in_bias, + relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0]))) #NCHWc + x = relay.add(x, in_bias) + conv_weight = relay.multiply(conv_weight, + relay.reshape(squeezed_scale, (1, in_channels//2, 1, 1, 2, 1))) #OIHWio + else: + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + x = relay.nn.relu(x) + in_bias = relay.divide(in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + x = relay.add(x, in_bias) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.const(_get_positive_scale((in_channels, 1, 1))) - y1 = before(x, weight, in_bias, in_scale, channels) + if blocking: + in_channels = shape[1] * shape[4] 
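+            # NCHWc-blocked input: total channels = outer dim (shape[1]) times
+            # the inner block size (shape[4])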
+ in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0])) + in_scale = relay.const(_get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0]))) + else: + in_channels = shape[1] + in_bias = relay.var("in_bias", shape=(in_channels, 1, 1)) + in_scale = relay.const(_get_positive_scale((in_channels, 1, 1))) + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - y1_expected = expected(x, weight, in_bias, in_scale, channels) + y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking) y1_folded = run_opt_pass(y1_folded, transform.InferType()) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 2) - + check((2, 4, 10, 10), 2, None) + check((2, 2, 10, 10, 2), 8, (2, 4)) def test_fold_fwd_dual_path(): """scale axis being consumed by two consumers""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): args = [x, conv_weight, in_bias] x = relay.multiply(in_scale, x) x = relay.nn.relu(x) @@ -94,363 +111,474 @@ def before(x, conv_weight, in_bias, in_scale, channels): y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) y2 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) z = relay.add(y1, y2) return relay.Function(args, z) - def expected(x, conv_weight, in_bias, in_scale, channels): + def expected(x, conv_weight, in_bias, in_scale, channels, blocking): args = [x, conv_weight, in_bias] x = relay.nn.relu(x) - in_bias = relay.divide(in_bias, in_scale) + if blocking: + _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], blocking[0])) #NHWCc + else: + _in_scale = in_scale + in_bias = relay.divide(in_bias, _in_scale) x = relay.subtract(x, in_bias) + if blocking: + _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio y1 = relay.nn.conv2d(x, - relay.multiply(conv_weight, in_scale), + relay.multiply(conv_weight, _in_scale), channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) + if blocking: + _in_scale = relay.reshape(in_scale, (1, 1, 1, channels//blocking[0], 1, blocking[0])) #HWIOio y2 = relay.nn.conv2d(x, - relay.multiply(conv_weight, in_scale), + relay.multiply(conv_weight, _in_scale), channels=channels, kernel_size=(3, 3), - data_layout="NHWC", - kernel_layout="HWIO", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", groups=channels, padding=(1, 1)) z = relay.add(y1, y2) return relay.Function(args, z) - def check(dshape, 
channels): + def check(dshape, channels, blocking): x = relay.var("x", shape=dshape) - in_channels = dshape[-1] + if blocking: + in_channels = dshape[3] * dshape[4] + wshape = (3, 3, 1, channels//blocking[1], 1, blocking[1]) # HWIOio + weight = relay.var("weight", shape=wshape) + in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0])) + in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0]))) + else: + in_channels = dshape[-1] + wshape = (3, 3, 1, channels) # HWIO + weight = relay.var("weight", shape=wshape) + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.const(_get_positive_scale(in_channels,)) + # test depthwise assert in_channels == channels - wshape = (3, 3, 1, channels) # HWIO - weight = relay.var("weight", shape=wshape) - in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.const(_get_positive_scale(in_channels,)) - y1 = before(x, weight, in_bias, in_scale, channels) + + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) - y1_expected = expected(x, weight, in_bias, in_scale, channels) + y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 3), 3) - + check((2, 4, 10, 3), 3, None) + check((2, 4, 10, 2, 2), 4, (2, 2)) def test_fold_fwd_fail(): """testcase where we canont fold""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): x = relay.multiply(x, in_scale) xx = relay.nn.leaky_relu(x, alpha=0.1) y1 = relay.nn.conv2d(xx, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", padding=(1, 1)) z = relay.add(y1, x) return relay.Function(relay.analysis.free_vars(z), z) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[-1] + if blocking: + in_channels = shape[3] * shape[4] + in_bias = relay.var("in_bias", shape=(in_channels//blocking[0],blocking[0])) + in_scale = relay.const(_get_positive_scale((in_channels//blocking[0],blocking[0]))) + else: + in_channels = shape[-1] + in_bias = relay.var("in_bias", shape=(in_channels,)) + in_scale = relay.const(_get_positive_scale(size=(in_channels,))) # test depthwise assert in_channels == channels weight = relay.var("weight") - in_bias = relay.var("in_bias", shape=(in_channels,)) - in_scale = relay.const(_get_positive_scale(size=(in_channels,))) - y1 = before(x, weight, in_bias, in_scale, channels) + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1, y1_folded) - check((2, 11, 10, 4), 4) - + check((2, 11, 10, 4), 4, None) + check((2, 11, 10, 2, 2), 4, (2,2)) def test_fold_fwd_relu_fail(): """testcase where we canont fold because scale can not pass relu""" - def before(x, conv_weight, in_bias, in_scale, channels): + def before(x, conv_weight, in_bias, in_scale, channels, blocking): x = relay.multiply(x, in_scale) xx = 
relay.nn.relu(x) y1 = relay.nn.conv2d(xx, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NHWC", + data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", + kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", padding=(1, 1)) z = relay.add(y1, x) return relay.Function(relay.analysis.free_vars(z), z) - def check(shape, channels, in_scale): + def check(shape, channels, blocking, in_scale): x = relay.var("x", shape=shape) - in_channels = shape[-1] - # test depthwise - assert in_channels == channels weight = relay.var("weight") - in_bias = relay.var("in_bias", shape=(in_channels,)) - y1 = before(x, weight, in_bias, in_scale, channels) + if blocking: + in_channels = shape[3] * shape[4] + in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0])) + else: + in_channels = shape[-1] + in_bias = relay.var("in_bias", shape=(in_channels,)) + + assert in_channels == channels + y1 = before(x, weight, in_bias, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1, y1_folded) in_scale = relay.var("in_scale", shape=(4,)) - check((2, 11, 10, 4), 4, in_scale) + check((2, 11, 10, 4), 4, None, in_scale) in_scale = relay.const(-_get_positive_scale((4,))) - check((2, 11, 10, 4), 4, in_scale) + check((2, 11, 10, 4), 4, None, in_scale) + + in_scale = relay.var("in_scale", shape=(1,1,1,2,2)) + check((2, 11, 10, 2, 2), 4, (2, 2), in_scale) + in_scale = relay.const(-_get_positive_scale((1,1,1,2,2))) + check((2, 11, 10, 2, 2), 4, (2, 2), in_scale) + + def test_fold_fwd_negative_scale(): """Testcase of folding negative scale""" - def before(x, conv_weight, in_scale, channels): + def before(x, conv_weight, in_scale, channels, blocking): args = [x, conv_weight] x = relay.multiply(x, in_scale) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def expected(x, conv_weight, in_scale, channels): + def expected(x, conv_weight, in_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight] - squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) + if blocking: + squeezed_scale = relay.squeeze(in_scale, axis=[0,2,3]) + conv_weight = relay.multiply( + conv_weight , relay.reshape(squeezed_scale, (1, in_channels//4, 1, 1, 4, 1))) + #blocking by "i" in OIHWio + else: + squeezed_scale = relay.squeeze(in_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] - in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) + if blocking: + in_channels = shape[1] * shape[4] + in_scale = relay.const(-_get_positive_scale((1, shape[1], 1, 1, shape[4]))) + else: + in_channels = 
shape[1] + in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) weight = relay.var("weight") - y1 = before(x, weight, in_scale, channels) + y1 = before(x, weight, in_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) - y1_expected = expected(x, weight, in_scale, channels) + y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 4) - + check((2, 4, 10, 10), 4, None) + check((2, 2, 10, 10, 2), 8, (2, 2)) def test_fold_bwd_simple(): """Simple testcase.""" - def before(x, conv_weight, out_bias, out_scale, channels): + def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + if blocking: + out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1])) + else: + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y = relay.add(y, out_bias) y = relay.nn.relu(y) + if blocking: + out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1])) y = relay.multiply(y, out_scale) return relay.Function(args, y) - def expected(x, conv_weight, out_bias, out_scale, channels): + def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) - squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + out_bias = relay.reshape(out_bias, (1, channels//blocking[1], 1, 1, blocking[1])) + out_scale = relay.reshape(out_scale, (1, channels//blocking[1], 1, 1, blocking[1])) + squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3]) + conv_weight = relay.multiply( + conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) - out_bias = relay.multiply(out_bias, + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") + if blocking: + out_bias = relay.multiply(out_bias, + relay.reshape(squeezed_scale, (1, channels//blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.multiply(out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2)) y = relay.add(y, out_bias) y = relay.nn.relu(y) return relay.Function(args, y) - def check(shape, channels): + def check(shape, in_channels, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = 
relay.var("weight") out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels, 1, 1))) - - y1 = before(x, weight, out_bias, out_scale, channels) + if blocking: + out_scale = relay.const(_get_positive_scale((channels,))) + else: + out_scale = relay.const(_get_positive_scale((channels,1, 1))) + y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 8) + check((2, 4, 10, 10), 4, 8, None) + check((2, 2, 10, 10, 16), 32, 64, (16, 16)) def test_fold_bwd_dual_path(): """Dual path testcase.""" - def before(x, conv_weight, out_bias, out_scale, channels): + def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y2) y = relay.add(y1, y2) y = relay.multiply(y, out_scale) return relay.Function(args, y) - def expected(x, conv_weight, out_bias, out_scale, channels): + def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) + if not blocking: + out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) def fold_conv_weight(): - return relay.multiply( - conv_weight , - relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + return relay.multiply( + conv_weight , + relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + return relay.multiply( + conv_weight , + relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y1 = relay.nn.conv2d(x, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(x, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y2) y = relay.add(y1, y2) return relay.Function(args, y) - def check(shape, channels): + def check(shape, in_channels, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - out_bias = relay.var("out_bias", 
shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels, 1, 1))) - - y1 = before(x, weight, out_bias, out_scale, channels) + if blocking: + out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) + out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.var("out_bias", shape=(channels,)) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) + + y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 8) - + check((2, 4, 10, 10), 4, 8, None) + check((2, 2, 10, 10, 2), 4, 8, (2, 2)) def test_fold_bwd_dual_consumer(): - def before(x, conv_weight, out_bias, out_scale, channels): + def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] y0 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y0 = relay.multiply(y0, out_scale) y0 = relay.nn.relu(y0) y1 = relay.nn.conv2d(y0, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.multiply(y1, out_scale) y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(y0, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.multiply(y2, out_scale) y2 = relay.nn.relu(y2) y = relay.add(y1, y2) return relay.Function(args, y) - def expected(x, conv_weight, out_bias, out_scale, channels): + def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight, out_bias] def fold_conv_weight(): squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) - return relay.multiply( - conv_weight , - relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + return relay.multiply( + conv_weight , + relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + return relay.multiply( + conv_weight , + relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y0 = relay.nn.conv2d(x, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y0 = relay.nn.relu(y0) y1 = relay.nn.conv2d(y0, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking 
else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(y0, fold_conv_weight(), channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y2) y = relay.add(y1, y2) return relay.Function(args, y) - def check(shape, channels): + def check(shape, in_channels, channels, blocking): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels,1, 1))) - - y1 = before(x, weight, out_bias, out_scale, channels) + if blocking: + out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) + out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.var("out_bias", shape=(channels,)) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) + + y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 4) - + check((2, 4, 10, 10), 4, 4, None) + check((2, 2, 10, 10, 2), 4, 4, (2, 2)) def test_fold_bwd_fail(): """Dual path testcase.""" - def fail1(x, conv_weight, out_bias, out_scale, channels): + def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y1 = relay.nn.relu(y1) y2 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), padding=(1, 1), - out_layout="CNHW") + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", + out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW") # fold will fail because the axis from two path # differs from each other. 
y2 = relay.nn.relu(y2) @@ -458,99 +586,123 @@ def fail1(x, conv_weight, out_bias, out_scale, channels): y = relay.multiply(y, out_scale) return relay.Function(args, y) - def fail2(x, conv_weight, out_bias, out_scale, channels): + def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): args = [x, conv_weight, out_bias] - out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) y1 = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y2 = relay.nn.relu(y1) # fold will fail because y1 is referred also by y2 y1 = relay.multiply(y1, out_scale) y = relay.add(y1, y2) return relay.Function(args, y) - def check(shape, channels, fbefore): + def check(shape, in_channels, channels, blocking, fbefore): x = relay.var("x", shape=shape) - in_channels = shape[1] weight = relay.var("weight") - out_bias = relay.var("out_bias", shape=(channels,)) - out_scale = relay.const(_get_positive_scale((channels, 1, 1))) - y1 = fbefore(x, weight, out_bias, out_scale, channels) + if blocking: + out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) + out_scale = relay.const(_get_positive_scale((channels // blocking[1], 1, 1, blocking[1]))) + else: + out_bias = relay.var("out_bias", shape=(channels, 1, 1)) + out_scale = relay.const(_get_positive_scale((channels, 1, 1))) + y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1_folded, y1) - check((4, 4, 10, 10), 4, fail1) - check((4, 4, 10, 10), 4, fail2) + check((4, 4, 10, 10), 4, 4, None, fail1) + check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1) + check((4, 4, 10, 10), 4, 4, None, fail2) + check((4, 2, 10, 10, 2), 4, 4, (2, 2), fail2) def test_fold_bwd_relu_fail(): """testcase where we canont fold because scale can not pass relu""" - def before(x, conv_weight, out_scale, channels): + def before(x, conv_weight, out_scale, channels, blocking): y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - data_layout="NCHW", - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y = relay.nn.relu(y) y = relay.multiply(x, out_scale) return relay.Function(relay.analysis.free_vars(y), y) - def check(shape, channels, out_scale): + def check(shape, channels, blocking, out_scale): x = relay.var("x", shape=shape) in_channels = shape[1] weight = relay.var("weight") - y1 = before(x, weight, out_scale, channels) + y1 = before(x, weight, out_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) assert tvm.ir.structural_equal(y1, y1_folded) out_scale = relay.var("in_scale", shape=(4, 1, 1)) - check((4, 4, 10, 10), 4, out_scale) + check((4, 4, 10, 10), 4, None, out_scale) out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0)).astype("float32") - check((4, 4, 10, 10), 4, out_scale) + check((4, 4, 10, 10), 4, None, out_scale) + + out_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2)) + check((4, 2, 10, 10, 2), 4, (2, 2), out_scale) + out_scale = relay.const(np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)).astype("float32") 
+ check((4, 2, 10, 10, 2), 4, (2, 2), out_scale) def test_fold_bwd_negative_scale(): """Testcase of folding negative scale""" - def before(x, conv_weight, out_scale, channels): + def before(x, conv_weight, out_scale, channels, blocking): args = [x, conv_weight] y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") y = relay.multiply(y, out_scale) return relay.Function(args, y) - def expected(x, conv_weight, out_scale, channels): + def expected(x, conv_weight, out_scale, channels, blocking): # use a fixed order of args so alpha equal check can pass args = [x, conv_weight] - squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) - conv_weight = relay.multiply( - conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) + if blocking: + squeezed_scale = relay.squeeze(out_scale, axis=[0,2,3]) + conv_weight = relay.multiply( + conv_weight , relay.reshape(squeezed_scale, (channels//blocking[1], 1, 1, 1, 1, blocking[1]))) + else: + squeezed_scale = relay.squeeze(out_scale, axis=[1,2]) + conv_weight = relay.multiply( + conv_weight , relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3)) y = relay.nn.conv2d(x, conv_weight, channels=channels, kernel_size=(3, 3), - padding=(1, 1)) + padding=(1, 1), + data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", + kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW") return relay.Function(args, y) - def check(shape, channels): + def check(shape, channels, blocking): x = relay.var("x", shape=shape) weight = relay.var("weight") - out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) - y1 = before(x, weight, out_scale, channels) + if blocking: + out_scale = relay.const(-_get_positive_scale((1,channels//blocking[1], 1, 1, blocking[1]))) + else: + out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) + y1 = before(x, weight, out_scale, channels, blocking) y1 = run_opt_pass(y1, transform.InferType()) type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) - y1_expected = expected(x, weight, out_scale, channels) + y1_expected = expected(x, weight, out_scale, channels, blocking) y1_expected = run_opt_pass(y1_expected, transform.InferType()) assert tvm.ir.structural_equal(y1_folded, y1_expected) - check((2, 4, 10, 10), 8) - + check((2, 4, 10, 10), 8, None) + check((2, 2, 10, 10, 2), 8, (2, 2)) if __name__ == "__main__": test_fold_fwd_simple() diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 28ccf6f5f941..25299caae30b 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -382,7 +382,7 @@ def test_no_pass(): def test_only_module_pass(): passes = [module_pass] sequential = tvm.transform.Sequential(opt_level=1, passes=passes) - with relay.build_config(required_pass=["mod_transform"]): + with tvm.transform.PassContext(required_pass=["mod_transform"]): ret_mod = sequential(mod) # Check the subtract function. sub_var, new_sub = extract_var_func(ret_mod, v_sub.name_hint) @@ -397,7 +397,7 @@ def test_only_function_pass(): # Check the subtract function. 
passes = [function_pass] sequential = tvm.transform.Sequential(opt_level=1, passes=passes) - with relay.build_config(required_pass=["func_transform"]): + with tvm.transform.PassContext(required_pass=["func_transform"]): ret_mod = sequential(mod) _, new_sub = extract_var_func(ret_mod, v_sub.name_hint) check_func(new_sub, get_ref_sub()) @@ -413,7 +413,7 @@ def test_multiple_passes(): passes = [module_pass, function_pass] sequential = tvm.transform.Sequential(opt_level=1, passes=passes) required = ["mod_transform", "func_transform"] - with relay.build_config(required_pass=required): + with tvm.transform.PassContext(required_pass=required): ret_mod = sequential(mod) # Check the abs function is added. @@ -490,7 +490,7 @@ def expected(): ]) mod = tvm.IRModule({"main": before()}) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): with tvm.target.create("llvm"): mod = seq(mod) @@ -515,7 +515,7 @@ def test_print_ir(capfd): ]) mod = tvm.IRModule({"main": func}) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): mod = seq(mod) out = capfd.readouterr().err @@ -530,6 +530,7 @@ def _tracer(module, info, is_before): if bool(is_before): __TRACE_COUNTER__ += 1 + def test_print_debug_callback(): global __TRACE_COUNTER__ shape = (1, 2, 3) @@ -548,10 +549,10 @@ def test_print_debug_callback(): assert __TRACE_COUNTER__ == 0 mod = tvm.IRModule({"main": func}) - with relay.build_config(opt_level=3, trace=_tracer): + with tvm.transform.PassContext(opt_level=3, trace=_tracer): mod = seq(mod) - assert __TRACE_COUNTER__ == 4 + assert __TRACE_COUNTER__ == 3 if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index e3c8991c8ebc..f2d615e9046a 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -16,10 +16,11 @@ # under the License. """Unit tests for merge composite.""" import tvm -from tvm import relay -from tvm import tir +from tvm import relay, tir +from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard from tvm.relay.testing import run_opt_pass + """ The merge composite pass is designed to merge multiple relay operators, that match a given pattern, and combine them into a single relay function. @@ -64,37 +65,32 @@ def make_add_sub_mul_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. add sub \ / \ / mul """ - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - sub_node = relay.subtract(x, y) - mul_node = relay.multiply(add_node, sub_node) - return mul_node + x = wildcard() + y = wildcard() + return (x + y) * (x - y) def make_add_relu_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. add | relu """ - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - r = relay.nn.relu(add_node) + add_node = wildcard() + wildcard() + r = is_op('nn.relu')(add_node) return r def make_conv_bias_relu_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. 
conv2d | @@ -102,17 +98,35 @@ def make_conv_bias_relu_pattern(): | relu """ - x = relay.var('x') - y = relay.var('y') - z = relay.var('z') - conv_node = relay.nn.conv2d(x, y) - bias_node = relay.nn.bias_add(conv_node, z) - r = relay.nn.relu(bias_node) + x = wildcard() + y = wildcard() + z = wildcard() + conv_node = is_op('nn.conv2d')(x, y) + bias_node = is_op('nn.bias_add')(conv_node, z) + r = is_op('nn.relu')(bias_node) + return r + + +def make_pattern_with_optional(): + r"""Create a pattern to match the following graph. Note that relu is optinal. + + conv2d + | + bias_add + | + (relu) + """ + x = wildcard() + y = wildcard() + z = wildcard() + conv_node = is_op('nn.conv2d')(x, y) + bias_node = is_op('nn.bias_add')(conv_node, z) + r = bias_node.optional(lambda x: is_op('nn.relu')(x)) return r def make_add_add_add_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. Useful for testing re-using a call node. x y @@ -123,15 +137,15 @@ def make_add_add_add_pattern(): | / add """ - x = relay.var('x') - y = relay.var('y') - add_node = relay.add(x, y) - add_node_1 = relay.add(x, add_node) - r = relay.add(add_node_1, add_node) + x = wildcard() + y = wildcard() + add_node = is_op('add')(x, y) + add_node_1 = is_op('add')(x, add_node) + r = is_op('add')(add_node_1, add_node) return r def make_bn_relu_pattern(): - """Create a pattern to match the following graph. + r"""Create a pattern to match the following graph. batch_norm | @@ -139,19 +153,27 @@ def make_bn_relu_pattern(): | relu """ - x = relay.var('x') - gamma = relay.var("gamma") - beta = relay.var("beta") - moving_mean = relay.var("moving_mean") - moving_var = relay.var("moving_var") - bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var) - tuple_get_item_node = bn_node[0] - r = relay.nn.relu(tuple_get_item_node) + x = wildcard() + gamma = wildcard() + beta = wildcard() + moving_mean = wildcard() + moving_var = wildcard() + bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var) + tuple_get_item_node = TupleGetItemPattern(bn_node, 0) + r = is_op('nn.relu')(tuple_get_item_node) return r +def check_result(pattern_table, graph, expected_graph): + """Utility function to check merge composite results.""" + result = run_opt_pass(graph, relay.transform.MergeComposite(pattern_table)) + assert not relay.analysis.free_vars(result), \ + "Found free vars in the result graph: {0}".format(str(result)) + expected = run_opt_pass(expected_graph, relay.transform.InferType()) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True), \ + "Graph mismatch: output vs. expected\n{0}\n=====\n{1}".format(str(result), str(expected)) def test_simple_merge(): - """Test composite function is correctly produced from simple graph. + r"""Test composite function is correctly produced from simple graph. We could expect the pattern `make_add_relu_pattern` to be merged into a single op `add_relu`. 
@@ -185,19 +207,17 @@ def expected(): relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) add_relu = add_relu.with_attr("Composite", "add_relu") + add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_") # merged function r = relay.Call(add_relu, [a, b]) return relay.Function([a, b], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_branch_merge(): - """Test composite function is correctly produced from branching graph. + r"""Test composite function is correctly produced from branching graph. We would expect the pattern `make_add_sub_mul_pattern` to be merged into a single op `add_sub_mul`. @@ -250,6 +270,7 @@ def expected(): mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul") + add_sub_mul = add_sub_mul.with_attr("PartitionedFromPattern", "add_subtract_multiply_") # add_sub_mul1 function in_3 = relay.var('in_3', shape=(10, 10)) @@ -259,6 +280,7 @@ def expected(): mul_node_1 = relay.multiply(add_node_1, sub_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul") + add_sub_mul_1 = add_sub_mul_1.with_attr("PartitionedFromPattern", "add_subtract_multiply_") # merged function m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) @@ -266,14 +288,11 @@ def expected(): r = relay.nn.relu(m_add_sub_mul_2) return relay.Function([a, b, c], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_reuse_call_merge(): - """Test composite function is correctly produced from simple graph + r"""Test composite function is correctly produced from simple graph which re-uses call nodes. We could expect the pattern `make_add_add_add` to be merged @@ -318,20 +337,18 @@ def expected(): add_node_2 = relay.add(add_node_1, add_node) add_add_add = relay.Function([in_1, in_2], add_node_2) add_add_add = add_add_add.with_attr("Composite", "add_add_add") + add_add_add = add_add_add.with_attr("PartitionedFromPattern", "add_add_add_") # merged function sub_node = relay.subtract(a, b) call = relay.Call(add_add_add, [sub_node, b]) return relay.Function([a, b], call) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_multiple_patterns(): - """Test different patterns are merged correctly in the graph. + r"""Test different patterns are merged correctly in the graph. We would expect the pattern `make_conv_bias_relu_pattern` to be merged into a single op `conv_bias_relu`. 
We would also expect `make_add_relu_pattern` @@ -402,6 +419,8 @@ def expected(): conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", "conv2d_bias_relu") + conv_bias_add_relu = conv_bias_add_relu.with_attr("PartitionedFromPattern", + "nn.conv2d_nn.bias_add_nn.relu_") # add_relu function in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) @@ -410,6 +429,7 @@ def expected(): r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) add_relu = add_relu.with_attr("Composite", "add_relu") + add_relu = add_relu.with_attr("PartitionedFromPattern", "add_nn.relu_") # merged function conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) @@ -417,14 +437,79 @@ def expected(): r = relay.multiply(add_relu_1, b) return relay.Function([data, kernel, bias, a, b], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) + + +def test_optional_pattern(): + r"""Test the pattern with optional operators. We can define a pattern with some operators + optional. The merge composite pass will create composite functions for all matched patterns, + but with different "PartitionedFromPattern" attributes. We expect the backend codegen to + analyze that attribute and determine the corresponding action. + + Pattern: Matched Case A: Matched Case B: + + conv2d conv2d conv2d + | | | + bias_add bias_add bias_add + | | + (relu) relu + + In the above example, the composite function for matched case A would have + PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_" while the one for matched case B + would be "nn.conv2d_nn.bias_add_". 
+ """ + pattern_table = [("layer", make_pattern_with_optional())] + + def before(): + x = relay.var('x', shape=(1, 3, 7, 7)) + w1 = relay.var('w', shape=(3, 3, 1, 1)) + b1 = relay.var('b', shape=(3, )) + w2 = relay.var('w', shape=(3, 3, 1, 1)) + b2 = relay.var('b', shape=(3, )) + conv = relay.nn.conv2d(x, w1, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b1) + relu = relay.nn.relu(bias) + conv = relay.nn.conv2d(relu, w2, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b2) + return relay.Function([x, w1, w2, b1, b2], bias) + + def expected(): + # Matched composite function A + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + func1 = relay.Function([x, w, b], relu) + func1 = func1.with_attr("Composite", "layer") + func1 = func1.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + + # Matched composite function B + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(1, 1)) + bias = relay.nn.bias_add(conv, b) + func2 = relay.Function([x, w, b], bias) + func2 = func2.with_attr("Composite", "layer") + func2 = func2.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_") + + # Main function + x = relay.var('x', shape=(1, 3, 7, 7)) + w1 = relay.var('w', shape=(3, 3, 1, 1)) + b1 = relay.var('b', shape=(3, )) + w2 = relay.var('w', shape=(3, 3, 1, 1)) + b2 = relay.var('b', shape=(3, )) + out1 = func1(x, w1, b1) + out2 = func2(out1, w2, b2) + return relay.Function([x, w1, w2, b1, b2], out2) + + check_result(pattern_table, before(), expected()) def test_merge_order(): - """Test that patterns are merged in the order they exist in the pattern table. + r"""Test that patterns are merged in the order they exist in the pattern table. There can be cases where one pattern is a subgraph of another, in which case it is not clear which match should take priority. 
The priority should come @@ -441,24 +526,24 @@ def test_merge_order(): """ def pattern_A(): - x = relay.var('x') - y = relay.var('y') - out = relay.add(x, y) - out = relay.abs(out) - out = relay.nn.relu(out) + x = wildcard() + y = wildcard() + out = is_op('add')(x, y) + out = is_op('abs')(out) + out = is_op('nn.relu')(out) return out def pattern_B(): - x = relay.var('x') - y = relay.var('y') - out = relay.add(x, y) - out = relay.abs(out) + x = wildcard() + y = wildcard() + out = is_op('add')(x, y) + out = is_op('abs')(out) return out def pattern_C(): - x = relay.var('x') - out = relay.abs(x) - out = relay.nn.relu(x) + x = wildcard() + out = is_op('abs')(x) + out = is_op('nn.relu')(out) return out def before(): @@ -469,7 +554,7 @@ def before(): out = relay.nn.relu(out) return relay.Function([input_1, input_2], out) - def after_A_priority(composite_name): + def after_A_priority(): input_1 = relay.var('input_1', shape=(10, 10)) input_2 = relay.var('input_2', shape=(10, 10)) x = relay.var('x') @@ -478,46 +563,65 @@ def after_A_priority(composite_name): out = relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.with_attr('Composite', composite_name) + merged_func = merged_func.with_attr('Composite', 'A') + merged_func = merged_func.with_attr('PartitionedFromPattern', 'add_abs_nn.relu_') ret = relay.Call(merged_func, [input_1, input_2]) return relay.Function([input_1, input_2], ret) + def after_B_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + y = relay.var('y') + out = relay.add(x, y) + out = relay.abs(out) + merged_func = relay.Function([x, y], out) + merged_func = merged_func.with_attr('Composite', 'B') + merged_func = merged_func.with_attr('PartitionedFromPattern', 'add_abs_') + out = relay.Call(merged_func, [input_1, input_2]) + ret = relay.nn.relu(out) + return relay.Function([input_1, input_2], ret) + + def after_C_priority(): + input_1 = relay.var('input_1', shape=(10, 10)) + input_2 = relay.var('input_2', shape=(10, 10)) + x = relay.var('x') + out = relay.abs(x) + out = relay.nn.relu(out) + merged_func = relay.Function([x], out) + merged_func = merged_func.with_attr('Composite', 'C') + merged_func = merged_func.with_attr('PartitionedFromPattern', 'abs_nn.relu_') + out = relay.add(input_1, input_2) + ret = relay.Call(merged_func, [out]) + return relay.Function([input_1, input_2], ret) + # check A highest priority pattern_table = [ ("A", pattern_A()), ("B", pattern_B()), ("C", pattern_C()), ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), after_A_priority()) # check B highest priority pattern_table = [ - ("B", pattern_A()), - ("C", pattern_B()), - ("A", pattern_C()), + ("B", pattern_B()), + ("C", pattern_C()), + ("A", pattern_A()), ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), after_B_priority()) # check C highest priority pattern_table = [ - ("C", pattern_A()), - ("A", pattern_B()), - ("B", pattern_C()), + ("C", 
pattern_C()), + ("A", pattern_A()), + ("B", pattern_B()), ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), after_C_priority()) def test_parallel_merge(): - """Tests that parallel patterns relying on the same inputs are correctly merged. + r"""Tests that parallel patterns relying on the same inputs are correctly merged. The test graph is difficult to draw out as ascii art. It is essentially two parallel add-sub-mul units which both consume input_1 and input_2 with their results being multiplied @@ -536,7 +640,7 @@ def before(): out = relay.multiply(branch_1, branch_2) return relay.Function([input_1, input_2], out) - def after(): + def expected(): input_1 = relay.var('input_1', shape=(10, 10)) input_2 = relay.var('input_2', shape=(10, 10)) x = relay.var('x') @@ -544,12 +648,14 @@ def after(): branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) func_1 = func_1.with_attr('Composite', "add_sub_mul") + func_1 = func_1.with_attr('PartitionedFromPattern', "add_subtract_multiply_") call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) func_2 = func_2.with_attr('Composite', "add_sub_mul") + func_2 = func_2.with_attr('PartitionedFromPattern', "add_subtract_multiply_") call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) return relay.Function([input_1, input_2], out) @@ -557,14 +663,11 @@ def after(): pattern_table = [ ("add_sub_mul", make_add_sub_mul_pattern()) ] - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_multiple_input_subgraphs(): - """Test the case when multiple input subgraphs feed into another subgraph. + r"""Test the case when multiple input subgraphs feed into another subgraph. 
(1) (2) (3) (4) add add add add @@ -629,6 +732,7 @@ def after_A(): add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu') + add_relu_1 = add_relu_1.with_attr('PartitionedFromPattern', 'add_nn.relu_') add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) x1 = relay.var('x1') y1 = relay.var('y1') @@ -636,6 +740,7 @@ def after_A(): add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2) add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu') + add_relu_2 = add_relu_2.with_attr('PartitionedFromPattern', 'add_nn.relu_') add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) x2 = relay.var('x2') y2 = relay.var('y2') @@ -644,6 +749,7 @@ def after_A(): add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.Function([x2, y2], add_sub_mul) add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul') + add_sub_mul = add_sub_mul.with_attr('PartitionedFromPattern', 'add_subtract_multiply_') add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) return relay.Function(inputs, add_sub_mul_call) @@ -657,6 +763,7 @@ def after_B(): add_relu = relay.nn.relu(add_relu) add_relu = relay.Function([x, y], add_relu) add_relu = add_relu.with_attr('Composite', 'add_relu') + add_relu = add_relu.with_attr('PartitionedFromPattern', 'add_nn.relu_') add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_calls.append(add_relu_call) @@ -669,17 +776,8 @@ def after_B(): ("add_sub_mul", make_add_sub_mul_pattern()), ("add_relu", make_add_relu_pattern()) ] - # check case 'A' - result = run_opt_pass(before()['A'], relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_A(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) - - # check case 'B' - result = run_opt_pass(before()['B'], relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(after_B(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before()['A'], after_A()) + check_result(pattern_table, before()['B'], after_B()) def test_tuple_get_item_merge(): @@ -717,15 +815,14 @@ def expected(): relu_node = relay.nn.relu(tuple_get_item_node) bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node) bn_relu = bn_relu.with_attr("Composite", "bn_relu") + bn_relu = bn_relu.with_attr("PartitionedFromPattern", + "nn.batch_norm_TupleGetItem0_nn.relu_") # merged function r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var]) return relay.Function([x, gamma, beta, moving_mean, moving_var], r) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table)) - assert not relay.analysis.free_vars(result) - expected = run_opt_pass(expected(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + check_result(pattern_table, before(), expected()) def test_pattern_with_check(): @@ -750,28 +847,166 @@ def _check_false(extract): conv = extract.args[0].args[0] return conv.attrs.data_layout == "NCHW" + def expected(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + func = 
relay.Function([x, w, b], relu) + func = func.with_attr("Composite", "conv_bias_relu") + func = func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8,)) + return relay.Function([x, w, b], func(x, w, b)) + + pattern_table_false = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false) + ] + check_result(pattern_table_false, before(), before()) + pattern_table_true = [ ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true) ] + check_result(pattern_table_true, before(), expected()) + + +def test_diamond_not_merge(): + r""" + The pattern on the left shouldn't match the structure on the right + + relu relu + | \ | \ + | clip | add + | / | | + mul | clip + | / + mul + """ + def get_pattern(): + conv = make_conv_bias_relu_pattern() + clip = is_op('clip')(conv, wildcard(), wildcard()) + return is_op('multiply')(conv, clip) + + def get_net(): + data = relay.var('data', shape=(1, 512, 28, 28)) + kernel = relay.var('kernel', shape=(256, 512, 1, 1)) + conv = relay.nn.conv2d(data, kernel, + kernel_size=(1, 1), + padding=(0, 0), + strides=(1, 1)) + bias = relay.nn.bias_add(conv, relay.var('bias', shape=(256,))) + relu = relay.nn.relu(bias) + add = relay.op.add(relu, relay.const(1.0)) + clip2 = relay.op.clip(add, 0, 255) + mul = relay.op.multiply(relu, clip2) + return relay.Function(relay.analysis.free_vars(mul), mul) + + pattern_table = [("pat", get_pattern())] + net = get_net() + check_result(pattern_table, net, net) + + +def test_type_check(): + """Test that we can query tensor types in the 'check' function.""" + def before(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8,)) + add = relay.op.add(x, x) + relu = relay.nn.relu(add) + conv = relay.nn.conv2d(relu, + w, + kernel_size=(3, 3), + kernel_layout="OIHW", + data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu2 = relay.nn.relu(bias) + return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType()) + + def expected_false(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8, )) + + x0 = relay.var('x') + y0 = relay.var('y') + + add = relay.op.add(y0, y0) + relu = relay.nn.relu(add) + func = relay.Function([x0, y0], relu) + func = func.with_attr("PartitionedFromPattern", "add_nn.relu_") + func = func.with_attr("Composite", "add_relu") + call = relay.Call(func, [x, x]) + + conv = relay.nn.conv2d(call, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu2 = relay.nn.relu(bias) + return relay.Function([x, w, b], relu2) + + def expected_true(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8, )) + + x0 = relay.var('x') + y0 = relay.var('y') + + add = relay.op.add(y0, y0) + relu = relay.nn.relu(add) + func = relay.Function([x0, y0], relu) + func = func.with_attr("PartitionedFromPattern", "add_nn.relu_") + func = func.with_attr("Composite", "add_relu") + call = relay.Call(func, [x, x]) + + x2 = relay.var('x') + w1 = relay.var('w') + b1 = relay.var('b') + conv = relay.nn.conv2d(x2, w1, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC") + bias = relay.nn.bias_add(conv, b1) + relu2 = relay.nn.relu(bias) + func = relay.Function([x2, w1, b1], relu2) + func = func.with_attr("Composite", "conv_bias_relu") + func 
= func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_") + call = relay.Call(func, [call, w, b]) + return relay.Function([x, w, b], call) + + def _check_type_true(extract): + conv = extract.args[0].args[0] + typ = conv.checked_type + return bool(typ.shape[0] == 1) + + def _check_type_false(extract): + conv = extract.args[0].args[0] + typ = conv.checked_type + return bool(typ.shape[0] != 1) + pattern_table_false = [ - ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false) + ("add_relu", make_add_relu_pattern()), + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false) ] + check_result(pattern_table_false, before(), expected_false()) - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false)) - expected = run_opt_pass(before(), relay.transform.InferType()) - assert tvm.ir.structural_equal(result, expected, map_free_vars=True) - - result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true)) - assert result.body.op.attrs["Composite"] == "conv_bias_relu" + pattern_table_true = [ + ("add_relu", make_add_relu_pattern()), + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true) + ] + check_result(pattern_table_true, before(), expected_true()) if __name__ == "__main__": test_simple_merge() test_branch_merge() test_multiple_patterns() + test_optional_pattern() test_merge_order() test_parallel_merge() test_multiple_input_subgraphs() test_reuse_call_merge() test_tuple_get_item_merge() test_pattern_with_check() + test_diamond_not_merge() + test_type_check() diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 3261ccd0d7c9..473ca9d66106 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Unit tests for graph partitioning.""" +# pylint: disable=not-callable import os import sys @@ -194,19 +195,22 @@ def update_lib(lib): def check_vm_result(): compile_engine.get().clear() - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() lib = update_lib(lib) exe = runtime.vm.Executable.load_exec(code, lib) vm = runtime.vm.VirtualMachine(exe) vm.init(ctx) - out = vm.run(**map_inputs) - tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + outs = vm.run(**map_inputs) + outs = outs if isinstance(outs, runtime.container.ADT) else [outs] + results = result if isinstance(result, list) else [result] + for out, ref in zip(outs, results): + tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=tol, atol=tol) def check_graph_runtime_result(): compile_engine.get().clear() - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) @@ -215,10 +219,14 @@ def check_graph_runtime_result(): rt_mod.set_input(name, data) rt_mod.set_input(**param) rt_mod.run() - out = tvm.nd.empty(out_shape, ctx=ctx) - out = rt_mod.get_output(0, out) - tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) + out_shapes = out_shape if isinstance(out_shape, list) else [out_shape] + results = result if isinstance(result, list) else [result] + + for idx, shape in enumerate(out_shapes): + out = tvm.nd.empty(shape, ctx=ctx) + out = rt_mod.get_output(idx, out) + tvm.testing.assert_allclose(out.asnumpy(), results[idx], rtol=tol, atol=tol) check_vm_result() check_graph_runtime_result() @@ -457,7 +465,6 @@ def test_extern_dnnl_mobilenet(): mod, params = relay.testing.mobilenet.get_workload( batch_size=1, dtype='float32') - mod["main"] = bind_params_by_name(mod["main"], params) mod = transform.AnnotateTarget(["dnnl"])(mod) mod = transform.MergeCompilerRegions()(mod) mod = transform.PartitionGraph()(mod) @@ -505,7 +512,7 @@ def partition(): transform.AlterOpLayout(), ]) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): mod = opt_pass(mod) return mod @@ -523,8 +530,8 @@ def expected(): bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple()) - func0 = set_func_attr(func0, "test_compiler", "test_compiler_0") - gv0 = relay.GlobalVar("test_compiler_0") + func0 = set_func_attr(func0, "test_compiler", "test_compiler_2") + gv0 = relay.GlobalVar("test_compiler_2") mod[gv0] = func0 # function for conv2d @@ -537,8 +544,8 @@ def expected(): channels=16, padding=(1, 1)) func1 = relay.Function([data1, weight1], conv) - func1 = set_func_attr(func1, "test_compiler", "test_compiler_1") - gv1 = relay.GlobalVar("test_compiler_1") + func1 = set_func_attr(func1, "test_compiler", "test_compiler_0") + gv1 = relay.GlobalVar("test_compiler_0") mod[gv1] = func1 # main function @@ -588,7 +595,7 @@ def partition(): transform.Inline(), ]) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): mod = opt_pass(mod) return mod @@ -631,7 +638,6 @@ def test_constant_propagation(): def expected(): mod = tvm.IRModule() - x = relay.const(ones) y = relay.var("y", shape=(8, 8)) x0 = relay.const(ones) y0 = relay.var("y0", shape=(8, 8)) @@ -713,12 +719,12 @@ def expected(): mod = 
tvm.IRModule() # function 0 - data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32")) - weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32")) - bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32")) - bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32")) - bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32")) - bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32")) + data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32")) + weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32")) + bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32")) + bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32")) + bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32")) + bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32")) conv_o = relay.nn.conv2d( data=data, @@ -731,12 +737,12 @@ def expected(): bn_var) relu_o = relay.nn.relu(bn_o[0]) - tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o)) + tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) - func0 = set_func_attr(func0, "test_target", "test_target_2") - gv0 = relay.GlobalVar("test_target_2") + func0 = set_func_attr(func0, "test_target", "test_target_0") + gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0 # body @@ -748,9 +754,9 @@ def expected(): bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) - f0_relu_o = relay.TupleGetItem(f0_o, 2) + f0_relu_o = relay.TupleGetItem(f0_o, 0) f0_mean_o = relay.TupleGetItem(f0_o, 1) - f0_var_o = relay.TupleGetItem(f0_o, 0) + f0_var_o = relay.TupleGetItem(f0_o, 2) f0_mean_abs = relay.abs(f0_mean_o) f0_var_abs = relay.abs(f0_var_o) @@ -792,22 +798,22 @@ def expected(): mod = tvm.IRModule() # function 1 - f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10)) + f1_cb1 = relay.var('test_target_0_i0', shape=(10, 10)) f1_O_1 = relay.abs(f1_cb1) f1_O_2 = relay.nn.relu(f1_O_1) f1_out = relay.Tuple((f1_O_2, f1_O_1)) func1 = relay.Function([f1_cb1], f1_out) - func1 = set_func_attr(func1, "test_target", "test_target_1") - gv1 = relay.GlobalVar("test_target_1") + func1 = set_func_attr(func1, "test_target", "test_target_0") + gv1 = relay.GlobalVar("test_target_0") mod[gv1] = func1 # function 0 - f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10)) - f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10)) + f2_cb3 = relay.var('test_target_1_i0', shape=(10, 10)) + f2_cb4 = relay.var('test_target_1_i1', shape=(10, 10)) f2_O_3 = relay.add(f2_cb3, f2_cb4) func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3) - func0 = set_func_attr(func0, "test_target", "test_target_0") - gv0 = relay.GlobalVar("test_target_0") + func0 = set_func_attr(func0, "test_target", "test_target_1") + gv0 = relay.GlobalVar("test_target_1") mod[gv0] = func0 # body @@ -879,7 +885,8 @@ def get_partitoned_mod(mod, params, pattern_table): transform.PartitionGraph() ]) - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with tvm.transform.PassContext(opt_level=3, + disabled_pass=["AlterOpLayout"]): return composite_partition(mod) def test_detect_pattern(pattern_table, include_bn, include_sigmoid, @@ -1027,7 +1034,7 @@ def test_different_output_region(): def test_duplicate_outputs(): target = 
"test_duplicate_outputs" - @reg.register("abs", "target." + target) + @tvm.ir.register_op_attr("abs", "target." + target) def abs(attrs, args): # pylint: disable=unused-variable return True @@ -1083,12 +1090,12 @@ def expected(): def test_duplicate_merge_and_tuplegetitem(): target = "test_duplicate_merge_and_tuplegetitem" - @reg.register("nn.batch_norm", "target." + target) - def abs(attrs, args): # pylint: disable=unused-variable + @tvm.ir.register_op_attr("nn.batch_norm", "target." + target) + def batch_norm(attrs, args): # pylint: disable=unused-variable return True - @reg.register("nn.relu", "target." + target) - def abs(attrs, args): # pylint: disable=unused-variable + @tvm.ir.register_op_attr("nn.relu", "target." + target) + def relu(attrs, args): # pylint: disable=unused-variable return True def create_graph(): @@ -1110,22 +1117,22 @@ def expected(): mod = tvm.IRModule() # function 0 - f0_i0 = relay.var(target+"_1_i0", shape=(10, 10)) - f0_i1 = relay.var(target+"_1_i1") - f0_i2 = relay.var(target+"_1_i2") - f0_i3 = relay.var(target+"_1_i3") - f0_i4 = relay.var(target+"_1_i4") + f0_i0 = relay.var(target + "_0_i0", shape=(10, 10)) + f0_i1 = relay.var(target + "_0_i1") + f0_i2 = relay.var(target + "_0_i2") + f0_i3 = relay.var(target + "_0_i3") + f0_i4 = relay.var(target + "_0_i4") f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4) f0_n1 = f0_n0[1] f0_n2 = relay.nn.relu(f0_n0[0]) - f0_o0 = relay.Tuple([f0_n1, f0_n2]) + f0_o0 = relay.Tuple([f0_n2, f0_n1]) func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", target) - func0 = func0.with_attr("global_symbol", target+"_1") - gv0 = relay.GlobalVar(target+"_1") + func0 = func0.with_attr("global_symbol", target + "_0") + gv0 = relay.GlobalVar(target + "_0") mod[gv0] = func0 # body @@ -1137,9 +1144,9 @@ def expected(): function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar) get_out0 = relay.TupleGetItem(function_out, 0) get_out1 = relay.TupleGetItem(function_out, 1) - out_2 = relay.tanh(get_out0) - out_3 = relay.log(get_out0) - out = relay.Tuple([get_out1, out_2, out_3]) + out_2 = relay.tanh(get_out1) + out_3 = relay.log(get_out1) + out = relay.Tuple([get_out0, out_2, out_3]) func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out) mod["main"] = func return mod @@ -1157,6 +1164,132 @@ def expected(): partitioned = seq(mod) assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) +def test_constant_tuples(): + @tvm.ir.register_op_attr("qnn.concatenate", "target.const_tuples") + def add(attrs, args): # pylint: disable=unused-variable + return True + + def create_graph(): + a = relay.var('a', shape=(10, 10), dtype="uint8") + b = relay.var('b', shape=(10, 10), dtype="uint8") + a1 = relay.abs(a) + + zeroi = relay.const(1, "int32") + zerof = relay.const(0, "float32") + con = relay.qnn.op.concatenate((a1, b), + input_scales=(zerof, zerof), + input_zero_points=(zeroi, zeroi), + output_scale=zerof, + output_zero_point=zeroi, + axis=1) + + f = relay.Function([a, b], con) + mod = tvm.IRModule.from_expr(f) + return mod + + seq = tvm.transform.Sequential([ + transform.AnnotateTarget("const_tuples"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ]) + + partitioned = seq(create_graph()) + concat = partitioned["const_tuples_0"].body + assert type(concat.args[1]) == relay.Tuple + assert type(concat.args[2]) == 
relay.Tuple + assert type(concat.args[3]) == relay.Constant + assert type(concat.args[4]) == relay.Constant + +def test_flatten_tuple_output(): + target = "test_flatten_tuple_output" + + @tvm.ir.register_op_attr("split", "target." + target) + def split(attrs, args): # pylint: disable=unused-variable + return True + + @tvm.ir.register_op_attr("abs", "target." + target) + def abs(attrs, args): # pylint: disable=unused-variable + return True + + def create_graph(): + a = relay.var('a', shape=(10, 10), dtype="uint8") + + a_split = relay.split(a, 2) + a_split_0 = relay.TupleGetItem(a_split.astuple(),0) + a_split_0_abs = relay.abs(a_split_0) + + a_con = relay.concatenate(a_split, 0) + a_split_0_relu = relay.nn.relu(a_split_0_abs) + + out = relay.Tuple((a_con, a_split_0_relu)) + f = relay.Function([a], out) + mod = tvm.IRModule.from_expr(f) + return mod + + def expected(): + mod = tvm.IRModule() + + # function 0 + f0_i0 = relay.var(target + "_0_i0", shape=(10, 10), dtype="uint8") + a_split = relay.split(f0_i0, 2) + a_split_0 = relay.TupleGetItem(a_split.astuple(), 0) + a_split_1 = relay.TupleGetItem(a_split.astuple(), 1) + a_split_abs_in = relay.TupleGetItem(a_split.astuple(), 0) + abs = relay.abs(a_split_abs_in) + tuple_out = relay.Tuple((a_split_0, a_split_1, abs)) + func0 = relay.Function([f0_i0], tuple_out) + + func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + func0 = func0.with_attr("Compiler", target) + func0 = func0.with_attr("global_symbol", target + "_0") + gv0 = relay.GlobalVar(target + "_0") + mod[gv0] = func0 + + #body + data = relay.var('a', shape=(10, 10), dtype="uint8") + f_out = gv0(data) + f_out_0 = relay.TupleGetItem(f_out, 0) + f_out_1 = relay.TupleGetItem(f_out, 1) + tuple = relay.Tuple((f_out_0, f_out_1)) + concat = relay.concatenate(tuple,0) + f_out_2 = relay.TupleGetItem(f_out, 2) + relu = relay.nn.relu(f_out_2) + ret_tuple = relay.Tuple((concat, relu)) + mod["main"] = relay.Function([data], ret_tuple) + return mod + + seq = tvm.transform.Sequential([ + transform.AnnotateTarget(target), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ]) + + partitioned = seq(create_graph()) + assert tvm.ir.structural_equal(partitioned, expected(), map_free_vars=True) + +def test_tuple_output_exec(): + """Test C codegen and runtime for a subgraph with a tuple output""" + a = relay.var('a', shape=(10, 10), dtype='float32') + b = relay.var('b', shape=(10, 10), dtype='float32') + ba = relay.annotation.compiler_begin(a, 'ccompiler') + bb = relay.annotation.compiler_begin(b, 'ccompiler') + add = relay.add(ba, bb) + sub = relay.subtract(ba, bb) + out = relay.Tuple((add, sub)) + eout = relay.annotation.compiler_end(out, 'ccompiler') + func=relay.Function([a, b], eout) + mod = tvm.IRModule() + mod["main"] = func + mod = transform.PartitionGraph()(mod) + + a_data = np.random.rand(10, 10).astype('float32') + b_data = np.random.rand(10, 10).astype('float32') + + check_result(mod, {'a': a_data, 'b': b_data}, + [(10, 10), (10, 10)], + [(a_data + b_data), (a_data - b_data)]) + if __name__ == "__main__": test_multi_node_compiler() test_extern_ccompiler_single_op() @@ -1173,3 +1306,6 @@ def expected(): test_multiple_use_of_an_output() test_duplicate_outputs() test_duplicate_merge_and_tuplegetitem() + test_constant_tuples() + test_flatten_tuple_output() + test_tuple_output_exec() diff --git a/tests/python/relay/test_simplify_fc_transpose.py b/tests/python/relay/test_simplify_fc_transpose.py new file mode 
100644 index 000000000000..e29038c49b8e --- /dev/null +++ b/tests/python/relay/test_simplify_fc_transpose.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import itertools + +import numpy as np +import scipy.sparse as sp + + +import tvm +from tvm.ir import IRModule +from tvm import relay +from tvm.relay.data_dep_optimization import simplify_fc_transpose + +def run_func(func, params, x): + with tvm.transform.PassContext(opt_level=3): + graph, lib, new_params = relay.build(func, "llvm", params=params) + + from tvm.contrib import graph_runtime + ctx = tvm.cpu(0) + dtype = 'float32' + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input('data', tvm.nd.array(x.astype(dtype))) + m.set_input(**new_params) + # execute + m.run() + # get outputs + tvm_output = m.get_output(0) + return tvm_output.asnumpy() + +def test_simplify_fc_transpose(): + data = relay.var("data", shape=(1, 32), dtype="float32") + x = relay.nn.relu(data) + w1 = relay.var("w1", shape=(32, 64), dtype="float32") + y = relay.nn.dense(x, relay.transpose(w1, axes=[1, 0])) + z = relay.nn.relu(y) + w2 = relay.var("w2", shape=(64, 16), dtype="float32") + zz = relay.nn.dense(z, relay.transpose(w2, axes=[1, 0])) + func = relay.Function(relay.analysis.free_vars(zz), zz) + params = { + "w1": tvm.nd.array(np.random.uniform(-1, 1, (32, 64)).astype("float32")), + "w2": tvm.nd.array(np.random.uniform(-1, 1, (64, 16)).astype("float32")) + } + x_np = np.random.randn(1, 32).astype("float32") + old_result = run_func(func, params, x_np) + + new_func, new_params = simplify_fc_transpose.convert(func, params) + new_result = run_func(new_func, new_params, x_np) + np.testing.assert_allclose(old_result, new_result, atol=1e-5, rtol=1e-5) + +if __name__ == "__main__": + test_simplify_fc_transpose() diff --git a/tests/python/relay/test_sparse_dense_convert.py b/tests/python/relay/test_sparse_dense_convert.py new file mode 100644 index 000000000000..e0204aeaf9d0 --- /dev/null +++ b/tests/python/relay/test_sparse_dense_convert.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import itertools + +import numpy as np +import scipy.sparse as sp + + +import tvm +from tvm.ir import IRModule +from tvm import relay + + +def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype="float32"): + Y = np.zeros((M, N), dtype=dtype) + assert M % BS_R == 0 + assert N % BS_C == 0 + nnz = int(density * M * N) + num_blocks = int(nnz / (BS_R * BS_C)) + 1 + candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C)))) + assert candidate_blocks.shape[0] == M // BS_R * N // BS_C + chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)] + for i in range(len(chosen_blocks)): + r, c = chosen_blocks[i] + Y[r:r+BS_R,c:c+BS_C] = np.random.randn(BS_R, BS_C) + s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C)) + assert s.data.shape == (num_blocks, BS_R, BS_C) + assert s.data.size >= nnz + assert s.indices.shape == (num_blocks, ) + assert s.indptr.shape == (M // BS_R + 1, ) + return s + +def run_func(func, params, x): + with tvm.transform.PassContext(opt_level=3): + graph, lib, new_params = relay.build(func, "llvm", params=params) + + from tvm.contrib import graph_runtime + ctx = tvm.cpu(0) + dtype = 'float32' + m = graph_runtime.create(graph, lib, ctx) + # set inputs + m.set_input('data', tvm.nd.array(x.astype(dtype))) + m.set_input(**new_params) + # execute + m.run() + # get outputs + tvm_output = m.get_output(0) + return tvm_output.asnumpy() + +def test_bsr_sparse_dense(): + data = relay.var("data", shape=(1, 128), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(768, 128), dtype="float32") + y = relay.nn.dense(x, w) + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array(random_bsr_matrix(768, 128, 32, 1, 0.1).todense()) + } + + x_np = np.random.randn(1, 128).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + sparse_func, params = relay.data_dep_optimization.bsr_dense.convert(func, params, (32, 1), 0.2) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + +if __name__ == "__main__": + test_bsr_sparse_dense() diff --git a/tests/python/unittest/test_arith_canonical_simplify.py b/tests/python/unittest/test_arith_canonical_simplify.py index 0dcf1fb5344c..525cd6c30736 100644 --- a/tests/python/unittest/test_arith_canonical_simplify.py +++ b/tests/python/unittest/test_arith_canonical_simplify.py @@ -126,7 +126,7 @@ def test_floormod_simplify(): x, y = te.var("x"), te.var("y") ck.verify(flm(flm((x*4) + y - 466036, 24528) - 24512, 16), flm((x*4) + y + 12, 16)) - + ck.verify(flm(flm((x*4), 16), 8), flm(x, 2) * 4) def test_canonical_mixed(): @@ -202,7 +202,7 @@ def test_reduce_combiner_simplify(): assert tvm.ir.structural_equal(lhs, rhs) # Test that components with side effects are not removed - side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic, None, 0) + side_effect = lambda *xs: tvm.tir.Call("int32", "dummy", xs, tvm.tir.Call.Intrinsic) ck.verify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0], sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) ck.verify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0], diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 5baabd16c615..372f0e9ce727 100644 --- 
a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -18,13 +18,6 @@ from tvm import te -def assert_expr_equal(a, b): - res = tvm.tir.ir_pass.Simplify(a - b) - equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 - if not equal: - raise ValueError("{} and {} are not equal".format(a, b)) - - def test_deduce(): a = te.var('a') b = te.var('b') @@ -41,90 +34,90 @@ def test_deduce(): e0 = (-b)*a+c-d res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d - c, b*-1) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) e0 = d*a+c-d res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d-c, d) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) e1 = (a*4+b < c) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = fdiv(c-1-b, 4) - assert_expr_equal(res1.max_value, ans1) + tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) # expression containing variable a is on rhs e1 = (c > a*4+b) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res1.max_value, ans1) + tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) e2 = (tvm.te.max(5, a * 4) < 0) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf" - assert str(res2.min_value) == "pos_inf" + assert str(res2.max_value) == "neg_inf: handle" + assert str(res2.min_value) == "pos_inf: handle" # expression containing variable a is on rhs e2 = (zero < tvm.te.max(5, a * 4)) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max_value) == "neg_inf" - assert str(res2.min_value) == "pos_inf" + assert str(res2.max_value) == "neg_inf: handle" + assert str(res2.min_value) == "pos_inf: handle" e3 = (-b)+a*c-d res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = fdiv(2,c)+1 - assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3) + tvm.testing.assert_prim_expr_equal(res3.min_value, ans3) res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3) + tvm.testing.assert_prim_expr_equal(res3.min_value, ans3) # tests for `EQ` op res4 = tvm.arith.deduce_bound(a, a == b, {}, {}) - assert_expr_equal(res4.max_value, b) - assert_expr_equal(res4.min_value, b) + tvm.testing.assert_prim_expr_equal(res4.max_value, b) + tvm.testing.assert_prim_expr_equal(res4.min_value, b) # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) - assert str(res5.max_value) == "neg_inf" - assert str(res5.min_value) == "pos_inf" + assert str(res5.max_value) == "neg_inf: handle" + assert str(res5.min_value) == "pos_inf: handle" # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) - assert_expr_equal(res6.max_value, 10) - assert_expr_equal(res6.min_value, 10) + 
tvm.testing.assert_prim_expr_equal(res6.max_value, 10) + tvm.testing.assert_prim_expr_equal(res6.min_value, 10) # Add, Sub in `EQ` e4 = ((a - c) == (b + d)) ans4 = (b + d + c) res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res7.max_value, ans4) - assert_expr_equal(res7.min_value, ans4) + tvm.testing.assert_prim_expr_equal(res7.max_value, ans4) + tvm.testing.assert_prim_expr_equal(res7.min_value, ans4) # Satisfiable Mul in `EQ` with negative sign res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {}) - assert_expr_equal(res8.max_value, -2) - assert_expr_equal(res8.min_value, -2) + tvm.testing.assert_prim_expr_equal(res8.max_value, -2) + tvm.testing.assert_prim_expr_equal(res8.min_value, -2) # Unsatisfiable Mul in `EQ` e5 = (4 * a == b) res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {}) - assert str(res9.max_value) == "neg_inf" - assert str(res9.min_value) == "pos_inf" + assert str(res9.max_value) == "neg_inf: handle" + assert str(res9.min_value) == "pos_inf: handle" # Unsatisfiable Mul in `EQ` res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf" - assert str(res10.min_value) == "pos_inf" + assert str(res10.max_value) == "neg_inf: handle" + assert str(res10.min_value) == "pos_inf: handle" def test_check(): @@ -158,21 +151,22 @@ def test_basic(a1, a2, coff): res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") < e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 + + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True) res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) >= 17, True) test_basic(0, 4, 4) test_basic(1, 5, 4) @@ -190,21 +184,21 @@ def test_complex(a1, a2, coff): res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) < 63, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) <= 63, True) res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s}) [t, 
x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) > 63, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) >= 63, True) test_complex(0, 4, 4) test_complex(0, 4, -4) diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py index d6953713f14b..129237a8c58b 100644 --- a/tests/python/unittest/test_arith_detect_clip_bound.py +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -23,15 +23,15 @@ def test_basic(): c = te.var("c") m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a]) - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1) assert m[0].value == 2 m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a, b]) assert len(m) == 0 m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20, b - 1 > 0), [a, b]) - assert tvm.tir.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0 - assert tvm.tir.ir_pass.Simplify(m[2] - 2).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], 20 - 10 * c) + tvm.testing.assert_prim_expr_equal(m[2], 2) if __name__ == "__main__": diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 278581d0cacd..82153ab5207e 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -22,14 +22,14 @@ def test_basic(): b = te.var("b") m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a]) assert m[0].value == 4 - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7) m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a]) assert len(m) == 0 m = tvm.arith.detect_linear_equation(a * 4 + (a+1) + b * 6 + 7, [a]) assert m[0].value == 5 - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7 + 1) m = tvm.arith.detect_linear_equation(a * b + 7, [a]) assert m[0] == b @@ -39,13 +39,15 @@ def test_basic(): m = tvm.arith.detect_linear_equation(b * 7, []) assert len(m) == 1 - assert tvm.tir.ir_pass.Simplify(m[0] - b * 7).value == 0 + tvm.testing.assert_prim_expr_equal(m[0], b * 7) def test_multivariate(): v = [te.var("v%d" % i) for i in range(4)] b = te.var("b") m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v) - assert(tvm.tir.analysis.expr_deep_equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5)) + + tvm.testing.assert_prim_expr_equal(m[0], b + 5) + assert(m[1].value == 8) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) @@ -61,11 +63,12 @@ def test_multivariate(): m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]]) assert(m[0].value == 0) - assert(tvm.tir.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + + tvm.testing.assert_prim_expr_equal(m[1], v[0] - v[1]) m = tvm.arith.detect_linear_equation((v[0] - v[1]), []) assert(len(m) == 1) - assert(tvm.tir.ir_pass.Simplify(m[0] - 
(v[0] - v[1])).value == 0) + tvm.testing.assert_prim_expr_equal(m[0], v[0] - v[1]) if __name__ == "__main__": test_basic() diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index 0d769aabf247..10337218dc87 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -22,21 +22,25 @@ def test_domain_touched(): j = te.var('j') n = tvm.runtime.convert(100) m = te.var('m') - a = te.placeholder((n, m), name = 'a') - b = te.placeholder((n, m), name = 'b') + + a = tvm.tir.decl_buffer((n, m), name='a') + b = tvm.tir.decl_buffer((n, m), name='b') + + ir = tvm.tir.For( i, 0, n, 0, 0, tvm.tir.For(j, 0, m, 0, 0, - tvm.tir.Provide( - a.op, - 0, - tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) + - tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0), + tvm.tir.BufferStore( + a, + tvm.tir.BufferLoad(b, [i - 1, j + 1]) + + tvm.tir.BufferLoad(a, [i - 1, j - 1]), [i, j] ) ) ) + a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False) + assert a_domain_r[0].min.value == -1 assert a_domain_r[0].extent.value == 100 assert a_domain_r[1].min.value == -1 diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index e57dcef75994..9919c7b96cf1 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -90,6 +90,20 @@ def test_mod(): flm = tvm.te.floormod ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9)) + ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9)) + + floordiv = tvm.te.floordiv + z = te.var("z") + ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3)) + ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, + (0, 7)) + ck1 = IntSetChecker() + ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2)) + ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3)) def test_max_min(): diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index dbfdde3ac883..813e10a58707 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -529,7 +529,13 @@ def test_min_index_simplify(): ck.verify(tvm.te.min(tvm.te.min(x, 11), 10), tvm.te.min(x, 10)) ck.verify(tvm.te.min(x * 3, 9), tvm.te.min(x, 3) * 3) + ck.verify(tvm.te.min(x * 2, 0), tvm.te.min(x, 0) * 2) + ck.verify(tvm.te.min(0 - x * 2, 0), tvm.te.max(x, 0) * -2) ck.verify(tvm.te.min(3 - x, 2), 3 - tvm.te.max(x, 1)) + ck.verify(tvm.te.min(x * (-2), -4), tvm.te.max(x, 2) * -2) + ck.verify(tvm.te.min(x * (-2), 4), tvm.te.max(x, -2) * -2) + ck.verify(tvm.te.min(x * (0), 4), 0) + ck.verify(tvm.te.min(x * (0), -4), -4) # DivMod rules # truc div @@ -610,6 +616,12 @@ def test_max_index_simplify(): ck.verify(tvm.te.max(x * 3, 9), tvm.te.max(x, 3) * 3) ck.verify(tvm.te.max(3 - x, 1), 3 - tvm.te.min(x, 2)) + ck.verify(tvm.te.max(x * 2, 0), tvm.te.max(x, 0) * 2) + ck.verify(tvm.te.max(0 - x * 2, 0), tvm.te.min(x, 0) * -2) + ck.verify(tvm.te.max(x * (-2), -4), tvm.te.min(x, 2) * -2) + ck.verify(tvm.te.max(x * (-2), 4), tvm.te.min(x, -2) * -2) + 
ck.verify(tvm.te.max(x * (0), 4), 4) + ck.verify(tvm.te.max(x * (0), -4), 0) # DivMod rules # truc div diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 45f8fc10aaf0..4f4c5ee97944 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -29,7 +29,7 @@ def run_expr(expr, vranges): """ def _compute_body(*us): vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} - return tir.ir_pass.Substitute(expr, vmap) + return tir.stmt_functor.substitute(expr, vmap) A = te.compute([r.extent.value for v, r in vranges.items()], _compute_body) args = [tvm.nd.empty(A.shape, A.dtype)] @@ -55,12 +55,13 @@ def check_bruteforce(bool_expr, vranges, cond=None): counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) raise AssertionError("Expression {}\nis not true on {}\n" "Counterexample: {}" - .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) + .format(tir.arith.Analyzer().simplify(bool_expr), vranges, counterex)) def check_solution(solution, vranges={}): """Check that solution is a bijective transformation""" def _check_forward(constraints1, constraints2, varmap, backvarmap): + ana = tvm.arith.Analyzer() all_vranges = vranges.copy() all_vranges.update({v: r for v, r in constraints1.ranges.items()}) @@ -68,19 +69,19 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): cond_on_vars = tir.const(1, 'bool') for v in constraints1.variables: # variable mapping is consistent - v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) + v_back = ana.simplify(tir.stmt_functor.substitute(varmap[v], backvarmap)) cond_on_vars = te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true - cond_subst = tir.ir_pass.Substitute( + cond_subst = tir.stmt_functor.substitute( te.all(tir.const(1, 'bool'), *constraints2.relations), backvarmap) # We have to include relations from vranges too for v in constraints2.variables: if v in constraints2.ranges: r = constraints2.ranges[v] range_cond = te.all(v >= r.min, v < r.min + r.extent) - range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) + range_cond = tir.stmt_functor.substitute(range_cond, backvarmap) cond_subst = te.all(cond_subst, range_cond) - cond_subst = tir.ir_pass.Simplify(cond_subst) + cond_subst = ana.simplify(cond_subst) check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) diff --git a/tests/python/unittest/test_format_si_prefix.py b/tests/python/unittest/test_format_si_prefix.py new file mode 100644 index 000000000000..69be62a063b9 --- /dev/null +++ b/tests/python/unittest/test_format_si_prefix.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +from numpy import isclose +import random +from tvm.autotvm import util + + +SI_PREFIXES = 'yzafpn\xb5m kMGTPEZY' + + +def test_format_si_prefix(): + # test float conversion + assert util.format_si_prefix(1024, 'k') == 1.024 + + for i, prefix in enumerate(SI_PREFIXES): + integer, decimal = random.randint(0, 1000), random.randint(0, 1000) + exp = -24 + 3 * i # 0th prefix (yocto) is 10^-24 + number = integer * (10 ** exp) + decimal * (10 ** (exp - 3)) + expected = (integer + decimal / 1000) + assert isclose(util.format_si_prefix(number, prefix), expected) + + assert util.format_si_prefix(0, 'y') == 0 + + +if __name__ == '__main__': + test_format_si_prefix() diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py index 48495f48dc5a..233e59b8d01e 100644 --- a/tests/python/unittest/test_ir_attrs.py +++ b/tests/python/unittest/test_ir_attrs.py @@ -15,20 +15,15 @@ # specific language governing permissions and limitations # under the License. import tvm +import pytest import tvm.ir._ffi_api def test_make_attrs(): - try: + with pytest.raises(AttributeError): x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx") - assert False - except tvm.error.TVMError as e: - assert str(e).find("unknown_key") != -1 - try: + with pytest.raises(AttributeError): x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx") - assert False - except tvm.error.TVMError as e: - assert str(e).find("upper bound") != -1 x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4)) assert x.name == "xx" diff --git a/tests/python/unittest/test_ir_type.py b/tests/python/unittest/test_ir_type.py index a0e7d2b46ad6..1072efb11600 100644 --- a/tests/python/unittest/test_ir_type.py +++ b/tests/python/unittest/test_ir_type.py @@ -72,7 +72,7 @@ def test_func_type(): def test_tuple_type(): tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type) - tf = tvm.ir.FuncType([], None, [], []) + tf = tvm.ir.FuncType([], tvm.ir.TupleType([]), [], []) tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32') fields = tvm.runtime.convert([tp, tf, tt]) diff --git a/tests/python/unittest/test_node_reflection.py b/tests/python/unittest/test_node_reflection.py index 975192293d87..b10951691715 100644 --- a/tests/python/unittest/test_node_reflection.py +++ b/tests/python/unittest/test_node_reflection.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
 import tvm
+import pytest
 from tvm import te
 
 def test_const_saveload_json():
@@ -101,6 +102,34 @@ def test_string():
     tvm.ir.assert_structural_equal(s1, s2)
 
 
+def test_pass_config():
+    cfg = tvm.transform.PassContext(opt_level=1, config={
+        "tir.UnrollLoop": {
+            "auto_max_step": 10,
+        }
+    })
+    assert cfg.opt_level == 1
+
+    assert cfg.config["tir.UnrollLoop"].auto_max_step == 10
+    # default option
+    assert cfg.config["tir.UnrollLoop"].explicit_unroll == True
+
+    # schema checking for specific config key
+    with pytest.raises(AttributeError):
+        cfg = tvm.transform.PassContext(config={
+            "tir.UnrollLoop": { "invalid": 1 }
+        })
+
+    # schema check for un-registered config
+    with pytest.raises(AttributeError):
+        cfg = tvm.transform.PassContext(config={ "invalid-opt": True })
+
+    # schema check for wrong type
+    with pytest.raises(AttributeError):
+        cfg = tvm.transform.PassContext(config={
+            "tir.UnrollLoop": 1
+        })
+
 if __name__ == "__main__":
     test_string()
     test_env_func()
@@ -108,3 +137,4 @@ def test_string():
     test_make_smap()
     test_const_saveload_json()
     test_make_sum()
+    test_pass_config()
diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py
index 84b26be6cbc1..5ecc21e520af 100644
--- a/tests/python/unittest/test_runtime_container.py
+++ b/tests/python/unittest/test_runtime_container.py
@@ -17,6 +17,7 @@
 import numpy as np
 import tvm
+import pickle
 from tvm import te
 from tvm import nd, relay
 from tvm.runtime import container as _container
@@ -56,6 +57,29 @@ def test_tuple_object():
     tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
 
 
+def test_string():
+    s = tvm.runtime.String("xyz")
+
+    assert isinstance(s, tvm.runtime.String)
+    assert isinstance(s, str)
+    assert s.startswith("xy")
+    assert s + "1" == "xyz1"
+    y = tvm.testing.echo(s)
+    assert isinstance(y, tvm.runtime.String)
+    assert s.__tvm_object__.same_as(y.__tvm_object__)
+    assert s == y
+
+    x = tvm.ir.load_json(tvm.ir.save_json(y))
+    assert isinstance(x, tvm.runtime.String)
+    assert x == y
+
+    # test pickle
+    z = pickle.loads(pickle.dumps(s))
+    assert isinstance(z, tvm.runtime.String)
+    assert s == z
+
+
 if __name__ == "__main__":
+    test_string()
     test_adt_constructor()
     test_tuple_object()
diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py
index 48eaf7dd306b..2207eb3a73fa 100644
--- a/tests/python/unittest/test_runtime_extension.py
+++ b/tests/python/unittest/test_runtime_extension.py
@@ -18,9 +18,10 @@
 from tvm import te
 import numpy as np
 
+
 @tvm.register_extension
 class MyTensorView(object):
-    _tvm_tcode = tvm.TypeCode.DLTENSOR_HANDLE
+    _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE
 
     def __init__(self, arr):
         self.arr = arr
diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py
index 658d9eb95ef9..ce47b16fc4d5 100644
--- a/tests/python/unittest/test_runtime_graph_debug.py
+++ b/tests/python/unittest/test_runtime_graph_debug.py
@@ -14,11 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json import os import tvm from tvm import te import numpy as np -import json from tvm import rpc from tvm.contrib import util from tvm.contrib.debugger import debug_runtime as graph_runtime @@ -75,7 +75,16 @@ def check_verify(): assert(len(os.listdir(directory)) == 1) #verify the file name is proper - assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME))) + graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME) + assert(os.path.exists(graph_dump_path)) + + # verify the graph contains some expected keys + with open(graph_dump_path) as graph_f: + dumped_graph = json.load(graph_f) + + assert isinstance(dumped_graph, dict) + for k in ("nodes", "arg_nodes", "node_row_ptr", "heads", "attrs"): + assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}" mod.run() #Verify the tensors are dumped diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py index 28fdb11c3de4..2eea3df5732a 100644 --- a/tests/python/unittest/test_runtime_micro.py +++ b/tests/python/unittest/test_runtime_micro.py @@ -25,8 +25,10 @@ from tvm.micro import create_micro_mod from tvm.relay.testing import resnet -# Use the host emulated micro device. -DEV_CONFIG = micro.device.host.default_config() +# # Use the host emulated micro device. +DEV_CONFIG_A = micro.device.host.generate_config() +DEV_CONFIG_B = micro.device.host.generate_config() +TARGET = 'c -device=micro_dev' def relay_micro_build(func, dev_config, params=None): """Create a graph runtime module with a micro device context from a Relay function. @@ -47,22 +49,41 @@ def relay_micro_build(func, dev_config, params=None): mod : tvm.runtime.Module graph runtime module for the target device """ - with tvm.target.build_config(disable_vectorize=True): - graph, c_mod, params = relay.build(func, target="c", params=params) - micro_mod = create_micro_mod(c_mod, dev_config) + with tvm.transform.PassContext(disabled_pass={'FuseOps'}, config={ + "tir.disable_vectorize": True + }): + graph, c_mod, params = relay.build(func, target=TARGET, params=params) + micro_mod = micro.create_micro_mod(c_mod, dev_config) ctx = tvm.micro_dev(0) mod = graph_runtime.create(graph, micro_mod, ctx) mod.set_input(**params) return mod +GDB_INIT_TEMPLATE = """ +layout asm +target remote localhost:{gdb_port} +set $pc = UTVMInit +break UTVMDone +""" + + +def reset_gdbinit(): + if 'server_port' not in DEV_CONFIG_A: + return + gdb_init_dir = os.environ['MICRO_GDB_INIT_DIR'] + with open(f'{gdb_init_dir}/.gdbinit', 'w') as f: + gdb_port = DEV_CONFIG_A['server_port'] - 3333 + f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port)) + + def test_alloc(): """Test tensor allocation on the device.""" if not tvm.runtime.enabled("micro_dev"): return shape = (1024,) dtype = "float32" - with micro.Session(DEV_CONFIG): + with micro.Session(DEV_CONFIG_A): ctx = tvm.micro_dev(0) np_tensor = np.random.uniform(size=shape).astype(dtype) micro_tensor = tvm.nd.array(np_tensor, ctx) @@ -76,6 +97,8 @@ def test_add(): shape = (1024,) dtype = "float32" + reset_gdbinit() + # Construct TVM expression. 
tvm_shape = tvm.runtime.convert(shape) A = te.placeholder(tvm_shape, name="A", dtype=dtype) @@ -86,14 +109,24 @@ def test_add(): func_name = "fadd" c_mod = tvm.build(s, [A, B, C], target="c", name=func_name) - with micro.Session(DEV_CONFIG): - micro_mod = create_micro_mod(c_mod, DEV_CONFIG) + with micro.Session(DEV_CONFIG_A) as sess: + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) micro_func = micro_mod[func_name] ctx = tvm.micro_dev(0) - a = tvm.nd.array(np.random.uniform(size=shape).astype(dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=shape).astype(dtype), ctx) + + a_np = np.random.uniform(size=shape).astype(dtype) + a = tvm.nd.array(a_np, ctx) + b_np = np.random.uniform(size=shape).astype(dtype) + b = tvm.nd.array(b_np, ctx) c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) micro_func(a, b, c) + + # ensure inputs weren't corrupted + tvm.testing.assert_allclose( + a.asnumpy(), a_np) + tvm.testing.assert_allclose( + b.asnumpy(), b_np) + # ensure output is correct tvm.testing.assert_allclose( c.asnumpy(), a.asnumpy() + b.asnumpy()) @@ -105,6 +138,8 @@ def test_workspace_add(): shape = (1024,) dtype = "float32" + reset_gdbinit() + # Construct TVM expression. tvm_shape = tvm.runtime.convert(shape) A = te.placeholder(tvm_shape, name="A", dtype=dtype) @@ -116,14 +151,19 @@ def test_workspace_add(): func_name = "fadd_two_workspace" c_mod = tvm.build(s, [A, C], target="c", name=func_name) - with micro.Session(DEV_CONFIG): - micro_mod = create_micro_mod(c_mod, DEV_CONFIG) + with micro.Session(DEV_CONFIG_A) as sess: + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) micro_func = micro_mod[func_name] ctx = tvm.micro_dev(0) - a = tvm.nd.array(np.random.uniform(size=shape).astype(dtype), ctx) + a_np = np.random.uniform(size=shape).astype(dtype) + a = tvm.nd.array(a_np, ctx) c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) micro_func(a, c) + # ensure input wasn't corrupted + tvm.testing.assert_allclose( + a.asnumpy(), a_np) + # ensure output is correct tvm.testing.assert_allclose( c.asnumpy(), a.asnumpy() + 2.0) @@ -141,47 +181,76 @@ def test_graph_runtime(): z = relay.add(xx, relay.const(1.0)) func = relay.Function([x], z) - with micro.Session(DEV_CONFIG): - mod = relay_micro_build(func, DEV_CONFIG) + with micro.Session(DEV_CONFIG_A): + mod = relay_micro_build(func, DEV_CONFIG_A) x_in = np.random.uniform(size=shape[0]).astype(dtype) mod.run(x=x_in) result = mod.get_output(0).asnumpy() + tvm.testing.assert_allclose( + mod.get_input(0).asnumpy(), x_in) tvm.testing.assert_allclose( result, x_in * x_in + 1.0) -def test_multiple_modules(): - """Test loading multiple modules on the device simultaneously.""" +def test_conv2d(): if not tvm.runtime.enabled("micro_dev"): return - shape = (1024,) - dtype = "float32" - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - # Construct Relay subtract program. 
- x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.subtract(x, relay.const(1.0)) - sub_const_func = relay.Function([x], ret) + from tvm.relay import create_executor + from tvm.relay import transform - with micro.Session(DEV_CONFIG): - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) - sub_const_mod = relay_micro_build(sub_const_func, DEV_CONFIG) + dshape = (1, 4, 16, 16) + dtype = 'int8' + func_name = 'fused_nn_conv2d' - x_in = np.random.uniform(size=shape[0]).astype(dtype) - add_const_mod.run(x=x_in) - add_result = add_const_mod.get_output(0).asnumpy() - sub_const_mod.run(x=x_in) - sub_result = sub_const_mod.get_output(0).asnumpy() + reset_gdbinit() - tvm.testing.assert_allclose( - add_result, x_in + 1.0) - tvm.testing.assert_allclose( - sub_result, x_in - 1.0) + # Construct Relay program. + x = relay.var("x", shape=dshape, dtype=dtype) + conv_expr = relay.nn.conv2d( + x, relay.var("w"), + kernel_size=(3, 3), + padding=(1, 1), + channels=4) + func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr) + mod = tvm.IRModule.from_expr(func) + mod = transform.InferType()(mod) + + x_shape = list(map(lambda x: x.value, mod['main'].params[0].checked_type.shape)) + w_shape = list(map(lambda x: x.value, mod['main'].params[1].checked_type.shape)) + out_shape = list(map(lambda x: x.value, mod['main'].ret_type.shape)) + + with tvm.transform.PassContext(config={ + "tir.disable_vectorize": True + }): + graph, c_mod, params = relay.build(mod, target="c") + + with micro.Session(DEV_CONFIG_A): + micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) + candidate_func_name = func_name + for i in range(100): + try: + micro_func = micro_mod[candidate_func_name] + break + except tvm.TVMError as e: + candidate_func_name = f'{func_name}_{i}' + else: + assert False + ctx = tvm.micro_dev(0) + + x_data = tvm.nd.array(np.random.uniform(size=x_shape).astype(dtype), ctx) + w_data = tvm.nd.array(np.random.uniform(size=w_shape).astype(dtype), ctx) + result = tvm.nd.array(np.zeros(shape=out_shape, dtype=dtype), ctx) + micro_func(x_data, w_data, result) + + out_data = np.zeros(out_shape, dtype=dtype) + params = { 'x': x_data.asnumpy(), 'w': w_data.asnumpy() } + intrp = create_executor('debug') + expected_result = intrp.evaluate(mod['main'])(x_data, w_data) + + tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy()) def test_interleave_sessions(): @@ -196,8 +265,8 @@ def test_interleave_sessions(): ret = relay.add(x, relay.const(1.0)) add_const_func = relay.Function([x], ret) - sess_a = micro.Session(DEV_CONFIG) - sess_b = micro.Session(DEV_CONFIG) + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) with sess_a: np_tensor_a = np.random.uniform(size=shape).astype(dtype) micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) @@ -205,13 +274,13 @@ def test_interleave_sessions(): np_tensor_b = np.random.uniform(size=shape).astype(dtype) micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) with sess_a: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) add_const_mod.run(x=micro_tensor_a) add_result = add_const_mod.get_output(0).asnumpy() tvm.testing.assert_allclose( add_result, np_tensor_a + 1.0) with sess_b: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B) add_const_mod.run(x=micro_tensor_b) add_result = add_const_mod.get_output(0).asnumpy() 
tvm.testing.assert_allclose( @@ -230,15 +299,15 @@ def test_nested_sessions(): ret = relay.add(x, relay.const(1.0)) add_const_func = relay.Function([x], ret) - sess_a = micro.Session(DEV_CONFIG) - sess_b = micro.Session(DEV_CONFIG) + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) with sess_a: np_tensor_a = np.random.uniform(size=shape).astype(dtype) micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) with sess_b: np_tensor_b = np.random.uniform(size=shape).astype(dtype) micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) add_const_mod.run(x=micro_tensor_a) add_result = add_const_mod.get_output(0).asnumpy() tvm.testing.assert_allclose( @@ -257,12 +326,12 @@ def test_inactive_session_use(): ret = relay.add(x, relay.const(1.0)) add_const_func = relay.Function([x], ret) - sess_a = micro.Session(DEV_CONFIG) - sess_b = micro.Session(DEV_CONFIG) + sess_a = micro.Session(DEV_CONFIG_A) + sess_b = micro.Session(DEV_CONFIG_B) with sess_a: np_tensor_a = np.random.uniform(size=shape).astype(dtype) micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG) + add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) with sess_b: # These objects belong to `sess_a`. @@ -272,12 +341,38 @@ def test_inactive_session_use(): add_result, np_tensor_a + 1.0) +# TODO add workspace alloc/free stress test + if __name__ == "__main__": test_alloc() + print() + print('finished alloc test') + input('[press enter to continue]') test_add() + print() + print('finished add test') + input('[press enter to continue]') test_workspace_add() + print() + print('finished workspace add test') + input('[press enter to continue]') test_graph_runtime() - test_multiple_modules() + print() + print('finished graph runtime test') + input('[press enter to continue]') + test_conv2d() + print() + print('finished conv2d test') + input('[press enter to continue]') test_interleave_sessions() + print() + print('finished interleaved sessions test') + input('[press enter to continue]') test_nested_sessions() + print() + print('finished nested sessions test') + input('[press enter to continue]') test_inactive_session_use() + print() + print('finished use inactive session test') + input('[press enter to continue]') diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index fce7d2f350dc..8473a67e6e41 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -67,7 +67,7 @@ def verify_gpu_mod_export(obj_format): resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) resnet50_mod, resnet50_params = relay.testing.resnet.get_workload(num_layers=50) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): _, resnet18_gpu_lib, _ = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params) _, resnet50_cpu_lib, _ = relay.build_module.build(resnet50_mod, "llvm", params=resnet50_params) @@ -93,7 +93,7 @@ def verify_multi_dso_mod_export(obj_format): return resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params) A = 
te.placeholder((1024,), name='A') @@ -177,7 +177,7 @@ def verify_multi_c_mod_export(): return resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): _, resnet18_cpu_lib, _ = relay.build_module.build(resnet18_mod, "llvm", params=resnet18_params) A = te.placeholder((1024,), name='A') diff --git a/tests/python/unittest/test_runtime_ndarray.py b/tests/python/unittest/test_runtime_ndarray.py index e3143794cc34..36312959da3d 100644 --- a/tests/python/unittest/test_runtime_ndarray.py +++ b/tests/python/unittest/test_runtime_ndarray.py @@ -72,6 +72,13 @@ def test_fp16_conversion(): tvm.testing.assert_allclose(expected, real) + +def test_dtype(): + dtype = tvm.DataType("handle") + assert dtype.type_code == tvm.DataTypeCode.HANDLE + + if __name__ == "__main__": test_nd_create() test_fp16_conversion() + test_dtype() diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index b61e6bb9fa01..7f01f880cd3d 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -18,13 +18,15 @@ from tvm import te import tvm.testing import os +import stat import logging import time import multiprocessing +import pytest import numpy as np from tvm import rpc -from tvm.contrib import util +from tvm.contrib import util, cc from tvm.rpc.tracker import Tracker @@ -77,15 +79,29 @@ def remotethrow(name): f1 = client.get_function("rpc.test.addone") assert f1(10) == 11 f3 = client.get_function("rpc.test.except") - try: + + with pytest.raises(tvm.error.RPCError): f3("abc") - assert False - except tvm.error.TVMError as e: - assert "abc" in str(e) f2 = client.get_function("rpc.test.strcat") assert f2("abc", 11) == "abc:11" + +def test_rpc_runtime_string(): + if not tvm.runtime.enabled("rpc"): + return + @tvm.register_func("rpc.test.runtime_str_concat") + def strcat(x, y): + return x + y + + server = rpc.Server("localhost", key="x1") + client = rpc.connect(server.host, server.port, key="x1") + func = client.get_function("rpc.test.runtime_str_concat") + x = tvm.runtime.container.String("abc") + y = tvm.runtime.container.String("def") + assert str(func(x, y)) == "abcdef" + + def test_rpc_array(): if not tvm.runtime.enabled("rpc"): return @@ -101,6 +117,58 @@ def remote_array_func(y): fremote = remote.get_function("rpc.test.remote_array_func") fremote(r_cpu) + +def test_rpc_large_array(): + # testcase of large array creation + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a_np = np.ones((5041, 720)).astype('float32') + b_np = np.ones((720, 192)).astype('float32') + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + np.testing.assert_equal(a.asnumpy(), a_np) + np.testing.assert_equal(b.asnumpy(), b_np) + + +def test_rpc_echo(): + def check(remote): + fecho = remote.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + with pytest.raises(RuntimeError): + raise_err = remote.get_function( + "testing.test_raise_error_callback")("RuntimeError") + raise_err() + + remote.cpu().sync() + with pytest.raises(AttributeError): + f3 = remote.system_lib()["notexist"] + + + temp = rpc.server._server_env([]) + server = rpc.Server("localhost") + client = rpc.connect(server.host, server.port) + check(rpc.LocalSession()) + + check(client) + # Test minrpc server. 
+ temp = util.tempdir() + minrpc_exec = temp.relpath("minrpc") + tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, []) + check(rpc.PopenSession(minrpc_exec)) + # minrpc on the remote + server = rpc.Server("localhost") + client = rpc.connect( + server.host, server.port, + session_constructor_args=["rpc.PopenSession", + open(minrpc_exec, "rb").read()]) + check(client) + + def test_rpc_file_exchange(): if not tvm.runtime.enabled("rpc"): return @@ -114,14 +182,20 @@ def test_rpc_file_exchange(): def test_rpc_remote_module(): if not tvm.runtime.enabled("rpc"): return - server = rpc.Server("localhost") - client = rpc.connect(server.host, server.port) # graph - n = tvm.runtime.convert(1024) + n = tvm.runtime.convert(102) A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + def check_remote(remote): if not tvm.runtime.enabled("llvm"): print("Skip because llvm is not enabled") @@ -133,13 +207,45 @@ def check_remote(remote): f.export_library(path_dso) remote.upload(path_dso) f1 = remote.load_module("dev_lib.so") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + def check_minrpc(): + if not tvm.runtime.enabled("llvm"): + print("Skip because llvm is not enabled") + return + if tvm.get_global_func("rpc.PopenSession", allow_missing=True) is None: + return + # export to minrpc + temp = util.tempdir() + f = tvm.build(s, [A, B], "llvm --system-lib", name="myadd") + path_minrpc = temp.relpath("dev_lib.minrpc") + f.export_library(path_minrpc, rpc.with_minrpc(cc.create_executable)) + + with pytest.raises(RuntimeError): + rpc.PopenSession("filenotexist") + + # statrt the minrpc session. 
+ remote = tvm.rpc.PopenSession(path_minrpc) + ctx = remote.cpu(0) + f1 = remote.system_lib() + + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) + time_f = f1.time_evaluator("myadd", remote.cpu(0), number=1) + cost = time_f(a, b).mean + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + + # change to not executable + os.chmod(path_minrpc, stat.S_IRUSR) + with pytest.raises(RuntimeError): + rpc.PopenSession(path_minrpc) + + def check_remote_link_cl(remote): """Test function to run remote code such as cl @@ -174,8 +280,8 @@ def check_remote_link_cl(remote): fhost = remote.load_module("myadd.o") fdev = remote.load_module("myadd.cl") fhost.import_module(fdev) - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) # Option 2: export library as a tar ball then handled by remote compiler @@ -183,13 +289,15 @@ def check_remote_link_cl(remote): f.export_library(path_tar) remote.upload(path_tar) fhost = remote.load_module("myadd.tar") - a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=102).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(102, dtype=A.dtype), ctx) fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote(client) check_remote(rpc.LocalSession()) + check_remote(client) + check_minrpc() + def test_rpc_return_func(): @@ -204,6 +312,37 @@ def addone(x): assert fadd(12) == 22 +def test_rpc_session_constructor_args(): + # start server + server0 = rpc.Server("localhost", key="x0") + server1 = rpc.Server("localhost", key="x1") + + def check_multi_hop(): + # use server0 as proxy to connect to server1 + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=[ + "rpc.Connect", server1.host, server1.port, "x1"]) + + fecho = client.get_function("testing.echo") + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + nd = tvm.nd.array([1,2,3], ctx=client.cpu(0)) + assert(nd.asnumpy()[1] == 2) + + def check_error_handling(): + with pytest.raises(tvm.error.RPCError): + client = rpc.connect( + server0.host, server0.port, key="x0", + session_constructor_args=["rpc.NonExistingConstructor"]) + + check_multi_hop() + check_error_handling() + + def test_rpc_return_ndarray(): # Use closure to check the ref counter correctness nd = tvm.nd.array(np.zeros(10).astype("float32")) @@ -221,6 +360,7 @@ def my_module(name): # start server server = rpc.Server("localhost", key="x1") client = rpc.connect(server.host, server.port, key="x1") + m = client.get_function("rpc.test.remote_return_nd") get_arr = m("get_arr") ref_count = m("ref_count") @@ -315,6 +455,7 @@ def target(host, port, device_key, timeout): time.sleep(0.5) summary = client.summary() + assert summary['queue_info'][device_key]['free'] == 0 assert summary['queue_info'][device_key]['pending'] == 1 @@ -334,6 +475,8 @@ def target(host, port, device_key, timeout): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) + test_rpc_echo() + test_rpc_session_constructor_args() test_rpc_return_ndarray() test_rpc_return_func() test_bigendian_rpc() @@ -344,3 
+487,4 @@ def target(host, port, device_key, timeout): test_local_func() test_rpc_tracker_register() test_rpc_tracker_request() + test_rpc_large_array() diff --git a/tests/python/unittest/test_target_codegen_blob.py b/tests/python/unittest/test_target_codegen_blob.py index 719ddfe2a820..7cd579397ec8 100644 --- a/tests/python/unittest/test_target_codegen_blob.py +++ b/tests/python/unittest/test_target_codegen_blob.py @@ -31,7 +31,7 @@ def test_resnet18(): def verify(data): mod, params = relay.testing.resnet.get_workload(num_layers=18) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, graph_params = relay.build_module.build(mod, "llvm", params=params) ctx = tvm.cpu() module = graph_runtime.create(graph, lib, ctx) @@ -42,7 +42,7 @@ def verify(data): return out resnet18_mod, resnet18_params = relay.testing.resnet.get_workload(num_layers=18) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, resnet18_gpu_lib, graph_params = relay.build_module.build(resnet18_mod, "cuda", params=resnet18_params) from tvm.contrib import util diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index c96531e4710e..0f00e08f9192 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -91,8 +91,7 @@ def check_c(): tvm.testing.assert_allclose( c.asnumpy(), a.asnumpy() + b.asnumpy()) - with tvm.target.build_config(offset_factor=4): - check_c() + check_c() def test_reinterpret(): diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 739fc6fda76d..1a7163ff129d 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -55,7 +55,12 @@ def check_cuda(dtype, n, lanes): check_cuda("float32", 64, 2) check_cuda("float32", 64, 3) check_cuda("float32", 64, 4) + check_cuda("int8", 64, 2) + check_cuda("int8", 64, 3) check_cuda("int8", 64, 4) + check_cuda("uint8", 64, 2) + check_cuda("uint8", 64, 3) + check_cuda("uint8", 64, 4) check_cuda("float16", 64, 2) check_cuda("float16", 64, 4) check_cuda("float16", 64, 6) @@ -112,15 +117,17 @@ def check_cuda(dtype, n, lanes): b = tvm.nd.empty((n,), B.dtype, ctx) fun(a,b) tvm.testing.assert_allclose(a.asnumpy(), b.asnumpy()) + check_cuda("int8", 64, 2) + check_cuda("int8", 64, 3) + check_cuda("int8", 64, 4) check_cuda("int8", 64, 8) check_cuda("int8", 64, 16) -def test_cuda_make_int8x4(): - def check_cuda(n, value): +def test_cuda_make_int8(): + def check_cuda(n, value, lanes): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): print("skip because cuda is not enabled..") return - lanes = 4 dtype = 'int8' ctx = tvm.gpu(0) A = te.compute((n, lanes), lambda i,j: tvm.tir.const(value, dtype=dtype)) @@ -133,9 +140,15 @@ def check_cuda(n, value): a = tvm.nd.empty(np_a.shape, dtype, ctx) fun(a) np.testing.assert_equal(a.asnumpy(), np_a) - check_cuda(64, 0xAB) - check_cuda(64, 0) - check_cuda(64, -3) + check_cuda(64, 0xAB, 4) + check_cuda(64, 0, 4) + check_cuda(64, -3, 4) + check_cuda(64, 0xAB, 3) + check_cuda(64, 0, 3) + check_cuda(64, -3, 3) + check_cuda(64, 0xAB, 2) + check_cuda(64, 0, 2) + check_cuda(64, -3, 2) def test_cuda_inf_nan(): @@ -182,7 +195,7 @@ def test_cuda_shuffle(): sch[c].bind(xo, thrx) sch[c].vectorize(xi) - def my_vectorize(stmt): + def MyVectorize(): def vectorizer(op): if op.for_type == tvm.tir.For.Vectorized: four = 
tvm.tir.const(4, 'int32') @@ -198,9 +211,13 @@ def vectorizer(op): new_b = tvm.tir.Shuffle(bs, ids) return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) return None - return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) - with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): + def _transform(f, *_): + return f.with_body( + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For'])) + return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize") + + with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, MyVectorize())]}): module = tvm.build(sch, [a, b, c], target='cuda') a_ = np.array(list(range(64)), dtype='int32') b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32') @@ -210,8 +227,95 @@ def vectorizer(op): module(nda, ndb, ndc) tvm.testing.assert_allclose(ndc.asnumpy(), ref) +def test_crossthread_reduction1(): + def check(device): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") + return + n = te.var("n") + m = te.var("m") + A = te.placeholder((n, m), name='A') + k = te.reduce_axis((0, m), "m") + B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B") + + def sched(nthd): + s = te.create_schedule(B.op) + ko, _ = s[B].split(B.op.reduce_axis[0], nparts=nthd) + s[B].bind(ko, te.thread_axis("threadIdx.x")) + s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) + func = tvm.build(s, [A, B], device) + return func + + def verify(nthd): + func = sched(nthd) + nn = 3 + # checks three typical cases + vals = [nthd-1, nthd, nthd+1] + for kk in [x for x in vals]: + size = (nn, kk) + a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), \ + np.sum(a.asnumpy(), axis=1), rtol=1e-3) + + verify(16) + verify(32) + verify(64) + + check("cuda") + check("rocm") + + +def test_crossthread_reduction2(): + def check(device): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") + return -def test_cuda_reducition_binding(): + n = te.var("n") + k0 = te.var("k0") + k1 = te.var("k1") + A = te.placeholder((n, k0, k1), name='A') + k0 = te.reduce_axis((0, k0), "k0") + k1 = te.reduce_axis((0, k1), "k1") + B = te.compute((n,), lambda i: te.sum(A[i, k0, k1], axis=(k0, k1)), name="B") + + def sched(nthdx, nthdy): + s = te.create_schedule(B.op) + k0o, _ = s[B].split(B.op.reduce_axis[0], nparts=nthdx) + k1o, _ = s[B].split(B.op.reduce_axis[1], nparts=nthdy) + s[B].bind(k0o, te.thread_axis("threadIdx.x")) + s[B].bind(k1o, te.thread_axis("threadIdx.y")) + s[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x")) + func = tvm.build(s, [A, B], device) + return func + + def verify(nthdx, nthdy): + func = sched(nthdx, nthdy) + nn = 3 + # checks three typical cases + vx = [nthdx-1, nthdx, nthdx+1] + vy = [nthdy-1, nthdy, nthdy+1] + for kk0, kk1 in [(x, y) for x in vx for y in vy]: + size = (nn, kk0, kk1) + a = tvm.nd.array(np.random.uniform(size=size).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(nn, dtype=B.dtype), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), \ + np.sum(a.asnumpy(), axis=(1, 2)), rtol=1e-3) + + verify(16, 16) + verify(32, 32) + verify(16, 32) + verify(32, 16) + + check("cuda") + check("rocm") + +def test_cuda_reduction_binding(): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): print("skip 
because cuda is not enabled..") return @@ -231,39 +335,43 @@ def test_cuda_reducition_binding(): fcuda = tvm.build(s, [A, B], "cuda") def test_rfactor_predicates(): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") - return + def check(device): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") + return - n = te.reduce_axis((0, 129), 'n') - A = te.placeholder((129,), name='A') - B = te.compute( (1, ), lambda b: - te.sum(A[n], - axis=n), - name='B' - ) + n = te.reduce_axis((0, 129), 'n') + A = te.placeholder((129,), name='A') + B = te.compute( (1, ), lambda b: + te.sum(A[n], + axis=n), + name='B' + ) - s = te.create_schedule(B.op) + s = te.create_schedule(B.op) - _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8) + _, ni = s[B].split(s[B].op.reduce_axis[0], factor=8) - BF = s.rfactor(B, ni, 0) - s[B].set_store_predicate(tx.var.equal(0)) + BF = s.rfactor(B, ni, 0) + s[B].set_store_predicate(tx.var.equal(0)) - s[B].bind(s[B].op.reduce_axis[0], tx) - s[B].bind(s[B].op.axis[0], bx) + s[B].bind(s[B].op.reduce_axis[0], tx) + s[B].bind(s[B].op.axis[0], bx) - s[BF].compute_at(s[B], s[B].op.axis[0]) + s[BF].compute_at(s[B], s[B].op.axis[0]) - _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2) + _, noi = s[BF].split(s[BF].op.reduce_axis[0], factor=2) - BF2 = s.rfactor(BF, noi, 0) + BF2 = s.rfactor(BF, noi, 0) - s[BF].bind(s[BF].op.axis[0], tx) - s[BF2].compute_at(s[BF], s[BF].op.axis[1]) + s[BF].bind(s[BF].op.axis[0], tx) + s[BF2].compute_at(s[BF], s[BF].op.axis[1]) - fcuda = tvm.build(s, [A, B], "cuda") + fcuda = tvm.build(s, [A, B], device) + check("cuda") + check("rocm") @unittest.skipIf(not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"), "skip because cuda is not enabled..") def test_cuda_const_float_to_half(): @@ -291,11 +399,12 @@ def test_cuda_const_float_to_half(): np.testing.assert_equal(c.asnumpy(), a_np > b.value) def test_cuda_reduction(): - def check_cuda(dtype, m=32, n=32): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") + def check(device, dtype, m=32, n=32): + ctx = tvm.context(device, 0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") return - if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): + if dtype == "float16" and not have_fp16(ctx.compute_version): print("Skip because gpu does not have fp16 support") return @@ -305,10 +414,9 @@ def check_cuda(dtype, m=32, n=32): d = a * b e = topi.elemwise_sum([c, d]) g = topi.sum(e) - with tvm.target.cuda(): + with tvm.target.create(device): sg = topi.cuda.schedule_reduce(g) - ctx = tvm.gpu(0) - func = tvm.build(sg, [a, b, g], 'cuda') + func = tvm.build(sg, [a, b, g], device) a_np = np.random.uniform(size=(m, n)).astype(a.dtype) b_np = np.random.uniform(size=(m, n)).astype(b.dtype) g_np = np.sum(np.add(a_np * b_np, a_np + b_np)) @@ -318,26 +426,27 @@ def check_cuda(dtype, m=32, n=32): func(a_nd, b_nd, g_nd) tvm.testing.assert_allclose(g_nd.asnumpy(), g_np, rtol=1e-3) - check_cuda("float32") - check_cuda("float16") + check("cuda", "float32") + check("rocm", "float32") + check("cuda", "float16") def test_cuda_mix_threaded_and_normal_reduction(): - def check_cuda(dtype, m=32, n=32): - if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): - print("skip because cuda is not enabled..") + def check(device, dtype, m=32, n=32): + ctx = tvm.context(device, 
0) + if not ctx.exist or not tvm.runtime.enabled(device): + print("skip because", device, "is not enabled..") return - if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version): + if dtype == "float16" and not have_fp16(ctx.compute_version): print("Skip because gpu does not have fp16 support") return a = tvm.te.placeholder((m, n), name="a", dtype=dtype) b = topi.sum(a) - with tvm.target.cuda(): + with tvm.target.create(device): sb = tvm.te.create_schedule(b.op) i, _ = b.op.reduce_axis sb[b].bind(i, tvm.te.thread_axis("threadIdx.x")) - ctx = tvm.gpu(0) - func = tvm.build(sb, [a, b], 'cuda') + func = tvm.build(sb, [a, b], device) a_np = np.random.uniform(size=(m, n)).astype(a.dtype) b_np = np.sum(a_np) a_nd = tvm.nd.array(a_np, ctx) @@ -345,8 +454,9 @@ def check_cuda(dtype, m=32, n=32): func(a_nd, b_nd) tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3) - check_cuda("float32") - check_cuda("float16") + check("cuda", "float32") + check("rocm", "float32") + check("cuda", "float16") def test_cuda_floordiv_with_vectorization(): if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): @@ -575,6 +685,8 @@ def check_cuda(dtype, n, l, padding, lanes): (0, 0)), mode='constant', constant_values=0) tvm.testing.assert_allclose(b.asnumpy(), ref) + check_cuda("int8", 64, 16, 3, 2) + check_cuda("uint8", 64, 16, 3, 2) check_cuda("int8", 64, 16, 3, 4) check_cuda("uint8", 64, 16, 3, 4) check_cuda("int32", 64, 16, 3, 4) @@ -585,11 +697,13 @@ def check_cuda(dtype, n, l, padding, lanes): test_cuda_vectorize_add() test_cuda_multiply_add() test_cuda_vectorize_load() - test_cuda_make_int8x4() + test_cuda_make_int8() test_cuda_inf_nan() test_cuda_shuffle() test_vectorized_casts() - test_cuda_reducition_binding() + test_cuda_reduction_binding() + test_crossthread_reduction1() + test_crossthread_reduction2() test_rfactor_predicates() test_cuda_const_float_to_half() test_cuda_reduction() diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 44b05c90ff17..1173b71ade6f 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -21,6 +21,7 @@ import numpy as np import ctypes import math +import re def test_llvm_intrin(): @@ -33,7 +34,7 @@ def test_llvm_intrin(): ] ib.emit(tvm.tir.Evaluate( tvm.tir.Call( - "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0))) + "int32", "prefetch", args, tvm.tir.Call.Intrinsic))) body = ib.get() mod = tvm.IRModule.from_expr( @@ -43,6 +44,18 @@ def test_llvm_intrin(): fcode = tvm.build(mod, None, "llvm") +def test_llvm_void_intrin(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("uint8", name="A") + # Create an intrinsic that returns void. + x = tvm.tir.call_llvm_intrin('', 'llvm.va_start', tvm.tir.const(1, 'uint32'), A) + ib.emit(x) + body = ib.get() + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) + fcode = tvm.build(mod, None, "llvm") + + def test_llvm_overloaded_intrin(): # Name lookup for overloaded intrinsics in LLVM 4- requires a name # that includes the overloaded types. 
@@ -178,8 +191,7 @@ def check_llvm(): tvm.testing.assert_allclose( c.asnumpy(), a.asnumpy() + b.asnumpy()) - with tvm.target.build_config(offset_factor=4): - check_llvm() + check_llvm() def test_llvm_persist_parallel(): @@ -293,7 +305,8 @@ def check_llvm(nn, base, stride): c.asnumpy(), a.asnumpy()[base:] + 1) check_llvm(64, 0, 2) check_llvm(4, 0, 1) - with tvm.target.build_config(restricted_func=False): + + with tvm.transform.PassContext(config={"tir.noalias": False}): check_llvm(4, 0, 3) @@ -422,7 +435,7 @@ def test_rank_zero_bound_checkers(): def check_llvm(n): if not tvm.runtime.enabled("llvm"): return - with tvm.target.build_config(instrument_bound_checkers=True): + with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}): A = te.placeholder((n, ), name='A') scale = te.placeholder((), name='scale') k = te.reduce_axis((0, n), name="k") @@ -450,12 +463,39 @@ def test_alignment(): s = te.create_schedule(B.op) bx, tx = s[B].split(B.op.axis[0], factor=8) s[B].vectorize(tx) - f = tvm.build(s, [A, B], "llvm") + f = tvm.build(s, [A, B], "llvm", name="test_alignment") - for l in f.get_source().split("\n"): + lines = f.get_source().split("\n") + + # Check alignment on load/store. + for l in lines: if "align" in l and "4 x float" in l: assert "align 32" in l + # Check parameter alignment. This looks for the definition of the + # outlined "compute_" function to see if there is an "align" attribute + # listed there. + def has_param_alignment(): + for l in lines: + if re.search(r'test_alignment_compute_\([^(]*align [0-9]', l): + return True + return False + + if tvm.target.codegen.llvm_version_major() >= 5: + assert has_param_alignment() + + # Check for assume intrinsics. This isn't 100% accurate, since it just + # checks if the llvm.assume is there, but a detailed check would require + # much deeper analysis of the LLVM IR.
+ def has_call_to_assume(): + for l in lines: + if re.search(r'call.*llvm.assume', l): + return True + return False + + assert has_call_to_assume() + + def test_llvm_div(): """Check that the semantics of div and mod is correct""" def check(start, end, dstart, dend, dtype, floor_div=False): @@ -613,7 +653,6 @@ def check_llvm_object(): temp = util.tempdir() o_path = temp.relpath("temp.o") m.save(o_path) - import re import shutil import subprocess import sys @@ -671,8 +710,7 @@ def test_llvm_shuffle(): c = te.compute((8, ), lambda x: a[x] + b[7-x]) sch = te.create_schedule(c.op) - def my_vectorize(stmt): - + def my_vectorize(): def vectorizer(op): store = op.body idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8) @@ -684,9 +722,13 @@ def vectorizer(op): value = new_a + new_b return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones) - return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For']) + def _transform(f, *_): + return f.with_body( + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ['tir.For'])) + + return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize") - with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]): + with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}): ir = tvm.lower(sch, [a, b, c], simple_mode=True) module = tvm.build(sch, [a, b, c]) a_ = tvm.nd.array(np.arange(1, 9, dtype='int32')) diff --git a/tests/python/unittest/test_target_codegen_rocm.py b/tests/python/unittest/test_target_codegen_rocm.py index f107e592d2d3..4c6304a7a31f 100644 --- a/tests/python/unittest/test_target_codegen_rocm.py +++ b/tests/python/unittest/test_target_codegen_rocm.py @@ -76,7 +76,7 @@ def check_inf_nan(ctx, n, value, dtype): check_inf_nan(ctx, 1, float('nan'), 'float64') @unittest.skipIf(not tvm.rocm(0).exist or not tvm.runtime.enabled("rocm"), "skip because rocm is not enabled..") -def test_rocm_reducition_binding(): +def test_rocm_reduction_binding(): k = te.reduce_axis((0, 32), 'k') A = te.placeholder((96, 32), name='A') B = te.compute( (96,), lambda m: @@ -132,6 +132,6 @@ def check_rocm(dtype, n, lanes): if __name__ == "__main__": test_rocm_cross_thread_reduction() test_rocm_inf_nan() - test_rocm_reducition_binding() + test_rocm_reduction_binding() test_rocm_copy() test_rocm_vectorize_add() diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index 442c4fed7b2f..b1d754605a46 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ b/tests/python/unittest/test_te_build_lower.py @@ -48,9 +48,9 @@ def test_split_uneven_unique_likely(): x, y = c.op.axis sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) - stmt = tvm.lower(sch, [a, b, c], simple_mode=True) + stmt = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) - assert str(stmt.body.body.body).count("likely") == 1 + if __name__ == "__main__": test_lower_rfactor() diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index b525d018340d..8ab65f129cc5 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -24,8 +24,8 @@ @pytest.mark.skip def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): - val = tvm.tir.ir_pass.Substitute(val, var_dict) - val = tvm.tir.ir_pass.Simplify(val) + val = tvm.tir.stmt_functor.substitute(val, var_dict) + val 
= tvm.arith.Analyzer().simplify(val) assert isinstance(val, (tvm.tir.IntImm,)) return val.value @@ -131,16 +131,16 @@ def test_outer_product(): assert isinstance(jbody.message, tvm.tir.StringImm) assert jbody.message.value == "index out of range!" jbody = jblock[1] - assert isinstance(jbody, tvm.tir.Provide) - assert jbody.func.name == 'c' - assert len(jbody.args) == 2 - assert jbody.args[0].name == 'i' - assert jbody.args[1].name == 'j' + assert isinstance(jbody, tvm.tir.ProducerStore) + assert jbody.producer.op.name == 'c' + assert len(jbody.indices) == 2 + assert jbody.indices[0].name == 'i' + assert jbody.indices[1].name == 'j' assert isinstance(jbody.value, tvm.tir.Mul) mul = jbody.value - assert isinstance(mul.a, tvm.tir.Call) - assert mul.a.name == 'a' - assert mul.b.name == 'b' + assert isinstance(mul.a, tvm.tir.ProducerLoad) + assert mul.a.producer.name == 'a' + assert mul.b.producer.name == 'b' func, ins, outs = run_and_check(outer_product, [n, m, a, b], {n: 99, m: 101}) temp = util.tempdir() @@ -187,51 +187,51 @@ def fanout(n, a): ibody = ir.body assert isinstance(ibody, tvm.tir.AttrStmt) abody = ibody.body - assert isinstance(abody, tvm.tir.Realize) + assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 - assert abody.func.name == 'sigma' + assert abody.producer.op.name == 'sigma' #Check i loop body rbody = abody.body - assert isinstance(rbody[0], tvm.tir.Provide) - assert rbody[0].func.name == 'sigma' - assert len(rbody[0].args) == 1 - assert rbody[0].args[0].value == 0 + assert isinstance(rbody[0], tvm.tir.ProducerStore) + assert rbody[0].producer.op.name == 'sigma' + assert len(rbody[0].indices) == 1 + assert rbody[0].indices[0].value == 0 #Check fanout loop jloop = rbody[1] assert jloop.loop_var.name == 'j' assert jloop.min.value == 0 assert jloop.extent.value == 3 jbody = jloop.body - assert isinstance(jbody, tvm.tir.Provide) - assert len(jbody.args) == 1 - assert jbody.args[0].value == 0 - assert jbody.func.name == 'sigma' + assert isinstance(jbody, tvm.tir.ProducerStore) + assert len(jbody.indices) == 1 + assert jbody.indices[0].value == 0 + assert jbody.producer.op.name == 'sigma' assert isinstance(jbody.value, tvm.tir.Add) value = jbody.value - assert isinstance(value.a, tvm.tir.Call) - assert value.a.name == 'sigma' - assert len(value.a.args) == 1 - assert value.a.args[0].value == 0 - assert value.b.name == 'a' - assert len(value.b.args) == 1 - assert tvm.ir.structural_equal(value.b.args[0], ir.loop_var + jloop.loop_var) + assert isinstance(value.a, tvm.tir.ProducerLoad) + assert value.a.producer.name == 'sigma' + assert len(value.a.indices) == 1 + assert value.a.indices[0].value == 0 + assert value.b.producer.name == 'a' + assert len(value.b.indices) == 1 + assert tvm.ir.structural_equal(value.b.indices[0], ir.loop_var + jloop.loop_var) divide= rbody[2] - assert isinstance(divide, tvm.tir.Provide) - assert len(divide.args) == 1 - assert divide.args[0].value == 0 + assert isinstance(divide, tvm.tir.ProducerStore) + assert len(divide.indices) == 1 + assert divide.indices[0].value == 0 value = divide.value assert isinstance(value, tvm.tir.Mul) - assert value.a.name == 'sigma' - assert len(value.a.args) == 1 - assert value.a.args[0].value == 0 + assert value.a.producer.name == 'sigma' + assert len(value.a.indices) == 1 + assert value.a.indices[0].value == 0 assert abs(value.b.value - (1 / 3.0)) < 1e-5 write = rbody[3] - assert isinstance(write, tvm.tir.Provide) - assert write.func.name == 'b' - 
assert write.value.name == 'sigma' - assert len(write.value.args) == 1 - assert write.value.args[0].value == 0 + assert isinstance(write, tvm.tir.ProducerStore) + assert write.producer.op.name == 'b' + assert write.value.producer.name == 'sigma' + assert len(write.value.indices) == 1 + assert write.value.indices[0].value == 0 func, ins, outs = run_and_check(fanout, [n, a], {n: 10}) run_and_check(func, ins, {n: 10}, outs=outs) @@ -365,7 +365,7 @@ def foo(a): a = te.placeholder((8, 4), 'float32') c = foo(a) s = te.create_schedule(c.op) - ir = tvm.lower(s, [a, c], simple_mode=True) + ir = tvm.lower(s, [a, c]) func, ins, outs = run_and_check(foo, [a], target='cuda') run_and_check(func, ins, outs=outs, target='cuda') @@ -517,7 +517,7 @@ def upstream(a): c = te.compute((20, ), lambda x: a[x] + b[x]) d = upstream(c) sch = te.create_schedule([c.op, d.op]) - ir = tvm.lower(sch, [a, b, d], simple_mode=True) + ir = tvm.lower(sch, [a, b, d]) func = tvm.build(sch, [a, b, d]) assert(func) @@ -730,7 +730,7 @@ def outer_product(a, b): joo, joi = sch[c].split(jo, 4) sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) - ir = tvm.lower(sch, [a, b, c], simple_mode=True) + ir = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -751,7 +751,7 @@ def outer_product(a, b): # Test fuse sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) - ir = tvm.lower(sch, [a, b, c], simple_mode=True) + ir = tvm.lower(sch, [a, b, c])["main"].body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index c9b422f7f0a4..2c851cc39789 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -115,7 +115,6 @@ def test_fuse_with_split(): assert any(isinstance(x, tvm.te.schedule.Fuse) for x in s[T].relations) assert tuple(s[T].leaf_iter_vars) == (xo, fused) -@pytest.mark.xfail def test_fuse_with_out_of_order_axis(): m = te.size_var('m') n = te.size_var('n') @@ -125,9 +124,10 @@ def test_fuse_with_out_of_order_axis(): s = te.create_schedule(T.op) y = T.op.axis[1] xo, xi = s[T].split(T.op.axis[0], factor=10) - fused = s[T].fuse(xo, y) # should throw here -@pytest.mark.xfail + with pytest.raises(RuntimeError): + fused = s[T].fuse(xo, y) # should throw here + def test_fuse_with_out_of_order_axis_with_reorder(): m = te.size_var('m') n = te.size_var('n') @@ -144,23 +144,21 @@ def test_fuse_with_out_of_order_axis_with_reorder(): y = T.op.axis[1] xo, xi = s[T].split(T.op.axis[0], factor=10) s[T].reorder(y, xo, xi) - fused = s[T].fuse(y, xi) # should throw here + + with pytest.raises(RuntimeError): + fused = s[T].fuse(y, xi) # should throw here def test_singleton(): - print("test singleton") A = te.placeholder((), name='A') T = te.compute((), lambda : A() + 1) s = te.create_schedule(T.op) - print("test singleton fin1") fused = s[T].fuse() assert any(isinstance(x, tvm.te.schedule.Singleton) for x in s[T].relations) assert tuple(s[T].leaf_iter_vars) == (fused,) dump = pkl.dumps(s) - print("test singleton fin3") s_loaded = pkl.loads(dump) - print("test singleton fin2") assert isinstance(s_loaded, tvm.te.schedule.Schedule) - print("test singleton fin") + def test_vectorize(): m = te.size_var('m') @@ -177,13 +175,14 @@ def test_vectorize(): assert s[T].iter_var_attrs[xi].iter_type == UNROLL assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE -@pytest.mark.xfail + def 
test_vectorize_commreduce(): V = te.placeholder((128,), name='V') ax = te.reduce_axis((0, 128), name='ax') O = te.compute((1,), lambda _: te.sum(V[ax], axis=[ax])) s = te.create_schedule(O.op) - s[O].vectorize(ax) # should throw here + with pytest.raises(RuntimeError): + s[O].vectorize(ax) # should throw here def test_pragma(): m = 100 @@ -271,8 +270,9 @@ def intrin_func(ins, outs, sp): assert(sp[1] == w) return tvm.tir.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1]) - with tvm.target.build_config(offset_factor=1): - intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w]) + intrin = te.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w], default_buffer_params={ + "offset_factor": 1 + }) assert intrin.op == z.op assert intrin.reduce_init is None assert tuple(intrin.inputs) == tuple(z.op.input_tensors) @@ -283,11 +283,11 @@ def intrin_func(ins, outs, sp): # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, C], simple_mode=True) + stmt = tvm.lower(s, [A, C])["main"].body assert isinstance(stmt.body.body, tvm.tir.Evaluate) assert len(stmt.body.body.value.args) == 5 - assert str(stmt.body.body.value.args[3]) == "(i*i)" - assert str(stmt.body.body.value.args[4]) == "(i + j)" + assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" + assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" if __name__ == "__main__": test_singleton() diff --git a/tests/python/unittest/test_te_schedule_bound_inference.py b/tests/python/unittest/test_te_schedule_bound_inference.py index edae527c0183..e226b7ad7703 100644 --- a/tests/python/unittest/test_te_schedule_bound_inference.py +++ b/tests/python/unittest/test_te_schedule_bound_inference.py @@ -139,19 +139,20 @@ def test_bound_fusesplit1(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) idxdiv = tvm.tir.indexdiv - assert(tvm.tir.ir_pass.Simplify( - bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0) + tvm.testing.assert_prim_expr_equal( + bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l)) expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1) for i in range(1, 6): for j in range(1, 6): for k in range(1, 6): vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")}) - comp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value - exp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(expected_extent, vars)).value - assert(comp_ext == exp_ext) + tvm.testing.assert_prim_expr_equal( + tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), + tvm.tir.stmt_functor.substitute(expected_extent, vars) + ) - assert(tvm.tir.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) + tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l) def test_bound_fusesplit2(): m = te.var("m") @@ -169,10 +170,10 @@ def test_bound_fusesplit2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")}) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3) - 
assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].min, vars), 2) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].min, vars), 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[0]].extent, vars), 1) + tvm.testing.assert_prim_expr_equal(tvm.tir.stmt_functor.substitute(bounds[A1.op.axis[1]].extent, vars), 3) def test_bound_warp(): @@ -383,23 +384,22 @@ def test_gemm_bound(): def test_bound_tensor_compute_op(): def intrin_test(): - m1 = te.var("m1") - n1 = te.var("n1") - a = te.placeholder((m1, n1), name='a') - c = te.compute((1, n1), lambda i, j : a[0, j] + a[1, j] + a[2, j], name='c') - - Ab = tvm.tir.decl_buffer(a.shape, name="Abuf", offset_factor=1) - Cb = tvm.tir.decl_buffer(c.shape, name="Cbuf", offset_factor=1) - - def intrin_func(ins, outs): - aa = ins[0] - cc = outs[0] - def _body(): - ib = tvm.tir.ir_builder.create() - ib.emit(tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r"))) - return ib.get() - return _body() - with tvm.target.build_config(offset_factor=1): + m1 = te.var("m1") + n1 = te.var("n1") + a = te.placeholder((m1, n1), name='a') + c = te.compute((1, n1), lambda i, j : a[0, j] + a[1, j] + a[2, j], name='c') + + Ab = tvm.tir.decl_buffer(a.shape, name="Abuf", offset_factor=1) + Cb = tvm.tir.decl_buffer(c.shape, name="Cbuf", offset_factor=1) + + def intrin_func(ins, outs): + aa = ins[0] + cc = outs[0] + def _body(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", "test", cc.access_ptr("w"), aa.access_ptr("r"))) + return ib.get() + return _body() return te.decl_tensor_intrin(c.op, intrin_func, binds={a : Ab, c : Cb}) test_func = intrin_test() diff --git a/tests/python/unittest/test_te_schedule_bound_inference_tiling.py b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py new file mode 100644 index 000000000000..3893bb6befda --- /dev/null +++ b/tests/python/unittest/test_te_schedule_bound_inference_tiling.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import te + +def test_bound_tile_mod(): + def compute(M_tiles, N_tiles, factor, dtype): + # Algo + M = M_tiles * factor + N = N_tiles * factor + + A = tvm.te.placeholder((N, M), name='A', dtype=dtype) + C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C') + s = tvm.te.create_schedule(C.op) + + return s, A, C + + def schedule(s, factor, padding, A, C): + C_local = s.cache_write(C, "local") + + n, m = C.op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + nio, nii = s[C].split(ni, 2) + n = s[C].fuse(nii, mi) + C_shared = s.cache_write(C, "shared") + bn, bm, ni, mi = C_shared.op.axis + s[C_shared].storage_align(ni, factor * 2, padding) + + n, m = s[C].op.axis + bn, bm, ni, mi = s[C].tile(n, m, factor, factor) + s[C].set_scope("global") + niio, niii = s[C].split(ni, 32) + s[C_shared].compute_at(s[C], niio) + + return s + + s, A, C = compute(2, 2, 128, "float16") + s = schedule(s, 128, 8, A, C) + bounds = tvm.te.schedule.InferBound(s) + check = (bounds[s.stages[2].op.axis[2]].extent == 16) + if(not check): + print(tvm.lower(s, [A, C], simple_mode=True)) + assert(check) + +if __name__ == "__main__": + test_bound_tile_mod() diff --git a/tests/python/unittest/test_te_schedule_graph.py b/tests/python/unittest/test_te_schedule_graph.py index d6d38e5f05c9..7d11020a95fd 100644 --- a/tests/python/unittest/test_te_schedule_graph.py +++ b/tests/python/unittest/test_te_schedule_graph.py @@ -41,7 +41,7 @@ def test_attach_path(): def test_fix_pt(): body = tvm.te.schedule.ScanGetBody(s_scan.op) - fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body) + fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op) assert(fxpt[s_scan.spatial_axis_[0]].value != 0) def test_scan_fix_point(): @@ -57,7 +57,7 @@ def test_scan0(): lambda t, i, j: x[t, j, i] + s_state[t-1, i, j], name="update") s_scan = tvm.te.scan(s_init, s_update, s_state) body = tvm.te.schedule.ScanGetBody(s_scan.op) - fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body) + fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op) assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1) @@ -66,7 +66,7 @@ def test_scan1(): lambda t, i, j: x[t, j, i] + s_state[t-1, j, i], name="update") s_scan = tvm.te.scan(s_init, s_update, s_state) body = tvm.te.schedule.ScanGetBody(s_scan.op) - fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op, body) + fxpt = tvm.te.schedule.ScanFixPointAnalysis(s_scan.op) assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 3e521ab07023..3f93c772a037 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -28,6 +28,9 @@ def test_schedule0(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A1], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) def test_schedule1(): @@ -43,6 +46,10 @@ def test_schedule1(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A1], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) + def test_schedule2(): m = te.var('m') @@ -57,6 +64,9 @@ def test_schedule2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = 
tvm.te.schedule.ScheduleOps(s, bounds) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [A, A2], stmt, None) + assert isinstance(func, tvm.tir.PrimFunc) def test_schedule_scan(): @@ -77,6 +87,7 @@ def test_schedule_scan(): stmt = tvm.te.schedule.ScheduleOps(s, bounds) + def test_inline_multi_reduce(): def argmax_comp(x, y): idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) @@ -144,7 +155,7 @@ def test_inline_mixed(): def check(x): if isinstance(x, tvm.tir.Call): assert x.func != A2 - tvm.tir.ir_pass.PostOrderVisit(s[C].op.body[0], check) + tvm.tir.stmt_functor.post_order_visit(s[C].op.body[0], check) def test_scan_inline1(): @@ -310,10 +321,9 @@ def intrin_func(ins, outs): "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) return body, reset, update - with tvm.target.build_config(data_alignment=16, - offset_factor=16): - return te.decl_tensor_intrin(z.op, intrin_func, - binds={w: Wb}) + buffer_params = {"data_alignment": 16, "offset_factor": 16} + return te.decl_tensor_intrin( + z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params) def test_schedule_tensor_compute1(): @@ -366,8 +376,9 @@ def intrin_func(ins, outs): ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) return ib.get() - with tvm.target.build_config(offset_factor=16): - return te.decl_tensor_intrin(z.op, intrin_func, binds=binds) + return te.decl_tensor_intrin(z.op, intrin_func, binds=binds, default_buffer_params={ + "offset_factor": 16 + }) def test_schedule_tensor_compute2(): @@ -506,23 +517,23 @@ def schedule(thread_tag, mem_scope) : def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret # local vs. threadIdx s = schedule(tx, "local") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # local vs. vthread s = schedule(vx, "local") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # shared vs. 
blockIdx s = schedule(by, "shared") - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body assert (not any( collect_visit(lowered_body, lambda x: isinstance(x, tvm.tir.IfThenElse)))) @@ -548,11 +559,11 @@ def test_local_stage_predicate2(): s[AA].compute_at(s[C], ooc) oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32) s[AA].bind(iaa, thread_x) - lowered_body = tvm.lower(s, [A, C], simple_mode=True).body + lowered_body = tvm.lower(s, [A, C])["main"].body def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret def visit_stmt(op): diff --git a/tests/python/unittest/test_tir_pass_rewrite_for_tensor_core.py b/tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py similarity index 100% rename from tests/python/unittest/test_tir_pass_rewrite_for_tensor_core.py rename to tests/python/unittest/test_te_schedule_postproc_rewrite_for_tensor_core.py diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index dafffed9bd44..5152235ef379 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -25,8 +25,8 @@ def intrin_func(ins, outs): xx, yy = ins zz = outs[0] return tvm.tir.call_packed("vadd", xx, yy, zz) - with tvm.target.build_config(offset_factor=16): - return te.decl_tensor_intrin(z.op, intrin_func) + buffer_params = {"offset_factor": 16} + return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params) def intrin_gemv(m, n): w = te.placeholder((m, n), name='w') @@ -52,10 +52,9 @@ def intrin_func(ins, outs): "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) return body, reset, update - with tvm.target.build_config(data_alignment=16, - offset_factor=16): - return te.decl_tensor_intrin(z.op, intrin_func, - binds={w: Wb}) + buffer_params = {"offset_factor": 16, "data_alignment": 16} + return te.decl_tensor_intrin( + z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params) def intrin_gemv_no_reset(m, n): w = te.placeholder((m, n), name='w') @@ -79,10 +78,10 @@ def intrin_func(ins, outs): "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0]) return body, None, update - with tvm.target.build_config(data_alignment=16, - offset_factor=16): - return te.decl_tensor_intrin(z.op, intrin_func, - binds={w: Wb}) + + buffer_params = {"offset_factor": 16, "data_alignment": 16} + return te.decl_tensor_intrin( + z.op, intrin_func, binds={w: Wb}, default_buffer_params=buffer_params) def test_tensorize_vadd(): @@ -105,9 +104,10 @@ def check(factor): assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[z], out_dom, in_dom, vadd) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0])) + ana.simplify(body[0]), + ana.simplify(vadd.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) @@ -139,9 +139,11 @@ def check(factor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() + assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - 
tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -164,9 +166,10 @@ def check_rfactor(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -188,9 +191,10 @@ def check_rfactor_no_reset(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -213,9 +217,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -242,8 +247,9 @@ def intrin_func(ins, outs): zz = outs[0] return tvm.tir.call_packed("op", xx, zz) - with tvm.target.build_config(offset_factor=2): - return te.decl_tensor_intrin(y.op, intrin_func) + return te.decl_tensor_intrin(y.op, intrin_func, default_buffer_params={ + "offset_factor": 2 + }) A = te.placeholder((5, 5), name='A') B = te.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)]) @@ -280,8 +286,7 @@ def intrin_multivadd(n): def intrin_func(ins, outs): return tvm.tir.call_packed("multivadd") - with tvm.target.build_config(): - return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd") + return te.decl_tensor_intrin(z.op, intrin_func, name="multivadd") def intrin_vadd(n): dtype = 'float32' @@ -291,9 +296,7 @@ def intrin_vadd(n): s = te.create_schedule(z.op) def create_buffer(t): - return tvm.tir.decl_buffer(t.shape, t.dtype, - name='W'+t.name, - offset_factor=16) + return tvm.tir.decl_buffer(t.shape, t.dtype, name='W'+t.name, offset_factor=16) def intrin_func(ins, outs): ib = tvm.tir.ir_builder.create() @@ -301,11 +304,9 @@ def intrin_func(ins, outs): ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) return ib.get() - - with tvm.target.build_config(offset_factor=16): - return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x), - y: create_buffer(y), - z: create_buffer(z)}) + return te.decl_tensor_intrin(z.op, intrin_func, binds={x: create_buffer(x), + y: create_buffer(y), + z: create_buffer(z)}) # cache_read, cache_write M = 1024 diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 55edd1c9958b..8d737c9f629b 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -117,8 +117,9 @@ def intrin_func(ins, 
outs): ib.emit(tvm.tir.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr'))) return ib.get() - with tvm.target.build_config(offset_factor=n): - return te.decl_tensor_intrin(z.op, intrin_func) + return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params={ + "offset_factor": n + }) vadd = intrin_vadd(factor) @@ -128,7 +129,7 @@ def intrin_func(ins, outs): lambda i: vadd(A[i, 0:factor], B[i, 0:factor])) s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, B, C], simple_mode=True) + stmt = tvm.lower(s, [A, B, C])["main"].body assert isinstance(stmt.body, tvm.tir.Evaluate) def test_tensor_compute2(): @@ -159,8 +160,8 @@ def intrin_func(ins, outs): "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l) return body, reset, update - with tvm.target.build_config(offset_factor=n): - return te.decl_tensor_intrin(z.op, intrin_func) + return te.decl_tensor_intrin(z.op, intrin_func, + default_buffer_params={"offset_factor": n}) vgemm = intrin_gemm(factor1, factor2, factor) @@ -171,7 +172,7 @@ def intrin_func(ins, outs): lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k)) s = te.create_schedule(C.op) - stmt = tvm.lower(s, [A, B, C], simple_mode=True) + stmt = tvm.lower(s, [A, B, C])["main"].body assert isinstance(stmt.body.body[0], tvm.tir.Evaluate) assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate) @@ -260,11 +261,11 @@ def test_tuple_with_different_deps(): stmt = tvm.te.schedule.ScheduleOps(sch, bounds) def get_B1_realize(x): - if isinstance(x, tvm.tir.Realize) and \ - x.func == B1.op and x.value_index == 1: + if isinstance(x, tvm.tir.ProducerRealize) and \ + x.producer.op == B1.op and x.producer.value_index == 1: ret.append(x) ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, get_B1_realize) + tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) assert stmt.node == C.op and len(ret) == 1 @@ -290,8 +291,8 @@ def intrin_func(ins, outs): dout = outs[0] return tvm.tir.call_packed("op", dinp, dout) - with tvm.target.build_config(offset_factor=1): - return te.decl_tensor_intrin(P.op, intrin_func) + return te.decl_tensor_intrin(P.op, intrin_func, + default_buffer_params={"offset_factor": 1}) A = te.placeholder((1, 64, 16, 16), name='A') P = pool(data=A, kernel=(3, 3), stride=(1, 1), padding=(0, 0, 0, 0), diff --git a/tests/python/unittest/test_tir_pass_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py similarity index 87% rename from tests/python/unittest/test_tir_pass_verify_gpu_code.py rename to tests/python/unittest/test_tir_analysis_verify_gpu_code.py index 6e138a29b3e9..11960cad04d4 100644 --- a/tests/python/unittest/test_tir_pass_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -19,10 +19,11 @@ from tvm import te def get_verify_pass(valid, **kwargs): - def verify_pass(stmt): - valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs) - return stmt - return verify_pass + def _fverify(f, *_): + valid[0] = tvm.tir.analysis.verify_gpu_code(f, kwargs) + return f + return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0) + def test_shared_memory(): def check_shared_memory(dtype): @@ -49,14 +50,14 @@ def check_shared_memory(dtype): if not tvm.context(target).exist: continue valid = [None] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=type_size * M - 1, max_threads_per_block=M))]}): tvm.build(s, [A, B], 
target) assert not valid[0] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=type_size * M, max_threads_per_block=M))]}): @@ -86,14 +87,14 @@ def test_local_memory(): continue valid = [None] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_local_memory_per_block=4 * M - 1, max_threads_per_block=1))]}): tvm.build(s, [A, B], target) assert not valid[0] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_local_memory_per_block=4 * M, max_threads_per_block=1))]}): @@ -121,21 +122,21 @@ def test_num_thread(): continue valid = [None] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=0, max_threads_per_block=N - 1))]}): tvm.build(s, [A, B], target) assert not valid[0] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=0, max_threads_per_block=N))]}): tvm.build(s, [A, B], target) assert valid[0] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=0, max_threads_per_block=N, @@ -143,7 +144,7 @@ def test_num_thread(): tvm.build(s, [A, B], target) assert not valid[0] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=0, max_threads_per_block=N, @@ -171,14 +172,14 @@ def test_multiple_kernels(): continue valid = [None] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=0, max_threads_per_block=N - 1))]}): tvm.build(s, [A, C], target) assert not valid[0] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_shared_memory_per_block=0, max_threads_per_block=N))]}): @@ -202,7 +203,7 @@ def test_wrong_bind(): continue valid = [None] - with tvm.target.build_config(**{"add_lower_pass": [ + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ (2, get_verify_pass(valid, max_threads_per_block=N*N))]}): tvm.build(s, [A, B], target) assert not valid[0] diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py index b3625082f6ed..386fceb150e3 100644 --- a/tests/python/unittest/test_tir_analysis_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -24,29 +24,6 @@ other_devices = ["llvm", "ext_dev"] -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = 
tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) - - f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( - "global_symbol", tvm.runtime.String("test")) - mod = tvm.IRModule({"test": f}) - return mod - - # All computations are bound. # So VerifyMemory pass is expected to succeed. # @@ -61,12 +38,12 @@ def test_verify_memory_all_bind(): s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x")) - mod = lower(s, [A, B]) + mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices + other_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.analysis.verify_memory(binded_mod) + tvm.tir.transform.VerifyMemory()(binded_mod) @@ -81,18 +58,18 @@ def test_verify_memory_not_bind(): # B is not bound to threads. s = te.create_schedule(B.op) - mod = lower(s, [A, B]) + mod = tvm.lower(s, [A, B]) for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - with pytest.raises(ValueError): - tvm.tir.analysis.verify_memory(binded_mod) + with pytest.raises(RuntimeError): + tvm.tir.transform.VerifyMemory()(binded_mod) for dev_type in other_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.analysis.verify_memory(binded_mod) + tvm.tir.transform.VerifyMemory()(binded_mod) # Computations are partially bound. @@ -111,18 +88,18 @@ def test_verify_memory_partially_bind(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - mod = lower(s, [A, B, C, D]) + mod = tvm.lower(s, [A, B, C, D]) for dev_type in gpu_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - with pytest.raises(ValueError): - tvm.tir.analysis.verify_memory(binded_mod) + with pytest.raises(RuntimeError): + tvm.tir.transform.VerifyMemory()(binded_mod) for dev_type in other_devices: binded_mod = tvm.tir.transform.Apply( lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) - tvm.tir.analysis.verify_memory(binded_mod) + tvm.tir.transform.VerifyMemory()(binded_mod) diff --git a/tests/python/unittest/test_tir_analysis_verify_ssa.py b/tests/python/unittest/test_tir_analysis_verify_ssa.py new file mode 100644 index 000000000000..8a15c3628074 --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_verify_ssa.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+import tvm +from tvm import te + +def test_verify_ssa(): + x = te.var('x') + y = te.var() + z = tvm.tir.Evaluate(x + y) + assert(tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([x, y],z))) + + assert(not tvm.tir.analysis.verify_ssa( + tvm.tir.PrimFunc([x, y], tvm.tir.LetStmt(x, 1, z)))) + + +if __name__ == "__main__": + test_verify_ssa() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index fe23955017a0..7ee1e539204b 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -48,17 +48,14 @@ def test_buffer_access_ptr_offset(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw", offset=100) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, 100) + tvm.testing.assert_prim_expr_equal(aptr.args[2], 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE v = te.size_var('int32') aptr = Ab.access_ptr("rw", offset=100 + 100 + v) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, 200 + v) + tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v)) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v)) + tvm.testing.assert_prim_expr_equal(aptr.args[2], tvm.tir.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE @@ -80,8 +77,7 @@ def test_buffer_vload(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - offset = tvm.tir.ir_pass.Simplify(load.index) - assert tvm.ir.structural_equal(offset, n * 2 + 103) + tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103) def test_buffer_index_merge_mult_mod(): diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 7a03e48e2270..8f03d1028bc6 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -112,14 +112,12 @@ def test_expr_constructor(): assert x.vectors[0] == a assert x.indices[0].value == 0 - x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern, None, 0) + x = tvm.tir.Call("float32", "xyz", [a], tvm.tir.Call.Extern) assert isinstance(x, tvm.tir.Call) assert x.dtype == "float32" assert x.name == "xyz" assert x.args[0] == a assert x.call_type == tvm.tir.Call.Extern - assert x.func == None - assert x.value_index == 0 v = te.var("aa") x = tvm.tir.Let(v, 1, v) @@ -160,12 +158,6 @@ def test_stmt_constructor(): assert x.index.value == 10 assert x.value.value == 1 - tensor = te.placeholder((), dtype="float32") - x = tvm.tir.Provide(tensor.op, 0, 10, []) - assert isinstance(x, tvm.tir.Provide) - assert x.value_index == 0 - assert x.value.value == 10 - x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop) assert isinstance(x, tvm.tir.Allocate) @@ -183,10 +175,6 @@ def test_stmt_constructor(): assert isinstance(x, tvm.tir.Free) assert x.buffer_var == buffer_var - x = tvm.tir.Realize(None, 0, "float", [], tvm.tir.const(1, "uint1"), nop) - assert isinstance(x, tvm.tir.Realize) - assert x.body == nop - x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop) @@ -194,9 +182,9 @@ def test_stmt_constructor(): assert x.then_case.value.value == 11 
assert x.else_case == nop - x = tvm.tir.Prefetch(None, 1, "float32", []) + b = tvm.tir.decl_buffer((1, 2)) + x = tvm.tir.Prefetch(b, []) assert isinstance(x, tvm.tir.Prefetch) - assert x.value_index == 1 if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 61a522ccafff..26bf80f5e1a5 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -63,6 +63,12 @@ def test_unary_intrin(): (tvm.tir.sinh, lambda x : np.sinh(x)), (tvm.tir.cosh, lambda x : np.cosh(x)), (tvm.tir.log1p, lambda x : np.log1p(x)), + (tvm.tir.asin, lambda x : np.arcsin(x)), + (tvm.tir.acos, lambda x : np.arccos(x)), + (tvm.tir.atan, lambda x : np.arctan(x)), + (tvm.tir.asinh, lambda x : np.arcsinh(x)), + (tvm.tir.acosh, lambda x : np.arccosh(x)), + (tvm.tir.atanh, lambda x : np.arctanh(x)), ] def run_test(tvm_intrin, np_func): m = te.var("m",) @@ -72,7 +78,7 @@ def run_test(tvm_intrin, np_func): f = tvm.build(s, [A, B], "llvm") ctx = tvm.cpu(0) n = 10 - a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx) + a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx) b = tvm.nd.array( \ np.random.uniform(size=n).astype(A.dtype), ctx) f(a, b) diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py index 9106be843b48..090acda00365 100644 --- a/tests/python/unittest/test_tir_ir_builder.py +++ b/tests/python/unittest/test_tir_ir_builder.py @@ -28,7 +28,6 @@ def test_for(): A[j] = A[j] + 2 body = ib.get() - print(body) assert isinstance(body, tvm.tir.AttrStmt) body = body.body assert isinstance(body, tvm.tir.Allocate) @@ -59,14 +58,13 @@ def test_if(): assert body.else_case.index.value == 0 def test_prefetch(): - A = te.placeholder((10, 20), name="A") + A = tvm.tir.decl_buffer((10, 20), name="A") ib = tvm.tir.ir_builder.create() n = te.size_var("n") with ib.for_range(0, n, name="i") as i: ib.emit( - tvm.tir.Prefetch( - A.op, A.value_index, A.dtype, + tvm.tir.Prefetch(A, [tvm.ir.Range.make_by_min_extent(i+1, 2), tvm.ir.Range.make_by_min_extent(0, 20)])) body = ib.get() diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 9f4ccadde94d..e6322592edaf 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -103,7 +103,7 @@ def test_basic(): a = te.var('a') b = te.var('b') c = a + b - assert str(c) == '(%s + %s)' % (a.name, b.name) + assert str(c) == '(%s: int32 + %s: int32)' % (a.name, b.name) def test_stmt(): @@ -138,11 +138,11 @@ def test_any(): assert False except ValueError: pass - assert str(tvm.tir.any(x < y)) == '(%s < %s)' % (x.name, y.name) - assert str(tvm.tir.any(x < y, x > z)) == '((%s < %s) || (%s > %s))' % ( + assert str(tvm.tir.any(x < y)) == '(%s: int32 < %s: int32)' % (x.name, y.name) + assert str(tvm.tir.any(x < y, x > z)) == '((%s: int32 < %s: int32) || (%s > %s: int32))' % ( x.name, y.name, x.name, z.name) assert str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) == \ - '(((%s < %s) || (%s > (%s + 1))) || (%s < (%s*2)))' % ( + '(((%s: int32 < %s: int32) || (%s > (%s: int32 + 1))) || (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) @@ -160,29 +160,29 @@ def test_all(): assert False except ValueError: pass - assert str(tvm.tir.all(x < y)) == '(%s < %s)' % (x.name, y.name) - assert str(tvm.tir.all(x < y, x > z)) == '((%s < %s) && (%s > %s))' % ( + assert str(tvm.tir.all(x < y)) == '(%s: int32 < %s: int32)' % 
(x.name, y.name) + assert str(tvm.tir.all(x < y, x > z)) == '((%s: int32 < %s: int32) && (%s > %s: int32))' % ( x.name, y.name, x.name, z.name) assert str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) == \ - '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % ( + '(((%s: int32 < %s: int32) && (%s > (%s: int32 + 1))) && (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) def test_bitwise(): x = te.var('x') y = te.var('y') - assert str(x << y) == 'shift_left(x, y)' - assert str(x >> y) == 'shift_right(x, y)' - assert str(x & y) == 'bitwise_and(x, y)' - assert str(x | y) == 'bitwise_or(x, y)' - assert str(x ^ y) == 'bitwise_xor(x, y)' - assert str(10 & x) == 'bitwise_and(10, x)' - assert str(10 | x) == 'bitwise_or(10, x)' - assert str(10 ^ x) == 'bitwise_xor(10, x)' - assert str(10 >> x) == 'shift_right(10, x)' - assert str(10 << x) == 'shift_left(10, x)' - assert str(10 % x) == 'floormod(10, x)' - assert str(~x) == 'bitwise_not(x)' + assert str(x << y) == '@shift_left(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x >> y) == '@shift_right(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x & y) == '@bitwise_and(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x | y) == '@bitwise_or(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(x ^ y) == '@bitwise_xor(x: int32, y: int32, dtype=int32, type="pure_intrin")' + assert str(10 & x) == '@bitwise_and(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 | x) == '@bitwise_or(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 ^ x) == '@bitwise_xor(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 >> x) == '@shift_right(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 << x) == '@shift_left(10, x: int32, dtype=int32, type="pure_intrin")' + assert str(10 % x) == 'floormod(10, x: int32)' + assert str(~x) == '@bitwise_not(x: int32, dtype=int32, type="pure_intrin")' assert(tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" @@ -239,12 +239,12 @@ def test_divide_by_zero(): def test_isnan(): x = te.var('x', 'float32') - assert str(tvm.tir.isnan(x)) == 'isnan(x)' + assert str(tvm.tir.isnan(x)) == '@isnan(x: float32, dtype=bool, type="pure_intrin")' assert str(tvm.tir.isnan(x).dtype) == 'bool' y = te.var('y', 'float16') - assert str(tvm.tir.isnan(y)) == 'isnan(float32(y))' + assert str(tvm.tir.isnan(y)) == '@isnan(cast(float32, y: float16), dtype=bool, type="pure_intrin")' z = te.var('z', 'int32') - assert str(tvm.tir.isnan(z)) == '(bool)0' + assert str(tvm.tir.isnan(z)) == 'False' k = te.var('k', 'int8x2') assert str(tvm.tir.isnan(k).dtype) == 'uint1x2' @@ -301,6 +301,10 @@ def test_buffer_load_store(): s = tvm.tir.BufferStore(b, 0.1, [0]) assert isinstance(s, tvm.tir.BufferStore) + s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)], + True, tvm.tir.Evaluate(0)) + assert isinstance(s, tvm.tir.BufferRealize) + def test_intimm_cond(): x = tvm.runtime.convert(1) diff --git a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py deleted file mode 100644 index 9a115be74559..000000000000 --- a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te - -def test_attrs_equal(): - x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1)) - assert tvm.ir.structural_equal(x, y) - assert not tvm.ir.structural_equal(x, z) - - dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert not tvm.ir.structural_equal(dattr, x) - dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) - assert tvm.ir.structural_equal(dattr, dattr2) - - assert tvm.ir.structural_equal({"x": x}, {"x": y}) - # array related checks - assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]}) - assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]}) - - n = te.var("n") - assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1}) - - - - - -def test_attrs_hash(): - fhash = tvm.ir.structural_hash - x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4)) - assert fhash({"x": x}) == fhash({"x": y}) - assert fhash({"x": x}) != fhash({"x": [y, 1]}) - assert fhash({"x": [x, 1]}) == fhash({"x": [y, 1]}) - assert fhash({"x": [x, 2]}) == fhash({"x": [y, 2]}) - - -if __name__ == "__main__": - test_attrs_equal() - test_attrs_hash() diff --git a/tests/python/unittest/test_tir_pass_basic.py b/tests/python/unittest/test_tir_pass_basic.py deleted file mode 100644 index 228e0c52c435..000000000000 --- a/tests/python/unittest/test_tir_pass_basic.py +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import tvm -from tvm import te - -def test_simplify(): - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod - x = te.var('x') - e1 = tvm.tir.ir_pass.Simplify(x + 2 + 1) - assert(tvm.ir.structural_equal(e1, x + 3)) - e2 = tvm.tir.ir_pass.Simplify(x * 3 + 5 * x) - assert(tvm.ir.structural_equal(e2, x * 8)) - e3 = tvm.tir.ir_pass.Simplify(x - tdiv(x, 3) * 3) - assert(tvm.ir.structural_equal(e3, tmod(x, 3))) - - -def test_verify_ssa(): - x = te.var('x') - y = te.var() - z = tvm.tir.Evaluate(x + y) - assert(tvm.tir.ir_pass.VerifySSA(z)) - - -def test_convert_ssa(): - x = te.var('x') - y = te.var() - let1 = tvm.tir.Let(x, 1, x + 1) - let2 = tvm.tir.Let(x, 1, x + y) - z = tvm.tir.Evaluate(let1 + let2) - assert(not tvm.tir.ir_pass.VerifySSA(z)) - z_ssa = tvm.tir.ir_pass.ConvertSSA(z) - assert(tvm.tir.ir_pass.VerifySSA(z_ssa)) - - -def test_expr_use_var(): - x = te.var('x') - assert(tvm.tir.ir_pass.ExprUseVar(x+1, x)) - assert(not tvm.tir.ir_pass.ExprUseVar(1+10, x)) - - -if __name__ == "__main__": - test_expr_use_var() diff --git a/tests/python/unittest/test_tir_pass_decorate_device_scope.py b/tests/python/unittest/test_tir_pass_decorate_device_scope.py deleted file mode 100644 index 327cfd9ed548..000000000000 --- a/tests/python/unittest/test_tir_pass_decorate_device_scope.py +++ /dev/null @@ -1,43 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
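# --- Illustrative sketch, not part of this patch ---
# The deleted test_tir_pass_basic.py above exercised ir_pass.Simplify and
# ir_pass.VerifySSA. SSA verification now lives in tvm.tir.analysis (see the
# new test_tir_analysis_verify_ssa.py added earlier in this patch); using the
# arith Analyzer for bare-expression simplification is an assumption of this
# sketch, not something the patch itself introduces.
import tvm
from tvm import te

def _ssa_and_simplify_sketch():
    x = te.var("x")
    y = te.var("y")
    assert tvm.tir.analysis.verify_ssa(
        tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x + y)))
    ana = tvm.arith.Analyzer()   # assumed stand-in for ir_pass.Simplify
    assert tvm.ir.structural_equal(ana.simplify(x + 2 + 1), x + 3)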
-import tvm -from tvm import te - -def test_decorate_device(): - m = te.size_var('m') - l = te.size_var('l') - A = te.placeholder((m, l), name='A') - - A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1') - A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - - s = te.create_schedule(A2.op) - xo, xi = s[A2].split(A2.op.axis[0], factor=8) - s[A1].compute_at(s[A2], xo) - s[A1].set_scope("shared") - - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt1 = tvm.tir.ir_pass.Simplify(stmt) - stmt2 = tvm.tir.ir_pass.DecorateDeviceScope(stmt1) - assert isinstance(stmt2, tvm.tir.AttrStmt) - assert stmt2.attr_key == "device_scope" - assert stmt1 == stmt2.body - -if __name__ == "__main__": - test_decorate_device() - diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_pass_hoist_if.py index f6bdbd6130f4..80e93a706ee7 100644 --- a/tests/python/unittest/test_tir_pass_hoist_if.py +++ b/tests/python/unittest/test_tir_pass_hoist_if.py @@ -32,18 +32,18 @@ def _visit(op): key = op if isinstance(op, tvm.tir.IfThenElse): global var_list - tvm.tir.ir_pass.PostOrderVisit(op.condition, _extract_vars) - val = [(op.then_case, op.else_case), ("IfThenElse", tuple(var_list))] + tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))] var_list.clear() elif isinstance(op, tvm.tir.For): - val = [(op.body,), ("For", op.loop_var.name)] + val = [(op.body,), ("tir.For", op.loop_var.name)] elif isinstance(op, tvm.tir.AttrStmt): - val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))] + val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))] else: return node_dict[key] = val - tvm.tir.ir_pass.PostOrderVisit(stmt, _visit) + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) for key, val in node_dict.items(): struct[val[1]] = tuple(node_dict[child][1] if child in node_dict else None for child in val[0]) @@ -67,10 +67,10 @@ def test_basic(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), - ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')), - ('For', 'i'): (('IfThenElse', ('i',)),)} + new_stmt = tvm.testing.HoistIfThenElse(stmt) + expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) def test_no_else(): @@ -86,10 +86,10 @@ def test_no_else(): ib.emit(tvm.tir.Evaluate(m)) stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),), - ('IfThenElse', ('i',)): (('For', 'j'), None), - ('For', 'i'): (('IfThenElse', ('i',)),)} + new_stmt = tvm.testing.HoistIfThenElse(stmt) + expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) def test_attr_stmt(): @@ -113,11 +113,11 @@ def test_attr_stmt(): data[bx * j + tx * j * k] = data[bx * j + tx * j * k] + 1.0 stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) - expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')): (('For', 'k'), ('For', 'k')), - ('For', 'j'): (('IfThenElse', ('i', 'j')),), 
('For', 'i'): (('For', 'j'),), - ('AttrStmt', 'thread_extent', 64): (('For', 'i'),), - ('AttrStmt', 'thread_extent', 32): (('AttrStmt', 'thread_extent', 64),)} + new_stmt = tvm.testing.HoistIfThenElse(stmt) + expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For', 'k'), ('tir.For', 'k')), + ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'): (('tir.For', 'j'),), + ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),), + ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent', 64),)} verify_structure(new_stmt, expected_struct) def test_nested_for(): @@ -137,10 +137,10 @@ def test_nested_for(): data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5 stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) - expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'): (('IfThenElse', ('i', 'j')),), - ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,), ('IfThenElse', ('i',)): (('For', 'j'), None), - ('For', 'i'): (('IfThenElse', ('i',)),)} + new_stmt = tvm.testing.HoistIfThenElse(stmt) + expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.For', 'l'): (('tir.IfThenElse', ('i', 'j')),), + ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For', 'j'): (None,), ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), + ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)} verify_structure(new_stmt, expected_struct) def test_if_block(): @@ -170,11 +170,11 @@ def test_if_block(): data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6 stmt = ib.get() - new_stmt = tvm.tir.ir_pass.HoistIfThenElse(stmt) - expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('IfThenElse', ('j',)): (None, None), - ('For', 'l'): (None,), ('For', 'k'): (None,), ('For', 'j'): (('For', 'j'),), - ('IfThenElse', ('i',)): (('For', 'j'), None), ('For', 'i'): (('IfThenElse', ('i',)),), - ('IfThenElse', ('n',)): (('For', 'j'), None)} + new_stmt = tvm.testing.HoistIfThenElse(stmt) + expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)): (None, None), + ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'j'),), + ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),), + ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)} verify_structure(new_stmt, expected_struct) diff --git a/tests/python/unittest/test_tir_pass_inline.py b/tests/python/unittest/test_tir_pass_inline.py deleted file mode 100644 index ad0591d3a7c1..000000000000 --- a/tests/python/unittest/test_tir_pass_inline.py +++ /dev/null @@ -1,54 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
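# --- Illustrative sketch, not part of this patch ---
# The hoist-if tests above switch from ir_pass.PostOrderVisit to
# tvm.tir.stmt_functor.post_order_visit, which invokes a callback on every node
# of a statement in post order; verify_structure() uses it to collect the loop
# and branch nesting. A minimal helper in the same style (the function name is
# hypothetical) would be:
import tvm

def _count_if_nodes_sketch(stmt):
    count = [0]
    def visit(node):
        # Walk the statement and record each IfThenElse node encountered.
        if isinstance(node, tvm.tir.IfThenElse):
            count[0] += 1
    tvm.tir.stmt_functor.post_order_visit(stmt, visit)
    return count[0]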
-import tvm -from tvm import te - -def test_inline(): - m = te.size_var('m') - A = te.placeholder((m,), name='A') - T = te.compute((m,), lambda i,: A[i] + 10, name='T') - stmt = tvm.tir.Evaluate(T[10] + 11 * T[100]) - stmt = tvm.tir.ir_pass.Inline( - stmt, T.op, [x.var for x in T.op.axis], T.op.body[0]) - print(stmt) - assert(tvm.tir.ir_pass.VerifySSA(stmt)) - - try: - # pass in int array(wrong argument type) - # must raise an error - stmt = tvm.tir.ir_pass.Inline( - T.op, [1,2,3], T.op.body, stmt) - assert False - except tvm.error.TVMError: - pass - -def test_inline2(): - m = te.size_var('m') - A = te.placeholder((m,), name='A') - T = te.compute((m,), lambda i,: A[i] + 10, name='T') - stmt = tvm.tir.Evaluate(te.exp(T[10]) + 11 * T[100]) - stmt = tvm.tir.ir_pass.Inline( - stmt, T.op, [x.var for x in T.op.axis], T.op.body[0]) - def check(op): - if isinstance(op, tvm.tir.Call): - assert op.func != T.op - tvm.tir.ir_pass.PostOrderVisit(stmt, check) - - -if __name__ == "__main__": - test_inline2() - test_inline() diff --git a/tests/python/unittest/test_tir_pass_virtual_thread.py b/tests/python/unittest/test_tir_pass_virtual_thread.py deleted file mode 100644 index 2d96696eed88..000000000000 --- a/tests/python/unittest/test_tir_pass_virtual_thread.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import tvm -from tvm import te - -def test_virtual_thread(): - m = te.var('m') - A = te.placeholder((m, ), name='A') - A1 = te.compute((m,), lambda i: A[i], name='A1') - A2 = te.compute((m,), lambda i: A1[i] + 3, name='A2') - - s = te.create_schedule(A2.op) - vx = te.thread_axis("vthread", name="vx") - xo, xi = s[A2].split(A2.op.axis[0], nparts=2) - s[A2].bind(xo, vx) - xo, xi = s[A2].split(xi, 8) - s[A1].compute_at(s[A2], xo) - - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt) - print(stmt) - -if __name__ == "__main__": - test_virtual_thread() diff --git a/tests/python/unittest/test_tir_pass_ir_transform.py b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py similarity index 95% rename from tests/python/unittest/test_tir_pass_ir_transform.py rename to tests/python/unittest/test_tir_stmt_functor_ir_transform.py index cb7417a7a54f..38529e927d52 100644 --- a/tests/python/unittest/test_tir_pass_ir_transform.py +++ b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py @@ -37,7 +37,7 @@ def postorder(op): if op.name == "TestA": return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1) return op - body = tvm.tir.ir_pass.IRTransform(body, preorder, postorder, ["Call"]) + body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) stmt_list = tvm.tir.stmt_list(body.body.body) assert stmt_list[0].value.args[0].name == "TestB" assert stmt_list[1].value.value == 0 diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 7fd2593bd365..29a330319622 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -22,7 +22,7 @@ def test_for(): def device_context(dev_id): ctx = tvm.tir.call_extern("handle", "device_context", dev_type, dev_id) return tvm.tir.Call( - "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic, None, 0) + "handle", "tvm_thread_context", [ctx], tvm.tir.Call.Intrinsic) ib = tvm.tir.ir_builder.create() n = te.var("n") diff --git a/tests/python/unittest/test_tir_pass_coproc_sync.py b/tests/python/unittest/test_tir_transform_coproc_sync.py similarity index 91% rename from tests/python/unittest/test_tir_pass_coproc_sync.py rename to tests/python/unittest/test_tir_transform_coproc_sync.py index b0e2050e2ee9..f6583493d646 100644 --- a/tests/python/unittest/test_tir_pass_coproc_sync.py +++ b/tests/python/unittest/test_tir_transform_coproc_sync.py @@ -37,7 +37,10 @@ def meminfo_cache(): ib.scope_attr(cp, "coproc_scope", 1) A[j] = A[j + k * 10] + 2 stmt = ib.get() - stmt = tvm.tir.ir_pass.CoProcSync(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body + body = stmt.body.body.body blist = tvm.tir.stmt_list(body) assert(blist[1].value.name == "cop.coproc_read_barrier") @@ -65,7 +68,10 @@ def test_coproc_sync2(): ib.scope_attr(cp, "coproc_scope", 2) A[ty] = 1.0 stmt = ib.get() - stmt = tvm.tir.ir_pass.CoProcSync(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body + 
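# --- Illustrative sketch, not part of this patch ---
# The coproc-sync tests above all follow the same wrapping idiom that this
# patch applies throughout the suite: build a statement with the ir_builder,
# wrap it into a PrimFunc/IRModule, run the TIR pass on the module, then read
# the rewritten body back out of "main". A generic helper (the name is
# hypothetical) capturing that pattern:
import tvm

def _apply_tir_pass_sketch(stmt, params, tir_pass):
    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(params, stmt))
    return tir_pass(mod)["main"].body

# e.g. _apply_tir_pass_sketch(ib.get(), [n], tvm.tir.transform.CoProcSync())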
def test_coproc_sync3(): def __check_list(tvm_array, py_list): @@ -91,7 +97,10 @@ def __check_list(tvm_array, py_list): A[0] = 0.0 stmt = ib.get() - stmt = tvm.tir.ir_pass.CoProcSync(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body + slist = tvm.tir.stmt_list(stmt[0].body.body) push_st = slist[2] slist = tvm.tir.stmt_list(slist[-1]) diff --git a/tests/python/unittest/test_tir_transform_decorate_device_scope.py b/tests/python/unittest/test_tir_transform_decorate_device_scope.py new file mode 100644 index 000000000000..cf9ea9e00fe1 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_decorate_device_scope.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te + +def test_decorate_device(): + x = te.var("x") + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x))) + + stmt = tvm.tir.transform.DecorateDeviceScope()(mod)["main"].body + assert stmt.attr_key == "device_scope" + +if __name__ == "__main__": + test_decorate_device() diff --git a/tests/python/unittest/test_tir_pass_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py similarity index 66% rename from tests/python/unittest/test_tir_pass_inject_copy_intrin.py rename to tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 8c34e344d73e..887b8b0c2b75 100644 --- a/tests/python/unittest/test_tir_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -26,16 +26,19 @@ def test_copy2d(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): assert dst.strides[0] == l assert dst.strides[1].value == 1 assert src.strides[0] == l assert tuple(src.shape) == (m, l) return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + def test_copy_pad(): m = te.var('m') @@ -48,18 +51,22 @@ def test_copy_pad(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + + func = 
tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): - assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 + tvm.testing.assert_prim_expr_equal(src.elem_offset, 0) assert pad_before[0].value == 1 assert pad_before[1].value == 0 assert pad_after[0].value == 1 assert pad_after[1].value == 0 assert pad_value.value == 1.0 return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + def test_single_point_test(): A = te.placeholder((1,), name='A') @@ -69,19 +76,20 @@ def test_single_point_test(): s[B].pragma(B.op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def cb(src, dst, pad_before, pad_after, pad_value): - assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 - assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0 - assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1 - assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 + tvm.testing.assert_prim_expr_equal(src.elem_offset, 0) + tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0) + tvm.testing.assert_prim_expr_equal(src.strides[0], 1) + tvm.testing.assert_prim_expr_equal(dst.strides[0], 1) return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) -def assert_expr_equal(a, b): - assert tvm.tir.ir_pass.Simplify(a - b).value == 0 + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + def test_copy_pad_split(): m = 4 * 3 @@ -96,22 +104,25 @@ def test_copy_pad_split(): s[Apad].pragma(s[Apad].op.axis[0], "memcpy") bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) + mod = tvm.tir.transform.Simplify()(mod._move()) + def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) - assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) + tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) rpad_before = tvm.te.max(1 - xo * 4, 0) rpad_after = tvm.te.max(xo * 4 - 7, 0) - assert_expr_equal(pad_before[0], rpad_before) - assert_expr_equal(pad_after[0], rpad_after) - assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) + tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before) + tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after) + tvm.testing.assert_prim_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) return tvm.tir.Evaluate(0) - stmt = tvm.tir.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) + + stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body + if 
__name__ == "__main__": diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py b/tests/python/unittest/test_tir_transform_inject_double_buffer.py similarity index 84% rename from tests/python/unittest/test_tir_pass_inject_double_buffer.py rename to tests/python/unittest/test_tir_transform_inject_double_buffer.py index 6b04db30f6d5..0b6b167c8660 100644 --- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_transform_inject_double_buffer.py @@ -36,19 +36,29 @@ def test_double_buffer(): C[j] = B[j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert isinstance(stmt.body.body, tvm.tir.Allocate) - assert stmt.body.body.extents[0].value == 2 mod = tvm.IRModule({ "db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt) }) + + opt = tvm.transform.Sequential( + [tvm.tir.transform.InjectDoubleBuffer(), + tvm.tir.transform.Simplify()]) + + with tvm.transform.PassContext(config={ + "tir.InjectDoubleBuffer" : {"split_loop" : 2} + }): + mod = opt(mod) + stmt = mod["db"].body + + assert isinstance(stmt.body.body, tvm.tir.Allocate) + assert stmt.body.body.extents[0].value == 2 + f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] def count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(f.body, count_sync) + tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 diff --git a/tests/python/unittest/test_tir_pass_inject_vthread.py b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py similarity index 83% rename from tests/python/unittest/test_tir_pass_inject_vthread.py rename to tests/python/unittest/test_tir_transform_inject_virtual_thread.py index 8fbd8295d238..c0789c654fbf 100644 --- a/tests/python/unittest/test_tir_pass_inject_vthread.py +++ b/tests/python/unittest/test_tir_transform_inject_virtual_thread.py @@ -40,9 +40,14 @@ def get_vthread(name): C[i * nthread + tx] = B[i] + 1 return ib.get() - stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread")) + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread("vthread"))))["main"].body + assert stmt.body.body.extents[0].value == 2 - stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("cthread")) + + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body + assert len(stmt.body.body.extents) == 3 @@ -67,16 +72,20 @@ def get_vthread(name): A[tx] = tx + 1.0 B[ty] = ty + 1.0 ib.emit(tvm.tir.call_extern("int32", "Run", - abuffer.access_ptr("r"), - bbuffer.access_ptr("r"), - cbuffer.access_ptr("rw"))) + abuffer.access_ptr("r"), + bbuffer.access_ptr("r"), + cbuffer.access_ptr("rw"))) return ib.get() - stmt = tvm.tir.ir_pass.InjectVirtualThread(get_vthread("vthread")) + + stmt = tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], get_vthread("cthread"))))["main"].body + assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.body.body.body.body.extents[0].value == 2 assert len(stmt.body.body.body.body.body.body.extents) == 3 + def test_vthread_if_then_else(): nthread = 2 tx = te.thread_axis("vthread") @@ -92,7 +101,10 @@ def test_vthread_if_then_else(): with ib.if_scope(i == 0): B[i] = A[i * nthread + tx] + 2 stmt = ib.get() - stmt = tvm.tir.ir_pass.InjectVirtualThread(stmt) + + stmt = 
tvm.tir.transform.InjectVirtualThread()(tvm.IRModule.from_expr( + tvm.tir.PrimFunc([], stmt)))["main"].body + assert stmt.body.body.body[0].else_case != None assert stmt.body.body.body[1].else_case == None diff --git a/tests/python/unittest/test_tir_pass_bound_checkers.py b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py similarity index 93% rename from tests/python/unittest/test_tir_pass_bound_checkers.py rename to tests/python/unittest/test_tir_transform_instrument_bound_checkers.py index d6c89b2ab878..fa27fddf4c98 100644 --- a/tests/python/unittest/test_tir_pass_bound_checkers.py +++ b/tests/python/unittest/test_tir_transform_instrument_bound_checkers.py @@ -18,32 +18,12 @@ import tvm from tvm import te import numpy as np + def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, True) - stmt = tvm.tir.ir_pass.RemoveNoOp(stmt) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, True) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - return stmt @pytest.mark.xfail def test_out_of_bounds_llvm(index_a, index_b): @@ -72,7 +52,6 @@ def test_in_bounds_llvm(): tgt = "llvm" tgt_host = "llvm" stmt = tvm.lower (s, [A, B, C], simple_mode=True) - print (stmt) fadd = tvm.build (s, [A, B, C], tgt, target_host=tgt_host, name="myadd") ctx = tvm.context(tgt, 0) a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) @@ -93,7 +72,6 @@ def test_out_of_bounds_vectorize_llvm(nn, index_a, index_b): tgt = "llvm" tgt_host = "llvm" stmt = tvm.lower (s, [a, b, c], simple_mode=True) - print (stmt) f = tvm.build(s, [a, b, c], tgt, target_host=tgt_host, name="myaddvec") ctx = tvm.cpu(0) n = nn @@ -192,13 +170,13 @@ def collect_branch_stmt (x): s = te.create_schedule(T.op) xo, xi = s[T].split(T.op.axis[0], factor=4) - bounds = tvm.te.schedule.InferBound(s) - stmt = lower (s, [A, B, T]) - # num_attributes = num_buffers * num_splits = 2 * 3 - # before instrumentation - assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) - assert_bound_instrumentation(stmt, check_branch_stmt, 0) - stmt = tvm.tir.ir_pass.InstrumentBoundCheckers(stmt) + with tvm.transform.PassContext(config={ + "tir.instrument_bound_checkers": True, + "tir.LoopPartition": {"partition_const_loop": True} + }): + mod = tvm.driver.lower(s, [A, B, T], name="main") + + stmt = mod["main"].body # after instrumentation assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) assert_bound_instrumentation(stmt, check_branch_stmt, 2) @@ -209,7 +187,10 @@ def collect_branch_stmt (x): def test_in_bounds_const_loop_partition_llvm(): - with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True): + with tvm.transform.PassContext(config={ + "tir.instrument_bound_checkers": True, + "tir.LoopPartition": {"partition_const_loop": True} + }): n = 21 A = te.placeholder((n, ), name='A') B = te.placeholder((n, ), name='B') @@ -228,7 
+209,10 @@ def test_in_bounds_const_loop_partition_llvm(): @pytest.mark.xfail def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b): - with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True): + with tvm.transform.PassContext(config={ + "tir.instrument_bound_checkers": True, + "tir.LoopPartition": {"partition_const_loop": True} + }): n = 21 A = te.placeholder((n, ), name='A') B = te.placeholder((n, ), name='B') @@ -462,7 +446,9 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm(): tvm.testing.assert_allclose(d.asnumpy(), d_np) if __name__ == "__main__": - with tvm.target.build_config(instrument_bound_checkers=True): + with tvm.transform.PassContext(config={ + "tir.instrument_bound_checkers": True, + }): # zero scale test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm() # in bound diff --git a/tests/python/unittest/test_tir_pass_lift_attr_scope.py b/tests/python/unittest/test_tir_transform_lift_attr_scope.py similarity index 88% rename from tests/python/unittest/test_tir_pass_lift_attr_scope.py rename to tests/python/unittest/test_tir_transform_lift_attr_scope.py index 0831565dc155..f5f4030d1b23 100644 --- a/tests/python/unittest/test_tir_pass_lift_attr_scope.py +++ b/tests/python/unittest/test_tir_transform_lift_attr_scope.py @@ -35,7 +35,10 @@ def test_coproc_lift(): A[j] = A[j] + 3 A[j] = A[j] + 3 body = ib.get() - body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope") + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + assert body.body.body.node == cp # only able to lift to the common pattern of the last two fors. @@ -52,7 +55,10 @@ def test_coproc_lift(): A[i] = A[i] + 2 body = ib.get() - body = tvm.tir.ir_pass.LiftAttrScope(body, "coproc_uop_scope") + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.LiftAttrScope("coproc_uop_scope")(mod)["main"].body + assert body.body.body.body[1].node == cp assert len(body.body.body.body) == 2 diff --git a/tests/python/unittest/test_tir_pass_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py similarity index 73% rename from tests/python/unittest/test_tir_pass_loop_partition.py rename to tests/python/unittest/test_tir_transform_loop_partition.py index 1256d8bbd4fc..ce8c16e87413 100644 --- a/tests/python/unittest/test_tir_pass_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -20,29 +20,9 @@ def collect_visit(stmt, f): ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x))) + tvm.tir.stmt_functor.post_order_visit(stmt, lambda x : ret.append(f(x))) return ret -def lower(sch, args): - binds = {} - arg_list = [] - for x in args: - if isinstance(x, te.tensor.Tensor): - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name) - assert x not in binds - binds[x] = buf - arg_list.append(buf) - else: - raise ValueError("args must be Tensor, Buffer or Var") - sch = sch.normalize() - bounds = tvm.te.schedule.InferBound(sch) - stmt = tvm.te.schedule.ScheduleOps(sch, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - return stmt def test_basic(): n = te.size_var('n') @@ -55,10 +35,16 @@ def test_basic(): bounds = tvm.te.schedule.InferBound(s) stmt = 
tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body[0])) - assert('if' in str(stmt.body.body[1])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) + mod = tvm.tir.transform.LoopPartition()(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any( + collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + assert(any( + collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + def test_const_loop(): n = 21 @@ -71,9 +57,15 @@ def test_const_loop(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, True) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body[0])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + with tvm.transform.PassContext(config={ + "tir.LoopPartition": {"partition_const_loop": True} + }): + mod = tvm.tir.transform.LoopPartition()(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse)))) def test_multi_loop(): ib = tvm.tir.ir_builder.create() @@ -87,8 +79,11 @@ def test_multi_loop(): with ib.else_scope(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt)) + mod = tvm.tir.transform.LoopPartition()(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + assert(not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) def test_multi_if(): @@ -107,9 +102,14 @@ def test_multi_if(): with ib.else_scope(): ib.emit(tvm.tir.Evaluate(n)) stmt = ib.get() - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body[0])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.tir.transform.LoopPartition()(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any( + collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + def test_thread_axis(): m = te.size_var('m') @@ -126,9 +126,14 @@ def test_thread_axis(): bounds = tvm.te.schedule.InferBound(s) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body[0])) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + mod = tvm.tir.transform.LoopPartition()(mod) + stmt = tvm.tir.transform.Simplify()(mod)["main"].body + + assert(not any( + collect_visit(stmt.body. 
body[0], lambda x: isinstance(x, tvm.tir.IfThenElse)))) + def test_vectorize(): n = te.size_var('n') @@ -147,11 +152,12 @@ def test_vectorize(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) - stmt = lower(s, [A, B]) + stmt = tvm.lower(s, [A, B], name="main")["main"].body body = stmt.body.body.body.body assert(x.var.name not in str(body.condition)) assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))) + def test_condition(): ib = tvm.tir.ir_builder.create() m = te.size_var('m') @@ -161,10 +167,14 @@ def test_condition(): ib.emit(tvm.tir.Evaluate( tvm.tir.Select(ib.likely(i*4+j 1, A[i-1], 1.0) - yy = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(y)).value + + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([i], tvm.tir.Evaluate(y))) + yy = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value z = tvm.tir.Select( tvm.tir.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) - zz = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(z)).value + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([i], tvm.tir.Evaluate(z))) + zz = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value + + a = tvm.tir.Select(tvm.tir.floordiv(i, 4) > 10, y, z) - a = tvm.tir.Select(tvm.te.floordiv(i, 4) > 10, y, z) - aa = tvm.tir.ir_pass.RewriteUnsafeSelect(tvm.tir.Evaluate(a)).value + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([i], tvm.tir.Evaluate(a))) + aa = tvm.tir.transform.RewriteUnsafeSelect()(mod)["main"].body.value assert yy.name == "tvm_if_then_else" assert zz.name == "tvm_if_then_else" assert isinstance(aa, tvm.tir.Select) diff --git a/tests/python/unittest/test_arith_stmt_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py similarity index 79% rename from tests/python/unittest/test_arith_stmt_simplify.py rename to tests/python/unittest/test_tir_transform_simplify.py index 45f083342410..48d0849bd1ee 100644 --- a/tests/python/unittest/test_arith_stmt_simplify.py +++ b/tests/python/unittest/test_tir_transform_simplify.py @@ -27,7 +27,9 @@ def test_stmt_simplify(): A[i] = C[i] body = tvm.tir.LetStmt(n, 10, ib.get()) - body = tvm.tir.ir_pass.CanonicalSimplify(body) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C, n], body)) + body = tvm.tir.transform.Simplify()(mod)["main"].body assert isinstance(body.body, tvm.tir.Store) @@ -44,10 +46,32 @@ def test_thread_extent_simplify(): with ib.if_scope(tx + ty < 12): A[tx] = C[tx + ty] body = tvm.tir.LetStmt(n, 10, ib.get()) - body = tvm.tir.ir_pass.CanonicalSimplify(body) + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C, n], body)) + body = tvm.tir.transform.Simplify()(mod)["main"].body assert isinstance(body.body.body.body, tvm.tir.Store) +def test_if_likely(): + ib = tvm.tir.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + n = te.size_var("n") + tx = te.thread_axis("threadIdx.x") + ty = te.thread_axis("threadIdx.y") + ib.scope_attr(tx, "thread_extent", 32) + ib.scope_attr(ty, "thread_extent", 32) + with ib.if_scope(ib.likely(tx * 32 + ty < n)): + with ib.if_scope(ib.likely(tx * 32 + ty < n)): + A[tx] = C[tx * 32 + ty] + body = ib.get() + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C, n], body)) + body = tvm.tir.transform.Simplify()(mod)["main"].body + assert isinstance(body.body.body, tvm.tir.IfThenElse) + assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse) + + def test_basic_likely_elimination(): n = te.size_var('n') X = te.placeholder(shape=(n,), 
name="x") @@ -106,5 +130,6 @@ def sls(n, d): if __name__ == "__main__": test_stmt_simplify() test_thread_extent_simplify() + test_if_likely() test_basic_likely_elimination() test_complex_likely_elimination() diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py similarity index 73% rename from tests/python/unittest/test_tir_pass_storage_flatten.py rename to tests/python/unittest/test_tir_transform_storage_flatten.py index 1eaadb35009d..5fea580fbf5c 100644 --- a/tests/python/unittest/test_tir_pass_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -30,11 +30,14 @@ def test_flatten2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [Ab, A2b], stmt, {A: Ab, A2: A2b}) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + def test_flatten_prefetch(): A = te.placeholder((25, 100, 4), name = 'A') @@ -42,9 +45,16 @@ def test_flatten_prefetch(): i = te.size_var('i') j = te.size_var('j') region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2), (j, 8), (0, 4)]] - stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: _A}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) + stmt = tvm.tir.Prefetch(_A, region) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [_A], stmt, {A: _A}) + + mod = tvm.IRModule.from_expr(func) + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify()])(mod) + stmt = mod["main"].body assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -62,12 +72,17 @@ def test_flatten_storage_align(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify()])(mod) + + stmt = mod["main"].body assert(stmt.body.extents[0].value == 17 * 8) + def test_flatten_double_buffer(): dtype = 'int64' n = 100 @@ -87,9 +102,20 @@ def test_flatten_double_buffer(): C[j] = B[j] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {}, 64) - stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) - stmt = tvm.tir.ir_pass.Simplify(stmt) + + mod = tvm.IRModule.from_expr( + tvm.tir.PrimFunc([A, C], stmt)) + + + with tvm.transform.PassContext(config={ + "tir.InjectDoubleBuffer" : {"split_loop" : 2} + }): + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.InjectDoubleBuffer(), + tvm.tir.transform.Simplify()])(mod) + + stmt = mod["main"].body assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 @@ -101,11 +127,11 @@ def test_flatten_double_buffer(): def 
count_sync(op): if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": count[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(f.body, count_sync) + tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 if __name__ == "__main__": - test_flatten_storage_align() test_flatten2() - test_flatten_prefetch() + test_flatten_storage_align() test_flatten_double_buffer() + test_flatten_prefetch() diff --git a/tests/python/unittest/test_tir_pass_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py similarity index 82% rename from tests/python/unittest/test_tir_pass_storage_rewrite.py rename to tests/python/unittest/test_tir_transform_storage_rewrite.py index b36d86b47af8..46ba687aebda 100644 --- a/tests/python/unittest/test_tir_pass_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -30,19 +30,22 @@ def test_storage_share(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 1 def register_mem(scope_tb, max_bits): @@ -72,13 +75,16 @@ def test_alloc_seq(): A[j] = 1.3 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 200 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 def test_alloc_different_dtypes(): @@ -129,8 +135,11 @@ def verify(n): body = stmt_generater(dtype_list, length) offset = offset_generater(dtype_list, length) - body = tvm.tir.ir_pass.StorageRewrite(body) - tvm.tir.ir_pass.PostOrderVisit(body, verify) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + + tvm.tir.stmt_functor.post_order_visit(body, verify) length = 1024 dtype_list = ["float16", "int32", "uint16", "int8"] @@ -157,19 +166,22 @@ def test_inplace_rule(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = 
tvm.tir.transform.StorageFlatten(64)(mod) + + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 2 @@ -189,18 +201,20 @@ def test_storage_combine(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert (n.extents[0].value == 16) - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 1 @@ -223,19 +237,20 @@ def test_storage_share_gpu(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A') - Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + alloc_stats = {"global": 0, "shared": 0} def verify(n): if isinstance(n, tvm.tir.AttrStmt): if n.attr_key == "storage_scope": alloc_stats[n.value.value] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert alloc_stats["global"] == 2 assert alloc_stats["shared"] == num_stage @@ -248,7 +263,9 @@ def test_parallel_alloc(): A[j] = A[j] + 2 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + assert (isinstance(body.body.body, tvm.tir.Allocate)) ib = tvm.tir.ir_builder.create() @@ -262,7 +279,9 @@ def test_parallel_alloc(): A = ib.allocate("float32", n, name="A", scope="global") A[j] = A[j] + 2 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body assert(isinstance(body.body.body.body.body, tvm.tir.Allocate)) @@ -284,21 +303,22 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = 
tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C') - Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. # verify inplace folding works num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 2 def test_exceed_mem(): @@ -373,24 +393,21 @@ def test_inplace_rule3(): assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - B0a = tvm.tir.decl_buffer(B0.shape, B0.dtype, name='B0') - B1a = tvm.tir.decl_buffer(B1.shape, B1.dtype, name='B1') - B2a = tvm.tir.decl_buffer(B2.shape, B2.dtype, name='B2') - B3a = tvm.tir.decl_buffer(B3.shape, B3.dtype, name='B3') - B4a = tvm.tir.decl_buffer(B4.shape, B4.dtype, name='B4') - B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5') - - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a, B3: B2a, B4: B4a, B5: B5a, B: Bb}, 64) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - stmt = tvm.tir.ir_pass.StorageRewrite(stmt) + func = tvm.te.schedule.SchedulePostProcToPrimFunc( + [B0, B1, B2, B3, B4, B5, B], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod) + + mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tir.transform.StorageRewrite()(mod) + stmt = mod["main"].body + # verify only have one allocations. 
# verify inplace folding works def verify(n): if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 70 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) def test_alloc_seq_type(): ib = tvm.tir.ir_builder.create() @@ -411,13 +428,16 @@ def test_alloc_seq_type(): A2[j] = A[j] body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 500 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 def test_alloc_seq_type2(): @@ -440,13 +460,16 @@ def test_alloc_seq_type2(): C[j] = 1.2 body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + num_alloc = [0] def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 200 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -469,7 +492,9 @@ def test_reuse_small_buffer(): E[j] = C[j] body = ib.get() - body = tvm.tir.ir_pass.StorageRewrite(body) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body num_alloc = [0] @@ -477,7 +502,7 @@ def verify(n): if isinstance(n, tvm.tir.Allocate): num_alloc[0] += 1 assert n.extents[0].value == 800 - tvm.tir.ir_pass.PostOrderVisit(body, verify) + tvm.tir.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 def test_replace_dataflow(): @@ -511,22 +536,23 @@ def compute(a, b): c = te.compute(shape, lambda i, j: compute(a, b)[i, j]) c = te.compute(shape, lambda i, j: 1 + c[i, j]) s = te.create_schedule(c.op) - stmt = tvm.lower(s, [a, b, c], simple_mode=True) + stmt = tvm.lower(s, [a, b, c])["main"].body def verify(n): if isinstance(n, tvm.tir.Allocate): assert n.extents[0].value == 268435456 - tvm.tir.ir_pass.PostOrderVisit(stmt, verify) + tvm.tir.stmt_functor.post_order_visit(stmt, verify) if __name__ == "__main__": + test_storage_share() test_alloc_seq() test_alloc_different_dtypes() test_inplace_rule() - test_storage_share() test_parallel_alloc() test_storage_combine() test_storage_share_gpu() test_inplace_rule2() + test_exceed_mem() test_inplace_rule3() test_alloc_seq_type() diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 9257f6cd3320..783b66983c48 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -34,15 +34,15 @@ def test_thread_storage_sync(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) + + func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.StorageFlatten(64)(mod._move()) cuda_target = tvm.target.create("cuda") - mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({ - "global_symbol": 
"test", "target": cuda_target})) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr({ + "global_symbol": "test", "target": cuda_target}))(mod._move()) fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] mod = tvm.IRModule.from_expr(fdevice) diff --git a/tests/python/unittest/test_tir_pass_unroll.py b/tests/python/unittest/test_tir_transform_unroll_loop.py similarity index 56% rename from tests/python/unittest/test_tir_pass_unroll.py rename to tests/python/unittest/test_tir_transform_unroll_loop.py index 165edab55f4e..68639940bb05 100644 --- a/tests/python/unittest/test_tir_pass_unroll.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -31,14 +31,24 @@ def test_unroll_loop(): Aptr[j + 1] = Aptr[i] + 1 stmt = ib.get() + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt)) + assert isinstance(stmt, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 16, 8, 0, True) - assert not isinstance(ret, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 15, 8, 0, True) - assert isinstance(ret, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 16, 8, 0, False) - assert isinstance(ret, tvm.tir.For) - assert ret.for_type == tvm.tir.For.Unrolled + + with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert not isinstance(ret, tvm.tir.For) + + with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 15}}): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert isinstance(ret, tvm.tir.For) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert isinstance(ret, tvm.tir.For) + assert ret.for_type == tvm.tir.For.Unrolled ib = tvm.tir.ir_builder.create() ib.scope_attr(tvm.tir.const(0, "int32"), "pragma_auto_unroll_max_step", 16) @@ -46,11 +56,16 @@ def test_unroll_loop(): wrapped = ib.get() wrapped = tvm.tir.SeqStmt([wrapped, stmt]) assert isinstance(ret, tvm.tir.For) - ret = tvm.tir.ir_pass.UnrollLoop(wrapped, 0, 8, 0, False) - assert isinstance(ret[0], tvm.tir.For) - assert ret[0].for_type == tvm.tir.For.Unrolled - assert isinstance(ret[1], tvm.tir.For) - assert ret[1].for_type != tvm.tir.For.Unrolled + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], wrapped)) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert isinstance(ret[0], tvm.tir.For) + assert ret[0].for_type == tvm.tir.For.Unrolled + assert isinstance(ret[1], tvm.tir.For) + assert ret[1].for_type != tvm.tir.For.Unrolled def test_unroll_fake_loop(): ib = tvm.tir.ir_builder.create() @@ -65,8 +80,17 @@ def test_unroll_fake_loop(): Aptr[j + 1] = Aptr[i] + 1 stmt = ib.get() - ret = tvm.tir.ir_pass.UnrollLoop(stmt, 8, 0, 1, True) - assert isinstance(ret[0], tvm.tir.Store) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt)) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": { + "auto_max_depth": 8, + "auto_max_extent": 1, + "explicit_unroll": False + }}): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert isinstance(ret[0], tvm.tir.Store) def test_unroll_single_count_loops(): n = te.size_var('n') @@ -78,8 +102,13 @@ def test_unroll_single_count_loops(): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # all parameters to UnrolLoops are default values except for # auto_unroll_max_extent which has 
been set to 1 (default:0) - after_unroll_stmt = tvm.tir.ir_pass.UnrollLoop(stmt, 0, 8, 1, True) - assert after_unroll_stmt == stmt + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) + + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": {"auto_max_step": 1} + }): + ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body + assert ret == stmt if __name__ == "__main__": test_unroll_loop() diff --git a/tests/python/unittest/test_tir_pass_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py similarity index 82% rename from tests/python/unittest/test_tir_pass_vectorize.py rename to tests/python/unittest/test_tir_transform_vectorize.py index 2ade843361c0..d7124b6b7e89 100644 --- a/tests/python/unittest/test_tir_pass_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -28,12 +28,16 @@ def test_vectorize_loop(): stmt = ib.get() assert isinstance(stmt.body, tvm.tir.For) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body.index, tvm.tir.Ramp) assert isinstance(stmt.body.value, tvm.tir.Broadcast) + def test_vectorize_vector(): dtype = 'int64' n = te.var('n') @@ -44,7 +48,10 @@ def test_vectorize_vector(): A[j] = tvm.tir.const(1, A.dtype) stmt = ib.get() assert isinstance(stmt.body, tvm.tir.For) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) assert not isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body.index, tvm.tir.Ramp) @@ -63,13 +70,17 @@ def test_vectorize_with_if(): with ib.if_scope(i < n): A[i] = 2.0 stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.IfThenElse) assert isinstance(stmt.then_case.index, tvm.tir.Ramp) assert isinstance(stmt.then_case.value, tvm.tir.Add) assert stmt.then_case.value.dtype == "float32x4" assert isinstance(stmt.else_case, tvm.tir.For) + def test_vectorize_with_le_cond(): n = te.var('n') ib = tvm.tir.ir_builder.create() @@ -78,9 +89,13 @@ def test_vectorize_with_le_cond(): with ib.if_scope(i <= n): A[i] = A[i] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) + def test_vectorize_with_ge_cond(): n = te.var('n') ib = tvm.tir.ir_builder.create() @@ -89,9 +104,13 @@ def test_vectorize_with_ge_cond(): with ib.if_scope(i >= n): A[i] = A[i] + 1 stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) + def test_vectorize_if_then_else(): n = te.var('n') x = te.var('x') @@ -102,7 +121,10 @@ def test_vectorize_if_then_else(): i > 0, A[i] + 1, A[i]) stmt = ib.get() - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert isinstance(stmt, tvm.tir.For) @@ -114,8 +136,12 @@ def 
test_vectorize_if_then_else(): k > 0, A[k * 4 + i], 0) stmt = ib.get() + assert isinstance(stmt.body, tvm.tir.For) - stmt = tvm.tir.ir_pass.VectorizeLoop(stmt) + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt)) + stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body + assert not isinstance(stmt.body, tvm.tir.For) assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast) diff --git a/tests/scripts/setup-pytest-env.sh b/tests/scripts/setup-pytest-env.sh new file mode 100755 index 000000000000..414186c97850 --- /dev/null +++ b/tests/scripts/setup-pytest-env.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# NOTE: allow unbound variable here +set +u + +if [[ ! -z $CI_PYTEST_ADD_OPTIONS ]]; then + export PYTEST_ADDOPTS="-v $CI_PYTEST_ADD_OPTIONS" +else + export PYTEST_ADDOPTS="-v " +fi +set -u + +export TVM_PATH=`pwd` +export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/topi/python diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 912e59eb0330..ce545bde6609 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -29,7 +29,7 @@ echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake echo set\(USE_VM_PROFILER ON\) >> config.cmake echo set\(USE_EXAMPLE_EXT_RUNTIME ON\) >> config.cmake -echo set\(USE_LLVM llvm-config-8\) >> config.cmake +echo set\(USE_LLVM llvm-config-10\) >> config.cmake echo set\(USE_NNPACK ON\) >> config.cmake echo set\(NNPACK_PATH /NNPACK/build/\) >> config.cmake echo set\(USE_ANTLR ON\) >> config.cmake @@ -38,3 +38,6 @@ echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VTA_FSIM ON\) >> config.cmake +echo set\(USE_TFLITE ON\) >> config.cmake +echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake +echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index 73a960981b7e..c7b073728b70 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -40,6 +40,8 @@ echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake echo set\(USE_VM_PROFILER ON\) >> config.cmake echo set\(USE_EXAMPLE_EXT_RUNTIME ON\) >> config.cmake echo set\(USE_ANTLR ON\) >> config.cmake +echo set\(USE_VTA_TSIM ON\) >> config.cmake +echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_BLAS openblas\) >> config.cmake echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake diff --git a/tests/scripts/task_config_build_wasm.sh b/tests/scripts/task_config_build_wasm.sh new 
file mode 100755 index 000000000000..cf388eb2fbdc --- /dev/null +++ b/tests/scripts/task_config_build_wasm.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u + +mkdir -p build +cd build +cp ../cmake/config.cmake . + +echo set\(USE_SORT ON\) >> config.cmake +echo set\(USE_MICRO ON\) >> config.cmake +echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake +echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake +echo set\(USE_VM_PROFILER ON\) >> config.cmake +echo set\(USE_EXAMPLE_EXT_RUNTIME ON\) >> config.cmake +echo set\(USE_LLVM llvm-config-11\) >> config.cmake +echo set\(USE_ANTLR ON\) >> config.cmake +echo set\(CMAKE_CXX_COMPILER g++\) >> config.cmake +echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake +echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake +echo set\(USE_VTA_TSIM ON\) >> config.cmake +echo set\(USE_VTA_FSIM ON\) >> config.cmake diff --git a/tests/scripts/task_cpp_unittest.sh b/tests/scripts/task_cpp_unittest.sh index 751e98e9abdc..5ac1843253d4 100755 --- a/tests/scripts/task_cpp_unittest.sh +++ b/tests/scripts/task_cpp_unittest.sh @@ -23,6 +23,10 @@ export LD_LIBRARY_PATH="lib:${LD_LIBRARY_PATH:-}" # NOTE: important to use abspath, when VTA is enabled. export VTA_HW_PATH=`pwd`/3rdparty/vta-hw +# to avoid CI thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + # Remove existing testcases rm -f build/*_test diff --git a/tests/scripts/task_golang.sh b/tests/scripts/task_golang.sh index 49965793f6b3..0ff6c39d602c 100755 --- a/tests/scripts/task_golang.sh +++ b/tests/scripts/task_golang.sh @@ -24,5 +24,9 @@ export LD_LIBRARY_PATH="lib:${LD_LIBRARY_PATH:-}" tvm_root="$(git rev-parse --show-toplevel)" export PYTHONPATH="$tvm_root/python":"$tvm_root/topi/python" +# to avoid CI CPU thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + # Golang tests make -C golang tests diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh index 63f16fd755f6..7ab4afae3c2e 100755 --- a/tests/scripts/task_java_unittest.sh +++ b/tests/scripts/task_java_unittest.sh @@ -22,6 +22,10 @@ set -u export PYTHONPATH=python export LD_LIBRARY_PATH="lib:${LD_LIBRARY_PATH:-}" +# to avoid CI CPU thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + CURR_DIR=$(cd `dirname $0`; pwd) SCRIPT_DIR=$CURR_DIR/../../jvm/core/src/test/scripts TEMP_DIR=$(mktemp -d) diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 544ef7224770..37c21b2b2d7b 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -46,6 +46,13 @@ fi echo "Check codestyle of c++ code..." make cpplint + +echo "clang-format check..." 
+# check latest change, for squash merge into master +./tests/lint/git-clang-format.sh HEAD~1 +# check against origin/master for PRs. +./tests/lint/git-clang-format.sh origin/dev + echo "Check codestyle of python code..." make pylint echo "Check codestyle of jni code..." diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index d24ed1af63ea..c239abb7bc9d 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -19,6 +19,18 @@ set -e set -u +source tests/scripts/setup-pytest-env.sh + +# to avoid CI CPU thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=4 + +cleanup() +{ + rm -rf /tmp/$$.log.txt +} +trap cleanup 0 + # cleanup old states rm -rf docs/_build mkdir -p docs/_build/html @@ -35,29 +47,39 @@ find . -type f -path "*.pyc" | xargs rm -f make cython3 cd docs -PYTHONPATH=`pwd`/../python make html +PYTHONPATH=`pwd`/../python make html |& tee /tmp/$$.log.txt +if grep -E "failed to execute" < /tmp/$$.log.txt; then + echo "Some of the sphinx-gallery examples failed to execute." + exit 1 +fi cd .. # C++ doc make doc rm -f docs/doxygen/html/*.map docs/doxygen/html/*.md5 -# JS doc -jsdoc -c web/.jsdoc_conf.json web/tvm_runtime.js web/README.md - # Java doc make javadoc +# type doc +cd web +npm install +npm run typedoc +cd .. + # Prepare the doc dir rm -rf _docs mv docs/_build/html _docs rm -f _docs/.buildinfo -mv docs/doxygen/html _docs/doxygen -mv out _docs/jsdoc -mv jvm/core/target/site/apidocs _docs/javadoc +mkdir -p _docs/api +mv docs/doxygen/html _docs/api/doxygen +mv jvm/core/target/site/apidocs _docs/api/javadoc +mv web/dist/docs _docs/api/typedoc echo "Start creating the docs tarball.." # make the tarball tar -C _docs -czf docs.tgz . echo "Finish creating the docs tarball" du -h docs.tgz + +echo "Finish everything" diff --git a/tests/scripts/task_python_frontend.sh b/tests/scripts/task_python_frontend.sh index 862de5a81c73..e5f9b20e3325 100755 --- a/tests/scripts/task_python_frontend.sh +++ b/tests/scripts/task_python_frontend.sh @@ -19,7 +19,7 @@ set -e set -u -export PYTHONPATH=python:topi/python +source tests/scripts/setup-pytest-env.sh # to avoid openblas threading error export TVM_BIND_THREADS=0 export OMP_NUM_THREADS=1 @@ -29,29 +29,23 @@ find . -type f -path "*.pyc" | xargs rm -f # Rebuild cython make cython3 -echo "Running relay TFLite frontend test..." -python3 -m pytest -v tests/python/frontend/tflite - echo "Running relay MXNet frontend test..." -python3 -m pytest -v tests/python/frontend/mxnet - -echo "Running relay Keras frontend test..." -python3 -m pytest -v tests/python/frontend/keras +python3 -m pytest tests/python/frontend/mxnet echo "Running relay ONNX frontend test..." -python3 -m pytest -v tests/python/frontend/onnx +python3 -m pytest tests/python/frontend/onnx echo "Running relay CoreML frontend test..." -python3 -m pytest -v tests/python/frontend/coreml +python3 -m pytest tests/python/frontend/coreml echo "Running relay Tensorflow frontend test..." -python3 -m pytest -v tests/python/frontend/tensorflow +python3 -m pytest tests/python/frontend/tensorflow echo "Running relay caffe2 frontend test..." -python3 -m pytest -v tests/python/frontend/caffe2 +python3 -m pytest tests/python/frontend/caffe2 echo "Running relay DarkNet frontend test..." -python3 -m pytest -v tests/python/frontend/darknet +python3 -m pytest tests/python/frontend/darknet echo "Running relay PyTorch frontend test..."
-python3 -m pytest -v tests/python/frontend/pytorch +python3 -m pytest tests/python/frontend/pytorch diff --git a/tests/scripts/task_web_test.sh b/tests/scripts/task_python_frontend_cpu.sh similarity index 65% rename from tests/scripts/task_web_test.sh rename to tests/scripts/task_python_frontend_cpu.sh index 947a133c1a7b..96c5ce631a17 100755 --- a/tests/scripts/task_web_test.sh +++ b/tests/scripts/task_python_frontend_cpu.sh @@ -16,28 +16,22 @@ # specific language governing permissions and limitations # under the License. +# Test frontends that only need CPU resources set -e set -u -export PYTHONPATH=python +source tests/scripts/setup-pytest-env.sh +# to avoid openblas threading error +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 -cp /emsdk-portable/.emscripten ~/.emscripten -source /emsdk-portable/emsdk_env.sh +find . -type f -path "*.pyc" | xargs rm -f -export EM_CONFIG=${HOME}/.emscripten -export EM_CACHE=${HOME}/.emscripten_cache +# Rebuild cython +make cython3 -echo "Build TVM Web runtime..." -make web +echo "Running relay TFLite frontend test..." +python3 -m pytest tests/python/frontend/tflite -echo "Prepare test libraries..." -python tests/web/prepare_test_libs.py - -echo "Start testing..." - -for test in tests/web/test_*.js; do - echo node $test - node $test -done - -echo "All tests finishes..." +echo "Running relay Keras frontend test..." +python3 -m pytest tests/python/frontend/keras diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index cfd6bd042e82..f7539d6a55fd 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -19,8 +19,11 @@ set -e set -u -export PYTHONPATH=`pwd`/python:`pwd`/topi/python:`pwd`/apps/extension/python +source tests/scripts/setup-pytest-env.sh +export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/apps/extension/python export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" + +# to avoid CI CPU thread throttling. export TVM_BIND_THREADS=0 export TVM_NUM_THREADS=2 @@ -42,26 +45,26 @@ rm -rf lib make cd ../.. -TVM_FFI=cython python3 -m pytest -v apps/extension/tests -TVM_FFI=ctypes python3 -m pytest -v apps/extension/tests +TVM_FFI=cython python3 -m pytest apps/extension/tests +TVM_FFI=ctypes python3 -m pytest apps/extension/tests # Test dso plugin cd apps/dso_plugin_module rm -rf lib make cd ../.. 
-TVM_FFI=cython python3 -m pytest -v apps/dso_plugin_module -TVM_FFI=ctypes python3 -m pytest -v apps/dso_plugin_module +TVM_FFI=cython python3 -m pytest apps/dso_plugin_module +TVM_FFI=ctypes python3 -m pytest apps/dso_plugin_module # Do not enable TensorFlow op # TVM_FFI=cython sh prepare_and_test_tfop_module.sh # TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh -TVM_FFI=ctypes python3 -m pytest -v tests/python/integration -TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib +TVM_FFI=ctypes python3 -m pytest tests/python/integration +TVM_FFI=ctypes python3 -m pytest tests/python/contrib -TVM_FFI=ctypes python3 -m pytest -v tests/python/relay +TVM_FFI=ctypes python3 -m pytest tests/python/relay # Do not enable OpenGL -# TVM_FFI=cython python -m pytest -v tests/webgl -# TVM_FFI=ctypes python3 -m pytest -v tests/webgl +# TVM_FFI=cython python -m pytest tests/webgl +# TVM_FFI=ctypes python3 -m pytest tests/webgl diff --git a/tests/scripts/task_python_nightly.sh b/tests/scripts/task_python_nightly.sh index 6ada8af31c5c..c2c0eab7d26a 100755 --- a/tests/scripts/task_python_nightly.sh +++ b/tests/scripts/task_python_nightly.sh @@ -19,7 +19,7 @@ set -e set -u -export PYTHONPATH=python:topi/python +source tests/scripts/setup-pytest-env.sh # Rebuild cython make cython3 @@ -27,4 +27,4 @@ make cython3 # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f -python3 -m pytest -v topi/tests/python/nightly +python3 -m pytest topi/tests/python/nightly diff --git a/tests/scripts/task_python_topi.sh b/tests/scripts/task_python_topi.sh index 5e5fcb87c51b..e483d5f7f4a6 100755 --- a/tests/scripts/task_python_topi.sh +++ b/tests/scripts/task_python_topi.sh @@ -19,7 +19,11 @@ set -e set -u -export PYTHONPATH=python:topi/python +source tests/scripts/setup-pytest-env.sh + +# to avoid CI thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 # Rebuild cython make cython3 @@ -27,4 +31,4 @@ make cython3 # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f -python3 -m pytest -v topi/tests/python +python3 -m pytest topi/tests/python diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index a5ad5cae3287..622646b76189 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -19,11 +19,11 @@ set -e set -u -export PYTHONPATH=python:topi/python +source tests/scripts/setup-pytest-env.sh # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f -TVM_FFI=ctypes python3 -m pytest -v tests/python/unittest +TVM_FFI=ctypes python3 -m pytest tests/python/unittest make cython3 -TVM_FFI=cython python3 -m pytest -v tests/python/unittest +TVM_FFI=cython python3 -m pytest tests/python/unittest diff --git a/tests/scripts/task_python_vta_fsim.sh b/tests/scripts/task_python_vta_fsim.sh index f269866c39e7..8080bbe756c7 100755 --- a/tests/scripts/task_python_vta_fsim.sh +++ b/tests/scripts/task_python_vta_fsim.sh @@ -19,8 +19,12 @@ set -e set -u -export TVM_PATH=`pwd` -export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/vta/python:${TVM_PATH}/topi/python +source tests/scripts/setup-pytest-env.sh +# to avoid CI thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + +export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/vta/python export VTA_HW_PATH=`pwd`/3rdparty/vta-hw # cleanup pycache @@ -36,8 +40,8 @@ cp ${VTA_HW_PATH}/config/fsim_sample.json ${VTA_HW_PATH}/config/vta_config.json # Run unit tests in functional/fast simulator echo "Running unittest in fsim..." 
-python3 -m pytest -v ${TVM_PATH}/vta/tests/python/unittest +python3 -m pytest ${TVM_PATH}/vta/tests/python/unittest # Run unit tests in functional/fast simulator echo "Running integration test in fsim..." -python3 -m pytest -v ${TVM_PATH}/vta/tests/python/integration +python3 -m pytest ${TVM_PATH}/vta/tests/python/integration diff --git a/tests/scripts/task_python_vta_tsim.sh b/tests/scripts/task_python_vta_tsim.sh index 49366748b895..c87d5483b8a5 100755 --- a/tests/scripts/task_python_vta_tsim.sh +++ b/tests/scripts/task_python_vta_tsim.sh @@ -19,10 +19,14 @@ set -e set -u -export TVM_PATH=`pwd` -export PYTHONPATH=${TVM_PATH}/python:${TVM_PATH}/vta/python:${TVM_PATH}/topi/python +source tests/scripts/setup-pytest-env.sh +export PYTHONPATH=${PYTHONPATH}:${TVM_PATH}/vta/python export VTA_HW_PATH=`pwd`/3rdparty/vta-hw +# to avoid CI CPU thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + # cleanup pycache find . -type f -path "*.pyc" | xargs rm -f @@ -51,11 +55,11 @@ make -C ${VTA_HW_PATH}/hardware/chisel USE_THREADS=0 lib # Run unit tests in cycle accurate simulator echo "Running unittest in tsim..." -python3 -m pytest -v ${TVM_PATH}/vta/tests/python/unittest +python3 -m pytest ${TVM_PATH}/vta/tests/python/unittest # Run unit tests in cycle accurate simulator echo "Running integration test in tsim..." -python3 -m pytest -v ${TVM_PATH}/vta/tests/python/integration +python3 -m pytest ${TVM_PATH}/vta/tests/python/integration # Reset default fsim simulation cp ${VTA_HW_PATH}/config/fsim_sample.json ${VTA_HW_PATH}/config/vta_config.json diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index fae07d34e992..17bad38fa71b 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -24,9 +24,13 @@ export TVM_HOME="$(git rev-parse --show-toplevel)" export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}" export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python" export RUST_DIR="$TVM_HOME/rust" -export LLVM_CONFIG_PATH=`which llvm-config-8` +export LLVM_CONFIG_PATH=`which llvm-config-10` echo "Using $LLVM_CONFIG_PATH" +# to avoid CI CPU thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + cd $RUST_DIR cargo fmt -- --check @@ -54,6 +58,12 @@ cd tests/test_tvm_dso cargo run cd - +# # run wasm32 test +# cd tests/test_wasm32 +# cargo build +# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm +# cd - + # run nn graph test cd tests/test_nn cargo run diff --git a/tests/scripts/task_sphinx_precheck.sh b/tests/scripts/task_sphinx_precheck.sh index 6709b281a88d..fd67b0ab539b 100755 --- a/tests/scripts/task_sphinx_precheck.sh +++ b/tests/scripts/task_sphinx_precheck.sh @@ -23,10 +23,6 @@ set -o pipefail cleanup() { - # cat error log if non zero exit - if [ $? ]; then - cat /tmp/$$.log.txt - fi rm -rf /tmp/$$.* } trap cleanup 0 @@ -40,15 +36,15 @@ make cython3 echo "PreCheck sphinx doc generation WARNINGS.." cd docs make clean -TVM_TUTORIAL_EXEC_PATTERN=none make html 2>/tmp/$$.log.txt +TVM_TUTORIAL_EXEC_PATTERN=none make html |& tee /tmp/$$.log.txt -grep -v -E "__mro__|RemovedInSphinx|UserWarning|FutureWarning|Keras" < /tmp/$$.log.txt > /tmp/$$.logclean.txt || true +grep -v -E "__mro__|UserWarning|FutureWarning|tensorflow|Keras|pytorch|TensorFlow|403" < /tmp/$$.log.txt > /tmp/$$.logclean.txt || true echo "---------Sphinx Log----------" cat /tmp/$$.logclean.txt echo "-----------------------------" if grep --quiet -E "WARN" < /tmp/$$.logclean.txt; then echo "WARNINIG found in the log, please fix them." 
- echo "You can reproduce locally by running ./tests/script/task_sphinx_precheck.sh" + echo "You can reproduce locally by running ./tests/scripts/task_sphinx_precheck.sh" exit 1 fi echo "No WARNINGS to be fixed." diff --git a/docker/install/ubuntu_install_opengl.sh b/tests/scripts/task_web_wasm.sh similarity index 84% rename from docker/install/ubuntu_install_opengl.sh rename to tests/scripts/task_web_wasm.sh index 9b8b6057905a..717d3284fce1 100755 --- a/docker/install/ubuntu_install_opengl.sh +++ b/tests/scripts/task_web_wasm.sh @@ -6,9 +6,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,9 +18,15 @@ set -e set -u -set -o pipefail -apt-get update --fix-missing +export PYTHONPATH=`pwd`/python -apt-get install -y --no-install-recommends \ - libgl1-mesa-dev libglfw3-dev +cd web +make clean +npm install +npm run lint +npm run prepwasm +npm run bundle +npm run test +npm run typedoc +cd .. diff --git a/tests/web/test_basic.js b/tests/web/test_basic.js deleted file mode 100644 index 6852319dbc12..000000000000 --- a/tests/web/test_basic.js +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -// Basic fields. 
-tvm.assert(tvm.float32 == "float32"); -tvm.assert(tvm.listGlobalFuncNames() !== "undefined"); -var sysLib = tvm.systemLib(); -tvm.assert(typeof sysLib.getFunction !== "undefined"); -sysLib.release(); - -// Test ndarray -function testArrayCopy(dtype, arr) { - var data = [1, 2, 3, 4, 5, 6]; - var a = tvm.empty([2, 3], dtype); - a.copyFrom(data); - var ret = a.asArray(); - tvm.assert(ret instanceof arr); - tvm.assert(ret.toString() == arr.from(data)); - a.release(); -} - -testArrayCopy("float32", Float32Array); -testArrayCopy("int", Int32Array); -testArrayCopy("int8", Int8Array); -testArrayCopy("uint8", Uint8Array); -testArrayCopy("float64", Float64Array); - -// Function registration -tvm.registerFunc("xyz", function(x, y) { - return x + y; -}); diff --git a/tests/web/test_packed_func.js b/tests/web/test_packed_func.js deleted file mode 100644 index d239f7346e74..000000000000 --- a/tests/web/test_packed_func.js +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/libtvm_web_runtime.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -function testGetGlobal() { - var targs = [10, 10.0, "hello"] - tvm.registerFunc("my_packed_func", function () { - tvm.assert(Array.from(arguments).toString() == targs, "assert fail"); - return 10 - }); - var f = tvm.getGlobalFunc("my_packed_func") - tvm.assert(tvm.isPackedFunc(f)); - y = f.apply(null, targs); - tvm.assert(y == 10); - f.release(); -} - - -function testReturnFunc() { - function addy(y) { - function add(x) { - return x + y; - } - return add; - } - var myf = tvm.convertFunc(addy); - var f = myf(10); - tvm.assert(tvm.isPackedFunc(f)); - tvm.assert(f(11) == 21); - myf.release(); - f.release(); -} - -function testByteArray() { - var a = new Uint8Array(3); - a[0] = 1; - a[1] = 2; - function myfunc(ss){ - tvm.assert(ss instanceof Uint8Array); - tvm.assert(ss.toString() == a); - } - f = tvm.convertFunc(myfunc); - f(a); - f.release(); -} - -testGetGlobal(); -testReturnFunc(); -testByteArray(); diff --git a/tests/webgl/test_local_gemm.py b/tests/webgl/test_local_gemm.py deleted file mode 100644 index 6bd22bf0057b..000000000000 --- a/tests/webgl/test_local_gemm.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import te -import numpy as np - -def test_local_gemm(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - nn = 1024 - n = te.var('n') - n = tvm.runtime.convert(nn) - m = n - l = n - A = te.placeholder((n, l), name='A', dtype='int32') - B = te.placeholder((m, l), name='B', dtype='int32') - k = te.reduce_axis((0, l), name='k') - C = te.compute((n, m), lambda ii, jj: te.sum(A[ii, k] * B[jj, k], axis=k), - name='CC') - - s = te.create_schedule(C.op) - s[C].opengl() - print(tvm.lower(s, [A, B, C], simple_mode=True)) - - f = tvm.build(s, [A, B, C], "opengl", name="gemm") - print("------opengl code------") - print(f.imported_modules[0].get_source(fmt="gl")) - - ctx = tvm.opengl() - n, m, l = nn, nn, nn - a_np = np.random.uniform(low=0, high=10, size=(n, l)).astype(A.dtype) - b_np = np.random.uniform(low=0, high=10, size=(m, l)).astype(B.dtype) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx) - f(a, b, c) - - tvm.testing.assert_allclose(c.asnumpy(), np.dot(a_np, b_np.T)) - -if __name__ == "__main__": - test_local_gemm() diff --git a/tests/webgl/test_local_multi_stage.py b/tests/webgl/test_local_multi_stage.py deleted file mode 100644 index 54a554b74ed9..000000000000 --- a/tests/webgl/test_local_multi_stage.py +++ /dev/null @@ -1,47 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import tvm -from tvm import te -import numpy as np - -def test_local_multi_stage(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="int32") - B = te.compute((n,), lambda i: A[i] + 1, name="B") - C = te.compute((n,), lambda i: B[i] * 2, name="C") - - s = te.create_schedule(C.op) - s[B].opengl() - s[C].opengl() - - f = tvm.build(s, [A, C], "opengl", name="multi_stage") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx) - c = tvm.nd.array(np.random.uniform(size=(n,)).astype(B.dtype), ctx) - f(a, c) - - tvm.testing.assert_allclose(c.asnumpy(), (a.asnumpy() + 1) * 2) - -if __name__ == "__main__": - test_local_multi_stage() diff --git a/tests/webgl/test_local_save_load.py b/tests/webgl/test_local_save_load.py deleted file mode 100644 index cca68020c0c2..000000000000 --- a/tests/webgl/test_local_save_load.py +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -def test_local_save_load(): - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - n = te.var("n") - A = te.placeholder((n,), name='A', dtype='int32') - B = te.placeholder((n,), name='B', dtype='int32') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - - f = tvm.build(s, [A, B, C], "opengl", target_host="llvm", name="myadd") - - ctx = tvm.opengl(0) - n = 10 - a = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(A.dtype), ctx) - b = tvm.nd.array(np.random.uniform(high=10, size=(n)).astype(B.dtype), ctx) - c = tvm.nd.array(np.zeros((n), dtype=C.dtype), ctx) - f(a, b, c) - - temp = util.tempdir() - path_so = temp.relpath("myadd.so") - f.export_library(path_so) - f1 = tvm.runtime.load_module(path_so) - f1(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - test_local_save_load() diff --git a/tests/webgl/test_local_topi_conv2d_nchw.py b/tests/webgl/test_local_topi_conv2d_nchw.py deleted file mode 100644 index 0d9b7776096a..000000000000 --- a/tests/webgl/test_local_topi_conv2d_nchw.py +++ /dev/null @@ -1,99 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Example code to do convolution. -Copied from topi/tests/python/test_topi_conv2d_nchw.py. -Should be removed once we fix OpenGL testing on Jenkins.""" -import os -import numpy as np -import tvm -from tvm import te -import topi -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - -def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding): - in_height = in_width = in_size - - A = te.placeholder((batch, in_channel, in_height, in_width), name='A') - W = te.placeholder((num_filter, in_channel, kernel, kernel), name='W') - B = topi.nn.conv2d_nchw(A, W, stride, padding) - C = topi.nn.relu(B) - - a_shape = get_const_tuple(A.shape) - w_shape = get_const_tuple(W.shape) - dtype = A.dtype - - @memoize("topi.tests.test_topi_conv2d.verify_con2d_nchw") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(dtype) - b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding) - c_np = np.maximum(b_np, 0) - return a_np, w_np, b_np, c_np - - a_np, w_np, b_np, c_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s1 = topi.generic.schedule_conv2d_nchw([B]) - s2 = topi.generic.schedule_conv2d_nchw([C]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.target.build_config(auto_unroll_max_step=1400, - unroll_explicit=(device != "cuda")): - func1 = tvm.build(s1, [A, W, B], device, name="conv2d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func2 = tvm.build(s2, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding)) - func1(a, w, b) - func2(a, w, c) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - - -def test_conv2d_nchw(): - # ResNet18 worklaods - verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3) - verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0) - verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1) - verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0) - verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1) - verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1) - verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0) - verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1) - verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1) - verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0) - verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1) - # Vgg16 workloads - verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1) - # Super resolution workloads - verify_conv2d_nchw(1, 1, 224, 64, 5, 1, 2) - verify_conv2d_nchw(1, 64, 224, 64, 3, 1, 1) - verify_conv2d_nchw(1, 64, 224, 32, 3, 1, 1) - verify_conv2d_nchw(1, 32, 224, 9, 3, 1, 1) - -if 
__name__ == "__main__": - test_conv2d_nchw() diff --git a/tests/webgl/test_local_topi_dense.py b/tests/webgl/test_local_topi_dense.py deleted file mode 100644 index 60dfe1ff690f..000000000000 --- a/tests/webgl/test_local_topi_dense.py +++ /dev/null @@ -1,76 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for dense operator -Copied from topi/tests/python/test_topi_dense.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -from topi.util import get_const_tuple -from tvm.contrib.pickle_memoize import memoize - - -def verify_dense(batch, in_dim, out_dim, use_bias=True): - A = te.placeholder((batch, in_dim), name='A') - B = te.placeholder((out_dim, in_dim), name='B') - C = te.placeholder((out_dim,), name='C') - D = topi.nn.dense(A, B, C if use_bias else None) - D = topi.nn.relu(D) - dtype = A.dtype - - # use memoize to pickle the test data for next time use - @memoize("topi.tests.test_topi_dense") - def get_ref_data(): - a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype) - b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype) - c_np = np.random.uniform(size=(out_dim,)).astype(dtype) - if use_bias: - d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0) - else: - d_np = np.maximum(np.dot(a_np, b_np.T), 0.0) - return (a_np, b_np, c_np, d_np) - # get the test data - a_np, b_np, c_np, d_np = get_ref_data() - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_dense(D) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(c_np, ctx) - d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx) - f = tvm.build(s, [A, B, C, D], device, name="dense") - f(a, b, c, d) - tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_dense(): - verify_dense(1, 1024, 1000, use_bias=True) - verify_dense(1, 1024, 1000, use_bias=False) - - -if __name__ == "__main__": - test_dense() diff --git a/tests/webgl/test_local_topi_pooling.py b/tests/webgl/test_local_topi_pooling.py deleted file mode 100644 index 3adae7bba51c..000000000000 --- a/tests/webgl/test_local_topi_pooling.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for pooling -Copied from topi/tests/python/test_topi_pooling.py. -Should be removed once we fix OpenGL testing on Jenkins. -""" -import numpy as np -import tvm -from tvm import te -import topi -import math -from topi.util import get_const_tuple - -def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode): - iw = ih - kw = kh - sw = sh - ph, pw = padding - A = te.placeholder((n, ic, ih, iw), name='A') - B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, - pool_type=pool_type, ceil_mode=ceil_mode) - B = topi.nn.relu(B) - dtype = A.dtype - - bshape = get_const_tuple(B.shape) - ashape = get_const_tuple(A.shape) - if ceil_mode: - assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1) - else: - assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1) - assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1) - - - a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype) - pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype) - no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw))) - pad_np[np.ix_(*no_zero)] = a_np - _, oc, oh, ow = get_const_tuple(B.shape) - b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype) - - if pool_type == 'avg': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - elif pool_type =='max': - for i in range(oh): - for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) - b_np = np.maximum(b_np, 0.0) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx) - print(tvm.lower(s, [A, B], simple_mode=True)) - - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_pool(): - verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False) - verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False) - verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False) - verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True) - - - -def verify_global_pool(n, c, h, w, pool_type): - A = te.placeholder((n, c, h, w), name='A') - B = topi.nn.global_pool(A, pool_type=pool_type) - B = topi.nn.relu(B) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - if pool_type == 'avg': - b_np = np.mean(a_np, axis=(2,3), keepdims=True) - elif pool_type =='max': - b_np = np.max(a_np, axis=(2,3), keepdims=True) - b_np = np.maximum(b_np, 0.0) - - def 
check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_global_pool(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - f = tvm.build(s, [A, B], device) - f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['opengl']: - check_device(device) - -def test_global_pool(): - verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - - -if __name__ == "__main__": - test_pool() - test_global_pool() diff --git a/tests/webgl/test_local_topi_softmax.py b/tests/webgl/test_local_topi_softmax.py deleted file mode 100644 index c0ddbf21419a..000000000000 --- a/tests/webgl/test_local_topi_softmax.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test code for softmax -Copied from topi/tests/python/test_topi_softmax.py. -Should be removed once we fix OpenGL testing on Jenkins. 
-""" - -import os -import numpy as np -import tvm -from tvm import te -import topi -import logging -from topi.util import get_const_tuple - -def verify_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - -def test_softmax(): - verify_softmax(32, 10) - verify_softmax(3, 4) - - -def verify_log_softmax(m, n): - A = te.placeholder((m, n), name='A') - B = topi.nn.log_softmax(A) - # confirm lower works - s = te.create_schedule([B.op]) - tvm.lower(s, [A, B], simple_mode=True) - a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) - b_np = topi.testing.log_softmax_python(a_np) - - def check_device(device): - if not tvm.runtime.enabled(device): - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - with tvm.target.create(device): - s = topi.generic.schedule_softmax(B) - ctx = tvm.context(device, 0) - a = tvm.nd.array(a_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - foo = tvm.build(s, [A, B], device, name="log_softmax") - foo(a, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ["opengl"]: - check_device(device) - - -def test_log_softmax(): - verify_log_softmax(32, 10) - verify_log_softmax(3, 4) - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - test_softmax() - test_log_softmax() diff --git a/tests/webgl/test_remote_save_load.py b/tests/webgl/test_remote_save_load.py deleted file mode 100644 index 34bbb3fa0f00..000000000000 --- a/tests/webgl/test_remote_save_load.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -The following instruction is based on web/README.md. - -Setup an RPC server: -$ python -m tvm.exec.rpc_proxy --example-rpc=1 - -Go to http://localhost:9190 in browser. - -Click "Connect To Proxy". 
- -Run this test script: -$ python tests/webgl/test_remote_save_load.py -""" - -import numpy as np -import tvm -from tvm import te -from tvm import rpc -from tvm.contrib import util, emscripten - -proxy_host = "localhost" -proxy_port = 9090 - -def try_remote_save_load(): - if not tvm.runtime.enabled("rpc"): - return - if not tvm.runtime.enabled("opengl"): - return - if not tvm.runtime.enabled("llvm"): - return - - # Build the module. - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") - s = te.create_schedule(C.op) - s[C].opengl() - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B, C], "opengl", target_host=target_host, name="myadd") - - remote = rpc.connect(proxy_host, proxy_port, key="js") - - temp = util.tempdir() - ctx = remote.opengl(0) - path_obj = temp.relpath("myadd.bc") - path_dso = temp.relpath("myadd.js") - path_gl = temp.relpath("myadd.gl") - path_json = temp.relpath("myadd.tvm_meta.json") - - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - f.imported_modules[0].save(path_gl) - - remote.upload(path_dso, "myadd.dso") - remote.upload(path_gl) - remote.upload(path_json) - - remote.download("myadd.dso") - remote.download("myadd.gl") - remote.download("myadd.tvm_meta.json") - - print('Loading myadd.dso') - fhost = remote.load_module("myadd.dso") - - print('Loading myadd.gl') - fdev = remote.load_module("myadd.gl") - - print('import_module') - fhost.import_module(fdev) - - print('running...') - a = tvm.nd.array(np.random.uniform(size=16).astype(A.dtype), ctx) - b = tvm.nd.array(np.zeros(16, dtype=A.dtype), ctx) - c = tvm.nd.array(np.zeros(16, dtype=C.dtype), ctx) - fhost(a, b, c) - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - -if __name__ == "__main__": - try_remote_save_load() diff --git a/tests/webgl/test_static_webgl_library.html b/tests/webgl/test_static_webgl_library.html deleted file mode 100644 index f9268c65edf3..000000000000 --- a/tests/webgl/test_static_webgl_library.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - -

TVM Test Page
[remaining markup of the deleted HTML test page is not reproduced here]
- - - - - - - - \ No newline at end of file diff --git a/tests/webgl/test_static_webgl_library.py b/tests/webgl/test_static_webgl_library.py deleted file mode 100644 index 929da4ca294c..000000000000 --- a/tests/webgl/test_static_webgl_library.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Create a static WebGL library and run it in the browser.""" - -from __future__ import absolute_import, print_function - -import os, shutil, SimpleHTTPServer, SocketServer -import tvm -from tvm import te -from tvm.contrib import emscripten, util -import numpy as np - -def try_static_webgl_library(): - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - - # Change to lib/ which contains "libtvm_runtime.bc". - os.chdir(os.path.join(curr_path, "../../lib")) - - # Create OpenGL module. - n = te.var("n") - A = te.placeholder((n,), name='A', dtype="float") - B = te.compute((n,), lambda *i: A[i], name="B") - - s = te.create_schedule(B.op) - s[B].opengl() - - target_host = "llvm -target=asmjs-unknown-emscripten -system-lib" - f = tvm.build(s, [A, B], name="identity", target="opengl", - target_host=target_host) - - # Create a JS library that contains both the module and the tvm runtime. 
- path_dso = "identity_static.js" - f.export_library(path_dso, emscripten.create_js, options=[ - "-s", "USE_GLFW=3", - "-s", "USE_WEBGL2=1", - "-lglfw", - ]) - - # Create "tvm_runtime.js" and "identity_static.html" in lib/ - shutil.copyfile(os.path.join(curr_path, "../../web/tvm_runtime.js"), - "tvm_runtime.js") - shutil.copyfile(os.path.join(curr_path, "test_static_webgl_library.html"), - "identity_static.html") - - port = 8080 - handler = SimpleHTTPServer.SimpleHTTPRequestHandler - httpd = SocketServer.TCPServer(("", port), handler) - print("Please open http://localhost:" + str(port) + "/identity_static.html") - httpd.serve_forever() - -if __name__ == "__main__": - try_static_webgl_library() diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index 98614c3d4903..1b36ace4608f 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -28,8 +28,8 @@ #include #include -#include #include +#include namespace topi { @@ -49,8 +49,8 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, std::string name = "T_broadcast_to", std::string tag = kBroadcast) { CHECK_GE(output_shape.size(), t->shape.size()) - << "Not a broadcast, output dimensionality smaller than input.\noutput: " - << output_shape << "\nvs\ninput: " << t; + << "Not a broadcast, output dimensionality smaller than input.\noutput: " << output_shape + << "\nvs\ninput: " << t; auto bh = detail::BroadcastShape(output_shape, t->shape); CHECK_EQ(output_shape.size(), bh.common_shape.size()); for (size_t i = 0; i < output_shape.size(); ++i) { @@ -59,57 +59,39 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, auto l = [&](tvm::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; - return tvm::te::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, - name, - tag); + return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + l, name, tag); } -#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::PrimExpr Name(const tvm::PrimExpr& a, \ - const tvm::PrimExpr& b) { \ - ComputeRule; \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, \ - std::string tag = kBroadcast) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return detail::WithBroadcast(l, A, B, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::PrimExpr& B, \ - std::string name = "T_" #Name, \ - std::string tag = kElementWise) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute(A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \ - return l(A(i), B); \ - }, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \ - const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, \ - std::string tag = kElementWise) { \ - auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute(B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { \ - return l(A, B(i)); \ - }, name, tag); \ +#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kBroadcast) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return detail::WithBroadcast(l, A, B, name, tag); 
\ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + A->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + B->shape, [&](const ::tvm::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, tag); \ } - -#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::te::Tensor& B) { \ - return topi::OpName(A, B); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, \ - const tvm::te::Tensor& B) { \ - return topi::OpName(A, B); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, \ - const tvm::PrimExpr& B) { \ - return topi::OpName(A, B); \ +#define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B) { \ + return topi::OpName(A, B); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B) { \ + return topi::OpName(A, B); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B) { \ + return topi::OpName(A, B); \ } /*! diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index f2ed029f5b33..30ad52510e6f 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -24,8 +24,8 @@ #ifndef TOPI_CONTRIB_CUBLAS_H_ #define TOPI_CONTRIB_CUBLAS_H_ -#include #include +#include namespace topi { namespace contrib { @@ -33,65 +33,51 @@ using namespace tvm; using namespace tvm::te; using namespace topi::detail; /*! -* \brief Create an op that multiplies lhs and rhs with cuBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor cublas_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies lhs and rhs with cuBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor cublas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto n = transa ? lhs->shape[1] : lhs->shape[0]; auto m = transb ? rhs->shape[0] : rhs->shape[1]; return make_extern( - { { n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.cublas.matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImm("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } /*! 
-* \brief Create an op that multiplies batch matrices -* lhs and rhs with cuBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor cublas_batch_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies batch matrices + * lhs and rhs with cuBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor cublas_batch_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto b = lhs->shape[0]; auto n = transa ? lhs->shape[2] : lhs->shape[1]; auto m = transb ? rhs->shape[1] : rhs->shape[2]; return make_extern( - { { b, n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.cublas.batch_matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{b, n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImm("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index f0bf92678f9a..988c37555b1c 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -25,6 +25,7 @@ #define TOPI_CONTRIB_ROCBLAS_H_ #include + #include "topi/detail/extern.h" namespace topi { @@ -32,33 +33,26 @@ namespace contrib { using namespace tvm; using namespace tvm::te; /*! -* \brief Create an op that multiplies lhs and rhs with rocBLAS -* -* \param lhs The left matrix operand -* \param rhs The right matrix operand -* \param transa Whether to transpose lhs -* \param transb Whether to transpose rhs -* -* \return The output tensor -*/ -inline Tensor rocblas_matmul(const Tensor& lhs, - const Tensor& rhs, - bool transa, - bool transb) { + * \brief Create an op that multiplies lhs and rhs with rocBLAS + * + * \param lhs The left matrix operand + * \param rhs The right matrix operand + * \param transa Whether to transpose lhs + * \param transb Whether to transpose rhs + * + * \return The output tensor + */ +inline Tensor rocblas_matmul(const Tensor& lhs, const Tensor& rhs, bool transa, bool transb) { auto n = transa ? lhs->shape[1] : lhs->shape[0]; auto m = transb ? 
rhs->shape[0] : rhs->shape[1]; return make_extern( - { { n, m } }, { lhs->dtype }, { lhs, rhs }, - [&](Array ins, Array outs) { - return call_packed({ - StringImmNode::make("tvm.contrib.rocblas.matmul"), - pack_buffer(ins[0]), - pack_buffer(ins[1]), - pack_buffer(outs[0]), - transa, - transb }); - }, "C", "", {})[0]; + {{n, m}}, {lhs->dtype}, {lhs, rhs}, + [&](Array ins, Array outs) { + return call_packed({StringImm("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), + pack_buffer(ins[1]), pack_buffer(outs[0]), transa, transb}); + }, + "C", "", {})[0]; } } // namespace contrib diff --git a/topi/include/topi/cuda/dense.h b/topi/include/topi/cuda/dense.h index 1f0701e1aa38..145d249f90c6 100644 --- a/topi/include/topi/cuda/dense.h +++ b/topi/include/topi/cuda/dense.h @@ -24,14 +24,14 @@ #ifndef TOPI_CUDA_DENSE_H_ #define TOPI_CUDA_DENSE_H_ -#include -#include -#include -#include -#include -#include #include +#include #include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -39,21 +39,19 @@ using namespace tvm::te; namespace cuda { /*! -* \brief Implementation of dense for CUDA backend -* -* \param target The target device -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense_cuda(const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Implementation of dense for CUDA backend + * + * \param target The target device + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. + * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense_cuda(const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -68,10 +66,8 @@ inline tvm::te::Tensor dense_cuda(const Target& target, CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::cublas_matmul(data, weight, false, true); if (bias.defined()) { - mm = tvm::te::compute({ batch, out_dim }, - [&](Var i, Var j) { - return mm(i, j) + bias(j); - }, "tensor", kBroadcast); + mm = tvm::te::compute( + {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast); } return mm; @@ -81,16 +77,15 @@ inline tvm::te::Tensor dense_cuda(const Target& target, } /*! -* \brief Create a CUDA schedule for dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_dense(const Target &target, const Array& outs) { - if (target->target_name == "cuda" && - target->libs().count("cublas")) { + * \brief Create a CUDA schedule for dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ +inline Schedule schedule_dense(const Target& target, const Array& outs) { + if (target->target_name == "cuda" && target->libs().count("cublas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/cuda/injective.h b/topi/include/topi/cuda/injective.h index a7792a5f4f1b..5a5c5af37349 100644 --- a/topi/include/topi/cuda/injective.h +++ b/topi/include/topi/cuda/injective.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_INJECTIVE_H_ #define TOPI_CUDA_INJECTIVE_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -63,7 +63,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/cuda/normalization.h b/topi/include/topi/cuda/normalization.h index bfc209db213b..f8f498eaffcf 100644 --- a/topi/include/topi/cuda/normalization.h +++ b/topi/include/topi/cuda/normalization.h @@ -24,20 +24,20 @@ #ifndef TOPI_CUDA_NORMALIZATION_H_ #define TOPI_CUDA_NORMALIZATION_H_ +#include +#include #include #include -#include -#include namespace topi { using namespace tvm; using namespace tvm::te; namespace cuda { /*! -* \brief Create a CUDA schedule for LRN -* \param outs The output tensors. -* \return A schedule for the given ops. -*/ + * \brief Create a CUDA schedule for LRN + * \param outs The output tensors. + * \return A schedule for the given ops. + */ inline Schedule schedule_lrn(const Array& outs) { Array out_ops; for (auto t : outs) { diff --git a/topi/include/topi/cuda/pooling.h b/topi/include/topi/cuda/pooling.h index 75b66b3a7c9d..87866f2c6933 100644 --- a/topi/include/topi/cuda/pooling.h +++ b/topi/include/topi/cuda/pooling.h @@ -24,12 +24,12 @@ #ifndef TOPI_CUDA_POOLING_H_ #define TOPI_CUDA_POOLING_H_ +#include +#include +#include +#include #include #include -#include -#include -#include -#include namespace topi { using namespace tvm; @@ -38,14 +38,14 @@ using namespace tvm::te; namespace cuda { /*! -* \brief Create a CUDA schedule for pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_pool(const Target &target, const Array& outs) { + * \brief Create a CUDA schedule for pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_pool(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -105,14 +105,14 @@ inline Schedule schedule_pool(const Target &target, const Array& outs) { } /*! -* \brief Create a CUDA schedule for global_pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_global_pool(const Target &target, const Array& outs) { + * \brief Create a CUDA schedule for global_pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ +inline Schedule schedule_global_pool(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -142,7 +142,7 @@ inline Schedule schedule_global_pool(const Target &target, const Array& s[out].split(i, num_thread, &by, &ty); IterVar bx, tx; s[out].split(c, num_thread, &bx, &tx); - s[out].reorder({ by, bx, ty, tx }); + s[out].reorder({by, bx, ty, tx}); s[out].bind(ty, thread_y); s[out].bind(tx, thread_x); s[out].bind(by, block_y); diff --git a/topi/include/topi/cuda/reduction.h b/topi/include/topi/cuda/reduction.h index add8d99aec91..35ce346eaaee 100644 --- a/topi/include/topi/cuda/reduction.h +++ b/topi/include/topi/cuda/reduction.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_REDUCTION_H_ #define TOPI_CUDA_REDUCTION_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -45,10 +45,8 @@ namespace cuda { * an index, such as argmax or argmin. * * \return The schedule given by sch -*/ -Schedule ScheduleReduce(const Target& target, - Operation op, - Schedule sch, + */ +Schedule ScheduleReduce(const Target& target, Operation op, Schedule sch, bool is_idx_reduce = false) { Tensor data_out; Tensor data_in; @@ -61,8 +59,8 @@ Schedule ScheduleReduce(const Target& target, } auto out_stage = sch[data_out]; - CHECK_GT(out_stage->op.as()->reduce_axis.size(), 0) << - "reduce_axis must be greater than zero"; + CHECK_GT(out_stage->op.as()->reduce_axis.size(), 0) + << "reduce_axis must be greater than zero"; bool all_reduce; int num_thread; @@ -120,10 +118,8 @@ Schedule ScheduleReduce(const Target& target, } } else { if (is_idx_reduce) { - sch[temp_idx_input].compute_at(stage_real, - stage_real->op.as()->axis[0]); - sch[temp_val_input].compute_at(stage_real, - stage_real->op.as()->axis[0]); + sch[temp_idx_input].compute_at(stage_real, stage_real->op.as()->axis[0]); + sch[temp_val_input].compute_at(stage_real, stage_real->op.as()->axis[0]); } } @@ -152,13 +148,13 @@ void TraverseBeforeReduce(Schedule s, Operation op) { } /*! -* \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each -* of the op's inputs. -* -* \param target The target to generate a schedule for. -* \param s The schedule we are building -* \param op The reduce op -*/ + * \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each + * of the op's inputs. + * + * \param target The target to generate a schedule for. + * \param s The schedule we are building + * \param op The reduce op + */ void TraverseAfterReduce(const Target& target, Schedule s, Operation op) { if (is_broadcast(op->tag)) { LOG(ERROR) << "Elementwise op after reduce is not yet supported"; @@ -178,13 +174,13 @@ void TraverseAfterReduce(const Target& target, Schedule s, Operation op) { } /*! -* \brief Create a CUDA schedule for a reduce operation. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a CUDA schedule for a reduce operation. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ Schedule schedule_reduce(const Target& target, Array outs) { CHECK_EQ(outs.size(), 1) << "outs must have size 1"; Array out_ops; diff --git a/topi/include/topi/cuda/softmax.h b/topi/include/topi/cuda/softmax.h index 4c88c3e9eddf..a3aa857d8c0c 100644 --- a/topi/include/topi/cuda/softmax.h +++ b/topi/include/topi/cuda/softmax.h @@ -24,11 +24,11 @@ #ifndef TOPI_CUDA_SOFTMAX_H_ #define TOPI_CUDA_SOFTMAX_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -44,7 +44,7 @@ namespace cuda { * * \return A schedule for the given ops. */ -inline Schedule schedule_softmax(const Target &target, const Array& outs) { +inline Schedule schedule_softmax(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/detail/array_utils.h b/topi/include/topi/detail/array_utils.h index 3a3453a7baf4..d7204722c4f6 100644 --- a/topi/include/topi/detail/array_utils.h +++ b/topi/include/topi/detail/array_utils.h @@ -39,7 +39,7 @@ using namespace tvm::te; * * \return True iff the given array contains the given item. */ -template +template inline bool contains(Array array, T item) { for (auto& i : array) { if (i == item) { diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h index 8622920dc374..ca3029327875 100644 --- a/topi/include/topi/detail/broadcast.h +++ b/topi/include/topi/detail/broadcast.h @@ -24,8 +24,8 @@ #ifndef TOPI_DETAIL_BROADCAST_H_ #define TOPI_DETAIL_BROADCAST_H_ -#include #include +#include #include #include @@ -77,10 +77,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, bh.vars1.push_front(bh.all_vars[0]); bh.vars2.push_front(bh.all_vars[0]); } else { - CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] - << " and " << shape2[s2_size - i] << " in: " - << tvm::Array(shape1.begin(), shape1.end()) - << " and " + CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " + << shape2[s2_size - i] + << " in: " << tvm::Array(shape1.begin(), shape1.end()) << " and " << tvm::Array(shape2.begin(), shape2.end()); } } @@ -97,10 +96,8 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, } inline tvm::Array InputIndexFromBroadcast( - const tvm::Array& ovars, - const tvm::te::Tensor& T, - const std::deque& my_vars, - const std::deque& all_vars) { + const tvm::Array& ovars, const tvm::te::Tensor& T, + const std::deque& my_vars, const std::deque& all_vars) { tvm::Array ivars; CHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. 
@@ -125,21 +122,16 @@ inline tvm::Array InputIndexFromBroadcast( } template -inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, - const tvm::te::Tensor& A, - const tvm::te::Tensor& B, - const std::string& name = "tensor", - const std::string& tag = "") { +inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A, + const tvm::te::Tensor& B, const std::string& name = "tensor", + const std::string& tag = "") { auto bh = BroadcastShape(A->shape, B->shape); auto l = [&](tvm::Array ovars) { return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)), B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; - return tvm::te::compute( - tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), - l, - name, - tag); + return tvm::te::compute(tvm::Array(bh.common_shape.begin(), bh.common_shape.end()), + l, name, tag); } } // namespace detail diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 74be9453ae61..9bd125119987 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -24,9 +24,10 @@ #ifndef TOPI_DETAIL_CONSTANT_UTILS_H_ #define TOPI_DETAIL_CONSTANT_UTILS_H_ -#include -#include +#include +#include #include +#include #include #include @@ -43,10 +44,7 @@ using namespace tvm::te; * * \return true if the given expr is a constant int or uint, false otherwise. */ -inline bool IsConstInt(PrimExpr expr) { - return - expr->IsInstance(); -} +inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } /*! * \brief Get the value of the given constant integer expression. An error @@ -73,13 +71,11 @@ inline int64_t GetConstInt(PrimExpr expr) { * * \return A vector of the integer values */ -inline std::vector GetConstIntValues( - Array exprs, const std::string& var_name) { +inline std::vector GetConstIntValues(Array exprs, const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { - CHECK(IsConstInt(expr)) << "All elements of " - << var_name << " must be constant integers"; + CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers"; result.push_back(GetConstInt(expr)); } return result; @@ -94,8 +90,8 @@ inline std::vector GetConstIntValues( * * \return A vector of the int64_t values */ -inline std::vector GetConstInt64Values( - Array exprs, const std::string& var_name) { +inline std::vector GetConstInt64Values(Array exprs, + const std::string& var_name) { std::vector result; if (!exprs.defined()) return result; for (auto expr : exprs) { @@ -106,8 +102,8 @@ inline std::vector GetConstInt64Values( } /*! 
- * \brief Check weather the two expressions are equal or not, if not simplify the expressions and check again - * \note This is stronger equality check than tvm::tir::Equal + * \brief Check weather the two expressions are equal or not, if not simplify the expressions and + * check again \note This is stronger equality check than tvm::tir::Equal * * \param lhs First expreesion * \param rhs Second expreesion @@ -119,7 +115,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr zero(0); - result = expr_equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero); + result = expr_equal(tvm::arith::Analyzer().Simplify(lhs - rhs), zero); } return result; } diff --git a/topi/include/topi/detail/extern.h b/topi/include/topi/detail/extern.h index ab83200ff387..b84fbc7722a1 100644 --- a/topi/include/topi/detail/extern.h +++ b/topi/include/topi/detail/extern.h @@ -25,9 +25,9 @@ #define TOPI_DETAIL_EXTERN_H_ #include -#include -#include +#include +#include namespace topi { namespace detail { @@ -43,13 +43,10 @@ using namespace tvm::te; * * \return The Buffer object */ -inline Buffer DeclExternBuffer(Array shape, - DataType dtype, - std::string name) { +inline Buffer DeclExternBuffer(Array shape, DataType dtype, std::string name) { auto data = var(name, DataType::Handle()); auto elem_offset = PrimExpr(); - return BufferNode::make(data, dtype, shape, Array(), elem_offset, name, "", - -1, 0, kDefault); + return Buffer(data, dtype, shape, Array(), elem_offset, name, "", -1, 0, kDefault); } /*! @@ -76,15 +73,12 @@ using FExtern = std::function, Array)>; * be one output Tensor for each element of out_shapes, with dtype equal to the corresponding * element of out_types. */ -inline Array make_extern(const Array< Array >& out_shapes, +inline Array make_extern(const Array >& out_shapes, const std::vector& out_types, - const Array& inputs, - FExtern fextern, - std::string name, - std::string tag, - ::tvm::Map attrs) { + const Array& inputs, FExtern fextern, std::string name, + std::string tag, ::tvm::Map attrs) { CHECK_EQ(out_shapes.size(), out_types.size()) - << "make_extern: out_shapes and out_types must have equal size"; + << "make_extern: out_shapes and out_types must have equal size"; Array input_placeholders; for (auto t : inputs) { @@ -96,11 +90,9 @@ inline Array make_extern(const Array< Array >& out_shapes, } auto body = fextern(input_placeholders, output_placeholders); - auto body_stmt = tvm::tir::EvaluateNode::make(body); + auto body_stmt = tvm::tir::Evaluate(body); - auto op = ExternOpNode::make( - name, tag, attrs, inputs, - input_placeholders, output_placeholders, body_stmt); + auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt); Array outputs; for (size_t i = 0; i < output_placeholders.size(); ++i) { @@ -119,27 +111,23 @@ inline Array make_extern(const Array< Array >& out_shapes, */ inline PrimExpr pack_buffer(Buffer buf) { CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; - auto shape = tvm::tir::CallNode::make( - DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, tvm::tir::CallNode::CallType::Intrinsic); + auto shape = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); PrimExpr strides; if (buf->strides.size() > 0) { - strides = tvm::tir::CallNode::make( - DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, - buf->shape, 
tvm::tir::CallNode::CallType::Intrinsic); + strides = tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_shape, + buf->shape, tvm::tir::CallNode::CallType::Intrinsic); } else { strides = 0; } - Array pack_args{ - buf->data, - shape, - strides, - make_const(DataType::Int(32), static_cast(buf->shape.size())), - make_const(buf->dtype, 0), - buf->elem_offset - }; - return tvm::tir::CallNode::make(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, - pack_args, tvm::tir::CallNode::CallType::Intrinsic); + Array pack_args{buf->data, + shape, + strides, + make_const(DataType::Int(32), static_cast(buf->shape.size())), + make_const(buf->dtype, 0), + buf->elem_offset}; + return tvm::tir::Call(DataType::Handle(), tvm::tir::intrinsic::tvm_stack_make_array, pack_args, + tvm::tir::CallNode::CallType::Intrinsic); } /*! @@ -152,8 +140,8 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(Array args) { - return tvm::tir::CallNode::make(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, - args, tvm::tir::CallNode::CallType::Intrinsic); + return tvm::tir::Call(DataType::Int(32), tvm::tir::intrinsic::tvm_call_packed, args, + tvm::tir::CallNode::CallType::Intrinsic); } } // namespace detail diff --git a/topi/include/topi/detail/pad_utils.h b/topi/include/topi/detail/pad_utils.h index 1f2a7c5d4185..7c416ecefb3c 100644 --- a/topi/include/topi/detail/pad_utils.h +++ b/topi/include/topi/detail/pad_utils.h @@ -18,16 +18,17 @@ */ /*! -* \file pad_utils.h -* \brief Padding helpers -*/ + * \file pad_utils.h + * \brief Padding helpers + */ #ifndef TOPI_DETAIL_PAD_UTILS_H_ #define TOPI_DETAIL_PAD_UTILS_H_ -#include +#include +#include +#include -#include "tvm/tir/expr.h" -#include "tvm/tir/op.h" +#include namespace topi { namespace detail { @@ -50,7 +51,7 @@ inline Array GetPadTuple(PrimExpr pad_h, PrimExpr pad_w) { auto pad_top = indexdiv(pad_h + 1, 2); auto pad_left = indexdiv(pad_w + 1, 2); - return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left }; + return {pad_top, pad_left, pad_h - pad_top, pad_w - pad_left}; } } // namespace detail diff --git a/topi/include/topi/detail/ravel_unravel.h b/topi/include/topi/detail/ravel_unravel.h index ca46da0a56f2..c87f2c997ca6 100644 --- a/topi/include/topi/detail/ravel_unravel.h +++ b/topi/include/topi/detail/ravel_unravel.h @@ -18,9 +18,9 @@ */ /*! -* \file ravel_unravel.h -* \brief Index ravel and unraval operations -*/ + * \file ravel_unravel.h + * \brief Index ravel and unraval operations + */ #ifndef TOPI_DETAIL_RAVEL_UNRAVEL_H_ #define TOPI_DETAIL_RAVEL_UNRAVEL_H_ @@ -34,13 +34,13 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Flatten the indices to 1D -* -* \param indices The input coordinates -* \param shape Shape of the tensor -* -* \return The index after flattening -*/ + * \brief Flatten the indices to 1D + * + * \param indices The input coordinates + * \param shape Shape of the tensor + * + * \return The index after flattening + */ inline PrimExpr RavelIndex(Array indices, Array shape) { CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size"; CHECK_GT(indices.size(), 0) << "indices must not be empty"; @@ -56,13 +56,13 @@ inline PrimExpr RavelIndex(Array indices, Array shape) { } /*! 
-* \brief Convert flattened index to coordinate array -* -* \param idx The 1D index -* \param shape Shape of the tensor -* -* \return The coordinate corresponding to the 1D index -*/ + * \brief Convert flattened index to coordinate array + * + * \param idx The 1D index + * \param shape Shape of the tensor + * + * \return The coordinate corresponding to the 1D index + */ inline Array UnravelIndex(PrimExpr idx, Array shape) { std::vector indices; diff --git a/topi/include/topi/detail/tensor_utils.h b/topi/include/topi/detail/tensor_utils.h index 6ac3982c3cf2..d144c75695ed 100644 --- a/topi/include/topi/detail/tensor_utils.h +++ b/topi/include/topi/detail/tensor_utils.h @@ -24,7 +24,6 @@ #ifndef TOPI_DETAIL_TENSOR_UTILS_H_ #define TOPI_DETAIL_TENSOR_UTILS_H_ - #include namespace topi { @@ -63,7 +62,7 @@ inline bool is_empty_shape(const Array& x) { * \return The interpolated value in the given index. */ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& indices, - const PrimExpr max_y, const PrimExpr max_x) { + const PrimExpr max_y, const PrimExpr max_x) { auto in_y = indices[2]; auto yf = tvm::floor(in_y); auto yc = tvm::cast(DataType::Int(32), tvm::ceil(in_y)); @@ -85,9 +84,7 @@ inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array& auto C = input(indices[0], indices[1], y1, x0); auto D = input(indices[0], indices[1], y1, x1); - return A * ( 1 - x_lerp) * ( 1 - y_lerp) + - B * x_lerp * (1 - y_lerp) + - C * (1 - x_lerp) * y_lerp + + return A * (1 - x_lerp) * (1 - y_lerp) + B * x_lerp * (1 - y_lerp) + C * (1 - x_lerp) * y_lerp + D * x_lerp * y_lerp; } diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 49eb088feb7a..a92d21c27afe 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -24,11 +24,12 @@ #ifndef TOPI_ELEMWISE_H_ #define TOPI_ELEMWISE_H_ -#include -#include #include +#include + #include #include + #include "broadcast.h" namespace topi { @@ -36,13 +37,11 @@ using namespace tvm; using namespace tvm::te; // Unary intrinsic operators -#define TOPI_DECLARE_UNARY_OP(OpName) \ - inline Tensor OpName(const Tensor& x, \ - std::string name = "T_" #OpName, \ - std::string tag = kElementWise) { \ - return compute(x->shape, [&](const Array& i) { \ - return ::tvm::OpName(x(i)); \ - }, name, tag); \ +#define TOPI_DECLARE_UNARY_OP(OpName) \ + inline Tensor OpName(const Tensor& x, std::string name = "T_" #OpName, \ + std::string tag = kElementWise) { \ + return compute( \ + x->shape, [&](const Array& i) { return ::tvm::OpName(x(i)); }, name, tag); \ } TOPI_DECLARE_UNARY_OP(exp); @@ -50,15 +49,24 @@ TOPI_DECLARE_UNARY_OP(erf); TOPI_DECLARE_UNARY_OP(sigmoid); TOPI_DECLARE_UNARY_OP(sqrt); TOPI_DECLARE_UNARY_OP(log); +TOPI_DECLARE_UNARY_OP(log2); +TOPI_DECLARE_UNARY_OP(log10); TOPI_DECLARE_UNARY_OP(floor); TOPI_DECLARE_UNARY_OP(ceil); TOPI_DECLARE_UNARY_OP(round); TOPI_DECLARE_UNARY_OP(trunc); TOPI_DECLARE_UNARY_OP(abs); TOPI_DECLARE_UNARY_OP(cos); +TOPI_DECLARE_UNARY_OP(cosh); TOPI_DECLARE_UNARY_OP(tan); TOPI_DECLARE_UNARY_OP(sin); +TOPI_DECLARE_UNARY_OP(sinh); +TOPI_DECLARE_UNARY_OP(acos); +TOPI_DECLARE_UNARY_OP(acosh); +TOPI_DECLARE_UNARY_OP(asin); +TOPI_DECLARE_UNARY_OP(asinh); TOPI_DECLARE_UNARY_OP(atan); +TOPI_DECLARE_UNARY_OP(atanh); TOPI_DECLARE_UNARY_OP(isnan); TOPI_DECLARE_UNARY_OP(tanh); TOPI_DECLARE_UNARY_OP(isfinite); @@ -68,9 +76,7 @@ TOPI_DECLARE_UNARY_OP(isinf); * \brief Fast_tanh_float implementation from Eigen * 
https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26 */ -inline Tensor fast_tanh_float(const Tensor& in, - std::string name, - std::string tag) { +inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) { // Clamp the inputs to the range [-9, 9] since anything outside // this range is +/-1.0f in single-precision. auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0)); @@ -90,178 +96,171 @@ inline Tensor fast_tanh_float(const Tensor& in, auto beta_4 = make_const(in->dtype, 1.18534705686654e-04); auto beta_6 = make_const(in->dtype, 1.19825839466702e-06); - return compute(x->shape, - [&](const Array& i) { - auto x2 = x(i) * x(i); - auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x(i) * p; - - auto q = x2 * beta_6 + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - return p / q; - }, - name, tag); + return compute( + x->shape, + [&](const Array& i) { + auto x2 = x(i) * x(i); + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x(i) * p; + + auto q = x2 * beta_6 + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + return p / q; + }, + name, tag); } /*! -* \brief Creates an operation that returns hyperbolic tanh of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is tanh -*/ -inline Tensor fast_tanh(const Tensor& x, - std::string name = "T_fast_tanh", + * \brief Creates an operation that returns hyperbolic tanh of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is tanh + */ +inline Tensor fast_tanh(const Tensor& x, std::string name = "T_fast_tanh", std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { // invoke fast_tanh_float implementation return fast_tanh_float(x, name, tag); } else { // fallback to default implementation - return compute(x->shape, [&](const Array& i) { - return ::tvm::tanh(x(i)); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ::tvm::tanh(x(i)); }, name, tag); } } /*! -* \brief Creates an operation that returns identity of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the identity operation -*/ -inline Tensor identity(const Tensor& x, - std::string name = "T_identity", + * \brief Creates an operation that returns identity of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the identity operation + */ +inline Tensor identity(const Tensor& x, std::string name = "T_identity", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return x(i); }, name, tag); } /*! 
-* \brief Creates an operation that returns the negation of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the negation operation -*/ -inline Tensor negative(const Tensor& x, - std::string name = "T_negative", + * \brief Creates an operation that returns the negation of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the negation operation + */ +inline Tensor negative(const Tensor& x, std::string name = "T_negative", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return -x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return -x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the logical NOT of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the logical NOT operation -*/ -inline Tensor logical_not(const Tensor& x, - std::string name = "T_logical_not", + * \brief Creates an operation that returns the logical NOT of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the logical NOT operation + */ +inline Tensor logical_not(const Tensor& x, std::string name = "T_logical_not", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return !x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return !x(i); }, name, tag); } /*! -* \brief Creates an operation that returns the bitwise NOT of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the bitwise NOT operation -*/ -inline Tensor bitwise_not(const Tensor& x, - std::string name = "T_bitwise_not", + * \brief Creates an operation that returns the bitwise NOT of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the bitwise NOT operation + */ +inline Tensor bitwise_not(const Tensor& x, std::string name = "T_bitwise_not", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - return ~x(i); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ~x(i); }, name, tag); } /*! 
-* \brief Returns the sign of the tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the sign -*/ -inline Tensor sign(const Tensor& x, - std::string name = "T_sign", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - PrimExpr zero = make_zero(x->dtype); - PrimExpr one = make_const(x->dtype, 1); - PrimExpr minus_one = make_const(x->dtype, -1); - auto s1 = tvm::tir::SelectNode::make((x(i) < zero), minus_one, zero); - auto s2 = tvm::tir::SelectNode::make((x(i) > zero), one, s1); - return s2; - }, name, tag); + * \brief Returns the sign of the tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the sign + */ +inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + PrimExpr zero = make_zero(x->dtype); + PrimExpr one = make_const(x->dtype, 1); + PrimExpr minus_one = make_const(x->dtype, -1); + auto s1 = tvm::tir::Select((x(i) < zero), minus_one, zero); + auto s2 = tvm::tir::Select((x(i) > zero), one, s1); + return s2; + }, + name, tag); } /*! -* \brief Creates an operation that returns rsqrt of a given tensor -* -* \param x The input tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the rsqrt operation -*/ -inline Tensor rsqrt(const Tensor& x, - std::string name = "tensor", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - PrimExpr one = make_const(x->dtype, 1); - return one/tvm::sqrt(x(i)); - }, name, tag); + * \brief Creates an operation that returns rsqrt of a given tensor + * + * \param x The input tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the rsqrt operation + */ +inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + PrimExpr one = make_const(x->dtype, 1); + return one / tvm::sqrt(x(i)); + }, + name, tag); } /*! 
-* \brief Creates an operation that clips each element of a tensor to -* the interval [a_min, a_max] -* -* \param x The input tensor -* \param a_min The inclusive lower bound of the interval -* \param a_max The inclusive upper bound of the interval -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the clip operation -*/ -inline Tensor clip(const Tensor& x, - const PrimExpr& a_min, - const PrimExpr& a_max, - std::string name = "T_clip", - std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - auto min_val = tvm::cast(x->dtype, a_min); - auto max_val = tvm::cast(x->dtype, a_max); - return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) - }, name, tag); + * \brief Creates an operation that clips each element of a tensor to + * the interval [a_min, a_max] + * + * \param x The input tensor + * \param a_min The inclusive lower bound of the interval + * \param a_max The inclusive upper bound of the interval + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the clip operation + */ +inline Tensor clip(const Tensor& x, const PrimExpr& a_min, const PrimExpr& a_max, + std::string name = "T_clip", std::string tag = kElementWise) { + return compute( + x->shape, + [&](const Array& i) { + auto min_val = tvm::cast(x->dtype, a_min); + auto max_val = tvm::cast(x->dtype, a_max); + return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*) + }, + name, tag); } /*! @@ -276,22 +275,23 @@ inline Tensor clip(const Tensor& x, * * \return A Tensor whose op member is the cast operation */ -inline Tensor cast(const Tensor& x, - DataType type, - std::string name = "T_cast", +inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", std::string tag = kElementWise) { - return compute(x->shape, [&](const Array& i) { - auto expr = x(i); - if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { - if (expr.dtype().lanes() == type.lanes()) { - return expr; - } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { - return tvm::tir::BroadcastNode::make(expr, type.lanes()); - } - } - - return tvm::cast(type, x(i)); - }, name, tag); + return compute( + x->shape, + [&](const Array& i) -> PrimExpr { + auto expr = x(i); + if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { + if (expr.dtype().lanes() == type.lanes()) { + return expr; + } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { + return tvm::tir::Broadcast(expr, type.lanes()); + } + } + + return tvm::cast(type, x(i)); + }, + name, tag); } /*! @@ -306,12 +306,12 @@ inline Tensor cast(const Tensor& x, */ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "tensor", std::string tag = kElementWise) { - return compute(x->shape, - [&](const Array& i) { - return tvm::tir::CallNode::make(type, "reinterpret", {x(i)}, - tvm::tir::CallNode::PureIntrinsic); - }, - name, tag); + return compute( + x->shape, + [&](const Array& i) { + return tvm::tir::Call(type, "reinterpret", {x(i)}, tvm::tir::CallNode::PureIntrinsic); + }, + name, tag); } /*! 
@@ -323,63 +323,58 @@ inline Tensor reinterpret(const Tensor& x, DataType type, std::string name = "te * * \return A Tensor whose op member is the sum operation */ -inline Tensor elemwise_sum(const Array& xs, - std::string name = "T_elemwise_sum", +inline Tensor elemwise_sum(const Array& xs, std::string name = "T_elemwise_sum", std::string tag = kElementWise) { CHECK_GT(xs.size(), 0) << "elemwise sum must have at least one input tensor."; - return compute(xs[0]->shape, [&](const Array& i) { - auto sum_expr = xs[0](i); - for (size_t j = 1; j < xs.size(); j++) { - sum_expr = sum_expr + xs[j](i); - } - return sum_expr; - }, name, tag); + return compute( + xs[0]->shape, + [&](const Array& i) { + auto sum_expr = xs[0](i); + for (size_t j = 1; j < xs.size(); j++) { + sum_expr = sum_expr + xs[j](i); + } + return sum_expr; + }, + name, tag); } /*! -* \brief Creates an operation that fill a tensor with fill_value -* -* \param shape The shape of a tensor -* \param dtype The Type of fill_value -* \param fill_value The value to be filled -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the full operation -*/ -inline Tensor full(const Array& shape, - DataType dtype, - const PrimExpr fill_value, - std::string name = "T_full", - std::string tag = kElementWise) { + * \brief Creates an operation that fill a tensor with fill_value + * + * \param shape The shape of a tensor + * \param dtype The Type of fill_value + * \param fill_value The value to be filled + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the full operation + */ +inline Tensor full(const Array& shape, DataType dtype, const PrimExpr fill_value, + std::string name = "T_full", std::string tag = kElementWise) { PrimExpr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } - return compute(shape, [&](const Array& i) { - return ev; - }, name, tag); + return compute( + shape, [&](const Array& i) { return ev; }, name, tag); } /*! -* \brief Creates an operation that construct a tensor with same shape as input tensor, -* then fill a tensor with fill_value -* -* \param x The input tensor -* \param fill_value The value to be filled -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op memeber is the full_like operation -*/ -inline Tensor full_like(const Tensor& x, - const PrimExpr fill_value, - std::string name = "T_full_like", - std::string tag = kElementWise) { + * \brief Creates an operation that construct a tensor with same shape as input tensor, + * then fill a tensor with fill_value + * + * \param x The input tensor + * \param fill_value The value to be filled + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op memeber is the full_like operation + */ +inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, + std::string name = "T_full_like", std::string tag = kElementWise) { PrimExpr ev = cast(x->dtype, fill_value); - return compute(x->shape, [&](const Array& i) { - return ev; - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ev; }, name, tag); } /*! 
@@ -403,9 +398,7 @@ inline Tensor full_like(const Tensor& x, * Approximation for fractional part: * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2)) */ -inline Tensor fast_exp_float32(const Tensor& _x, - std::string name, - std::string tag) { +inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) { auto x_hi = make_const(DataType::Float(32), 88.3762626647950f); auto x_lo = make_const(DataType::Float(32), -88.3762626647949f); auto log2e = make_const(DataType::Float(32), 1.44269504088896341f); @@ -420,25 +413,25 @@ inline Tensor fast_exp_float32(const Tensor& _x, auto one_half = make_const(DataType::Float(32), 0.5f); auto b = make_const(DataType::Float(32), 127.0f); - return compute(_x->shape, - [&](const Array& i) { - // clamp x - auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); - // integer part - auto n = ::tvm::floor(x * log2e + one_half); - // fractional part - auto f = x - n * ln2; - auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3])* f+ p[4]) * f - + p[5]) * f * f + f + one; - // Return 2^m * exp(r). - auto ef = tvm::reinterpret(DataType::Float(32), - ::tvm::cast(DataType::Int(32), n + b) << 23); - return ::tvm::max(ef * y, _x(i)); // NOLINT(*) - }, - name, tag); + return compute( + _x->shape, + [&](const Array& i) { + // clamp x + auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo); + // integer part + auto n = ::tvm::floor(x * log2e + one_half); + // fractional part + auto f = x - n * ln2; + auto y = + (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f + p[4]) * f + p[5]) * f * f + f + one; + // Return 2^m * exp(r). + auto ef = + tvm::reinterpret(DataType::Float(32), ::tvm::cast(DataType::Int(32), n + b) << 23); + return ::tvm::max(ef * y, _x(i)); // NOLINT(*) + }, + name, tag); } - /*! * \brief Fast exponential function implementation * @@ -449,16 +442,14 @@ inline Tensor fast_exp_float32(const Tensor& _x, * \return A Tensor whose op member is exponent operation * */ -inline Tensor fast_exp(const Tensor& x, - std::string name = "T_fast_exp", - std::string tag = kElementWise) { +inline Tensor fast_exp(const Tensor& x, std::string name = "T_fast_exp", + std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { auto ret = fast_exp_float32(x, name, tag); return ret; } else { - return compute(x->shape, [&](const Array& i) { - return ::tvm::exp(x(i)); - }, name, tag); + return compute( + x->shape, [&](const Array& i) { return ::tvm::exp(x(i)); }, name, tag); } } @@ -466,9 +457,7 @@ inline Tensor fast_exp(const Tensor& x, * \brief Fast_tanh_float implementation from Eigen * https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290 */ -inline Tensor fast_erf_float32(const Tensor& data, - std::string name, - std::string tag) { +inline Tensor fast_erf_float32(const Tensor& data, std::string name, std::string tag) { auto plus_4 = make_const(DataType::Float(32), 4.f); auto minus_4 = make_const(DataType::Float(32), -4.f); @@ -488,28 +477,31 @@ inline Tensor fast_erf_float32(const Tensor& data, auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f); auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f); - return compute(data->shape, [&](const Array &i) { - // clamp x - auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); - auto x2 = x * x; - - // Evaluate the numerator polynomial p. 
- auto p = x2 * alpha_13 + alpha_11; - p = x2 * p + alpha_9; - p = x2 * p + alpha_7; - p = x2 * p + alpha_5; - p = x2 * p + alpha_3; - p = x2 * p + alpha_1; - p = x * p; - - // Evaluate the denominator polynomial p. - auto q = x2 * beta_8 + beta_6; - q = x2 * q + beta_4; - q = x2 * q + beta_2; - q = x2 * q + beta_0; - - return p / q; - }, name, tag); + return compute( + data->shape, + [&](const Array& i) { + // clamp x + auto x = tvm::max(tvm::min(data(i), plus_4), minus_4); + auto x2 = x * x; + + // Evaluate the numerator polynomial p. + auto p = x2 * alpha_13 + alpha_11; + p = x2 * p + alpha_9; + p = x2 * p + alpha_7; + p = x2 * p + alpha_5; + p = x2 * p + alpha_3; + p = x2 * p + alpha_1; + p = x * p; + + // Evaluate the denominator polynomial p. + auto q = x2 * beta_8 + beta_6; + q = x2 * q + beta_4; + q = x2 * q + beta_2; + q = x2 * q + beta_0; + + return p / q; + }, + name, tag); } /*! @@ -521,8 +513,7 @@ inline Tensor fast_erf_float32(const Tensor& data, * * \return A Tensor whose op member is erf operation */ -inline Tensor fast_erf(const Tensor& x, - std::string name = "T_fast_erf", +inline Tensor fast_erf(const Tensor& x, std::string name = "T_fast_erf", std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { auto ret = fast_erf_float32(x, name, tag); diff --git a/topi/include/topi/generic/default.h b/topi/include/topi/generic/default.h index 640ab9545141..403b943a16e6 100644 --- a/topi/include/topi/generic/default.h +++ b/topi/include/topi/generic/default.h @@ -24,11 +24,11 @@ #ifndef TOPI_GENERIC_DEFAULT_H_ #define TOPI_GENERIC_DEFAULT_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -36,14 +36,14 @@ using namespace tvm::te; namespace generic { /*! -* \brief Create a generic default schedule for the given output tensors. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule(const Target& target, Array outs) { + * \brief Create a generic default schedule for the given output tensors. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule default_schedule(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -53,15 +53,15 @@ inline Schedule default_schedule(const Target& target, Array outs) { } /*! -* \brief Create a generic default schedule for the given output tensors, and apply -* auto inline -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule_auto_inline(const Target& target, Array outs) { + * \brief Create a generic default schedule for the given output tensors, and apply + * auto inline + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ +inline Schedule default_schedule_auto_inline(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/generic/extern.h b/topi/include/topi/generic/extern.h index e08158f297be..3954ac69ada8 100644 --- a/topi/include/topi/generic/extern.h +++ b/topi/include/topi/generic/extern.h @@ -24,12 +24,12 @@ #ifndef TOPI_GENERIC_EXTERN_H_ #define TOPI_GENERIC_EXTERN_H_ -#include -#include -#include -#include #include #include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -37,14 +37,14 @@ using namespace tvm::te; namespace generic { /*! -* \brief Schedule an extern op followed by injective operations -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the op. -*/ -inline Schedule schedule_extern(const Target& target, Array outs) { + * \brief Schedule an extern op followed by injective operations + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the op. + */ +inline Schedule schedule_extern(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/generic/injective.h b/topi/include/topi/generic/injective.h index 7a5aff7eaf80..69962dc645c0 100644 --- a/topi/include/topi/generic/injective.h +++ b/topi/include/topi/generic/injective.h @@ -24,11 +24,11 @@ #ifndef TOPI_GENERIC_INJECTIVE_H_ #define TOPI_GENERIC_INJECTIVE_H_ +#include +#include +#include #include #include -#include -#include -#include namespace topi { using namespace tvm; @@ -57,7 +57,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. */ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index a1ee8c1a5901..2a195b34fc4f 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -24,12 +24,12 @@ #ifndef TOPI_NN_H_ #define TOPI_NN_H_ -#include #include +#include +#include +#include #include -#include #include -#include #include #include @@ -37,19 +37,6 @@ namespace topi { using namespace tvm; using namespace tvm::te; -namespace detail { - -template -tvm::PrimExpr Map(const tvm::Array& exprs, T op) { - CHECK_GE(exprs.size(), 1); - tvm::PrimExpr res = exprs[0]; - for (size_t i = 1; i < exprs.size(); ++i) { - res = op(res, exprs[i]); - } - return res; -} - -} // namespace detail /*! * \brief Creates an operation that performs a rectified linear unit @@ -62,43 +49,38 @@ tvm::PrimExpr Map(const tvm::Array& exprs, T op) { * \return A Tensor whose op member is the relu operation */ template -inline tvm::te::Tensor relu(const tvm::te::Tensor& t, - T threshold = static_cast(0), - std::string name = "T_relu", - std::string tag = kElementWise) { +inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast(0), + std::string name = "T_relu", std::string tag = kElementWise) { return tvm::te::compute( t->shape, [&](const tvm::Array& i) { auto threshold_const = tvm::tir::make_const(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, - name, - tag); + name, tag); } /*! 
-* \brief Creates an operation that performs a leaky rectified linear unit -* -* \param t The input tensor -* \param alpha The slope for the small gradient when t < 0 -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the leaky relu operation -*/ -inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, - double alpha = 0.1, - std::string name = "T_leaky_relu", - std::string tag = kElementWise) { + * \brief Creates an operation that performs a leaky rectified linear unit + * + * \param t The input tensor + * \param alpha The slope for the small gradient when t < 0 + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the leaky relu operation + */ +inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, + std::string name = "T_leaky_relu", + std::string tag = kElementWise) { return tvm::te::compute( - t->shape, - [&](const tvm::Array& i) { - auto value = t(i); - auto calpha = tvm::tir::make_const(value.dtype(), alpha); - return tvm::tir::SelectNode::make(value > 0, value, value * calpha); - }, - name, - tag); + t->shape, + [&](const tvm::Array& i) { + auto value = t(i); + auto calpha = tvm::tir::make_const(value.dtype(), alpha); + return tvm::tir::Select(value > 0, value, value * calpha); + }, + name, tag); } /*! @@ -112,27 +94,20 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, * * \return A Tensor whose op member is the parametric relu operation */ -inline tvm::te::Tensor prelu(const tvm::te::Tensor &x, - const tvm::te::Tensor &slope, - const int axis = 1, - std::string name = "T_prelu", - std::string tag = kBroadcast) { - CHECK((size_t)axis < x->shape.size()) << - "Wrong axis (" << axis << ")value. "; - CHECK(topi::detail::GetConstInt(slope->shape[0]) == - topi::detail::GetConstInt(x->shape[axis])) - << "Wrong slope shape received."; +inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& slope, + const int axis = 1, std::string name = "T_prelu", + std::string tag = kBroadcast) { + CHECK((size_t)axis < x->shape.size()) << "Wrong axis (" << axis << ")value. "; + CHECK(topi::detail::GetConstInt(slope->shape[0]) == topi::detail::GetConstInt(x->shape[axis])) + << "Wrong slope shape received."; - return tvm::te::compute(x->shape, - [&](const tvm::Array &indices) { - auto xval = x(indices); - return tvm::tir::SelectNode::make( - xval > 0, - xval, - xval * slope(indices[axis])); - }, - name, - tag); + return tvm::te::compute( + x->shape, + [&](const tvm::Array& indices) { + auto xval = x(indices); + return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis])); + }, + name, tag); } /*! 
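As with the elemwise helpers, a small illustrative sketch of the relu, leaky_relu, and prelu entry points touched above; the NCHW shape and the per-channel slope follow the axis=1 default documented in the header, while everything else (names, sizes, chaining) is an assumption for the example only:

#include <tvm/te/operation.h>
#include <topi/nn.h>

// relu/leaky_relu are plain element-wise ops; prelu broadcasts a per-channel
// slope tensor along `axis` (channel axis 1 for NCHW by default).
tvm::te::Tensor activation_example() {
  using namespace tvm;
  auto x = te::placeholder({1, 8, 32, 32}, DataType::Float(32), "x");  // NCHW
  auto slope = te::placeholder({8}, DataType::Float(32), "slope");     // one slope per channel
  auto r = topi::relu(x, 0.0f);               // max(x, 0)
  auto lr = topi::leaky_relu(r, 0.1);         // Select(r > 0, r, 0.1 * r)
  return topi::prelu(lr, slope, /*axis=*/1);  // Select(lr > 0, lr, slope[c] * lr)
}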
@@ -172,27 +147,25 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor &x, * * */ -inline tvm::te::Tensor pad(const tvm::te::Tensor& t, - const tvm::Array& pad_before, - tvm::Array pad_after = tvm::Array(), - PrimExpr pad_value = PrimExpr(), - std::string name = "T_pad", - std::string tag = kElementWise, - std::string pad_mode = "constant") { +inline tvm::te::Tensor pad(const tvm::te::Tensor& t, const tvm::Array& pad_before, + tvm::Array pad_after = tvm::Array(), + PrimExpr pad_value = PrimExpr(), std::string name = "T_pad", + std::string tag = kElementWise, std::string pad_mode = "constant") { if (pad_after.size() < pad_before.size()) { for (size_t i = pad_after.size(); i < pad_before.size(); ++i) { pad_after.push_back(pad_before[i]); } } + arith::Analyzer analyzer; CHECK_GE(pad_before.size(), 1); CHECK_EQ(pad_before.size(), pad_after.size()); tvm::Array output_shape; tvm::Array pad_before_int32; tvm::Array pad_after_int32; - for (const auto &ele : pad_before) { + for (const auto& ele : pad_before) { pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } - for (const auto &ele : pad_after) { + for (const auto& ele : pad_after) { pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele)); } for (size_t i = 0; i < t->shape.size(); ++i) { @@ -200,13 +173,14 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, output_shape.push_back(t->shape[i]); } else { output_shape.push_back( - tvm::tir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); + analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); } } if (!pad_value.defined()) { pad_value = tvm::tir::make_const(t->dtype, 0); } + auto l = [&](tvm::Array ovars) { tvm::Array indices; tvm::Array sel; @@ -223,31 +197,27 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, indices.push_back(ovars[i]); } if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) { - sel.push_back(tvm::tir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); + sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); } if (pad_mode == "edge") { - pad_idx.push_back(tvm::if_then_else( - ovars[i] < pad_before[i], - 0, - tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], - t->shape[i] - 1, - ovars[i] - pad_before[i]))); + pad_idx.push_back( + tvm::if_then_else(ovars[i] < pad_before[i], 0, + tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], + t->shape[i] - 1, ovars[i] - pad_before[i]))); } else if (pad_mode == "reflect") { - pad_idx.push_back(tvm::if_then_else( - ovars[i] < pad_before[i], - pad_before[i] - ovars[i], - tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], - t->shape[i] * 2 - ovars[i] + pad_before[i] - 2, - ovars[i] - pad_before[i]))); + pad_idx.push_back( + tvm::if_then_else(ovars[i] < pad_before[i], pad_before[i] - ovars[i], + tvm::if_then_else(ovars[i] >= pad_before[i] + t->shape[i], + t->shape[i] * 2 - ovars[i] + pad_before[i] - 2, + ovars[i] - pad_before[i]))); } } if (sel.size() != 0) { + auto fand = [](PrimExpr a, PrimExpr b) { return a && b; }; if (pad_mode == "constant") { - return tvm::if_then_else( - detail::Map(sel, tvm::tir::AndNode::make), t(indices), pad_value); + return tvm::if_then_else(foldl(fand, const_true(1), sel), t(indices), pad_value); } else if (pad_mode == "edge" || pad_mode == "reflect") { - return tvm::if_then_else( - detail::Map(sel, tvm::tir::AndNode::make), t(indices), t(pad_idx)); + return tvm::if_then_else(foldl(fand, const_true(1), sel), t(indices), t(pad_idx)); } } return t(indices); @@ -275,34 +245,27 @@ inline 
tvm::te::Tensor pad(const tvm::te::Tensor& t, * \return A Tensor whose op member is the 2-D convolution operation (NCHW * layout) */ -inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_conv2d_nchw", - std::string tag = kConv2dNCHW) { +inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, + std::string name = "T_conv2d_nchw", + std::string tag = kConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::Array output_shape{ - I->shape[0], // B - W->shape[0], // O - indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H - indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W + I->shape[0], // B + W->shape[0], // O + indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H + indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { - return tvm::sum( - T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), - {i, kh, kw}); + return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } @@ -326,14 +289,10 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D convolution operation * (HWCN layout) */ -inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_conv2d_hwcn", - std::string tag = kConv2dHWCN) { +inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, int stride_w = 1, + std::string name = "T_conv2d_hwcn", + std::string tag = kConv2dHWCN) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; @@ -341,22 +300,19 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, tvm::Array output_shape{ indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1, // W - I->shape[2], // B - W->shape[3] // O + I->shape[2], // B + W->shape[3] // O }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { - return tvm::sum( - T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), - {i, kh, kw}); + return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } - /*! 
* \brief Creates an operation that performs a 2-D depthwise convolution with * an NCHW-layout @@ -377,67 +333,59 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D depthwise convolution operation * (NCHW layout) */ -inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_depthwise_conv2d_nchw", - std::string tag = kDepthwiseConv2dNCHW) { +inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_depthwise_conv2d_nchw", + std::string tag = kDepthwiseConv2dNCHW) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; auto pCM = W->shape[1]; // channel_multiplier tvm::Array output_shape{ - I->shape[0], // B - W->shape[1], // O + I->shape[0], // B + W->shape[1], // O indexdiv(I->shape[2] - W->shape[2] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[3] - W->shape[3] + 2 * pad_w, stride_w) + 1 // W }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[1]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[2]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); + auto T = + (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) * - W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), + W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } -inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_depthwise_conv2d_nhwc", - std::string tag = kDepthwiseConv2dNHWC) { +inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_depthwise_conv2d_nhwc", + std::string tag = kDepthwiseConv2dNHWC) { CHECK_EQ(4, I->shape.size()); CHECK_EQ(4, W->shape.size()); auto pH = I->shape[1]; auto pW = I->shape[2]; auto pCM = W->shape[1]; // channel_multiplier tvm::Array output_shape{ - I->shape[0], // B + I->shape[0], // B indexdiv(I->shape[1] - W->shape[1] + 2 * pad_h, stride_h) + 1, // H - indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W - W->shape[3], // O + indexdiv(I->shape[2] - W->shape[2] + 2 * pad_w, stride_w) + 1, // W + W->shape[3], // O }; auto i = tvm::te::reduce_axis(tvm::Range{0, I->shape[3]}, "i"); auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); - auto T = (pad_h == 0 && pad_w == 0) - ? I - : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); + auto T = + (pad_h == 0 && pad_w == 0) ? 
I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) { return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) * - W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), + W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), {kh, kw, i}); }; return tvm::te::compute(output_shape, l, name, tag); @@ -463,22 +411,19 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, * \return A Tensor whose op member is the 2-D groupconvolution operation * (NCHW layout) */ -inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, - const tvm::te::Tensor& W, - int pad_h = 0, - int pad_w = 0, - int stride_h = 1, - int stride_w = 1, - std::string name = "T_group_conv2d_ngchw", - std::string tag = kGroupConv2d) { +inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::te::Tensor& W, + int pad_h = 0, int pad_w = 0, int stride_h = 1, + int stride_w = 1, + std::string name = "T_group_conv2d_ngchw", + std::string tag = kGroupConv2d) { CHECK_EQ(5, I->shape.size()); CHECK_EQ(5, W->shape.size()); auto pH = I->shape[2]; auto pW = I->shape[3]; tvm::Array output_shape{ - I->shape[0], // B - I->shape[1], // G - W->shape[2], // O + I->shape[0], // B + I->shape[1], // G + W->shape[2], // O indexdiv(I->shape[3] - W->shape[3] + 2 * pad_h, stride_h) + 1, // H indexdiv(I->shape[4] - W->shape[4] + 2 * pad_w, stride_w) + 1 // W }; @@ -495,9 +440,8 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, tvm::tir::Var o = args[2]; tvm::tir::Var h = args[3]; tvm::tir::Var w = args[4]; - return tvm::sum( - I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), - {i, kh, kw}); + return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), + {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); } diff --git a/topi/include/topi/nn/batch_matmul.h b/topi/include/topi/nn/batch_matmul.h index 12075e6d67ea..80525c427976 100644 --- a/topi/include/topi/nn/batch_matmul.h +++ b/topi/include/topi/nn/batch_matmul.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_BATCH_MATMUL_H_ #define TOPI_NN_BATCH_MATMUL_H_ -#include #include +#include #include @@ -35,15 +35,14 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Creates an operation that calculates matrix multiplication in batch. -* -* \param x Tensor with shape [batch, M, K] -* \param y Tensor with shape [batch, N, K] -* -* \return Tensor with shape [batch, M, N] -*/ -inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, - const tvm::te::Tensor& y) { + * \brief Creates an operation that calculates matrix multiplication in batch. 
+ * + * \param x Tensor with shape [batch, M, K] + * \param y Tensor with shape [batch, N, K] + * + * \return Tensor with shape [batch, M, N] + */ +inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, const tvm::te::Tensor& y) { CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data"; CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data"; @@ -54,10 +53,8 @@ inline tvm::te::Tensor batch_matmul(const tvm::te::Tensor& x, auto k = tvm::te::reduce_axis(Range(0, K), "k"); auto result = tvm::te::compute( - { batch, M, N }, - [&](Var b, Var i, Var j) { - return tvm::sum(x(b, i, k) * y(b, j, k), { k }); - }, "tensor", "batch_matmul"); + {batch, M, N}, [&](Var b, Var i, Var j) { return tvm::sum(x(b, i, k) * y(b, j, k), {k}); }, + "tensor", "batch_matmul"); return result; } diff --git a/topi/include/topi/nn/bias_add.h b/topi/include/topi/nn/bias_add.h index 209c30ca875b..18e95deaccb1 100644 --- a/topi/include/topi/nn/bias_add.h +++ b/topi/include/topi/nn/bias_add.h @@ -24,10 +24,10 @@ #ifndef TOPI_NN_BIAS_ADD_H_ #define TOPI_NN_BIAS_ADD_H_ -#include -#include #include +#include #include +#include #include @@ -35,16 +35,15 @@ namespace topi { namespace nn { /*! -* \brief Creates an operation that calculates data + bias -* -* \param data Tensor with shape [batch, in_dim] -* \param bias Tensor with shape [batch]. -* \param axis The axis to add the bias to. -* \return Tensor with shape [batch, in_dim] -*/ -inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, - const tvm::te::Tensor& bias, - int axis) { + * \brief Creates an operation that calculates data + bias + * + * \param data Tensor with shape [batch, in_dim] + * \param bias Tensor with shape [batch]. + * \param axis The axis to add the bias to. + * \return Tensor with shape [batch, in_dim] + */ +inline tvm::te::Tensor bias_add(const tvm::te::Tensor& data, const tvm::te::Tensor& bias, + int axis) { int data_ndim = data->shape.size(); if (axis < 0) { axis += data_ndim; diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h index 6bda65317706..c0626cd43c7f 100644 --- a/topi/include/topi/nn/bnn.h +++ b/topi/include/topi/nn/bnn.h @@ -24,10 +24,10 @@ #ifndef TOPI_NN_BNN_H_ #define TOPI_NN_BNN_H_ -#include -#include -#include #include +#include +#include +#include #include @@ -37,70 +37,67 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Binarization and bit-packing along a certain axis. -* -* \param data N-D tensor, can be any layout -* \param axis The axis along which to do binarization and bit-packing. This axis -* must have a size equal to an integer multiple of 32. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return Output tensor with dtype uint32 -*/ -inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, - int axis, - std::string name = "PackedInput", - std::string tag = "binarize_pack") { + * \brief Binarization and bit-packing along a certain axis. + * + * \param data N-D tensor, can be any layout + * \param axis The axis along which to do binarization and bit-packing. This axis + * must have a size equal to an integer multiple of 32. 
+ * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return Output tensor with dtype uint32 + */ +inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, + std::string name = "PackedInput", + std::string tag = "binarize_pack") { auto ishape = data->shape; CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0) - << "binarize_pack: axis size must be a multiple of 32"; + << "binarize_pack: axis size must be a multiple of 32"; + arith::Analyzer analyzer; auto n = ishape.size(); Array oshape; for (size_t i = 0; i < n; ++i) { - oshape.push_back(i == static_cast(axis) ? - tvm::tir::Simplify(indexdiv(ishape[i], 32)) : - ishape[i]); + oshape.push_back(i == static_cast(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32)) + : ishape[i]); } return tvm::te::compute( - oshape, - [&](const Array& indices) { - Array start_idx; - for (size_t i = 0; i < n; ++i) { - start_idx.push_back(i == static_cast(axis) ? - indices[i] * 32 : - static_cast(indices[i])); - } - auto packed = make_const(DataType::UInt(32), 0); - for (size_t j = 0; j < 32; ++j) { - Array idx; + oshape, + [&](const Array& indices) { + Array start_idx; for (size_t i = 0; i < n; ++i) { - idx.push_back(i == static_cast(axis) ? - start_idx[i] + static_cast(j) : - start_idx[i]); + start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 + : static_cast(indices[i])); } - auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0); - packed = (packed | sign); - if (j == 31) { - return packed; + auto packed = make_const(DataType::UInt(32), 0); + for (size_t j = 0; j < 32; ++j) { + Array idx; + for (size_t i = 0; i < n; ++i) { + idx.push_back(i == static_cast(axis) ? start_idx[i] + static_cast(j) + : start_idx[i]); + } + auto sign = tvm::cast(DataType::UInt(32), data(idx) >= 0); + packed = (packed | sign); + if (j == 31) { + return packed; + } + packed = packed << 1; } - packed = packed << 1; - } - return packed; // never reached, but suppress compiler warning - }, name, tag); + return packed; // never reached, but suppress compiler warning + }, + name, tag); } /*! 
-* \brief Binary matrix multiplication using xor and bit-count -* -* \param data Tensor with shape [batch, in_dim], dtype is uint32 -* \param weight Tensor with shape [out_dim, in_dim], dtype is uint32 -* -* \return Tensor with shape [batch, out_dim], dtype is float32 -*/ -inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, - const tvm::te::Tensor& weight) { + * \brief Binary matrix multiplication using xor and bit-count + * + * \param data Tensor with shape [batch, in_dim], dtype is uint32 + * \param weight Tensor with shape [out_dim, in_dim], dtype is uint32 + * + * \return Tensor with shape [batch, out_dim], dtype is float32 + */ +inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight) { CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight"; CHECK_EQ(data->dtype, DataType::UInt(32)) << "binary_dense requires uint32 data"; @@ -112,16 +109,13 @@ inline tvm::te::Tensor binary_dense(const tvm::te::Tensor& data, auto k = tvm::te::reduce_axis(Range(0, in_dim), "k"); auto matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k }); - }, "tensor", "binary_dense"); + {batch, out_dim}, + [&](Var i, Var j) { return tvm::sum(popcount(data(i, k) ^ weight(j, k)), {k}); }, "tensor", + "binary_dense"); return tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return 32 * in_dim - 2.0f * matmul(i, j); - }, "tensor", kElementWise); + {batch, out_dim}, [&](Var i, Var j) { return 32 * in_dim - 2.0f * matmul(i, j); }, "tensor", + kElementWise); } } // namespace nn diff --git a/topi/include/topi/nn/dense.h b/topi/include/topi/nn/dense.h index 57f071a2ebfd..4ee36c275ef3 100644 --- a/topi/include/topi/nn/dense.h +++ b/topi/include/topi/nn/dense.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_DENSE_H_ #define TOPI_NN_DENSE_H_ -#include #include +#include #include @@ -35,19 +35,17 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Creates an operation that calculates data * weight^T + bias -* -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense(const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Creates an operation that calculates data * weight^T + bias + * + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. 
+ * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense(const tvm::te::Tensor& data, const tvm::te::Tensor& weight, + const tvm::te::Tensor& bias, const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -60,18 +58,17 @@ inline tvm::te::Tensor dense(const tvm::te::Tensor& data, auto k = tvm::te::reduce_axis(Range(0, in_dim), "k"); auto matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return tvm::sum(tvm::cast(out_dtype, data(i, k)) * - tvm::cast(out_dtype, weight(j, k)), { k }); - }, "tensor", "dense"); + {batch, out_dim}, + [&](Var i, Var j) { + return tvm::sum(tvm::cast(out_dtype, data(i, k)) * tvm::cast(out_dtype, weight(j, k)), {k}); + }, + "tensor", "dense"); if (bias.defined()) { matmul = tvm::te::compute( - { batch, out_dim }, - [&](Var i, Var j) { - return matmul(i, j) + tvm::cast(out_dtype, bias(j)); - }, "tensor", kBroadcast); + {batch, out_dim}, + [&](Var i, Var j) { return matmul(i, j) + tvm::cast(out_dtype, bias(j)); }, "tensor", + kBroadcast); } return matmul; diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h index a67bf3a300b2..0d3ab89bbae6 100644 --- a/topi/include/topi/nn/dilate.h +++ b/topi/include/topi/nn/dilate.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_DILATE_H_ #define TOPI_NN_DILATE_H_ -#include -#include #include +#include +#include #include @@ -36,13 +36,13 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Create a new expression of the logical and of all -* conditions in the arguments. -* -* \param args The arguments to find the logical conjunction of -* -* \return The logical conjunction expression -*/ + * \brief Create a new expression of the logical and of all + * conditions in the arguments. + * + * \param args The arguments to find the logical conjunction of + * + * \return The logical conjunction expression + */ PrimExpr all(Array args) { CHECK_GT(args.size(), 0) << "all requires at least one argument"; @@ -54,52 +54,50 @@ PrimExpr all(Array args) { } /*! -* \brief Dilate data with zeros -* -* \param x The input tensor, this can have any number of -* dimensions and any layout. -* \param strides Dilation stride for each dimension. Stride 1 -* means no dilation. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return The output tensor. -*/ -inline Tensor dilate(const Tensor& x, - Array strides, - std::string name = "tensor", + * \brief Dilate data with zeros + * + * \param x The input tensor, this can have any number of + * dimensions and any layout. + * \param strides Dilation stride for each dimension. Stride 1 + * means no dilation. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The output tensor. 
+ */ +inline Tensor dilate(const Tensor& x, Array strides, std::string name = "tensor", std::string tag = kInjective) { auto n = x->shape.size(); - CHECK_EQ(n, strides.size()) - << "strides size (" << strides.size() - << ") must match dimension of x (" << n << ")"; + CHECK_EQ(n, strides.size()) << "strides size (" << strides.size() + << ") must match dimension of x (" << n << ")"; Array out_shape; + arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { - out_shape.push_back(tvm::tir::Simplify( - (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); + out_shape.push_back( + analyzer.Simplify((x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); } return tvm::te::compute( - out_shape, - [&](const Array& indices) { - Array not_zero; - Array index_tuple; - for (size_t i = 0; i < n; ++i) { - if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { - index_tuple.push_back(indices[i]); - } else { - index_tuple.push_back(indexdiv(indices[i], strides[i])); - not_zero.push_back((indexmod(indices[i], strides[i])) == 0); + out_shape, + [&](const Array& indices) { + Array not_zero; + Array index_tuple; + for (size_t i = 0; i < n; ++i) { + if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) { + index_tuple.push_back(indices[i]); + } else { + index_tuple.push_back(indexdiv(indices[i], strides[i])); + not_zero.push_back((indexmod(indices[i], strides[i])) == 0); + } + } + if (not_zero.size() > 0) { + auto all_not_zero = all(not_zero); + return tvm::if_then_else(all_not_zero, x(index_tuple), make_const(x->dtype, 0)); } - } - if (not_zero.size() > 0) { - auto all_not_zero = all(not_zero); - return tvm::if_then_else( - all_not_zero, x(index_tuple), make_const(x->dtype, 0)); - } - return x(index_tuple); - }, name, tag); + return x(index_tuple); + }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/flatten.h b/topi/include/topi/nn/flatten.h index 81cef2eda17b..1ac5de4a2ed1 100644 --- a/topi/include/topi/nn/flatten.h +++ b/topi/include/topi/nn/flatten.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_FLATTEN_H_ #define TOPI_NN_FLATTEN_H_ -#include -#include #include +#include +#include #include #include @@ -37,25 +37,23 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. -* This requires the input tensor to have constant sized dimensions. -* -* \param x The input tensor. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A 2-D tensor. -*/ -inline Tensor flatten(const Tensor& x, - std::string name = "tensor", - std::string tag = kInjective) { + * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. + * This requires the input tensor to have constant sized dimensions. + * + * \param x The input tensor. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A 2-D tensor. 
+ */ +inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) { auto ishape = x->shape; PrimExpr dim = 1; for (size_t i = 1; i < ishape.size(); ++i) { dim = dim * ishape[i]; } - Array oshape({ ishape[0], dim }); + Array oshape({ishape[0], dim}); std::vector extra_shape; for (size_t i = 1; i < ishape.size(); ++i) { @@ -64,17 +62,19 @@ inline Tensor flatten(const Tensor& x, std::reverse(extra_shape.begin(), extra_shape.end()); return tvm::te::compute( - oshape, [&](Var i, Var j) { - PrimExpr idx = j; - std::vector index; - for (auto s : extra_shape) { - index.push_back(indexmod(idx, s)); - idx = indexdiv(idx, s); - } - index.push_back(i); - std::reverse(index.begin(), index.end()); - return x(index); - }, name, tag); + oshape, + [&](Var i, Var j) { + PrimExpr idx = j; + std::vector index; + for (auto s : extra_shape) { + index.push_back(indexmod(idx, s)); + idx = indexdiv(idx, s); + } + index.push_back(i); + std::reverse(index.begin(), index.end()); + return x(index); + }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/local_response_norm.h b/topi/include/topi/nn/local_response_norm.h index 14dec390e24a..4e8dfd99a517 100644 --- a/topi/include/topi/nn/local_response_norm.h +++ b/topi/include/topi/nn/local_response_norm.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_LOCAL_RESPONSE_NORM_H_ #define TOPI_NN_LOCAL_RESPONSE_NORM_H_ -#include #include +#include #include @@ -35,60 +35,45 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Local response normalization inference operator -* -* \param data The input tensor. 4-D shape NCHW or NHWC -* \param size Integer to define normalisation window size -* \param axis Input data layout channel axis -* \param alpha Float scaling factor -* \param beta Exponent value -* \param bias Offset to avoid dividing by zero -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the Local response normalization operation -*/ -inline Tensor lrn(const Tensor& data, - int size, - int axis = 1, - float alpha = 0.0001, - float beta = 0.75, - float bias = 2, - std::string name = "tensor", + * \brief Local response normalization inference operator + * + * \param data The input tensor. 
4-D shape NCHW or NHWC + * \param size Integer to define normalisation window size + * \param axis Input data layout channel axis + * \param alpha Float scaling factor + * \param beta Exponent value + * \param bias Offset to avoid dividing by zero + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the Local response normalization operation + */ +inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.0001, + float beta = 0.75, float bias = 2, std::string name = "tensor", std::string tag = kBroadcast) { CHECK_EQ(data->shape.size(), 4) << "LRN requires 4-D input"; CHECK_EQ(size % 2, 1) << "size should be odd number"; CHECK(axis == 1 || axis == 3) << "axis should be 1 or 3 for NCHW and NHWC"; auto input_shape = data->shape; - Array pad_before{ 0, 0, 0, 0}; - Array pad_after{ 0, 0, 0, 0}; - pad_before.Set(axis, static_cast(size/2)); - pad_after.Set(axis, static_cast(size/2)); + Array pad_before{0, 0, 0, 0}; + Array pad_after{0, 0, 0, 0}; + pad_before.Set(axis, static_cast(size / 2)); + pad_after.Set(axis, static_cast(size / 2)); auto pad_data = pad(data, pad_before, pad_after, 0, "pad_data"); auto rxs = tvm::te::reduce_axis(Range(0, size), "rxs"); Tensor sqr_sum; if (axis == 1) { - sqr_sum = tvm::te::compute(input_shape, - [&](Var i, Var l, Var j, Var k) { - return tvm::sum(pad_data(i, l + rxs, j, k) * - pad_data(i, l + rxs, j, k), - {rxs}); - }); + sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) { + return tvm::sum(pad_data(i, l + rxs, j, k) * pad_data(i, l + rxs, j, k), {rxs}); + }); } else if (axis == 3) { - sqr_sum = tvm::te::compute(input_shape, - [&](Var i, Var l, Var j, Var k) { - return tvm::sum(pad_data(i, l, j, k + rxs) * - pad_data(i, l, j, k + rxs), - {rxs}); - }); + sqr_sum = tvm::te::compute(input_shape, [&](Var i, Var l, Var j, Var k) { + return tvm::sum(pad_data(i, l, j, k + rxs) * pad_data(i, l, j, k + rxs), {rxs}); + }); } - auto sqrt_sum_up = tvm::te::compute( - input_shape, - [&](Var i, Var j, Var k, Var l) { - return tvm::pow(bias + - (div(alpha * sqr_sum(i, j, k, l), size)), - beta); - }); + auto sqrt_sum_up = tvm::te::compute(input_shape, [&](Var i, Var j, Var k, Var l) { + return tvm::pow(bias + (div(alpha * sqr_sum(i, j, k, l), size)), beta); + }); return topi::divide(data, sqrt_sum_up); } } // namespace nn diff --git a/topi/include/topi/nn/mapping.h b/topi/include/topi/nn/mapping.h index 17d14045ac4b..2bf3314e7377 100644 --- a/topi/include/topi/nn/mapping.h +++ b/topi/include/topi/nn/mapping.h @@ -24,8 +24,8 @@ #ifndef TOPI_NN_MAPPING_H_ #define TOPI_NN_MAPPING_H_ -#include #include +#include #include @@ -35,49 +35,39 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Scale and shift with NCHW order -* -* \param x The input tensor. -* \param scale Scale tensor, 1-D of size channel -* \param shift Shift tensor, 1-D of size channel -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the scale shift operation -*/ -inline Tensor scale_shift_nchw(const Tensor& x, - const Tensor& scale, - const Tensor& shift, - std::string name = "ScaleShift", - std::string tag = kBroadcast) { + * \brief Scale and shift with NCHW order + * + * \param x The input tensor. 
+ * \param scale Scale tensor, 1-D of size channel + * \param shift Shift tensor, 1-D of size channel + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the scale shift operation + */ +inline Tensor scale_shift_nchw(const Tensor& x, const Tensor& scale, const Tensor& shift, + std::string name = "ScaleShift", std::string tag = kBroadcast) { return tvm::te::compute( - x->shape, - [&](Var b, Var c, Var h, Var w) { - return x(b, c, h, w) * scale(c) + shift(w); - }, name, tag); + x->shape, [&](Var b, Var c, Var h, Var w) { return x(b, c, h, w) * scale(c) + shift(c); }, + name, tag); } /*! -* \brief Scale and shift with NHWC order -* -* \param x The input tensor. -* \param scale Scale tensor, 1-D of size channel -* \param shift Shift tensor, 1-D of size channel -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the scale shift operation -*/ -inline Tensor scale_shift_nhwc(const Tensor& x, - const Tensor& scale, - const Tensor& shift, - std::string name = "ScaleShift", - std::string tag = kBroadcast) { + * \brief Scale and shift with NHWC order + * + * \param x The input tensor. + * \param scale Scale tensor, 1-D of size channel + * \param shift Shift tensor, 1-D of size channel + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the scale shift operation + */ +inline Tensor scale_shift_nhwc(const Tensor& x, const Tensor& scale, const Tensor& shift, + std::string name = "ScaleShift", std::string tag = kBroadcast) { return tvm::te::compute( - x->shape, - [&](Var b, Var h, Var w, Var c) { - return x(b, h, w, c) * scale(c) + shift(w); - }, name, tag); + x->shape, [&](Var b, Var h, Var w, Var c) { return x(b, h, w, c) * scale(c) + shift(c); }, + name, tag); } } // namespace nn diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 20b7b246317b..f6435cd2f42a 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -45,31 +45,25 @@ enum PoolType : int { kMaxPool, }; - /*! -* \brief Perform pooling on height and width dimension of data. -* -* \param x The input tensor -* \param kernel_size Vector of two ints: {kernel_height, kernel_width} -* \param stride_size Vector of two ints: {stride_height, stride_width} -* \param padding_size Vector of two ints: {padding_height, padding_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param height_axis index of the height dimension -* \param width_axis index of the width dimension -* \param count_include_pad Whether include padding in the calculation -* -* \return The output tensor in same layout order -*/ -inline Tensor pool_impl(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const size_t height_axis, - const size_t width_axis, - bool count_include_pad) { + * \brief Perform pooling on height and width dimension of data. 
+ * + * \param x The input tensor + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param height_axis index of the height dimension + * \param width_axis index of the width dimension + * \param count_include_pad Whether include padding in the calculation + * + * \return The output tensor in same layout order + */ +inline Tensor pool_impl(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const size_t height_axis, + const size_t width_axis, bool count_include_pad) { CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements"; CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements"; @@ -102,11 +96,11 @@ inline Tensor pool_impl(const Tensor& x, Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); - - auto out_height = tvm::tir::Simplify( - indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); - auto out_width = tvm::tir::Simplify( - indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); + arith::Analyzer analyzer; + auto out_height = + analyzer.Simplify(indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); + auto out_width = + analyzer.Simplify(indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); @@ -115,69 +109,73 @@ inline Tensor pool_impl(const Tensor& x, out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); - const int64_t *padding_h0 = as_const_int(pad_top); - const int64_t *padding_w0 = as_const_int(pad_left); - const int64_t *padding_h1 = as_const_int(pad_bottom); - const int64_t *padding_w1 = as_const_int(pad_right); + const int64_t* padding_h0 = as_const_int(pad_top); + const int64_t* padding_w0 = as_const_int(pad_left); + const int64_t* padding_h1 = as_const_int(pad_bottom); + const int64_t* padding_w1 = as_const_int(pad_right); const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); if (pool_type == kMaxPool) { - auto temp = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::max(temp(indices), { dheight, dwidth }); - }, "tensor", "pool_max"); + auto temp = do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + indices.Set(height_axis, output[height_axis] * stride_height + dheight); + indices.Set(width_axis, output[width_axis] * stride_width + dwidth); + return tvm::max(temp(indices), {dheight, dwidth}); + }, + "tensor", "pool_max"); } else if (pool_type == kAvgPool) { // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. - auto pool_sum = tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - indices.Set(height_axis, output[height_axis] * stride_height + dheight); - indices.Set(width_axis, output[width_axis] * stride_width + dwidth); - return tvm::sum(temp(indices), { dheight, dwidth }); - }, "tensor", "pool_sum"); + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + indices.Set(height_axis, output[height_axis] * stride_height + dheight); + indices.Set(width_axis, output[width_axis] * stride_width + dwidth); + return tvm::sum(temp(indices), {dheight, dwidth}); + }, + "tensor", "pool_sum"); // TVM compute for dividing the reduced window sum by kernel size. - return tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - return div(pool_sum(indices), (kernel_height * kernel_width)); - } else { - PrimExpr h_start = output[height_axis] * stride_height - pad_top; - PrimExpr w_start = output[width_axis] * stride_width - pad_left; - PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); - PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); - h_start = tir::MaxNode::make(h_start, make_const(DataType::DataType::Int(32), 0)); - w_start = tir::MaxNode::make(w_start, make_const(DataType::DataType::Int(32), 0)); - PrimExpr divide_factor = tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + return div(pool_sum(indices), (kernel_height * kernel_width)); + } else { + PrimExpr h_start = output[height_axis] * stride_height - pad_top; + PrimExpr w_start = output[width_axis] * stride_width - pad_left; + + PrimExpr h_end = min(h_start + kernel_height, height); + PrimExpr w_end = min(w_start + kernel_width, width); + h_start = max(h_start, make_const(DataType::DataType::Int(32), 0)); + w_start = max(w_start, make_const(DataType::DataType::Int(32), 0)); + PrimExpr divide_factor = max((h_end - h_start) * (w_end - w_start), + make_const(DataType::DataType::Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; } } -inline Tensor pool_grad_impl(const Tensor& out_grad, - const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, bool ceil_mode, - const size_t height_axis, const size_t width_axis, +inline Tensor pool_grad_impl(const Tensor& 
out_grad, const Tensor& x, + const Array& kernel_size, const Array& stride_size, + const Array& padding_size, PoolType pool_type, + bool ceil_mode, const size_t height_axis, const size_t width_axis, bool count_include_pad) { CHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)"; CHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)"; @@ -212,11 +210,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); - + arith::Analyzer analyzer; auto out_height = - tvm::tir::Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); + analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); auto out_width = - tvm::tir::Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); + analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); @@ -237,62 +235,57 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom); ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right); - auto windowh = tvm::te::reduce_axis( - Range(0, (kernel_height + stride_height - 1) / stride_height)); - auto windoww = tvm::te::reduce_axis( - Range(0, (kernel_width + stride_width - 1) / stride_width)); + auto windowh = + tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); + auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); auto argmax = MakeArgmaxReducer(); - auto pad_x = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - - auto mp_argmax = - tvm::te::compute( - out_shape, - [&](const Array& inds) { - Array window_inds{inds.begin(), inds.end()}; - window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); - window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); - auto idx = detail::RavelIndex(window_inds, ravel_shape); - return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr); - }, - "maxpool_grad_argmax", kCommReduceIdx); + auto pad_x = do_pad ? 
pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + + auto mp_argmax = tvm::te::compute( + out_shape, + [&](const Array& inds) { + Array window_inds{inds.begin(), inds.end()}; + window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight); + window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth); + auto idx = detail::RavelIndex(window_inds, ravel_shape); + return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr); + }, + "maxpool_grad_argmax", kCommReduceIdx); auto mp_inds = mp_argmax[0]; return tvm::te::compute( x->shape, [&](const Array& inds) { - Array pad_inds {inds.begin(), inds.end()}; + Array pad_inds{inds.begin(), inds.end()}; pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top); pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left); auto idx = detail::RavelIndex(pad_inds, ravel_shape); - Array out_idx {inds.begin(), inds.end()}; + Array out_idx{inds.begin(), inds.end()}; out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); - PrimExpr out_idx_lower_h = tir::SelectNode::make( + PrimExpr out_idx_lower_h = tir::Select( pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); - PrimExpr out_idx_lower_w = tir::SelectNode::make( + PrimExpr out_idx_lower_w = tir::Select( pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( - tvm::if_then_else(tir::AndNode::make( - tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[width_axis] >= out_idx_lower_w), - mp_inds(out_idx) == idx), - out_grad(out_idx), make_const(x->dtype, 0)), + tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h, + out_idx[width_axis] >= out_idx_lower_w), + mp_inds(out_idx) == idx), + out_grad(out_idx), make_const(x->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); } else if (pool_type == kAvgPool) { - auto windowh = tvm::te::reduce_axis( - Range(0, (kernel_height + stride_height - 1) / stride_height)); - auto windoww = tvm::te::reduce_axis( - Range(0, (kernel_width + stride_width - 1) / stride_width)); + auto windowh = + tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height)); + auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width)); return tvm::te::compute( x->shape, [&](const Array& inds) { @@ -304,12 +297,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh)); out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); - PrimExpr out_idx_lower_h = tir::SelectNode::make( - pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), - (pad_h_idx - kernel_height) / stride_height + 1); - PrimExpr out_idx_lower_w = tir::SelectNode::make( - pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), - (pad_w_idx - kernel_width) / stride_width + 1); + PrimExpr out_idx_lower_h = + tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0), + (pad_h_idx - kernel_height) / stride_height + 1); + PrimExpr out_idx_lower_w = + tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0), + (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements if (count_include_pad) { @@ -317,21 +310,20 @@ inline 
Tensor pool_grad_impl(const Tensor& out_grad, } else { PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top; PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left; - PrimExpr h_end = tir::MinNode::make(h_start + kernel_height, height); - PrimExpr w_end = tir::MinNode::make(w_start + kernel_width, width); - h_start = tir::MaxNode::make(h_start, make_const(DataType::Int(32), 0)); - w_start = tir::MaxNode::make(w_start, make_const(DataType::Int(32), 0)); + + PrimExpr h_end = min(h_start + kernel_height, height); + PrimExpr w_end = min(w_start + kernel_width, width); + h_start = max(h_start, make_const(DataType::Int(32), 0)); + w_start = max(w_start, make_const(DataType::Int(32), 0)); divide_factor = - tir::MaxNode::make((h_end - h_start) * (w_end - w_start), - make_const(DataType::Int(32), 1)); + max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1)); } - return tvm::sum(tvm::if_then_else( - tir::AndNode::make( - tir::AndNode::make(out_idx[height_axis] >= out_idx_lower_h, - out_idx[height_axis] < out_height), - tir::AndNode::make(out_idx[width_axis] >= out_idx_lower_w, - out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), + return tvm::sum( + tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h, + out_idx[height_axis] < out_height), + tir::And(out_idx[width_axis] >= out_idx_lower_w, + out_idx[width_axis] < out_width)), + out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -341,15 +333,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, } } -inline bool find_depth_height_width(const std::string& layout, - int* depth_axis, - int* height_axis, +inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis, int* width_axis) { *depth_axis = -1, *height_axis = -1, *width_axis = -1; int curr_idx = 0; for (size_t i = 0; i < layout.size(); ++i) { - if ((layout[i] >= 'A' && layout[i] <= 'Z') || - (layout[i] >= 'a' && layout[i] <= 'z')) { + if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) { if (layout[i] == 'D') { if (*depth_axis != -1) return false; *depth_axis = curr_idx; @@ -370,21 +359,18 @@ inline bool find_depth_height_width(const std::string& layout, return true; } -inline bool find_height_width(const std::string& layout, - int* height_axis, - int* width_axis) { +inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) { int dummy; - CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); + CHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false); if (*height_axis != -1 && *width_axis != -1) { return true; } return false; } -inline bool find_width(const std::string& layout, - int* width_axis) { +inline bool find_width(const std::string& layout, int* width_axis) { int dummy; - CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false); + CHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false); if (*width_axis != -1) { return true; } @@ -392,48 +378,42 @@ inline bool find_width(const std::string& layout, } /*! -* \brief Perform pooling on height and width dimension of data. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, etc. 
are valid for pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of two ints: {kernel_height, kernel_width} -* \param stride_size Vector of two ints: {stride_height, stride_width} -* \param padding_size Vector of two ints: {padding_height, padding_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCHW16c can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `H` and `W`, one can pass `NCHWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCHW", + * \brief Perform pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of two ints: {kernel_height, kernel_width} + * \param stride_size Vector of two ints: {stride_height, stride_width} + * \param padding_size Vector of two ints: {padding_height, padding_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. 
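For illustration only (not part of this hunk): a minimal sketch of calling the pool() overload documented here, whose definition follows just below. The include paths and the four-element {top, left, bottom, right} padding order are assumptions based on pool_impl's pad_top/pad_left/pad_bottom/pad_right handling, not something this comment states; if the two-element {pad_height, pad_width} form were expected instead, only two values would be passed.

#include <tvm/te/operation.h>
#include <topi/nn/pool.h>

// 3x3 average pooling, stride 2, symmetric padding 1, excluding padded zeros from the mean.
tvm::te::Tensor AvgPool2dSketch() {
  auto x = tvm::te::placeholder({1, 32, 28, 28}, tvm::DataType::Float(32), "x");
  return topi::nn::pool(x, /*kernel_size=*/{3, 3}, /*stride_size=*/{2, 2},
                        /*padding_size=*/{1, 1, 1, 1}, topi::nn::kAvgPool,
                        /*ceil_mode=*/false, "NCHW", /*count_include_pad=*/false);
}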
+ * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW", bool count_include_pad = true) { int height_axis = -1, width_axis = -1; - CHECK(find_height_width(layout, &height_axis, &width_axis)) - << "Unsupported layout " << layout; - return pool_impl(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, height_axis, width_axis, - count_include_pad); + CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; + return pool_impl(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, height_axis, + width_axis, count_include_pad); } /*! @@ -476,34 +456,27 @@ inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array& output_size, - PoolType pool_type, - const std::vector& axes) { + * \brief Perform adaptive pooling on N dimensional data + * + * \param x The input tensor + * \param output_size int vector of size in each dimension + * \param pool_type The type of pooling operator + * \param axes indices of each dimension + * + * \return The output tensor in same layout order + */ +inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_size, + PoolType pool_type, const std::vector& axes) { const auto n_dim = output_size.size(); CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension"; @@ -533,32 +506,41 @@ inline Tensor adaptive_pool_impl(const Tensor& x, }; if (pool_type == kMaxPool) { - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, true); - return tvm::max(x(indices), reduce_axes); // NOLINT(*) - }, "tensor", "adaptive_pool_max"); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, true); + return tvm::max(x(indices), reduce_axes); // NOLINT(*) + }, + "tensor", "adaptive_pool_max"); } else if (pool_type == kAvgPool) { - auto pool_sum = tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, true); - return tvm::sum(x(indices), reduce_axes); - }, "tensor", "adaptive_pool_sum"); - - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - Array reduce_axes; - std::tie(indices, reduce_axes) = get_iter_vars(output, false); - - PrimExpr divide_factor = tvm::cast(x->dtype, 1); - for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); - } + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, true); + return tvm::sum(x(indices), reduce_axes); + }, + "tensor", "adaptive_pool_sum"); - return div(pool_sum(indices), divide_factor); - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + Array reduce_axes; + std::tie(indices, reduce_axes) = get_iter_vars(output, false); + + PrimExpr divide_factor = tvm::cast(x->dtype, 1); + for (size_t i = 0; i < n_dim; ++i) { + divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); + } + + 
return div(pool_sum(indices), divide_factor); + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; @@ -566,118 +548,107 @@ inline Tensor adaptive_pool_impl(const Tensor& x, } /*! -* \brief Adaptively perform pooling on height and width dimension of data. -* The pooling kernel and stride sizes are automatically chosen for desired output sizes. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, etc. are valid for pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* -* \param x The input tensor -* \param output_size Vector of two ints: {output_height, output_width} -* \param pool_type The type of pooling operator -* \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCHW16c can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `H` and `W`, one can pass `NCHWc` as well. -* -* \return The output tensor in same layout order -*/ -inline Tensor adaptive_pool(const Tensor& x, - const Array& output_size, - PoolType pool_type, + * \brief Adaptively perform pooling on height and width dimension of data. + * The pooling kernel and stride sizes are automatically chosen for desired output sizes. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. + * For example, NCHW, NCHW16c, etc. are valid for pool, + * while NCHW16w, NCHW16h are not. + * See \a layout for more information of the layout string convention. + * + * \param x The input tensor + * \param output_size Vector of two ints: {output_height, output_width} + * \param pool_type The type of pooling operator + * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCHW16c can describe a 5-D tensor of + * [batch_size, channel, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `H` and `W`, one can pass `NCHWc` as well. 
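A hedged usage sketch for the adaptive_pool() wrapper documented above (its definition follows just below); the 7x7 target size, shapes, and layout are purely illustrative choices.

#include <tvm/te/operation.h>
#include <topi/nn/pool.h>

// Adaptive max pooling to a fixed 7x7 spatial output, whatever the input H/W happen to be.
tvm::te::Tensor AdaptivePoolSketch() {
  auto x = tvm::te::placeholder({1, 64, 17, 23}, tvm::DataType::Float(32), "x");
  return topi::nn::adaptive_pool(x, /*output_size=*/{7, 7}, topi::nn::kMaxPool, "NCHW");
}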
+ * + * \return The output tensor in same layout order + */ +inline Tensor adaptive_pool(const Tensor& x, const Array& output_size, PoolType pool_type, const std::string& layout = "NCHW") { int height_axis = -1, width_axis = -1; - CHECK(find_height_width(layout, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + CHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis}); } /*! -* \brief Adaptively perform pooling on three dimensional data. -* See the two dimensional version above for details. -* \param x The input tensor -* \param output_size Vector of three ints: {output_depth, output_height, output_width} -* \param pool_type The type of pooling operator -* \param layout The input layout. The default is "NCDHW". -*/ -inline Tensor adaptive_pool3d(const Tensor& x, - const Array& output_size, - PoolType pool_type, - const std::string& layout = "NCDHW") { + * \brief Adaptively perform pooling on three dimensional data. + * See the two dimensional version above for details. + * \param x The input tensor + * \param output_size Vector of three ints: {output_depth, output_height, output_width} + * \param pool_type The type of pooling operator + * \param layout The input layout. The default is "NCDHW". + */ +inline Tensor adaptive_pool3d(const Tensor& x, const Array& output_size, + PoolType pool_type, const std::string& layout = "NCDHW") { int depth_axis = -1, height_axis = -1, width_axis = -1; CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + << "Unsupported layout " << layout; return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis}); } /*! -* \brief Perform global pooling on height and width dimension of data. -* It decides the height and width dimension according to the layout string, -* in which 'W' and 'H' means width and height respectively. -* Width and height dimension cannot be split. -* For example, NCHW, NCHW16c, ... are valid for global_pool, -* while NCHW16w, NCHW16h are not. -* See \a layout for more information of the layout string convention. -* -* \param x The input tensor represent as layout -* \param pool_type The type of pooling operator -* \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the sub-dimension. -* For example, `NCHW16c` can describe a 5-D tensor of -* [batch_size, channel, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of -* dimensions other than `H` and `W`, one can pass `NCHWc` as well. -* -* \return The output tensor in same layout with height and width dimension size of 1. -* e.g., for NCHW, the output shape will be [batch, channel, 1, 1] -*/ -inline Tensor global_pool(const Tensor& x, - PoolType pool_type, - const std::string& layout = "NCHW") { + * \brief Perform global pooling on height and width dimension of data. + * It decides the height and width dimension according to the layout string, + * in which 'W' and 'H' means width and height respectively. + * Width and height dimension cannot be split. 
+ * For example, NCHW, NCHW16c, ... are valid for global_pool,
+ * while NCHW16w, NCHW16h are not.
+ * See \a layout for more information of the layout string convention.
+ *
+ * \param x The input tensor represent as layout
+ * \param pool_type The type of pooling operator
+ * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
+ * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
+ * where upper case indicates a dimension and
+ * the corresponding lower case (with factor size) indicates the sub-dimension.
+ * For example, `NCHW16c` can describe a 5-D tensor of
+ * [batch_size, channel, height, width, channel_block].
+ * (in which factor size `16` will not be used in pooling but for other operators,
+ * it can be used to decide the output shape).
+ * Since pooling does not care about the factor size of
+ * dimensions other than `H` and `W`, one can pass `NCHWc` as well.
+ *
+ * \return The output tensor in same layout with height and width dimension size of 1.
+ * e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
+ */
+inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
   return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
 }
 
 /*!
-* \brief Perform pooling on N-dimension of data.
-*
-* \param x The input tensor
-* \param kernel_size Vector of N ints
-* \param stride_size Vector of N ints
-* \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
-* head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
-* \param pool_type The type of pooling operator
-* \param ceil_mode Whether to use ceil when calculating the output size
-* \param axis Vector of indices for the N dimensions
-* \param count_include_pad Whether include padding in the calculation
-*
-* \return The output tensor in same layout order
-*/
-inline Tensor pool_impl_nd(const Tensor& x,
-                           const Array<PrimExpr>& kernel_size,
-                           const Array<PrimExpr>& stride_size,
-                           const Array<PrimExpr>& padding_size,
-                           PoolType pool_type,
-                           bool ceil_mode,
-                           const std::vector<int>& axis,
+ * \brief Perform pooling on N-dimension of data.
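Since global_pool() above simply forwards to adaptive_pool() with an output size of {1, 1}, a usage sketch is short; the shape in the comment follows the \return note above and is otherwise illustrative.

#include <tvm/te/operation.h>
#include <topi/nn/pool.h>

// Global average pooling: an NCHW input of [N, C, H, W] pools down to [N, C, 1, 1].
tvm::te::Tensor GlobalAvgPoolSketch() {
  auto x = tvm::te::placeholder({8, 256, 14, 14}, tvm::DataType::Float(32), "x");
  return topi::nn::global_pool(x, topi::nn::kAvgPool, "NCHW");
}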
+ * + * \param x The input tensor + * \param kernel_size Vector of N ints + * \param stride_size Vector of N ints + * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ..., + * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN] + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param axis Vector of indices for the N dimensions + * \param count_include_pad Whether include padding in the calculation + * + * \return The output tensor in same layout order + */ +inline Tensor pool_impl_nd(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::vector& axis, bool count_include_pad) { int k_size = kernel_size.size(); int x_size = x->shape.size(); CHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel"; CHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of" - " kernel"; + " kernel"; CHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel"; Array daxis; @@ -696,8 +667,8 @@ inline Tensor pool_impl_nd(const Tensor& x, stride[i] = cast(DataType::Int(32), stride_size[i]); pad_head[i] = cast(DataType::Int(32), padding_size[i]); pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]); - const int64_t *padding0 = as_const_int(pad_head[i]); - const int64_t *padding1 = as_const_int(pad_tail[i]); + const int64_t* padding0 = as_const_int(pad_head[i]); + const int64_t* padding1 = as_const_int(pad_tail[i]); do_pad = (do_pad) ? do_pad : ((padding0 && *padding0) || (padding1 && *padding1)); if (ceil_mode) { @@ -711,70 +682,77 @@ inline Tensor pool_impl_nd(const Tensor& x, pad_before.Set(ii, pad_head[i]); pad_after.Set(ii, pad_tail[i]); - auto out_dim = tvm::tir::Simplify( - indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); + arith::Analyzer analyzer; + auto out_dim = analyzer.Simplify( + indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); out_shape.Set(ii, out_dim); } if (pool_type == kMaxPool) { - auto temp = do_pad ? pad( - x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; - return tvm::te::compute(out_shape, [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); - } + auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x; + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } - return tvm::max(temp(indices), daxis); - }, "tensor", "pool_max"); + return tvm::max(temp(indices), daxis); + }, + "tensor", "pool_max"); } else if (pool_type == kAvgPool) { // Pad the inputs auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x; // TVM compute for summing the pooling window. 
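To make the output-extent arithmetic in pool_impl_nd concrete, a small standalone sketch with illustrative numbers; it assumes the usual trick of folding stride - 1 into the tail padding when ceil_mode is set (the body of the if (ceil_mode) block falls outside this hunk).

// Pooled extent along one axis, as computed above (integer division is floor here).
int PooledExtent(int in, int kernel, int stride, int pad_head, int pad_tail, bool ceil_mode) {
  if (ceil_mode) pad_tail += stride - 1;  // assumed ceil_mode adjustment
  return (in - kernel + pad_head + pad_tail) / stride + 1;
}
// PooledExtent(6, 3, 2, 0, 0, false) == 2, while ceil_mode == true gives 3,
// because the last, partially covered window is kept.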
- auto pool_sum = tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - indices.Set(ii, output[ii] * stride[i] + daxis[i]); - } - return tvm::sum(temp(indices), daxis); - }, "tensor", "pool_sum"); + auto pool_sum = tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + indices.Set(ii, output[ii] * stride[i] + daxis[i]); + } + return tvm::sum(temp(indices), daxis); + }, + "tensor", "pool_sum"); // TVM compute for dividing the reduced window sum by kernel size. - return tvm::te::compute(out_shape, - [&](const Array& output) { - Array indices; - for (const Var& var : output) indices.push_back(var); - if (count_include_pad) { - auto kernel_size = make_const(DataType::Int(32), 1); - for (int i = 0; i < k_size; i++) { - kernel_size *= kernel[i]; - } - return div(pool_sum(indices), kernel_size); - } else { - std::vector start(k_size); - std::vector end(k_size); - auto kernel_size = make_const(DataType::Int(32), 1); - for (int i = 0; i < k_size; i++) { - int ii = axis[i]; - start[i] = output[ii] * stride[i] - pad_head[i]; - end[i] = tir::MinNode::make(start[i] + kernel[i], x->shape[ii]); - start[i] = tir::MaxNode::make(start[i], make_const(DataType::Int(32), 0)); - kernel_size *= (end[i] - start[i]); - } - - PrimExpr divide_factor = tir::MaxNode::make(kernel_size, make_const(DataType::Int(32), 1)); - return div(pool_sum(indices), divide_factor); - } - }, "tensor", kElementWise); + return tvm::te::compute( + out_shape, + [&](const Array& output) { + Array indices; + for (const Var& var : output) indices.push_back(var); + if (count_include_pad) { + auto kernel_size = make_const(DataType::Int(32), 1); + for (int i = 0; i < k_size; i++) { + kernel_size *= kernel[i]; + } + return div(pool_sum(indices), kernel_size); + } else { + std::vector start(k_size); + std::vector end(k_size); + auto kernel_size = make_const(DataType::Int(32), 1); + for (int i = 0; i < k_size; i++) { + int ii = axis[i]; + start[i] = output[ii] * stride[i] - pad_head[i]; + end[i] = min(start[i] + kernel[i], x->shape[ii]); + start[i] = max(start[i], make_const(DataType::Int(32), 0)); + kernel_size *= (end[i] - start[i]); + } + + PrimExpr divide_factor = max(kernel_size, make_const(DataType::Int(32), 1)); + return div(pool_sum(indices), divide_factor); + } + }, + "tensor", kElementWise); } else { LOG(ERROR) << "Unrecognized pool_type: " << pool_type; return x; @@ -782,94 +760,85 @@ inline Tensor pool_impl_nd(const Tensor& x, } /*! -* \brief Perform pooling on the width dimension of data. -* Width axis is determined by the layout string -* in which 'W' means width. -* Width dimension cannot be split. -* For example, NCW, NCW16c, etc. are valid for pool, -* while NCW16w is not. -* See \a layout for more information of the layout string convention. -* \param x The input tensor. -* \param kernel_size Vector of three ints: {kernel_width} -* \param stride_size Vector of three ints: {stride_width} -* \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'W' appears. 
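A worked sketch of the count_include_pad == false divisor used above: the window is clamped to the unpadded input extent on each pooled axis and the window sum is divided by the clamped element count. Written here for the 2-D case for concreteness; all values are illustrative.

#include <algorithm>

// Divisor for one output element of average pooling when padded values are excluded.
int AvgPoolDivisor(int out_h, int out_w, int stride, int pad, int kernel, int height, int width) {
  int h_start = out_h * stride - pad, w_start = out_w * stride - pad;
  int h_end = std::min(h_start + kernel, height), w_end = std::min(w_start + kernel, width);
  h_start = std::max(h_start, 0);
  w_start = std::max(w_start, 0);
  return std::max((h_end - h_start) * (w_end - w_start), 1);
}
// For a 3x3 kernel, stride 1, pad 1 on a 5x5 input, the corner element divides by 4,
// whereas count_include_pad == true always divides by kernel_height * kernel_width == 9.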
-* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCW16c can describe a 4-D tensor of -* [batch_size, channel, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `W`, one can pass `NCWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool1d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCW", + * \brief Perform pooling on the width dimension of data. + * Width axis is determined by the layout string + * in which 'W' means width. + * Width dimension cannot be split. + * For example, NCW, NCW16c, etc. are valid for pool, + * while NCW16w is not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of three ints: {kernel_width} + * \param stride_size Vector of three ints: {stride_width} + * \param padding_size Vector of six ints: {head_pad_width, tail_pad_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'W' appears. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCW16c can describe a 4-D tensor of + * [batch_size, channel, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `W`, one can pass `NCWc` as well. + * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool1d(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCW", bool count_include_pad = true) { int width_axis = -1; - CHECK(find_width(layout, &width_axis)) - << "Unsupported layout " << layout; + CHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout; std::vector axis = {width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, axis, count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, + count_include_pad); } /*! -* \brief Perform pooling on depth, height and width dimension of data. -* It decides the depth, height and width dimension according to the layout string, -* in which 'D', 'W' and 'H' means depth, width and height respectively. -* Depth, Width and height dimension cannot be split. -* For example, NCDHW, NCDHW16c, etc. are valid for pool, -* while NCDHW16d, NCDHW16w or NCDHW16h are not. -* See \a layout for more information of the layout string convention. 
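A usage sketch for pool1d() above. Note that the CHECK_EQs in pool_impl_nd imply one kernel element, one stride element and two padding elements (head, tail) for the 1-D case, even though the parameter comments say "three"/"six" ints; the values below follow the checks.

#include <tvm/te/operation.h>
#include <topi/nn/pool.h>

// 1-D max pooling over the W axis of an NCW tensor: kernel 3, stride 2, padding {1, 1}.
tvm::te::Tensor MaxPool1dSketch() {
  auto x = tvm::te::placeholder({1, 16, 128}, tvm::DataType::Float(32), "x");
  return topi::nn::pool1d(x, {3}, {2}, {1, 1}, topi::nn::kMaxPool,
                          /*ceil_mode=*/false, "NCW");
}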
-* \param x The input tensor. -* \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width} -* \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width} -* \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, -* tail_pad_depth, tail_pad_height, tail_pad_width} -* \param pool_type The type of pooling operator -* \param ceil_mode Whether to use ceil when calculating the output size -* \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear. -* The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, -* where upper case indicates a dimension and -* the corresponding lower case (with factor size) indicates the split dimension. -* For example, NCDHW16c can describe a 6-D tensor of -* [batch_size, channel, depth, height, width, channel_block]. -* (in which factor size `16` will not be used in pooling but for other operators, -* it can be used to decide the output shape). -* Since pooling does not care about the factor size of dimensions -* other than `D`, `H` and `W`, one can pass `NCDHWc` as well. -* \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* -* -* \return The output tensor in the same layout -*/ -inline Tensor pool3d(const Tensor& x, - const Array& kernel_size, - const Array& stride_size, - const Array& padding_size, - PoolType pool_type, - bool ceil_mode, - const std::string& layout = "NCDHW", + * \brief Perform pooling on depth, height and width dimension of data. + * It decides the depth, height and width dimension according to the layout string, + * in which 'D', 'W' and 'H' means depth, width and height respectively. + * Depth, Width and height dimension cannot be split. + * For example, NCDHW, NCDHW16c, etc. are valid for pool, + * while NCDHW16d, NCDHW16w or NCDHW16h are not. + * See \a layout for more information of the layout string convention. + * \param x The input tensor. + * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width} + * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width} + * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width, + * tail_pad_depth, tail_pad_height, tail_pad_width} + * \param pool_type The type of pooling operator + * \param ceil_mode Whether to use ceil when calculating the output size + * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear. + * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers, + * where upper case indicates a dimension and + * the corresponding lower case (with factor size) indicates the split dimension. + * For example, NCDHW16c can describe a 6-D tensor of + * [batch_size, channel, depth, height, width, channel_block]. + * (in which factor size `16` will not be used in pooling but for other operators, + * it can be used to decide the output shape). + * Since pooling does not care about the factor size of dimensions + * other than `D`, `H` and `W`, one can pass `NCDHWc` as well. 
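And the 3-D counterpart (pool3d is defined just below), assuming the six-element {head_d, head_h, head_w, tail_d, tail_h, tail_w} padding order listed above; again purely illustrative.

#include <tvm/te/operation.h>
#include <topi/nn/pool.h>

// 2x2x2 average pooling on an NCDHW tensor with no padding.
tvm::te::Tensor AvgPool3dSketch() {
  auto x = tvm::te::placeholder({1, 8, 16, 32, 32}, tvm::DataType::Float(32), "x");
  return topi::nn::pool3d(x, {2, 2, 2}, {2, 2, 2}, {0, 0, 0, 0, 0, 0},
                          topi::nn::kAvgPool, /*ceil_mode=*/false, "NCDHW");
}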
+ * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' + * + * + * \return The output tensor in the same layout + */ +inline Tensor pool3d(const Tensor& x, const Array& kernel_size, + const Array& stride_size, const Array& padding_size, + PoolType pool_type, bool ceil_mode, const std::string& layout = "NCDHW", bool count_include_pad = true) { int depth_axis = -1, height_axis = -1, width_axis = -1; CHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis)) - << "Unsupported layout " << layout; + << "Unsupported layout " << layout; std::vector axis = {depth_axis, height_axis, width_axis}; - return pool_impl_nd(x, kernel_size, stride_size, padding_size, - pool_type, ceil_mode, axis, count_include_pad); + return pool_impl_nd(x, kernel_size, stride_size, padding_size, pool_type, ceil_mode, axis, + count_include_pad); } } // namespace nn diff --git a/topi/include/topi/nn/softmax.h b/topi/include/topi/nn/softmax.h index dc76a9e3e61a..5ebeb6b8a4bf 100644 --- a/topi/include/topi/nn/softmax.h +++ b/topi/include/topi/nn/softmax.h @@ -24,9 +24,9 @@ #ifndef TOPI_NN_SOFTMAX_H_ #define TOPI_NN_SOFTMAX_H_ -#include #include #include +#include #include #include @@ -37,18 +37,16 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Softmax activation -* -* \param x The input tensor. Can be any dimension -* \param axis The channel axis along which softmax is performed -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the softmax operation -*/ -inline Tensor softmax(const Tensor &x, - int axis = -1, - std::string name = "tensor", + * \brief Softmax activation + * + * \param x The input tensor. Can be any dimension + * \param axis The channel axis along which softmax is performed + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the softmax operation + */ +inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor", std::string tag = "softmax_output") { auto input_shape = x->shape; auto ndim = input_shape.size(); @@ -61,11 +59,10 @@ inline Tensor softmax(const Tensor &x, auto k2 = tvm::te::reduce_axis(Range(0, input_shape[axis]), "k2"); auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); - tvm::Map attrs; + tvm::Map attrs; attrs.Set("axis", Integer(axis)); - auto insert_reduce_index = [axis, ndim](const Array &indices, - const IterVar &reduce_index) { + auto insert_reduce_index = [axis, ndim](const Array& indices, const IterVar& reduce_index) { Array eval_range; int arg_counter = 0; for (size_t i = 0; i < ndim; ++i) { @@ -77,61 +74,54 @@ inline Tensor softmax(const Tensor &x, return eval_range; }; - auto get_non_reduce_indices = [axis, ndim](const Array &indices) { + auto get_non_reduce_indices = [axis, ndim](const Array& indices) { Array non_reduce_indices; for (size_t i = 0; i < ndim; ++i) { - if (static_cast(i) != axis) - non_reduce_indices.push_back(indices[i]); + if (static_cast(i) != axis) non_reduce_indices.push_back(indices[i]); } return non_reduce_indices; }; - auto _compute_max = [&](const Array &indices) { + auto _compute_max = [&](const Array& indices) { auto eval_range = insert_reduce_index(indices, k1); return topi::MaxOp(x(eval_range), {k1}); }; - auto _compute_exp = [&](const Tensor &max_elem, - const Array &indices) { + auto _compute_exp = [&](const Tensor& max_elem, const Array& indices) { auto non_reduce_indices = 
get_non_reduce_indices(indices); return tvm::exp(x(indices) - max_elem(non_reduce_indices)); }; - auto _compute_expsum = [&](const Tensor &exp, - const Array &indices) { + auto _compute_expsum = [&](const Tensor& exp, const Array& indices) { auto eval_range = insert_reduce_index(indices, k2); return tvm::sum(exp(eval_range), {k2}); }; - auto _normalize = [&](const Tensor &exp, const Tensor &expsum, - const Array &indices) { + auto _normalize = [&](const Tensor& exp, const Tensor& expsum, const Array& indices) { auto non_reduce_indices = get_non_reduce_indices(indices); return exp(indices) / expsum(non_reduce_indices); }; auto max_elem = tvm::te::compute(reduced_shape, _compute_max); - auto exp = tvm::te::compute(input_shape, [&](const Array &indices) { - return _compute_exp(max_elem, indices); - }); - auto expsum = tvm::te::compute(reduced_shape, [&](const Array &indices) { - return _compute_expsum(exp, indices); - }); - return tvm::te::compute(input_shape, [&](const Array &indices) { - return _normalize(exp, expsum, indices); - }, name, tag, attrs); + auto exp = tvm::te::compute( + input_shape, [&](const Array& indices) { return _compute_exp(max_elem, indices); }); + auto expsum = tvm::te::compute( + reduced_shape, [&](const Array& indices) { return _compute_expsum(exp, indices); }); + return tvm::te::compute( + input_shape, [&](const Array& indices) { return _normalize(exp, expsum, indices); }, + name, tag, attrs); } /*! -* \brief Log softmax activation -* -* \param x The input tensor. 2-D where log softmax is performed along the second dimension -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the log softmax operation -*/ -inline Tensor log_softmax(const Tensor& x, - std::string name = "tensor", + * \brief Log softmax activation + * + * \param x The input tensor. 
2-D where log softmax is performed along the second dimension + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the log softmax operation + */ +inline Tensor log_softmax(const Tensor& x, std::string name = "tensor", std::string tag = "log_softmax_output") { CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input"; @@ -139,19 +129,16 @@ inline Tensor log_softmax(const Tensor& x, PrimExpr n = x->shape[1]; auto k = tvm::te::reduce_axis(Range(0, n), "k"); - auto max_elem = tvm::te::compute( - { m }, [&](Var i) { - return tvm::max(x(i, k), Array{ k }); }); + auto max_elem = + tvm::te::compute({m}, [&](Var i) { return tvm::max(x(i, k), Array{k}); }); k = tvm::te::reduce_axis(Range(0, n), "k"); - auto expsum = tvm::te::compute( - { m }, [&](Var i) { - return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); }); + auto expsum = + tvm::te::compute({m}, [&](Var i) { return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), {k}); }); return tvm::te::compute( - x->shape, [&](Var i, Var j) { - return x(i, j) - max_elem(i) - tvm::log(expsum(i)); - }, name, tag); + x->shape, [&](Var i, Var j) { return x(i, j) - max_elem(i) - tvm::log(expsum(i)); }, name, + tag); } } // namespace nn diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index 81c6963835e5..85555000dc1c 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -24,18 +24,18 @@ #ifndef TOPI_REDUCTION_H_ #define TOPI_REDUCTION_H_ -#include #include +#include +#include #include #include #include -#include -#include +#include #include +#include #include #include -#include namespace topi { using namespace tvm; @@ -45,21 +45,21 @@ using namespace tvm::te; using FReduce = std::function& axis)>; /*! \brief The operation to use for CommReduceIdx */ -using FCommReduce = std::function< - Array(Array exprs, const Array& axis, PrimExpr* condition)>; +using FCommReduce = std::function(Array exprs, const Array& axis, + PrimExpr* condition)>; /*! -* \brief Convert a reduction axis which could be empty or have negative -* elements into a real axis with valid dimension indices. -* -* \param ndim Number of dimensions in the target. -* \param axis The axis parameter. -* -* \return A non-empty sorted array of valid dimension indices, with no duplicates. -* If the input axis is empty, the result will be an axis including all dimensions. -* If any input element is negative, it will be treated as an offset from the -* last dimension (same as python indexing rules). -*/ + * \brief Convert a reduction axis which could be empty or have negative + * elements into a real axis with valid dimension indices. + * + * \param ndim Number of dimensions in the target. + * \param axis The axis parameter. + * + * \return A non-empty sorted array of valid dimension indices, with no duplicates. + * If the input axis is empty, the result will be an axis including all dimensions. + * If any input element is negative, it will be treated as an offset from the + * last dimension (same as python indexing rules). 
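Stepping back to the softmax kernels defined above, a small usage sketch; axis 1 (the channel axis of an NCHW activation) and the shapes are illustrative assumptions.

#include <tvm/te/operation.h>
#include <topi/nn/softmax.h>

// Softmax over the channel axis of a 4-D activation, and log_softmax on 2-D logits
// (log_softmax CHECKs that its input is exactly 2-D).
void SoftmaxSketch() {
  auto act = tvm::te::placeholder({1, 1000, 7, 7}, tvm::DataType::Float(32), "act");
  auto probs = topi::nn::softmax(act, /*axis=*/1);
  auto logits = tvm::te::placeholder({32, 1000}, tvm::DataType::Float(32), "logits");
  auto log_probs = topi::nn::log_softmax(logits);
  (void)probs;
  (void)log_probs;
}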
+ */ inline std::vector GetRealAxis(int ndim, const Array& axis) { std::vector real_axis; if (!axis.defined() || axis.size() == 0) { @@ -78,8 +78,7 @@ inline std::vector GetRealAxis(int ndim, const Array& axis) { real_axis.push_back(static_cast(val)); } std::sort(real_axis.begin(), real_axis.end()); - real_axis.resize( - std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); + real_axis.resize(std::unique(real_axis.begin(), real_axis.end()) - real_axis.begin()); } return real_axis; } @@ -89,17 +88,14 @@ inline Array MakeReduceAxes(const std::vector& real_axis, const Te Array reduce_axes; for (auto i : real_axis) { std::string name = "k" + std::to_string(i); - reduce_axes.push_back( - tvm::te::reduce_axis(Range(0, data->shape[i]), name)); + reduce_axes.push_back(tvm::te::reduce_axis(Range(0, data->shape[i]), name)); } return reduce_axes; } /*! \brief Calculate the target shape for a reduce op */ -inline Array MakeReduceTargetShape(const std::vector& real_axis, - const Tensor& data, - bool keepdims, - bool atleast1d) { +inline Array MakeReduceTargetShape(const std::vector& real_axis, const Tensor& data, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); Array target_shape; if (keepdims) { @@ -137,9 +133,7 @@ inline Array MakeReduceTargetShape(const std::vector& real_axis, * * \return The result tensor. */ -inline Tensor DoCommReduce(const Tensor& data, - FReduce func, - const Array& target_shape, +inline Tensor DoCommReduce(const Tensor& data, FReduce func, const Array& target_shape, const std::vector& reduce_axes, const std::vector& squeeze_axes) { auto r_axes = MakeReduceAxes(reduce_axes, data); @@ -182,45 +176,39 @@ inline Tensor DoCommReduce(const Tensor& data, * * \return The result tensor. */ -inline Tensor CommReduce(const Tensor& data, - const Array& axis, - FReduce func, - bool keepdims, - bool atleast1d) { +inline Tensor CommReduce(const Tensor& data, const Array& axis, FReduce func, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); return DoCommReduce(data, func, target_shape, real_axis, - keepdims ? std::vector() : real_axis); + keepdims ? std::vector() : real_axis); } /*! -* \brief Create an index reduction operation. -* -* \param data The input tensor. -* \param axis The axes along which the reduction is performed. -* \param func The reduction function -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return The result tensor. -*/ -inline Tensor CommReduceIdx(const Tensor& data, - const Array& axis, - FCommReduce func, - bool keepdims, - bool atleast1d) { + * \brief Create an index reduction operation. + * + * \param data The input tensor. + * \param axis The axes along which the reduction is performed. + * \param func The reduction function + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return The result tensor. 
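A quick sketch of the axis normalization GetRealAxis performs, per its doc comment and body above; the dimension count and axis values are illustrative.

#include <vector>
#include <topi/reduction.h>

// Negative axes count back from the last dimension; the result is sorted and de-duplicated;
// an empty axis list means "reduce over every dimension".
void GetRealAxisSketch() {
  std::vector<int> a = topi::GetRealAxis(4, {-1, 1});  // -> {1, 3}
  std::vector<int> b = topi::GetRealAxis(4, {});       // -> {0, 1, 2, 3}
  (void)a;
  (void)b;
}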
+ */ +inline Tensor CommReduceIdx(const Tensor& data, const Array& axis, FCommReduce func, + bool keepdims, bool atleast1d) { auto ndim = data->shape.size(); CHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); auto reduce_axes = MakeReduceAxes(real_axis, data); auto target_shape = MakeReduceTargetShape(real_axis, data, keepdims, atleast1d); - auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, &data] - (const Array& indices) { + auto compute = [ndim, keepdims, &real_axis, &reduce_axes, &func, + &data](const Array& indices) { Array eval_range; Array eval_indices; int arg_counter = 0; @@ -247,18 +235,16 @@ inline Tensor CommReduceIdx(const Tensor& data, ravel_shape.push_back(data->shape[i]); } auto idx = detail::RavelIndex(eval_indices, ravel_shape); - return func({ idx, data(eval_range) }, reduce_axes, nullptr); + return func({idx, data(eval_range)}, reduce_axes, nullptr); }; - auto temp_idx_val = tvm::te::compute(target_shape, compute, - data->op->name + "_red_temp", kCommReduceIdx); + auto temp_idx_val = + tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduceIdx); auto temp_idx = temp_idx_val[0]; auto temp_val = temp_idx_val[1]; return tvm::te::compute( - target_shape, - [&temp_idx](const Array& indices) { return temp_idx(indices); }, - data->op->name + "_red", - kCommReduceIdx); + target_shape, [&temp_idx](const Array& indices) { return temp_idx(indices); }, + data->op->name + "_red", kCommReduceIdx); } /*! \brief A combiner function for a reduction */ @@ -276,11 +262,10 @@ using FIdentity = std::function(std::vector types)>; * * \return A reducer function which creates a reduce expression over an axis. */ -inline FCommReduce MakeCommReducer(FCombine fcombine, - FIdentity fidentity, +inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, std::string name = "reduce") { - return [fcombine, fidentity, name] - (Array exprs, const Array& axis, PrimExpr* condition) { + return [fcombine, fidentity, name](Array exprs, const Array& axis, + PrimExpr* condition) { Array lhs, rhs; std::vector dtypes; @@ -295,20 +280,17 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, auto id_elem = fidentity(dtypes); auto cond = condition != nullptr ? *condition : tir::const_true(); - auto combiner = tvm::tir::CommReducerNode::make(lhs, rhs, result, id_elem); + auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem); Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { - outputs.push_back( - tvm::tir::ReduceNode::make(combiner, exprs, axis, cond, static_cast(i))); + outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast(i))); } return outputs; }; } /*! \brief Wrap tvm::min to ensure we get the correct overload */ -inline PrimExpr MinOp(PrimExpr source, Array axis) { - return tvm::min(source, axis); -} +inline PrimExpr MinOp(PrimExpr source, Array axis) { return tvm::min(source, axis); } /*! \brief Wrap tvm::max to ensure we get the correct overload */ inline PrimExpr MaxOp(PrimExpr source, Array axis) { @@ -321,21 +303,19 @@ inline PrimExpr ProdOp(PrimExpr source, Array axis) { } /*! -* \brief Creates an operation that sums array elements over a given axis -* -* \param data The input tensor -* \param axis The axis to sum over. If axis is empty, the operation will -* sum over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. 
This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the sum operation -*/ -inline Tensor sum(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that sums array elements over a given axis + * + * \param data The input tensor + * \param axis The axis to sum over. If axis is empty, the operation will + * sum over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the sum operation + */ +inline Tensor sum(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); } @@ -347,8 +327,7 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { std::vector reduce_axes; std::vector squeeze_axes; - for (int i_ax = ishape.size() - 1, - o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) { + for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) { if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) { --o_ax; continue; @@ -369,117 +348,107 @@ inline Tensor collapse_sum(const Tensor& data, Array target_shape) { } /*! -* \brief Creates an operation that computes the logical AND of elements -* over a given axis -* -* \param data The input boolean tensor -* \param axis The axes to reduce. If axis is empty, the operation will -* perform logical AND over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the all operation -*/ -inline Tensor all(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that computes the logical AND of elements + * over a given axis + * + * \param data The input boolean tensor + * \param axis The axes to reduce. If axis is empty, the operation will + * perform logical AND over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the all operation + */ +inline Tensor all(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::all, keepdims, atleast1d); } /*! -* \brief Creates an operation that computes the logical OR of elements -* over a given axis -* -* \param data The input boolean tensor -* \param axis The axes to reduce. If axis is empty, the operation will -* perform logical OR over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. 
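A usage sketch for the reducers above, showing the keepdims behaviour the parameter comments describe; shapes are illustrative.

#include <tvm/te/operation.h>
#include <topi/reduction.h>

// Sum over the last axis: without keepdims the axis is dropped; with keepdims it stays
// as size 1, so the result still broadcasts against the input.
void SumSketch() {
  auto x = tvm::te::placeholder({8, 16, 32}, tvm::DataType::Float(32), "x");
  auto s = topi::sum(x, {-1});                      // shape [8, 16]
  auto sk = topi::sum(x, {-1}, /*keepdims=*/true);  // shape [8, 16, 1]
  (void)s;
  (void)sk;
}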
-* -* \return A Tensor whose op member is the all operation -*/ -inline Tensor any(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that computes the logical OR of elements + * over a given axis + * + * \param data The input boolean tensor + * \param axis The axes to reduce. If axis is empty, the operation will + * perform logical OR over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the all operation + */ +inline Tensor any(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, tvm::any, keepdims, atleast1d); } /*! -* \brief Creates an operation that finds the minimum of elements over -* a given axis. -* -* \param data The input tensor -* \param axis The axis to find the minimum over. If axis is empty, the -* operation will find the minimum over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the min operation -*/ -inline Tensor min(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the minimum of elements over + * a given axis. + * + * \param data The input tensor + * \param axis The axis to find the minimum over. If axis is empty, the + * operation will find the minimum over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the min operation + */ +inline Tensor min(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MinOp, keepdims, atleast1d); } /*! -* \brief Creates an operation that finds the maximum of elements over -* a given axis. -* -* \param data The input tensor -* \param axis The axis to find the maximum over. If axis is empty, the -* operation will find the maximum over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the max operation -*/ -inline Tensor max(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the maximum of elements over + * a given axis. + * + * \param data The input tensor + * \param axis The axis to find the maximum over. If axis is empty, the + * operation will find the maximum over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. 
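For the boolean reducers, a similar sketch; the boolean placeholder dtype is an assumption about how such a mask would typically be created.

#include <tvm/te/operation.h>
#include <topi/reduction.h>

// Row-wise any() over a boolean mask: true if any element of the row is true.
void AnySketch() {
  auto mask = tvm::te::placeholder({4, 10}, tvm::DataType::Bool(), "mask");
  auto row_any = topi::any(mask, {1});  // shape [4]
  (void)row_any;
}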
+ * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the max operation + */ +inline Tensor max(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, MaxOp, keepdims, atleast1d); } /*! -* \brief Creates an operation that finds the indices of the minimum -* values over a given axis. -* -* \param data The input tensor -* \param axis The axis along which the argmin is performed. If axis is empty, -* the operation will find the minimum index over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the argmin operation -*/ -inline Tensor argmin(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the indices of the minimum + * values over a given axis. + * + * \param data The input tensor + * \param axis The axis along which the argmin is performed. If axis is empty, + * the operation will find the minimum index over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the argmin operation + */ +inline Tensor argmin(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { auto fcombine = [](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::SelectNode::make(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::SelectNode::make(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val + result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(lhs[1] <= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val + result.push_back(tvm::max_value(types[1])); // val return result; }; auto func = MakeCommReducer(fcombine, fidentity, "argmin"); @@ -489,57 +458,53 @@ inline Tensor argmin(const Tensor& data, inline FCommReduce MakeArgmaxReducer() { auto fcombine = [](Array lhs, Array rhs) { Array result; - result.push_back(tvm::tir::SelectNode::make(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::SelectNode::make(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val + result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[0], rhs[0])); // idx + result.push_back(tvm::tir::Select(lhs[1] >= rhs[1], lhs[1], rhs[1])); // val return result; }; auto fidentity = [](std::vector types) { Array result; result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::min_value(types[1])); // val + result.push_back(tvm::min_value(types[1])); // val return result; }; return MakeCommReducer(fcombine, fidentity, "argmax"); } /*! -* \brief Creates an operation that finds the indices of the maximum -* values over a given axis. -* -* \param data The input tensor -* \param axis The axis along which the argmax is performed. 
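Illustrative aside, not part of the patch: the arg-reductions carry an (index, value) pair through the reducer, combine the pair with tir::Select, and seed it with index -1 and the max/min identity value. On the caller side they look like any other reduction; shapes below are assumptions for the example:

#include <topi/reduction.h>
#include <tvm/te/operation.h>

void ArgReduceSketch() {
  auto scores = tvm::te::placeholder({16, 10}, tvm::DataType::Float(32), "scores");
  // Index of the largest element along axis 1; only the index half of the
  // (idx, val) pair is returned to the caller.
  auto best = topi::argmax(scores, {1});
  // keepdims gives shape (16, 1) instead of (16,).
  auto best_kept = topi::argmax(scores, {1}, /*keepdims=*/true);
}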
If axis is empty, -* the operation will find the maximum index over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the argmax operation -*/ -inline Tensor argmax(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates an operation that finds the indices of the maximum + * values over a given axis. + * + * \param data The input tensor + * \param axis The axis along which the argmax is performed. If axis is empty, + * the operation will find the maximum index over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the argmax operation + */ +inline Tensor argmax(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { auto reducer = MakeArgmaxReducer(); return CommReduceIdx(data, axis, reducer, keepdims, atleast1d); } /*! -* \brief Creates product operation over given axis. -* -* \param data The input tensor -* \param axis The axis to do product over. If axis is empty, the -* operation will do the product over all elements of the array. -* \param keepdims If this is set to true, the axes which are reduced are -* left in the result as dimensions with size one. This enables the result -* to broadcast correctly against the input array. -* \param atleast1d Whether the output need to be atleast1d. -* -* \return A Tensor whose op member is the prod operation -*/ -inline Tensor prod(const Tensor& data, - const Array& axis, - bool keepdims = false, + * \brief Creates product operation over given axis. + * + * \param data The input tensor + * \param axis The axis to do product over. If axis is empty, the + * operation will do the product over all elements of the array. + * \param keepdims If this is set to true, the axes which are reduced are + * left in the result as dimensions with size one. This enables the result + * to broadcast correctly against the input array. + * \param atleast1d Whether the output need to be atleast1d. + * + * \return A Tensor whose op member is the prod operation + */ +inline Tensor prod(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { return CommReduce(data, axis, ProdOp, keepdims, atleast1d); } diff --git a/topi/include/topi/rocm/dense.h b/topi/include/topi/rocm/dense.h index 629b34e6ddaf..72f8ee62e155 100644 --- a/topi/include/topi/rocm/dense.h +++ b/topi/include/topi/rocm/dense.h @@ -24,14 +24,15 @@ #ifndef TOPI_ROCM_DENSE_H_ #define TOPI_ROCM_DENSE_H_ -#include -#include #include -#include "topi/detail/array_utils.h" -#include "topi/nn/dense.h" +#include +#include + #include "topi/contrib/rocblas.h" -#include "topi/generic/extern.h" #include "topi/cuda/dense.h" +#include "topi/detail/array_utils.h" +#include "topi/generic/extern.h" +#include "topi/nn/dense.h" namespace topi { using namespace tvm; @@ -39,21 +40,19 @@ using namespace tvm::te; namespace rocm { /*! 
-* \brief Implementation of dense for rocm backend -* -* \param target The target device -* \param data Tensor with shape [batch, in_dim] -* \param weight Tensor with shape [out_dim, in_dim] -* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() -* \param out_dtype Output data type. Used for mixed precision. -* -* \return Tensor with shape [batch, out_dim] -*/ -inline tvm::te::Tensor dense_rocm(const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { + * \brief Implementation of dense for rocm backend + * + * \param target The target device + * \param data Tensor with shape [batch, in_dim] + * \param weight Tensor with shape [out_dim, in_dim] + * \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor() + * \param out_dtype Output data type. Used for mixed precision. + * + * \return Tensor with shape [batch, out_dim] + */ +inline tvm::te::Tensor dense_rocm(const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data"; CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight"; if (bias.defined()) { @@ -68,10 +67,8 @@ inline tvm::te::Tensor dense_rocm(const Target& target, CHECK_EQ(data->dtype, out_dtype) << "Mixed precision not supported."; auto mm = topi::contrib::rocblas_matmul(data, weight, false, true); if (bias.defined()) { - mm = tvm::te::compute({ batch, out_dim }, - [&](Var i, Var j) { - return mm(i, j) + bias(j); - }, "tensor", kBroadcast); + mm = tvm::te::compute( + {batch, out_dim}, [&](Var i, Var j) { return mm(i, j) + bias(j); }, "tensor", kBroadcast); } return mm; @@ -81,16 +78,15 @@ inline tvm::te::Tensor dense_rocm(const Target& target, } /*! -* \brief Create a rocm schedule for dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_dense(const Target &target, const Array& outs) { - if (target->target_name == "rocm" && - target->libs().count("rocblas")) { + * \brief Create a rocm schedule for dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_dense(const Target& target, const Array& outs) { + if (target->target_name == "rocm" && target->libs().count("rocblas")) { return topi::generic::schedule_extern(target, outs); } diff --git a/topi/include/topi/rocm/injective.h b/topi/include/topi/rocm/injective.h index f3a3f3b0cbd2..e7415bfd0ff2 100644 --- a/topi/include/topi/rocm/injective.h +++ b/topi/include/topi/rocm/injective.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_INJECTIVE_H_ #define TOPI_ROCM_INJECTIVE_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/injective.h" @@ -57,7 +57,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out * * \return A schedule for the given ops. 
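Illustrative aside, not part of the patch: a sketch of how the rocBLAS dense path might be driven. The Target handle is assumed to be created elsewhere with rocblas in its libs (e.g. "rocm -libs=rocblas"), which is what makes schedule_dense pick the extern rocBLAS schedule rather than a generic fallback.

#include <topi/rocm/dense.h>
#include <tvm/te/operation.h>

tvm::te::Schedule DenseRocmSketch(const tvm::Target& target) {
  auto data = tvm::te::placeholder({64, 128}, tvm::DataType::Float(32), "data");
  auto weight = tvm::te::placeholder({256, 128}, tvm::DataType::Float(32), "weight");
  // Pass a default-constructed Tensor() to omit the bias term.
  auto out = topi::rocm::dense_rocm(target, data, weight, tvm::te::Tensor(), data->dtype);
  return topi::rocm::schedule_dense(target, {out});
}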
*/ -inline Schedule schedule_injective(const Target &target, const Array& outs) { +inline Schedule schedule_injective(const Target& target, const Array& outs) { return topi::cuda::schedule_injective(target, outs); } diff --git a/topi/include/topi/rocm/normalization.h b/topi/include/topi/rocm/normalization.h index 303f4a8302c7..832868348b67 100644 --- a/topi/include/topi/rocm/normalization.h +++ b/topi/include/topi/rocm/normalization.h @@ -24,22 +24,20 @@ #ifndef TOPI_ROCM_NORMALIZATION_H_ #define TOPI_ROCM_NORMALIZATION_H_ -#include -#include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for LRN -* \param outs The output tensors. -* \return A schedule for the given ops. -*/ -inline Schedule schedule_lrn(const Array& outs) { - return topi::cuda::schedule_lrn(outs); -} + * \brief Create a rocm schedule for LRN + * \param outs The output tensors. + * \return A schedule for the given ops. + */ +inline Schedule schedule_lrn(const Array& outs) { return topi::cuda::schedule_lrn(outs); } } // namespace rocm } // namespace topi diff --git a/topi/include/topi/rocm/pooling.h b/topi/include/topi/rocm/pooling.h index 7d1f36f2ee33..0b68a0ac5366 100644 --- a/topi/include/topi/rocm/pooling.h +++ b/topi/include/topi/rocm/pooling.h @@ -24,12 +24,12 @@ #ifndef TOPI_ROCM_POOLING_H_ #define TOPI_ROCM_POOLING_H_ -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; @@ -38,26 +38,26 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_pool(const Target &target, const Array& outs) { + * \brief Create a rocm schedule for pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_pool(const Target& target, const Array& outs) { return topi::cuda::schedule_pool(target, outs); } /*! -* \brief Create a rocm schedule for global_pool -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_global_pool(const Target &target, const Array& outs) { + * \brief Create a rocm schedule for global_pool + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_global_pool(const Target& target, const Array& outs) { return topi::cuda::schedule_global_pool(target, outs); } diff --git a/topi/include/topi/rocm/reduction.h b/topi/include/topi/rocm/reduction.h index ea4b65623928..512bf20b4bc1 100644 --- a/topi/include/topi/rocm/reduction.h +++ b/topi/include/topi/rocm/reduction.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_REDUCTION_H_ #define TOPI_ROCM_REDUCTION_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/reduction.h" @@ -37,13 +37,13 @@ using namespace tvm::te; namespace rocm { /*! -* \brief Create a rocm schedule for a reduce operation. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ + * \brief Create a rocm schedule for a reduce operation. 
+ * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ Schedule schedule_reduce(const Target& target, Array outs) { return topi::cuda::schedule_reduce(target, outs); } diff --git a/topi/include/topi/rocm/softmax.h b/topi/include/topi/rocm/softmax.h index 63a0304d2818..de05c4cec9d3 100644 --- a/topi/include/topi/rocm/softmax.h +++ b/topi/include/topi/rocm/softmax.h @@ -24,10 +24,10 @@ #ifndef TOPI_ROCM_SOFTMAX_H_ #define TOPI_ROCM_SOFTMAX_H_ -#include #include -#include +#include #include +#include #include "topi/cuda/softmax.h" @@ -45,7 +45,7 @@ namespace rocm { * * \return A schedule for the given ops. */ -inline Schedule schedule_softmax(const Target &target, const Array& outs) { +inline Schedule schedule_softmax(const Target& target, const Array& outs) { return topi::cuda::schedule_softmax(target, outs); } diff --git a/topi/include/topi/tags.h b/topi/include/topi/tags.h index 8d353b949ab6..1e9ec446dfa3 100644 --- a/topi/include/topi/tags.h +++ b/topi/include/topi/tags.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -43,16 +43,12 @@ constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nh constexpr auto kGroupConv2d = "group_conv2d"; inline bool is_broadcast(std::string tag) { - return - tag.rfind(kElementWise, 0) == 0 || - tag.rfind(kBroadcast, 0) == 0; + return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0; } inline bool is_injective(std::string tag) { - return - tag.rfind(kElementWise, 0) == 0 || - tag.rfind(kBroadcast, 0) == 0 || - tag.rfind(kInjective, 0) == 0; + return tag.rfind(kElementWise, 0) == 0 || tag.rfind(kBroadcast, 0) == 0 || + tag.rfind(kInjective, 0) == 0; } } // namespace topi diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 431ace5bc11e..794796702d00 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -24,19 +24,19 @@ #ifndef TOPI_TRANSFORM_H_ #define TOPI_TRANSFORM_H_ -#include -#include -#include -#include #include +#include #include +#include +#include +#include -#include -#include -#include #include +#include #include +#include #include +#include namespace topi { using namespace tvm; @@ -44,30 +44,25 @@ using namespace tvm::te; using namespace topi::detail; /*! 
-* \brief Creates an operation to insert new dimensions of length 1 -* -* \param x The input tensor -* \param axis The index of the first new dimension (allows negative -* indices as offsets from the last dimension) -* \param num_newaxis The number of new dimensions to insert -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the dim expansion operation -*/ -inline Tensor expand_dims(const Tensor& x, - int axis, - int num_newaxis = 1, - std::string name = "T_expand_dims", - std::string tag = kBroadcast) { + * \brief Creates an operation to insert new dimensions of length 1 + * + * \param x The input tensor + * \param axis The index of the first new dimension (allows negative + * indices as offsets from the last dimension) + * \param num_newaxis The number of new dimensions to insert + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the dim expansion operation + */ +inline Tensor expand_dims(const Tensor& x, int axis, int num_newaxis = 1, + std::string name = "T_expand_dims", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - CHECK(num_newaxis >= 0) - << "expand_dims only accepts `num_newaxis >= 0`" - << ", but got num_newaxis = " << num_newaxis; + << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(num_newaxis >= 0) << "expand_dims only accepts `num_newaxis >= 0`" + << ", but got num_newaxis = " << num_newaxis; if (axis < 0) { // Calculate offset from last dimension axis = ndim + axis + 1; @@ -84,32 +79,32 @@ inline Tensor expand_dims(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); - } - for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + for (size_t i = axis + num_newaxis; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, + name, tag); } /*! -* \brief Permute the dimensions of an array -* -* \param x The input tensor -* \param axes The indices of the permutation. If this is empty, -* the dimensions will be reversed. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the transpose operation -*/ -inline Tensor transpose(const Tensor& x, - Array axes, - std::string name = "T_transpose", + * \brief Permute the dimensions of an array + * + * \param x The input tensor + * \param axes The indices of the permutation. If this is empty, + * the dimensions will be reversed. 
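Illustrative aside, not part of the patch: expand_dims resolves a negative axis as an offset from one past the last dimension, then inserts num_newaxis unit dimensions at that position. A small sketch with assumed shapes:

#include <topi/transform.h>
#include <tvm/te/operation.h>

void ExpandDimsSketch() {
  auto x = tvm::te::placeholder({3, 4}, tvm::DataType::Float(32), "x");
  auto a = topi::expand_dims(x, 0);      // result shape (1, 3, 4)
  auto b = topi::expand_dims(x, -1, 2);  // axis -1 maps to 2, result shape (3, 4, 1, 1)
}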
+ * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the transpose operation + */ +inline Tensor transpose(const Tensor& x, Array axes, std::string name = "T_transpose", std::string tag = kInjective) { if (!axes.defined() || axes.size() == 0) { axes = Array(); @@ -127,11 +122,11 @@ inline Tensor transpose(const Tensor& x, axes.Set(i, new_axis); } CHECK((new_axis >= 0) && (new_axis < static_cast(x->shape.size()))) - << "axis=" << axis << " is invalid for the " - << static_cast(x->shape.size()) << "-dimensional input tensor"; + << "axis=" << axis << " is invalid for the " << static_cast(x->shape.size()) + << "-dimensional input tensor"; for (size_t j = 0; j < axes.size(); ++j) { - if (i !=j) { + if (i != j) { CHECK(new_axis != static_cast(axes[j]->value)) << "repeated axis in transpose"; } } @@ -139,33 +134,33 @@ inline Tensor transpose(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - std::vector idx; - for (size_t i = 0; i < axes.size(); ++i) { - idx.push_back(1); - } - for (size_t i = 0; i < axes.size(); ++i) { - int axis = static_cast(axes[i]->value); - idx[axis] = indices[i]; - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + std::vector idx; + for (size_t i = 0; i < axes.size(); ++i) { + idx.push_back(1); + } + for (size_t i = 0; i < axes.size(); ++i) { + int axis = static_cast(axes[i]->value); + idx[axis] = indices[i]; + } + return x(idx); + }, + name, tag); } /*! -* \brief flip/reverse elements of an array in a particular axis -* -* \param x The input tensor -* \param axis The axis along which the tensors will be reveresed -* (allows negative indices) -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reverse operation -*/ -inline Tensor flip(const Tensor& x, - int axis = 0, - std::string name = "T_flip", + * \brief flip/reverse elements of an array in a particular axis + * + * \param x The input tensor + * \param axis The axis along which the tensors will be reveresed + * (allows negative indices) + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reverse operation + */ +inline Tensor flip(const Tensor& x, int axis = 0, std::string name = "T_flip", std::string tag = kInjective) { size_t src_tensor_dim = x->shape.size(); int axis_inp = axis; @@ -175,42 +170,42 @@ inline Tensor flip(const Tensor& x, } CHECK((0 <= axis) && (axis < static_cast(x->shape.size()))) - << "axis=" << axis_inp << " is invalid for the " - << static_cast(x->shape.size()) << "-dimensional input tensor"; + << "axis=" << axis_inp << " is invalid for the " << static_cast(x->shape.size()) + << "-dimensional input tensor"; // Reverse the Input Tensor in the axis specified return compute( - x->shape, [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - if (i == static_cast(axis)) { - real_indices.push_back(x->shape[i] - indices[i] - 1); - } else { - real_indices.push_back(indices[i]); + x->shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + if (i == static_cast(axis)) { + real_indices.push_back(x->shape[i] - indices[i] - 1); + } else { + real_indices.push_back(indices[i]); + } } - } - return x(real_indices); - }, name, tag); + return x(real_indices); + }, + name, tag); } /*! 
-* \brief Reshape a tensor -* -* \param x The input tensor -* \param newshape The new shape -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reshape operation -*/ -inline Tensor reshape(const Tensor& x, - Array newshape, - std::string name = "T_reshape", + * \brief Reshape a tensor + * + * \param x The input tensor + * \param newshape The new shape + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reshape operation + */ +inline Tensor reshape(const Tensor& x, Array newshape, std::string name = "T_reshape", std::string tag = kInjective) { auto x_shape = x->shape; Array target_shape; - for (const auto &ele : newshape) { + for (const auto& ele : newshape) { if (ele.as()) { target_shape.push_back(cast(DataType::Int(32), ele)); } else { @@ -219,16 +214,16 @@ inline Tensor reshape(const Tensor& x, } if (is_empty_shape(target_shape)) { - return compute(target_shape, - [&](const Array &indices) { return tvm::cast(x->dtype, 0); }, - name, tag); + return compute( + target_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); } else { return compute( - target_shape, [&](const Array& indices) { - return x(UnravelIndex( - RavelIndex(Array{indices.begin(), indices.end()}, target_shape), - x_shape)); - }, name, tag); + target_shape, + [&](const Array& indices) { + return x(UnravelIndex( + RavelIndex(Array{indices.begin(), indices.end()}, target_shape), x_shape)); + }, + name, tag); } } @@ -243,9 +238,7 @@ inline Tensor reshape(const Tensor& x, * \return A Tensor of coordinate arrays. */ -inline Tensor unravel_index(const Tensor& x, - const Tensor& shape, - std::string name = "T_unravel", +inline Tensor unravel_index(const Tensor& x, const Tensor& shape, std::string name = "T_unravel", std::string tag = kInjective) { auto x_shape = x->shape; auto shape_shape = shape->shape; @@ -281,23 +274,20 @@ inline Tensor unravel_index(const Tensor& x, } /*! -* \brief Remove size 1 dimensions from the shape of a tensor. -* The removed dimensions must have a constant size of 1. -* -* \param x The input tensor -* \param axis Indices of the dimensions to remove. If this is empty, -* all entries with a constant size of 1 will be removed. + * \brief Remove size 1 dimensions from the shape of a tensor. + * The removed dimensions must have a constant size of 1. + * + * \param x The input tensor + * \param axis Indices of the dimensions to remove. If this is empty, + * all entries with a constant size of 1 will be removed. * \param atleast1d Whether the output need to be atleast1d. 
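Illustrative aside, not part of the patch: reshape maps each output coordinate to a flat offset with RavelIndex and back into the source layout with UnravelIndex, so any target shape with the same element count is accepted; squeeze only removes axes whose constant extent is 1. A sketch with assumed shapes:

#include <topi/transform.h>
#include <tvm/te/operation.h>

void ReshapeSqueezeSketch() {
  auto x = tvm::te::placeholder({2, 1, 6}, tvm::DataType::Float(32), "x");
  auto flat = topi::reshape(x, {12});  // same 12 elements, now 1-D
  auto r = topi::reshape(x, {3, 4});   // any compatible target shape works
  auto s = topi::squeeze(x, {1});      // drops the size-1 middle axis -> (2, 6)
}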
-* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the squeeze operation -*/ -inline Tensor squeeze(const Tensor& x, - Array axis, - bool atleast1d = false, - std::string name = "T_squeeze", - std::string tag = kInjective) { + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the squeeze operation + */ +inline Tensor squeeze(const Tensor& x, Array axis, bool atleast1d = false, + std::string name = "T_squeeze", std::string tag = kInjective) { auto ndim = x->shape.size(); std::vector axis_val; if (!axis.defined() || axis.size() == 0) { @@ -312,8 +302,7 @@ inline Tensor squeeze(const Tensor& x, if (val < 0) { val += static_cast(x->shape.size()); } - CHECK_EQ(GetConstInt(x->shape[val]), 1) << - "Dimension " << val << " must have size 1"; + CHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1"; axis_val.push_back(val); } } @@ -331,152 +320,140 @@ inline Tensor squeeze(const Tensor& x, } return compute( - out_shape, [&](const Array& indices) { - Array real_indices; - int flag = 0; - for (size_t i = 0; i < ndim; ++i) { - if (axis_set.count(static_cast(i)) == 0) { - real_indices.push_back(indices[i - flag]); - } else { - real_indices.push_back(0); - flag += 1; + out_shape, + [&](const Array& indices) { + Array real_indices; + int flag = 0; + for (size_t i = 0; i < ndim; ++i) { + if (axis_set.count(static_cast(i)) == 0) { + real_indices.push_back(indices[i - flag]); + } else { + real_indices.push_back(0); + flag += 1; + } } - } - return x(real_indices); - }, name, tag); + return x(real_indices); + }, + name, tag); } /*! -* \brief Join a sequence of tensors along an existing axis -* -* \param inputs The input tensors -* \param axis The axis along which the tensors will be joined -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the concatenate operation -*/ -inline Tensor concatenate(const Array& inputs, - int axis = 0, - std::string name = "T_concat", + * \brief Join a sequence of tensors along an existing axis + * + * \param inputs The input tensors + * \param axis The axis along which the tensors will be joined + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the concatenate operation + */ +inline Tensor concatenate(const Array& inputs, int axis = 0, std::string name = "T_concat", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); - CHECK(-ndim <= axis && axis < ndim) - << "concatenate only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; + CHECK(-ndim <= axis && axis < ndim) << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim; } - CHECK_LT(axis, inputs[0]->shape.size()) << - "axis out of bounds"; + CHECK_LT(axis, inputs[0]->shape.size()) << "axis out of bounds"; Array axis_sizes; for (auto t : inputs) { axis_sizes.push_back(t->shape[axis]); } - + arith::Analyzer analyzer; PrimExpr join_size = axis_sizes[0]; for (size_t i = 1; i < axis_sizes.size(); ++i) { join_size += axis_sizes[i]; } - join_size = tvm::tir::Simplify(join_size); + join_size = analyzer.Simplify(join_size); Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? 
join_size : inputs[0]->shape[i]); } return compute( - out_shape, [&](const Array& indices) { - auto ret = inputs[0](indices); - auto ind = indices[axis]; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - ind -= axis_sizes[i]; + out_shape, + [&](const Array& indices) { + auto ret = inputs[0](indices); + auto ind = indices[axis]; + for (size_t i = 0; i < inputs.size() - 1; ++i) { + ind -= axis_sizes[i]; + + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(ind); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); + ret = tvm::if_then_else(ind >= 0, inputs[i + 1](idx), ret); } - idx.push_back(ind); - for (size_t i = axis + 1; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - - ret = tvm::if_then_else(ind >= 0, - inputs[i + 1](idx), - ret); - } - return ret; - }, name, tag); + return ret; + }, + name, tag); } /*! -* \brief Join a sequence of tensors along a new axis. -* -* \param inputs The input tensors -* \param axis The axis along which the tensors will be stacked -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the stack operation -*/ -inline Tensor stack(const Array& inputs, - int axis = 0, - std::string name = "T_stack", + * \brief Join a sequence of tensors along a new axis. + * + * \param inputs The input tensors + * \param axis The axis along which the tensors will be stacked + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the stack operation + */ +inline Tensor stack(const Array& inputs, int axis = 0, std::string name = "T_stack", std::string tag = kInjective) { int ndim = static_cast(inputs[0]->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "stack only accepts `axis` in [-ndim, ndim)" - << ", but got axis = " << axis - << ", and ndim = " << ndim; + << "stack only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis << ", and ndim = " << ndim; if (axis < 0) { axis += ndim + 1; } - CHECK_LT(axis, inputs[0]->shape.size() + 1) << - "axis out of bounds"; + CHECK_LT(axis, inputs[0]->shape.size() + 1) << "axis out of bounds"; const int stack_size = static_cast(inputs.size()); Array out_shape; - for (size_t i = 0; i < static_cast(axis); ++i) - out_shape.push_back(inputs[0]->shape[i]); + for (size_t i = 0; i < static_cast(axis); ++i) out_shape.push_back(inputs[0]->shape[i]); out_shape.push_back(stack_size); for (size_t i = static_cast(axis); i < static_cast(ndim); ++i) out_shape.push_back(inputs[0]->shape[i]); return compute( - out_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < indices.size(); ++i) - if (i != static_cast(axis)) - idx.push_back(indices[i]); - auto ind = indices[axis]; - auto ret = inputs[0](idx); - for (int i = 0; i < static_cast(inputs.size() - 1); ++i) { - ret = tvm::if_then_else(ind == i + 1, - inputs[i + 1](idx), - ret); - } - return ret; - }, name, tag); + out_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < indices.size(); ++i) + if (i != static_cast(axis)) idx.push_back(indices[i]); + auto ind = indices[axis]; + auto ret = inputs[0](idx); + for (int i = 0; i < static_cast(inputs.size() - 1); ++i) { + ret = tvm::if_then_else(ind == i + 1, inputs[i + 1](idx), ret); + } + return ret; + }, + name, tag); } /*! 
-* \brief Split a tensor into multiple sub-tensors -* -* \param x The input tensor -* \param split_indices The indices to split the input at. This must be in ascending -* order. -* \param axis The axis to split along. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Array split(const Tensor& x, - Array split_indices, - int axis, - std::string name = "T_split", - std::string tag = kInjective) { + * \brief Split a tensor into multiple sub-tensors + * + * \param x The input tensor + * \param split_indices The indices to split the input at. This must be in ascending + * order. + * \param axis The axis to split along. + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Array split(const Tensor& x, Array split_indices, int axis, + std::string name = "T_split", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(x->shape.size()); } @@ -488,12 +465,11 @@ inline Array split(const Tensor& x, for (Integer idx : split_indices) { int val = static_cast(idx->value); - CHECK_GT(val, begin_ids.back()) - << "split_indices must be sorted"; + CHECK_GT(val, begin_ids.back()) << "split_indices must be sorted"; begin_ids.push_back(val); } - Array< Array > out_shapes; + Array > out_shapes; for (size_t i = 0; i < begin_ids.size(); ++i) { int out_axis_size; if (i == begin_ids.size() - 1) { @@ -516,9 +492,9 @@ inline Array split(const Tensor& x, Array result; for (size_t i = 0; i < begin_ids.size(); ++i) { - result.push_back( - compute( - out_shapes[i], [&](const Array& indices) { + result.push_back(compute( + out_shapes[i], + [&](const Array& indices) { auto begin = begin_ids[i]; Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { @@ -530,43 +506,40 @@ inline Array split(const Tensor& x, } return x(real_indices); - }, name, tag)); + }, + name, tag)); } return result; } /*! -* \brief strided_slice of a tensor -* -* \param x The input tensor -* \param begin The indices to begin with in the slicing -* \param end Indicies indicating end of the slice -* \param strides Specifies the stride values, it can be negative -* in that case, the input tensor will be reversed in that particular axis -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Tensor strided_slice(const Tensor& x, - const Array& begin, - const Array& end, - const Array& strides, - std::string name = "T_strided_slice", - std::string tag = kInjective) { + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param slice_mode Specifies the slice mode + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, + const Array& strides, std::string slice_mode = "end", + std::string name = "T_strided_slice", std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); // Setup the ranges. 
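  // Editorial note, not part of the patch: strides the caller does not supply now default
  // to 1 (the vector below is pre-filled and only the leading user-provided entries are
  // overwritten), and the new slice_mode parameter changes how `end` is read further down:
  // in "size" mode each end[i] is a length added to begin[i], with a negative value meaning
  // "slice to the end of the axis"; in the default "end" mode it stays an absolute stop index.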
// NOTE: this code duplicates the shape inference logic relay.op // Consider to refactor in the future. - std::vector stride_vec; - for (Integer i : strides) { - CHECK(i.defined()); - stride_vec.push_back(i->value); - } - for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) { - stride_vec.push_back(1); + std::vector stride_vec(src_tensor_dim, 1); + for (size_t i = 0; i < strides.size(); ++i) { + CHECK(strides[i].defined()); + stride_vec[i] = strides[i]->value; } + const int64_t max_range = std::numeric_limits::max(); std::vector begin_vec; @@ -585,8 +558,15 @@ inline Tensor strided_slice(const Tensor& x, std::vector end_vec; for (size_t i = 0; i < end.size(); ++i) { // allow end to be None + if (!end[i].defined()) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (slice_mode == "size") { + if (end[i]->value < 0) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else { + end_vec.push_back(begin_vec[i] + end[i]->value); + } } else { end_vec.push_back(end[i]->value); } @@ -615,43 +595,43 @@ inline Tensor strided_slice(const Tensor& x, int64_t end_i = index_canonicalization(end_vec[i]); int interval = std::abs(end_i - begin_i); - int slice_size = static_cast((interval - + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); + int slice_size = + static_cast((interval + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); CHECK(stride_vec[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] - << "] is invalid for axis=" << i; + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; begin_expr.push_back(make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back(make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - stride_vec[i])); + strides_expr.push_back( + make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), stride_vec[i])); out_shape.push_back(slice_size); } return compute( - out_shape, [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return x(real_indices); - }, name, tag); + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + } + return x(real_indices); + }, + name, tag); } /*! -* \brief Split a tensor into a number of sub-tensors -* -* \param x The input tensor -* \param num_sections The number of sections to split the tensor into. -* this must be an integer factor of the size of the axis being split. -* \param axis The axis to split along. -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the split operation -*/ -inline Array split_sections(const Tensor& x, - int num_sections, - int axis, + * \brief Split a tensor into a number of sub-tensors + * + * \param x The input tensor + * \param num_sections The number of sections to split the tensor into. + * this must be an integer factor of the size of the axis being split. + * \param axis The axis to split along. 
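Illustrative aside, not part of the patch: a caller-side sketch of the two slice modes of strided_slice, with assumed shapes.

#include <topi/transform.h>
#include <tvm/te/operation.h>

void StridedSliceSketch() {
  auto x = tvm::te::placeholder({10, 20}, tvm::DataType::Float(32), "x");
  // "end" mode: rows [2, 8) with stride 2, all columns -> shape (3, 20).
  auto a = topi::strided_slice(x, {2, 0}, {8, 20}, {2, 1}, "end");
  // "size" mode: end is a length, so this takes 4 rows starting at row 2 -> shape (4, 20).
  auto b = topi::strided_slice(x, {2, 0}, {4, 20}, {1, 1}, "size");
}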
+ * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ +inline Array split_sections(const Tensor& x, int num_sections, int axis, std::string name = "T_split_sections", std::string tag = kInjective) { if (axis < 0) { @@ -663,8 +643,8 @@ inline Array split_sections(const Tensor& x, CHECK_GT(num_sections, 0) << "Slice count must be > 0"; CHECK_EQ(src_axis_size % num_sections, 0) - << "num_sections must be an integer factor of the size of axis " << axis - << " (" << src_axis_size << ")"; + << "num_sections must be an integer factor of the size of axis " << axis << " (" + << src_axis_size << ")"; Array split_indices; auto seg_size = src_axis_size / num_sections; @@ -679,22 +659,19 @@ inline Array split_sections(const Tensor& x, } /*! -* \brief Take elements from an flattened input array when axis is None. -* -* \param a The source array. -* \param indices The indices of the values to extract. -* \param mode The mode of the operation. -* \param name The name of the operation. -* \param mode The mode of to handle out of bound indices. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the take operation -*/ -inline Tensor take(const Tensor& a, - const Tensor& indices, - std::string mode = "clip", - std::string name = "T_take", - std::string tag = kInjective) { + * \brief Take elements from an flattened input array when axis is None. + * + * \param a The source array. + * \param indices The indices of the values to extract. + * \param mode The mode of the operation. + * \param name The name of the operation. + * \param mode The mode of to handle out of bound indices. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the take operation + */ +inline Tensor take(const Tensor& a, const Tensor& indices, std::string mode = "clip", + std::string name = "T_take", std::string tag = kInjective) { Array a_shape = a->shape; Array out_shape = indices->shape; PrimExpr a_size = 1; @@ -704,44 +681,44 @@ inline Tensor take(const Tensor& a, if (mode == "clip") { return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { auto idx = tvm::min(tvm::max(0, indices(out_index)), a_size - 1); return a(UnravelIndex(idx, a_shape)); - }, name, tag); + }, + name, tag); } else if (mode == "fast") { LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " "Make sure input indices are in bound"; return compute( - out_shape, [&](const Array& out_index) { - return a(UnravelIndex(indices(out_index), a_shape)); - }, name, tag); + out_shape, + [&](const Array& out_index) { return a(UnravelIndex(indices(out_index), a_shape)); }, + name, tag); } else { // mode == "wrap" return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { auto idx = truncmod(truncmod(indices(out_index), a_size) + a_size, a_size); return a(UnravelIndex(idx, a_shape)); - }, name, tag); + }, + name, tag); } } - /*! -* \brief Mask the out-of-boundary elements of each sequence. -* -* \param data The source array. -* \param valid_length The real length of each sequence. -* \param mask_value The masking value. -* \param axis The axis of the temporal dimension of the sequence -* \param name The name of the operation. -* \param tag The tag to mark the operation. 
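Illustrative aside, not part of the patch: the flattened-take overload treats `a` as 1-D regardless of its rank, and the modes differ only in how out-of-range indices are handled. A sketch with assumed shapes:

#include <topi/transform.h>
#include <tvm/te/operation.h>

void TakeFlatSketch() {
  auto a = tvm::te::placeholder({3, 4}, tvm::DataType::Float(32), "a");
  auto idx = tvm::te::placeholder({5}, tvm::DataType::Int(32), "idx");
  auto clipped = topi::take(a, idx, "clip");  // indices clamped into [0, 11]
  auto wrapped = topi::take(a, idx, "wrap");  // indices taken modulo 12
  // "fast" skips the bounds handling entirely and is only safe for in-range indices.
}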
-* -* \return A Tensor whose op member is the sequence_mask operation -*/ -inline Tensor sequence_mask(const Tensor& data, - const Tensor& valid_length, - double mask_value, - int axis, - std::string name = "T_sequence_mask", + * \brief Mask the out-of-boundary elements of each sequence. + * + * \param data The source array. + * \param valid_length The real length of each sequence. + * \param mask_value The masking value. + * \param axis The axis of the temporal dimension of the sequence + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the sequence_mask operation + */ +inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, double mask_value, + int axis, std::string name = "T_sequence_mask", std::string tag = kInjective) { CHECK(axis == 0 || axis == 1) << "axis must be either 0 or 1"; CHECK_EQ(valid_length->shape.size(), 1) << "valid_length must have ndim=1, i.e., (batch_size,)."; @@ -749,38 +726,36 @@ inline Tensor sequence_mask(const Tensor& data, auto batch_dim = data->shape[1 - axis]; Array out_shape = data->shape; Tensor out = compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array len_index; auto tid = out_index[axis]; auto bid = out_index[1 - axis]; len_index.push_back(bid); - PrimExpr ret = tvm::if_then_else( - tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tir::make_const(data->dtype, mask_value), data(out_index)); + PrimExpr ret = + tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), + tvm::tir::make_const(data->dtype, mask_value), data(out_index)); return ret; - }, name, tag); + }, + name, tag); return out; } /*! -* \brief Take elements from an array along an axis. -* -* \param a The source array. -* \param indices The indices of the values to extract. -* \param axis The axis over which to select values. By default, -* the flattened input array is used. -* \param mode The mode for handling out of bound indices. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the take operation -*/ -inline Tensor take(const Tensor& a, - const Tensor& indices, - int axis, - std::string mode = "clip", - std::string name = "T_take", - std::string tag = kInjective) { + * \brief Take elements from an array along an axis. + * + * \param a The source array. + * \param indices The indices of the values to extract. + * \param axis The axis over which to select values. By default, + * the flattened input array is used. + * \param mode The mode for handling out of bound indices. + * \param name The name of the operation. + * \param tag The tag to mark the operation. 
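Illustrative aside, not part of the patch: sequence_mask keeps the first valid_length[b] steps of each sequence along the time axis and overwrites the rest with mask_value. A sketch assuming a (max_length, batch_size, feature) layout, so the time axis is 0:

#include <topi/transform.h>
#include <tvm/te/operation.h>

void SequenceMaskSketch() {
  auto data = tvm::te::placeholder({20, 8, 32}, tvm::DataType::Float(32), "data");
  auto valid_length = tvm::te::placeholder({8}, tvm::DataType::Int(32), "valid_length");
  auto masked = topi::sequence_mask(data, valid_length, /*mask_value=*/0.0, /*axis=*/0);
}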
+ * + * \return A Tensor whose op member is the take operation + */ +inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string mode = "clip", + std::string name = "T_take", std::string tag = kInjective) { if (axis < 0) { axis += static_cast(a->shape.size()); } @@ -801,30 +776,32 @@ inline Tensor take(const Tensor& a, } if (mode == "clip") { return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; for (size_t j = 0; j < static_cast(axis); ++j) { real_indices.push_back(out_index[j]); } - auto idx = tvm::min(tvm::max(0, indices(indices_position)), - axis_dim - 1); + auto idx = tvm::min(tvm::max(0, indices(indices_position)), axis_dim - 1); real_indices.push_back(idx); for (size_t j = axis + indices_len; j < out_index.size(); ++j) { real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } else if (mode == "fast") { LOG(WARNING) << "Fast mode segfaults when there are out-of-bounds indices. " "Make sure input indices are in bound"; return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; @@ -836,12 +813,14 @@ inline Tensor take(const Tensor& a, real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } else { // mode == "wrap" return compute( - out_shape, [&](const Array& out_index) { + out_shape, + [&](const Array& out_index) { Array indices_position; - for (size_t j = axis; j < static_cast(axis+indices_len); ++j) { + for (size_t j = axis; j < static_cast(axis + indices_len); ++j) { indices_position.push_back(out_index[j]); } Array real_indices; @@ -854,82 +833,78 @@ inline Tensor take(const Tensor& a, real_indices.push_back(out_index[j]); } return a(real_indices); - }, name, tag); + }, + name, tag); } } /*! -* \brief Return the elements, either from x or y, depending on the condition. -* -* \param condition The condition array. -* \param x First array to be selected. -* \param y Second array to be selected. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor selected from x or y depending on condition. -*/ -inline Tensor where(const Tensor& condition, - const Tensor& x, - const Tensor& y, - std::string name = "T_where", - std::string tag = kBroadcast) { + * \brief Return the elements, either from x or y, depending on the condition. + * + * \param condition The condition array. + * \param x First array to be selected. + * \param y Second array to be selected. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor selected from x or y depending on condition. 
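Illustrative aside, not part of the patch: `where` accepts either a condition with the same shape as x and y, or a 1-D condition that selects along the first axis. A sketch with assumed shapes:

#include <topi/transform.h>
#include <tvm/te/operation.h>

void WhereSketch() {
  auto x = tvm::te::placeholder({4, 5}, tvm::DataType::Float(32), "x");
  auto y = tvm::te::placeholder({4, 5}, tvm::DataType::Float(32), "y");
  auto cond_full = tvm::te::placeholder({4, 5}, tvm::DataType::Int(32), "cond");
  auto cond_rows = tvm::te::placeholder({4}, tvm::DataType::Int(32), "cond_rows");
  auto a = topi::where(cond_full, x, y);  // element-wise select
  auto b = topi::where(cond_rows, x, y);  // selects whole rows by the first axis
}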
+ */ +inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, + std::string name = "T_where", std::string tag = kBroadcast) { CHECK_EQ(x->shape.size(), y->shape.size()) - << "x and y must have the same shape.Got different number of dimension: " - << x->shape.size() << " vs " << y->shape.size(); - CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " - << x->dtype << " vs " << y->dtype; + << "x and y must have the same shape.Got different number of dimension: " << x->shape.size() + << " vs " << y->shape.size(); + CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " + << y->dtype; Array oshape = x->shape; Tensor out; if (condition->shape.size() != 1) { CHECK_EQ(condition->shape.size(), x->shape.size()) - << "condition array must be either have the same shape as x or to be a " - "1-D array.Got different number of dimension: " - << condition->shape.size() << " vs " << x->shape.size(); + << "condition array must be either have the same shape as x or to be a " + "1-D array.Got different number of dimension: " + << condition->shape.size() << " vs " << x->shape.size(); out = compute( - oshape, [&](const Array& indices) { - return tvm::tir::SelectNode::make(condition(indices) != 0, x(indices), y(indices)); - }, name, tag); + oshape, + [&](const Array& indices) { + return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices)); + }, + name, tag); } else { CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) - << "If condition is 1-D, the first dimension must be the same as x: " - << condition->shape[0] << " vs " << x->shape[0]; + << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0] + << " vs " << x->shape[0]; out = compute( - oshape, [&](const Array& indices) { - Array condition_idx{indices[0]}; - return tvm::tir::SelectNode::make(condition(condition_idx) != 0, - x(indices), y(indices)); - }, name, tag); + oshape, + [&](const Array& indices) { + Array condition_idx{indices[0]}; + return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices)); + }, + name, tag); } return out; } /*! 
-* \brief Creates an operation to repeat elements of an array -* -* \param x The input tensor -* \param repeats The number of repetitions for each element -* \param axis The axis along which to repeat values (allows -* negative indices as offsets from the last dimension) -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the repeat operation -*/ -inline Tensor repeat(const Tensor& x, - int repeats, - int axis, - std::string name = "T_repeat", + * \brief Creates an operation to repeat elements of an array + * + * \param x The input tensor + * \param repeats The number of repetitions for each element + * \param axis The axis along which to repeat values (allows + * negative indices as offsets from the last dimension) + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the repeat operation + */ +inline Tensor repeat(const Tensor& x, int repeats, int axis, std::string name = "T_repeat", std::string tag = kBroadcast) { int ndim = static_cast(x->shape.size()); CHECK(-ndim - 1 <= axis && axis <= ndim) - << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" - << ", but got axis = " << axis - << ", and data.ndim = " << ndim; - CHECK(repeats >= 1) - << "repeat only accepts `repeats >= 1`" - << ", but got repeats = " << repeats; + << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]" + << ", but got axis = " << axis << ", and data.ndim = " << ndim; + CHECK(repeats >= 1) << "repeat only accepts `repeats >= 1`" + << ", but got repeats = " << repeats; if (axis < 0) { // Calculate offset from last dimension axis += ndim; @@ -944,32 +919,32 @@ inline Tensor repeat(const Tensor& x, } return compute( - new_shape, [&](const Array& indices) { - Array idx; - for (size_t i = 0; i < static_cast(axis); ++i) { - idx.push_back(indices[i]); - } - idx.push_back(indexdiv(indices[axis], repeats)); - for (size_t i = axis + 1; i < indices.size(); ++i) { - idx.push_back(indices[i]); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + for (size_t i = 0; i < static_cast(axis); ++i) { + idx.push_back(indices[i]); + } + idx.push_back(indexdiv(indices[axis], repeats)); + for (size_t i = axis + 1; i < indices.size(); ++i) { + idx.push_back(indices[i]); + } + return x(idx); + }, + name, tag); } /*! 
-* \brief Creates an operation to tile elements of an array -* -* \param x The input tensor -* \param reps The number of times for repeating the tensor -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the tile operation -*/ -inline Tensor tile(const Tensor& x, - Array reps, - std::string name = "T_tile", + * \brief Creates an operation to tile elements of an array + * + * \param x The input tensor + * \param reps The number of times for repeating the tensor + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the tile operation + */ +inline Tensor tile(const Tensor& x, Array reps, std::string name = "T_tile", std::string tag = kBroadcast) { size_t ndim = x->shape.size(); size_t rdim = reps.size(); @@ -983,60 +958,99 @@ inline Tensor tile(const Tensor& x, reps_shape.push_back(reps[i]); } } else if (ndim > rdim) { - for (size_t i = 0; i < ndim; ++i) - data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < (ndim - rdim); ++i) - reps_shape.push_back(1); - for (size_t i = 0; i < rdim; ++i) - reps_shape.push_back(reps[i]); + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } else { - for (size_t i = 0; i < (rdim - ndim); ++i) - data_shape.push_back(1); - for (size_t i = 0; i < ndim; ++i) - data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < rdim; ++i) - reps_shape.push_back(reps[i]); + for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); + for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(reps[i]); } - for (size_t i = 0; i < tdim; ++i) - new_shape.push_back(data_shape[i] * reps_shape[i]); + for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); if (is_empty_shape(new_shape)) { - return compute(new_shape, - [&](const Array& indices) { return tvm::cast(x->dtype, 0);}, - name, tag); + return compute( + new_shape, [&](const Array& indices) { return tvm::cast(x->dtype, 0); }, name, tag); } else { return compute( - new_shape, [&](const Array& indices) { - Array idx; - if (ndim >= rdim) { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[i], x->shape[i])); - } else { - for (size_t i = 0; i < ndim; ++i) - idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); - } - return x(idx); - }, name, tag); + new_shape, + [&](const Array& indices) { + Array idx; + if (ndim >= rdim) { + for (size_t i = 0; i < ndim; ++i) idx.push_back(indexmod(indices[i], x->shape[i])); + } else { + for (size_t i = 0; i < ndim; ++i) + idx.push_back(indexmod(indices[rdim - ndim + i], x->shape[i])); + } + return x(idx); + }, + name, tag); } } /*! -* \brief Gather elements from a n-dimension array. -* -* \param data The source array. -* \param indices The indices of the values to extract. -* \param name The name of the operation. -* \param tag The tag to mark the operation. -* -* \return A Tensor whose op member is the gather_nd operation -*/ -inline Tensor gather_nd(const Tensor& data, - const Tensor& indices, - std::string name = "T_gather_nd", + * \brief Gather values along given axis from given indices. + * + * \param data The input data to the operator. + * \param axis The axis along which to index. + * \param indices The indices of values to gather. 
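Illustrative aside, not part of the patch: the new `gather` operator reads, for every output position, the element of `data` addressed by `indices` along the chosen axis; indices must have the same rank as data and its shape becomes the output shape. A sketch with assumed shapes:

#include <topi/transform.h>
#include <tvm/te/operation.h>

void GatherSketch() {
  auto data = tvm::te::placeholder({4, 8}, tvm::DataType::Float(32), "data");
  // Same rank as data; the output shape equals the indices shape (4, 3).
  auto indices = tvm::te::placeholder({4, 3}, tvm::DataType::Int(32), "indices");
  // out(i, j) = data(i, indices(i, j))
  auto out = topi::gather(data, /*axis=*/1, indices);
}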
+ * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the gather operation + */ +inline Tensor gather(const Tensor& data, int axis, const Tensor& indices, + std::string name = "T_gather", std::string tag = kInjective) { + size_t ndim_d = data->shape.size(); + size_t ndim_i = indices->shape.size(); + CHECK_GE(ndim_d, 1) << "Cannot gather from a scalar."; + CHECK_EQ(ndim_d, ndim_i); + CHECK_GE(axis, 0); + CHECK_LT(axis, ndim_d); + size_t indices_dim_i = static_cast(GetConstInt(indices->shape[axis])); + CHECK_GE(indices_dim_i, 1); + CHECK(indices->dtype.is_int()); + + Array out_shape; + for (size_t i = 0; i < ndim_i; ++i) { + out_shape.push_back(indices->shape[i]); + } + + return compute( + out_shape, + [&](const Array& out_index) { + Array indices_position; + for (size_t i = 0; i < ndim_i; ++i) { + indices_position.push_back(out_index[i]); + } + Array real_indices; + for (size_t i = 0; i < ndim_i; ++i) { + if (i == (size_t)axis) { + real_indices.push_back(indices(indices_position)); + } else { + real_indices.push_back(indices_position[i]); + } + } + return data(real_indices); + }, + name, tag); +} + +/*! + * \brief Gather elements from a n-dimension array. + * + * \param data The source array. + * \param indices The indices of the values to extract. + * \param name The name of the operation. + * \param tag The tag to mark the operation. + * + * \return A Tensor whose op member is the gather_nd operation + */ +inline Tensor gather_nd(const Tensor& data, const Tensor& indices, std::string name = "T_gather_nd", std::string tag = kInjective) { size_t ndim_d = data->shape.size(); size_t ndim_i = indices->shape.size(); - CHECK_GT(ndim_i, 1) << "indices tensor must have at least 2 dimensions"; + CHECK_GE(ndim_i, 1) << "indices tensor must have at least 1 dimensions"; size_t indices_dim0 = static_cast(GetConstInt(indices->shape[0])); CHECK_LE(indices_dim0, ndim_d) << "dim 0 of indices tensor must be no more " << "than dimensions of data tensor"; @@ -1051,27 +1065,31 @@ inline Tensor gather_nd(const Tensor& data, out_shape.push_back(make_const(DataType::Int(32), 1)); } return compute( - out_shape, [&](const Array& out_index) { - Array indices_position; - indices_position.push_back(0); - for (size_t i = 0; i < ndim_i - 1; ++i) { - indices_position.push_back(out_index[i]); - } - Array real_indices; - for (size_t i = 0; i < indices_dim0; ++i) { - indices_position.Set(0, make_const(DataType::Int(32), i)); - if (indices->dtype.is_int()) { - real_indices.push_back(indices(indices_position)); - } else { - real_indices.push_back( - tvm::cast(tvm::DataType::Int(32), indices(indices_position))); - } - } - for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { - real_indices.push_back(out_index[i]); + out_shape, + [&](const Array& out_index) { + Array indices_position; + indices_position.push_back(0); + for (size_t i = 0; i < ndim_i - 1; ++i) { + indices_position.push_back(out_index[i]); + } + Array real_indices; + for (size_t i = 0; i < indices_dim0; ++i) { + indices_position.Set(0, make_const(DataType::Int(32), i)); + if (indices->dtype.is_int()) { + real_indices.push_back(indices(indices_position)); + } else { + real_indices.push_back(tvm::cast(tvm::DataType::Int(32), indices(indices_position))); } + } + if (real_indices.size() == ndim_d) { return data(real_indices); - }, name, tag); + } + for (size_t i = ndim_i - 1; i < out_index.size(); ++i) { + real_indices.push_back(out_index[i]); + } + return data(real_indices); + }, + 
name, tag); } /*! @@ -1089,18 +1107,13 @@ inline Tensor gather_nd(const Tensor& data, * * \return A Tensor whose op member is the matmul operation */ -inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, - const tvm::te::Tensor& B, - bool trans_a = false, - bool trans_b = false, - std::string name = "T_matmul", - std::string tag = kMatMul) { - tvm::Array output_shape{A->shape[trans_a ? 1 : 0], - B->shape[trans_b ? 0 : 1]}; +inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B, + bool trans_a = false, bool trans_b = false, + std::string name = "T_matmul", std::string tag = kMatMul) { + tvm::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { - return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), - {k}); + return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); }; return tvm::te::compute(output_shape, l, name, tag); } @@ -1116,45 +1129,35 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, - const tvm::te::Tensor& B, - int axes = 2, - std::string name = "T_tensordot", - std::string tag = kMatMul) { +inline Tensor tensordot(const Tensor& A, const tvm::te::Tensor& B, int axes = 2, + std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_GE(A->shape.size(), axes); CHECK_GE(B->shape.size(), axes); Array output_shape(A->shape.begin(), A->shape.end() + (-axes)); - for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) - output_shape.push_back(*it); + for (auto it = B->shape.begin() + axes; it != B->shape.end(); ++it) output_shape.push_back(*it); Array iter_vars; for (int i = 0; i < axes; ++i) iter_vars.push_back(reduce_axis(Range(0, B->shape[i]), "k" + std::to_string(i))); - auto func = - [&A, &B, &iter_vars, axes] - (const Array& input_indices) { - Array A_indices( - input_indices.begin(), - input_indices.begin() + (A->shape.size() - axes)); - for (auto& v : iter_vars) - A_indices.push_back(v); - - Array B_indices; - for (auto& v : iter_vars) - B_indices.push_back(v); - - auto it = input_indices.begin() + (A->shape.size() - axes); - for (; it != input_indices.end(); ++it) - B_indices.push_back(*it); - - // Some passes don't like reductions with empty axis, so avoid it here - if (iter_vars.empty()) - return A(A_indices) * B(B_indices); - else - return sum(A(A_indices) * B(B_indices), iter_vars); - }; + auto func = [&A, &B, &iter_vars, axes](const Array& input_indices) { + Array A_indices(input_indices.begin(), + input_indices.begin() + (A->shape.size() - axes)); + for (auto& v : iter_vars) A_indices.push_back(v); + + Array B_indices; + for (auto& v : iter_vars) B_indices.push_back(v); + + auto it = input_indices.begin() + (A->shape.size() - axes); + for (; it != input_indices.end(); ++it) B_indices.push_back(*it); + + // Some passes don't like reductions with empty axis, so avoid it here + if (iter_vars.empty()) + return A(A_indices) * B(B_indices); + else + return sum(A(A_indices) * B(B_indices), iter_vars); + }; return compute(output_shape, func, name, tag); } @@ -1171,11 +1174,8 @@ inline Tensor tensordot(const Tensor& A, * * \return A Tensor computing the result */ -inline Tensor tensordot(const Tensor& A, - const tvm::te::Tensor& B, - Array A_axes, - Array B_axes, - std::string name = "T_tensordot", +inline Tensor tensordot(const Tensor& A, 
const tvm::te::Tensor& B, Array A_axes, + Array B_axes, std::string name = "T_tensordot", std::string tag = kMatMul) { CHECK_EQ(A_axes.size(), B_axes.size()); @@ -1191,47 +1191,42 @@ inline Tensor tensordot(const Tensor& A, output_shape.push_back(B->shape[i]); Array iter_vars; - for (unsigned i = 0; i < B_axes_val.size(); ++i) - iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); - - auto func = - [&A, &B, &iter_vars, A_axes_val, B_axes_val] - (const Array& input_indices) { - int idx_input = 0; - Array A_indices; - for (unsigned i = 0; i < A->shape.size(); ++i) { - auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); - if (axes_pos == A_axes_val.end()) - A_indices.push_back(input_indices[idx_input++]); - else - A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); - } + for (unsigned i = 0; i < B_axes_val.size(); ++i) + iter_vars.push_back(reduce_axis(Range(0, B->shape[B_axes_val[i]]), "k" + std::to_string(i))); + + auto func = [&A, &B, &iter_vars, A_axes_val, B_axes_val](const Array& input_indices) { + int idx_input = 0; + Array A_indices; + for (unsigned i = 0; i < A->shape.size(); ++i) { + auto axes_pos = std::find(A_axes_val.begin(), A_axes_val.end(), i); + if (axes_pos == A_axes_val.end()) + A_indices.push_back(input_indices[idx_input++]); + else + A_indices.push_back(iter_vars[axes_pos - A_axes_val.begin()]); + } - Array B_indices; - for (unsigned i = 0; i < B->shape.size(); ++i) { - auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); - if (axes_pos == B_axes_val.end()) - B_indices.push_back(input_indices[idx_input++]); - else - B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]); - } - return sum(A(A_indices) * B(B_indices), iter_vars); - }; + Array B_indices; + for (unsigned i = 0; i < B->shape.size(); ++i) { + auto axes_pos = std::find(B_axes_val.begin(), B_axes_val.end(), i); + if (axes_pos == B_axes_val.end()) + B_indices.push_back(input_indices[idx_input++]); + else + B_indices.push_back(iter_vars[axes_pos - B_axes_val.begin()]); + } + return sum(A(A_indices) * B(B_indices), iter_vars); + }; return compute(output_shape, func, name, tag); } -inline Tensor arange(const PrimExpr& start, - const PrimExpr& stop, - const PrimExpr& step, - DataType dtype, - std::string name = "T_arange", - std::string tag = kInjective) { - PrimExpr num_elem = tvm::cast(tvm::DataType::Int(32), tvm::ceil( - tvm::cast(tvm::DataType::Float(32), stop - start) / step)); +inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr& step, + DataType dtype, std::string name = "T_arange", std::string tag = kInjective) { + PrimExpr num_elem = tvm::cast( + tvm::DataType::Int(32), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); Array shape; - return compute({num_elem}, [&](const Array& indices) { - return tvm::cast(dtype, start + step * indices[0]); - }, name, tag); + return compute( + {num_elem}, + [&](const Array& indices) { return tvm::cast(dtype, start + step * indices[0]); }, name, + tag); } /*! @@ -1243,33 +1238,33 @@ inline Tensor arange(const PrimExpr& start, * \param tag output tensor tag. 
* \return A tensor with shape in \p dst_layout */ -inline Tensor layout_transform(const Tensor& src, - const std::string& src_layout, +inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, const std::string& dst_layout, const std::string name = "T_layout_trans", const std::string tag = kInjective) { - Layout src_layout_struct = LayoutNode::make(src_layout); - Layout dst_layout_struct = LayoutNode::make(dst_layout); + Layout src_layout_struct(src_layout); + Layout dst_layout_struct(dst_layout); if (src_layout_struct.Equals(dst_layout_struct)) { return src; } CHECK(src_layout_struct.defined() && dst_layout_struct.defined()) - << "cannot convert from/to undefined layout"; + << "cannot convert from/to undefined layout"; auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct); - CHECK(layout_converter.defined()) - << "cannot convert from " << src_layout << " to " << dst_layout; + CHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; Array dst_shape = layout_converter.ForwardShape(src->shape); return compute( - dst_shape, [&](const Array& dst_indices) { - Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); - Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); - return src(src_indices); - }, name, tag); + dst_shape, + [&](const Array& dst_indices) { + Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); + Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); + return src(src_indices); + }, + name, tag); } /*! @@ -1280,20 +1275,21 @@ inline Tensor layout_transform(const Tensor& src, * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor shape(const Tensor& src, - DataType dtype, - const std::string name = "T_shape", +inline Tensor shape(const Tensor& src, DataType dtype, const std::string name = "T_shape", const std::string tag = kInjective) { int ndim = static_cast(src->shape.size()); Array out_shape{ndim}; - return compute(out_shape, [&](const Array& indices) { - auto idx = indices[0]; - PrimExpr ret = 0; - for (int i = 0; i < ndim; ++i) { - ret = tvm::if_then_else(idx == i, src->shape[i], ret); - } - return tvm::cast(dtype, ret); - }, name, tag); + return compute( + out_shape, + [&](const Array& indices) { + auto idx = indices[0]; + PrimExpr ret = 0; + for (int i = 0; i < ndim; ++i) { + ret = tvm::if_then_else(idx == i, src->shape[i], ret); + } + return tvm::cast(dtype, ret); + }, + name, tag); } /*! @@ -1304,19 +1300,21 @@ inline Tensor shape(const Tensor& src, * \param tag output tensor tag. * \return Tensor of input shape. */ -inline Tensor ndarray_size(const Tensor& src, - const DataType& dtype, +inline Tensor ndarray_size(const Tensor& src, const DataType& dtype, const std::string& name = "ndarray_size", const std::string& tag = kInjective) { int ndim = static_cast(src->shape.size()); Array out_ndarray_size = {1}; - return compute(out_ndarray_size, [&](const Array& indices) { - PrimExpr ret = 1; - for (int i = 0; i < ndim; ++i) { - ret *= src->shape[i]; - } - return tvm::cast(dtype, ret); - }, name, tag); + return compute( + out_ndarray_size, + [&](const Array& indices) { + PrimExpr ret = 1; + for (int i = 0; i < ndim; ++i) { + ret *= src->shape[i]; + } + return tvm::cast(dtype, ret); + }, + name, tag); } /*! @@ -1332,14 +1330,9 @@ inline Tensor ndarray_size(const Tensor& src, * \param tag output tensor tag. * \return one-hot tensor. 
*/ -inline Tensor one_hot(const Tensor& indices, - const PrimExpr on_value, - const PrimExpr off_value, - int depth, - int axis, - const DataType& dtype, - const std::string name = "T_one_hot", - const std::string tag = kInjective) { +inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const PrimExpr off_value, + int depth, int axis, const DataType& dtype, + const std::string name = "T_one_hot", const std::string tag = kInjective) { Array oshape; int ndim = indices->shape.size() + 1; int indices_index = 0; @@ -1354,19 +1347,70 @@ inline Tensor one_hot(const Tensor& indices, PrimExpr on_value_cast = cast(dtype, on_value); PrimExpr off_value_cast = cast(dtype, off_value); - return compute(oshape, [&](const Array& iter_vars) { - Array indices_indices; - for (size_t i = 0; i < iter_vars.size(); i++) { - if (static_cast(i) == true_axis) { - continue; - } + return compute( + oshape, + [&](const Array& iter_vars) { + Array indices_indices; + for (size_t i = 0; i < iter_vars.size(); i++) { + if (static_cast(i) == true_axis) { + continue; + } - indices_indices.push_back(iter_vars[i]); - } + indices_indices.push_back(iter_vars[i]); + } + + auto idx = iter_vars[true_axis]; + return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast); + }, + name, tag); +} - auto idx = iter_vars[true_axis]; - return tir::SelectNode::make(indices(indices_indices) == idx, on_value_cast, off_value_cast); - }, name, tag); +/*! + * \brief Get a dense tensor. + * \param sparse_indices sparse_indices[i] contains sparse_values[i] will be placed. + * \param output_shape is the shape of the dense output tensor . + * \param sparse_values is a 0-D or 1-D tensor. Values for each row of sparse_indices. + * \param default_value is a 0-D tensor. Defaults to zero. + * \param name output tensor name. + * \param tag output tensor tag. + * \return Tensor of output_shape. + */ +inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array& output_shape, + const Tensor& sparse_values, const PrimExpr& default_value, + const std::string name = "T_sparse_to_dense", + const std::string tag = kInjective) { + CHECK(sparse_indices->dtype.is_int()) << "sparse_indices only accepts integer values"; + CHECK_LE(sparse_indices->shape.size(), 3) << "sparse_indices tensor should be 0D, 1D, or 2D only"; + CHECK_LE(sparse_values->shape.size(), 2) << "sparse_values tensor should be 0D or 1D only"; + + const auto rank_sparse_indices = static_cast(sparse_indices->shape.size()); + Array oshape; + for (auto l : output_shape) { + oshape.push_back(l); + } + return compute( + oshape, + [&](const Array& indices) { + PrimExpr ret = default_value; + if (0 == rank_sparse_indices) { + ret = if_then_else(indices[0] == sparse_indices[0], sparse_values[0], ret); + } else if (1 == rank_sparse_indices) { + for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { + ret = if_then_else(indices[0] == sparse_indices[j], sparse_values[j], ret); + } + } else { + for (int j = 0; j < GetConstInt(sparse_indices->shape[0]); j++) { + PrimExpr aggregate_condition; + for (int k = 0; k < GetConstInt(sparse_indices->shape[1]); k++) { + PrimExpr comparision = indices[k] == sparse_indices[j][k]; + aggregate_condition = 0 == k ? 
comparision : aggregate_condition && comparision; + } + ret = if_then_else(aggregate_condition, sparse_values[j], ret); + } + } + return ret; + }, + name, tag); } } // namespace topi diff --git a/topi/include/topi/vision/reorg.h b/topi/include/topi/vision/reorg.h index 06931e424de3..5bd79f67f052 100644 --- a/topi/include/topi/vision/reorg.h +++ b/topi/include/topi/vision/reorg.h @@ -24,11 +24,11 @@ #ifndef TOPI_VISION_REORG_H_ #define TOPI_VISION_REORG_H_ -#include #include #include #include #include +#include #include #include @@ -39,18 +39,16 @@ using namespace tvm; using namespace tvm::te; /*! -* \brief Reorg operation -* -* \param data The input tensor. Can be any dimension -* \param stride The input integer used as stride in reorg operation -* \param name The name of the operation -* \param tag The tag to mark the operation -* -* \return A Tensor whose op member is the reorg operation -*/ -inline Tensor reorg(const Tensor &data, - int stride = 1, - std::string name = "tensor", + * \brief Reorg operation + * + * \param data The input tensor. Can be any dimension + * \param stride The input integer used as stride in reorg operation + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the reorg operation + */ +inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tensor", std::string tag = "reorg_output") { auto input_shape = data->shape; @@ -60,15 +58,14 @@ inline Tensor reorg(const Tensor &data, int w_in = GetConstInt(input_shape[3]); int out_c = c_in / (stride * stride); - auto out = tvm::te::compute(input_shape, - [&](Var b, Var k, Var j, Var i) { - return data(b * stride * stride, - indexmod(k, out_c) * stride * stride, - (j*stride + indexdiv(indexdiv(k, out_c), stride)) * stride, - (i*stride + indexmod(indexdiv(k, out_c), stride))); - }, - name, - tag); + auto out = tvm::te::compute( + input_shape, + [&](Var b, Var k, Var j, Var i) { + return data(b * stride * stride, indexmod(k, out_c) * stride * stride, + (j * stride + indexdiv(indexdiv(k, out_c), stride)) * stride, + (i * stride + indexmod(indexdiv(k, out_c), stride))); + }, + name, tag); out_c = c_in * stride * stride; int out_h = h_in / stride; diff --git a/topi/include/topi/x86/bnn.h b/topi/include/topi/x86/bnn.h index 53b7a8e0739e..a59d30da3dce 100644 --- a/topi/include/topi/x86/bnn.h +++ b/topi/include/topi/x86/bnn.h @@ -24,10 +24,10 @@ #ifndef TOPI_X86_BNN_H_ #define TOPI_X86_BNN_H_ -#include #include -#include +#include #include +#include namespace topi { using namespace tvm; @@ -35,14 +35,14 @@ using namespace tvm::te; namespace x86 { /*! -* \brief Create a generic schedule for binarize_pack -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_binarize_pack(const Target &target, const Array& outs) { + * \brief Create a generic schedule for binarize_pack + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_binarize_pack(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); @@ -67,14 +67,14 @@ inline Schedule schedule_binarize_pack(const Target &target, const Array } /*! -* \brief Create a generic schedule for binary_dense -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. 
-* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_binary_dense(const Target &target, const Array& outs) { + * \brief Create a generic schedule for binary_dense + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_binary_dense(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/include/topi/x86/default.h b/topi/include/topi/x86/default.h index 9b6efa511d8d..07337810a694 100644 --- a/topi/include/topi/x86/default.h +++ b/topi/include/topi/x86/default.h @@ -24,11 +24,11 @@ #ifndef TOPI_X86_DEFAULT_H_ #define TOPI_X86_DEFAULT_H_ -#include #include +#include +#include #include #include -#include namespace topi { using namespace tvm; @@ -36,16 +36,15 @@ using namespace tvm::te; namespace x86 { /*! -* \brief Helper to create a default x86 schedule for the given ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* \param auto_inline Whether to apply the auto inline step. -* -* \return A schedule for the given ops. -*/ -inline Schedule MakeDefaultSchedule(const Target &target, - const Array& outs, + * \brief Helper to create a default x86 schedule for the given ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * \param auto_inline Whether to apply the auto inline step. + * + * \return A schedule for the given ops. + */ +inline Schedule MakeDefaultSchedule(const Target& target, const Array& outs, bool auto_inline) { Array out_ops; for (auto t : outs) { @@ -66,7 +65,7 @@ inline Schedule MakeDefaultSchedule(const Target &target, if (axis.size() == 4) { auto n = axis[0]; auto c = axis[1]; - auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h + auto fused = detail::Fuse(s[x], {n, c}); // for nhwc layout, fuse n and h s[x].parallel(fused); } else { s[x].parallel(axis[0]); @@ -76,26 +75,26 @@ inline Schedule MakeDefaultSchedule(const Target &target, } /*! -* \brief Create a default x86 schedule for the given ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule(const Target &target, const Array& outs) { + * \brief Create a default x86 schedule for the given ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule default_schedule(const Target& target, const Array& outs) { return MakeDefaultSchedule(target, outs, false); } /*! -* \brief Create a default x86 schedule for the given ops, with auto inline -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule default_schedule_auto_inline(const Target &target, const Array& outs) { + * \brief Create a default x86 schedule for the given ops, with auto inline + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. 
+ */ +inline Schedule default_schedule_auto_inline(const Target& target, const Array& outs) { return MakeDefaultSchedule(target, outs, true); } diff --git a/topi/include/topi/x86/injective.h b/topi/include/topi/x86/injective.h index 182140d68c5c..069a97170816 100644 --- a/topi/include/topi/x86/injective.h +++ b/topi/include/topi/x86/injective.h @@ -24,10 +24,10 @@ #ifndef TOPI_X86_INJECTIVE_H_ #define TOPI_X86_INJECTIVE_H_ -#include #include -#include +#include #include +#include namespace topi { using namespace tvm; @@ -48,7 +48,7 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out if (axis.size() == 4) { auto n = axis[0]; auto c = axis[1]; - auto fused = detail::Fuse(sch[out], { n, c }); // for nhwc layout, fuse n and h + auto fused = detail::Fuse(sch[out], {n, c}); // for nhwc layout, fuse n and h sch[out].parallel(fused); } else { sch[out].parallel(axis[0]); @@ -57,14 +57,14 @@ inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out } /*! -* \brief Create an x86 schedule for the given injective ops. -* -* \param target The target to generate a schedule for. -* \param outs The output tensors. -* -* \return A schedule for the given ops. -*/ -inline Schedule schedule_injective(const Target &target, const Array& outs) { + * \brief Create an x86 schedule for the given injective ops. + * + * \param target The target to generate a schedule for. + * \param outs The output tensors. + * + * \return A schedule for the given ops. + */ +inline Schedule schedule_injective(const Target& target, const Array& outs) { Array out_ops; for (auto t : outs) { out_ops.push_back(t->op); diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index f1019e667e81..56c3a740b843 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -39,6 +39,7 @@ from .transform import * from .broadcast import * from .sort import * +from .scatter import * from .argwhere import * from . import generic from . import nn @@ -48,7 +49,6 @@ from . import mali from . import bifrost from . import intel_graphics -from . import opengl from . import util from . import rocm from . import vision diff --git a/topi/python/topi/arm_cpu/__init__.py b/topi/python/topi/arm_cpu/__init__.py index eb05dd839e32..e121fbc7ec6d 100644 --- a/topi/python/topi/arm_cpu/__init__.py +++ b/topi/python/topi/arm_cpu/__init__.py @@ -25,3 +25,4 @@ from .bitserial_conv2d import * from .bitserial_dense import * from .injective import * +from . 
import cortex_m7 diff --git a/topi/python/topi/arm_cpu/bitserial_conv2d.py b/topi/python/topi/arm_cpu/bitserial_conv2d.py index b7da66f9168f..ac1ac45c1b38 100644 --- a/topi/python/topi/arm_cpu/bitserial_conv2d.py +++ b/topi/python/topi/arm_cpu/bitserial_conv2d.py @@ -270,8 +270,9 @@ def _instr(index): return irb.get() # body, reset, update return _instr(0), _instr(1), _instr(2) - with tvm.target.build_config(offset_factor=1, partition_const_loop=True): - return te.decl_tensor_intrin(z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb}) + buffer_params = {"offset_factor": 1} + return te.decl_tensor_intrin( + z.op, _intrin_func, binds={w: Wb, x:Xb, z:Zb}, default_buffer_params=buffer_params) # ARM specific schedule that using custom microkernel def _schedule_spatial_conv2d_nhwc(cfg, s, data_pad, data_vec, kernel_vec, diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 25b338e06b5f..4faee42f75cc 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -31,6 +31,7 @@ conv2d_spatial_pack_nhwc, \ schedule_conv2d_spatial_pack_nchw, \ schedule_conv2d_spatial_pack_nhwc +from .cortex_m7.conv2d import direct_simd @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu") @@ -166,15 +167,20 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til idxm(b*VP + bb, nW) * m + nu], name='d') - # transform kernel - if pre_computed: - U = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + VC = cfg['tile_k'].size[-1] + kvshape = (KH + tile_size - 1, KW + tile_size - 1, idxd(CO, VC), CI, VC) + U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: - r_kh = te.reduce_axis((0, KH), 'r_kh') - r_kw = te.reduce_axis((0, KW), 'r_kw') - U = te.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk: - te.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * - G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') + # transform kernel + if pre_computed: + U = kernel + else: + r_kh = te.reduce_axis((0, KH), 'r_kh') + r_kw = te.reduce_axis((0, KW), 'r_kw') + U = te.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk: + te.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * + G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') # transform image r_eps = te.reduce_axis((0, alpha), 'r_eps') @@ -425,3 +431,15 @@ def _callback(op): traverse_inline(s, outs[0].op, _callback) return s + +@autotvm.register_topi_compute("conv2d_direct_simd.arm_cpu") +def conv2d_direct_simd(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with SIMD (v7e-m).""" + return direct_simd.conv2d_direct_simd_compute( + cfg, data, kernel, strides, padding, dilation, out_dtype) + + +@autotvm.register_topi_schedule("conv2d_direct_simd.arm_cpu") +def schedule_conv2d_direct_simd(cfg, outs): + """Create schedule for conv2d_direct_simd""" + return direct_simd.conv2d_direct_simd_nhwc_schedule(cfg, outs) diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py index 3d194cce6534..3206168d51bd 100644 --- a/topi/python/topi/arm_cpu/conv2d_alter_op.py +++ b/topi/python/topi/arm_cpu/conv2d_alter_op.py @@ -59,6 +59,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): data, kernel = tinfos out_dtype = out_type.dtype + # We only perform layout alteration for NCHW data layout. 
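+    # Returning None from this hook tells the alter_op_layout pass to keep the
+    # operator unchanged, so an NHWC conv2d (for example data_layout="NHWC"
+    # with kernel_layout="HWIO") keeps its original layouts instead of going
+    # through the NCHW-specific weight packing handled further down.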
+ if data_layout == "NHWC": + return None + # Extract data types data_tensor, kernel_tensor = tinfos data_dtype = data_tensor.dtype @@ -113,7 +117,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): weight_expr = relay.reshape(weight_expr, newshape=(KH + tile_size - 1, KW + tile_size - 1, - idxd(CO, VC), VC, CI)) + CO // VC, VC, CI)) weight_expr = relay.transpose(weight_expr, axes=[0, 1, 2, 4, 3]) new_attrs['tile_size'] = tile_size diff --git a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py index 3bb9dc73e2db..8cf8401e7a07 100644 --- a/topi/python/topi/arm_cpu/conv2d_spatial_pack.py +++ b/topi/python/topi/arm_cpu/conv2d_spatial_pack.py @@ -109,12 +109,15 @@ def conv2d_spatial_pack_nchw(cfg, data, kernel, strides, padding, dilation, data_pad[n][ci][h*VH*HSTR+vh][w*VW*WSTR+vw], name='data_vec') - if pre_packed: - kernel_vec = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + kernel_vec = tvm.te.placeholder(kvshape, kernel.dtype, name="kernel") else: - kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc: - kernel[co*VC+vc][ci][kh][kw], - name='kernel_vec') + if pre_packed: + kernel_vec = kernel + else: + kernel_vec = te.compute(kvshape, lambda co, ci, kh, kw, vc: + kernel[co*VC+vc][ci][kh][kw], + name='kernel_vec') ci = te.reduce_axis((0, CI), name='ci') kh = te.reduce_axis((0, KH), name='kh') @@ -152,13 +155,13 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, cfg["ann_reduce"].apply(s, conv, [kh, kw], axis_lens=[get_const_int(kh.dom.extent), get_const_int(kw.dom.extent)], - max_unroll=16, + max_unroll=None, cfg=cfg) cfg["ann_spatial"].apply(s, conv, [vh, vw, vc], axis_lens=[cfg['tile_oh'].size[-1], cfg['tile_ow'].size[-1], cfg['tile_co'].size[-1]], - max_unroll=16, + max_unroll=None, cfg=cfg) # schedule fusion @@ -187,12 +190,8 @@ def schedule_conv2d_spatial_pack_nchw(cfg, s, data_vec, kernel_vec, s[data_vec].parallel(h) if kernel_vec.op.name == 'kernel_vec': - co, _, _, _, _ = s[kernel_vec].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(co, 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: + co, _, _, _, _ = s[kernel_vec].op.axis s[kernel_vec].parallel(co) elif kernel_vec.op.name == 'kernel_vec_conv2d_transpose': # for conv2d transpose co, _, _, _, _ = s[kernel_vec].op.axis @@ -267,9 +266,13 @@ def conv2d_spatial_pack_nhwc(cfg, data, kernel, strides, padding, dilation, out_ data_vec = te.compute(dvshape, lambda n, oho, owo, ohi, owi, ic: data_pad[n][oho*OHI*HSTR+ohi][owo*OWI*WSTR+owi][ic], name='data_vec') - kernel_vec = te.compute(kvshape, lambda oco, kh, kw, ic, oci: \ - kernel[kh][kw][ic][oco*OCI+oci], - name='kernel_vec') + + if autotvm.GLOBAL_SCOPE.in_tuning: + kernel_vec = tvm.te.placeholder(kvshape, kernel.dtype, name="kernel") + else: + kernel_vec = te.compute(kvshape, lambda oco, kh, kw, ic, oci: \ + kernel[kh][kw][ic][oco*OCI+oci], + name='kernel_vec') ic = te.reduce_axis((0, IC), name='ic') kh = te.reduce_axis((0, KH), name='kh') @@ -339,12 +342,13 @@ def schedule_conv2d_spatial_pack_nhwc(cfg, s, op, output): s[kernel_vec].compute_at(s[conv], compat_axis) s[data_vec].compute_at(s[conv], compat_axis) - # schedule kernel pack - oco, kh, kw, ic, oci = kernel_vec.op.axis - s[kernel_vec].vectorize(oci) - s[kernel_vec].unroll(ic) - if cfg['compat'].val == 2: - s[kernel_vec].parallel(oco) + if not autotvm.GLOBAL_SCOPE.in_tuning: + # schedule 
kernel pack + oco, kh, kw, ic, oci = kernel_vec.op.axis + s[kernel_vec].vectorize(oci) + s[kernel_vec].unroll(ic) + if cfg['compat'].val == 2: + s[kernel_vec].parallel(oco) # schedule data pack if data_vec.op.name == 'data_vec_undilated': diff --git a/topi/python/topi/opengl/__init__.py b/topi/python/topi/arm_cpu/cortex_m7/__init__.py similarity index 66% rename from topi/python/topi/opengl/__init__.py rename to topi/python/topi/arm_cpu/cortex_m7/__init__.py index 0ddbea0d9791..631c5f7ff447 100644 --- a/topi/python/topi/opengl/__init__.py +++ b/topi/python/topi/arm_cpu/cortex_m7/__init__.py @@ -14,13 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Schedules specialized for cortex-m7.""" -# pylint: disable=redefined-builtin, wildcard-import -"""CUDA specific declaration and schedules.""" -from __future__ import absolute_import as _abs -from .conv2d_nchw import schedule_conv2d_nchw -from .injective import schedule_injective, schedule_elemwise, schedule_broadcast -from .softmax import schedule_softmax -from .dense import schedule_dense -from .pooling import schedule_pool, schedule_adaptive_pool +from . import conv2d diff --git a/topi/python/topi/arm_cpu/cortex_m7/conv2d/__init__.py b/topi/python/topi/arm_cpu/cortex_m7/conv2d/__init__.py new file mode 100644 index 000000000000..cc4faf97b126 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/conv2d/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Conv2d implementations for cortex-m7.""" + +from . import direct_simd diff --git a/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct.py b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct.py new file mode 100644 index 000000000000..7d3e945fef14 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
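+# A minimal numpy sketch of the NHWC/HWIO convolution semantics that the
+# compute functions below obtain from topi's conv2d_nhwc. Illustration only:
+# the helper name `_conv2d_nhwc_ref` and the restriction to a single integer
+# stride, symmetric padding and unit dilation are simplifying assumptions and
+# are not used anywhere else in this file.
+import numpy as np
+
+def _conv2d_nhwc_ref(data, kernel, stride=1, pad=0):
+    """Reference conv2d: data is (N, H, W, CI), kernel is (KH, KW, CI, CO)."""
+    n, h, w, _ = data.shape
+    kh, kw, _, co = kernel.shape
+    x = np.pad(data, ((0, 0), (pad, pad), (pad, pad), (0, 0)))
+    oh = (h + 2 * pad - kh) // stride + 1
+    ow = (w + 2 * pad - kw) // stride + 1
+    out = np.zeros((n, oh, ow, co), dtype=np.float32)
+    for i in range(oh):
+        for j in range(ow):
+            # contract the (KH, KW, CI) input window against the HWIO kernel
+            patch = x[:, i * stride:i * stride + kh, j * stride:j * stride + kw, :]
+            out[:, i, j, :] = np.tensordot(patch, kernel, axes=([1, 2, 3], [0, 1, 2]))
+    return out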
+# pylint: disable=invalid-name +"""Direct implementation of conv2d.""" + +import tvm +from tvm import autotvm +from tvm.autotvm.task import deserialize_args +from topi.nn.conv2d import conv2d_nchw, conv2d_nhwc +from topi.util import get_const_tuple, get_const_int, traverse_inline + +def conv2d_direct(*args, **kwargs): + """Schedule function for directly-scheduled conv2d.""" + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + layout = args[-2] + cfg = autotvm.get_config() + args = [cfg] + args + conv = conv2d_direct_compute(*args) + if layout == 'NHWC': + sched = conv2d_direct_nhwc_schedule(cfg, [data, kernel, conv]) + elif layout == 'NCHW': + sched = conv2d_direct_nchw_schedule(cfg, [data, kernel, conv]) + else: + raise RuntimeError(f'unsupported data layout "{layout}"') + return sched, [data, kernel, conv] + + +conv2d_direct.template_key = 'direct' +conv2d_direct.default_data_layout = 'NHWC' +conv2d_direct.default_kernel_layout = 'HWIO' + +@autotvm.register_topi_compute('conv2d_direct.micro_dev') +def conv2d_direct_compute(*args): + layout = args[-2] + if layout == 'NHWC': + return _conv2d_direct_nhwc_compute(*args) + if layout == 'NCHW': + return _conv2d_direct_nchw_compute(*args) + + raise RuntimeError(f'unsupported data layout "{layout}"') + + +def _conv2d_direct_nhwc_compute(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): + assert layout == 'NHWC' + conv = conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) + + # Config Space Definition + N, H, W, CI = get_const_tuple(data.shape) + KH, KW, _, CO = get_const_tuple(kernel.shape) + n, oh, ow, co = cfg.axis(N), cfg.axis(H), cfg.axis(W), cfg.axis(CO) + kh, kw, ci = cfg.reduce_axis(KH), cfg.reduce_axis(KW), cfg.reduce_axis(CI) + + # TODO should we add a max_factor attr to these splits? 
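+    # The cfg.define_* calls below only declare the tuning space; they do not
+    # change the compute above. The matching NHWC schedule applies a chosen
+    # point of the space later via cfg['tile_co'].apply(...),
+    # cfg['reorder_0'].apply(...) and the unroll pragmas. For illustration, a
+    # two-way split of an output-channel axis of extent 64 could yield
+    # (co_outer, vc) = (16, 4).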
+ co, vc = cfg.define_split('tile_co', co, num_outputs=2) + oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2) + ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2) + + cfg.define_reorder('reorder_0', + [n, co, oh, ow, ci, kh, kw, vh, vw, vc], + policy='candidate', candidate=[ + [n, co, oh, ow, ci, kh, kw, vh, vw, vc], + [n, co, oh, ow, ci, kh, kw, vc, vh, vw], + [n, co, oh, ow, ci, vh, vw, vc, kh, kw], + [n, co, oh, ow, ci, vc, vh, vw, kh, kw]]) + + cfg.define_annotate('ann_reduce', [kh, kw], policy='try_unroll') + cfg.define_annotate('ann_spatial', [vh, vw, vc], policy='try_unroll') + + cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32]) + cfg.define_knob('unroll_explicit', [0, 1]) + + return conv + + +def _conv2d_direct_nchw_compute(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): + assert layout == 'NCHW' + conv = conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype) + + ########################### + # Config Space Definition # + ########################### + cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32]) + cfg.define_knob('unroll_explicit', [0, 1]) + + return conv + + +@autotvm.register_topi_schedule('conv2d_direct_nhwc.micro_dev') +def conv2d_direct_nhwc_schedule(cfg, outs): + """Schedule function for directly-scheduled conv2d on NHWC layout.""" + sched = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc' not in op.tag: + return + + ### extract tensors ### + output = op.output(0) + conv = op + data_vec = conv.input_tensors[0] + kernel = conv.input_tensors[1] # pylint: disable=unused-variable + last = outs[0] # pylint: disable=unused-variable + + # tile reduction axes + n, oh, ow, co = sched[conv].op.axis + kh, kw, ci = sched[conv].op.reduce_axis + # NOTE we can't inline data padding in the SIMD path, because it + # introduces conditionals in the inner loop. 
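+        # Here, in the plain direct schedule, the pad stage is inlined into the
+        # convolution (compute_inline below); scalar code tolerates the
+        # resulting boundary conditionals. The SIMD variant in direct_simd.py
+        # instead leaves `padded_data` as a separate stage so that its
+        # tensorized inner loop is not broken up by boundary checks.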
+ data_pad = data_vec.op + sched[data_pad].compute_inline() + + co, vc = cfg['tile_co'].apply(sched, conv, co) + oh, vh = cfg['tile_oh'].apply(sched, conv, oh) + ow, vw = cfg['tile_ow'].apply(sched, conv, ow) + cfg['reorder_0'].apply(sched, conv, [n, co, oh, ow, ci, kh, kw, vh, vw, vc]) + cfg['ann_reduce'].apply(sched, conv, [kh, kw], + axis_lens=[get_const_int(kh.dom.extent), + get_const_int(kw.dom.extent)], + max_unroll=8, + cfg=cfg) + cfg['ann_spatial'].apply(sched, conv, [vh, vw, vc], + axis_lens=[cfg['tile_oh'].size[-1], + cfg['tile_ow'].size[-1], + cfg['tile_co'].size[-1]], + max_unroll=8, + cfg=cfg) + + kernel_scope = n # this is the scope to attach global config inside this kernel + + # tune unroll + sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + traverse_inline(sched, outs[-1].op, _callback) + return sched + + +@autotvm.register_topi_schedule('conv2d_direct_nchw.micro_dev') +def conv2d_direct_nchw_schedule(cfg, outs): + """Schedule function for Cortex-M7 direct implementation of conv2d.""" + # use default schedule + sched = tvm.create_schedule([x.op for x in outs]) + + conv = outs[-1].op + output = conv.output(0) + data_vec = conv.input_tensors[0] + data_pad = data_vec.op + sched[data_pad].compute_inline() + + # TODO add more schedule opts (similar to the NHWC template) + + n, _, _, _ = sched[conv].op.axis + kernel_scope = n # this is the scope to attach global config inside this kernel + + # tune unroll + sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + return sched diff --git a/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py new file mode 100644 index 000000000000..fd411251272e --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/conv2d/direct_simd.py @@ -0,0 +1,163 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
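+# A minimal numpy sketch of the inner tile that the schedule below tensorizes
+# with intrin_gemm_MxKxN: an (M, K) by (N, K) transposed matmul, C = A @ B.T,
+# with int8 operands widened to an int32 accumulator. With the HWOI kernel
+# layout declared below, kernel[kh, kw, f, :] is a row over input channels,
+# which lines up with the (N, K) operand of the micro-kernel (rows indexed by
+# output channel, columns by input channel). The name `gemm_tile_ref` is an
+# illustrative assumption and is not used anywhere else in this file.
+import numpy as np
+
+def gemm_tile_ref(a, b):
+    """a: (M, K) int8, b: (N, K) int8 -> (M, N) int32 accumulator tile."""
+    return np.einsum("mk,nk->mn", a.astype(np.int32), b.astype(np.int32))
+
+# For example:
+#   a = np.random.randint(-128, 128, (2, 8), dtype=np.int8)
+#   b = np.random.randint(-128, 128, (4, 8), dtype=np.int8)
+#   assert np.array_equal(gemm_tile_ref(a, b), a.astype(np.int32) @ b.astype(np.int32).T)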
+# pylint: disable=invalid-name, no-value-for-parameter +"""Direct implementation of conv2d.""" + +from tvm import autotvm +from tvm.autotvm.task import deserialize_args +from tvm import te +from topi.util import simplify, traverse_inline +from topi.nn.pad import pad +from topi.nn.util import get_pad_tuple + +from ..micro_kernel.gemm import ( + intrin_gemm_MxKxN, gemm_MxKxN_impl, +) + +def conv2d_direct_simd(*args, **kwargs): + """Defines the Cortex-M7 SIMD implementation of conv2d.""" + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + layout = args[-2] + cfg = autotvm.get_config() + args = [cfg] + args + assert layout == 'NHWC' + conv = conv2d_direct_simd_compute(*args) + sched = conv2d_direct_simd_nhwc_schedule(cfg, [data, kernel, conv]) + return sched, [data, kernel, conv] + + +conv2d_direct_simd.template_key = 'direct_simd' +conv2d_direct_simd.default_data_layout = 'NHWC' +conv2d_direct_simd.default_kernel_layout = 'HWOI' + +def conv2d_direct_simd_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute function for Cortex-M7 SIMD implementation of conv2d.""" + assert isinstance(strides, int) or len(strides) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(strides, int): + stride_h = stride_w = strides + else: + stride_h, stride_w = strides + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch_size, in_height, in_width, in_channels = data.shape + kernel_h, kernel_w, out_channels, _ = kernel.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + padded_data = pad(data, pad_before, pad_after, name='padded_data') + + rc = te.reduce_axis((0, in_channels), name='rc') + ry = te.reduce_axis((0, kernel_h), name='ry') + rx = te.reduce_axis((0, kernel_w), name='rx') + + conv = te.compute( + (batch_size, out_height, out_width, out_channels), + lambda nn, yy, xx, ff: te.sum( + padded_data[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + kernel[ry, rx, ff, rc].astype(out_dtype), axis=[ry, rx, rc]), + name='conv2d', tag='conv2d_nhwc') + + ########################### + # Config Space Definition # + ########################### + n, oh, ow, co = (cfg.axis(batch_size.value), + cfg.axis(out_height.value), + cfg.axis(out_width.value), + cfg.axis(out_channels.value)) + kh, kw, ci = (cfg.reduce_axis(kernel_h.value), + cfg.reduce_axis(kernel_w.value), + cfg.reduce_axis(in_channels.value)) + + assert in_channels.value % 4 == 0 + owo, owi = cfg.define_split('tile_ow', ow, policy='factors', num_outputs=2) + cio, cii = cfg.define_split('tile_ci', ci, policy='factors', num_outputs=2, + filter=lambda x: x.size[-1] % 4 == 0) + coo, coi = cfg.define_split('tile_co', co, policy='factors', num_outputs=2) + + cfg.define_reorder('reorder_0_simd', + [n, oh, owo, owi, coo, coi, kh, kw, cio, cii], + policy='candidate', candidate=[ + [n, oh, kh, kw, owo, coo, cio, owi, coi, cii], + [n, oh, kh, kw, coo, owo, cio, 
owi, coi, cii], + [n, kh, kw, oh, owo, coo, cio, owi, coi, cii], + [n, kh, kw, oh, coo, owo, cio, owi, coi, cii]]) + + cfg.define_knob('auto_unroll_max_step', [0, 2, 4, 8, 16, 32]) + cfg.define_knob('unroll_explicit', [0, 1]) + + return conv + + +def conv2d_direct_simd_nhwc_schedule(cfg, outs): + """Schedule function for Cortex-M7 SIMD implementation of conv2d.""" + sched = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc' not in op.tag: + return + + # extract tensors + output = op.output(0) + conv = op + data_vec = conv.input_tensors[0] + kernel = conv.input_tensors[1] # pylint: disable=unused-variable + last = outs[0] # pylint: disable=unused-variable + + # tile reduction axes + n, oh, ow, co = sched[conv].op.axis + kh, kw, ci = sched[conv].op.reduce_axis + + M = cfg['tile_ow'].size[-1] + K = cfg['tile_ci'].size[-1] + N = cfg['tile_co'].size[-1] + + owo, owi = cfg['tile_ow'].apply(sched, conv, ow) + cio, cii = cfg['tile_ci'].apply(sched, conv, ci) + coo, coi = cfg['tile_co'].apply(sched, conv, co) + + cfg['reorder_0_simd'].apply(sched, conv, [n, oh, owo, owi, coo, coi, kh, kw, cio, cii]) + + gemm, uniq_id = intrin_gemm_MxKxN(M, K, N, data_vec.dtype, output.dtype) + sched[output].tensorize(owi, gemm) + sched[output].pragma(n, 'import_c', gemm_MxKxN_impl(M, K, N, uniq_id)) + + # this is the scope to attach global config inside this kernel + kernel_scope = n + + # tune unroll + sched[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + sched[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + traverse_inline(sched, outs[-1].op, _callback) + return sched diff --git a/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/__init__.py b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py new file mode 100644 index 000000000000..7bd9bdb0cb1b --- /dev/null +++ b/topi/python/topi/arm_cpu/cortex_m7/micro_kernel/gemm.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, no-value-for-parameter +"""Defines gemm intrinsics for SIMD matrix multiplication.""" + +import random +import string + +import tvm +from tvm import te + +########################## +# MxKxN MatMul Intrinsic # +########################## + +# NOTE this is transposed matmul (A * B^T) +def intrin_gemm_MxKxN(M, K, N, in_dtype, out_dtype): + """Defines a SIMD-accelerated transposed matmul.""" + # we generate a unique ID for every intrinsic definition, to prevent name + # collisions in the generated source (e.g., if there are multiple operators + # in the same module that use the same intrinsic) + # + # TODO(weberlo, areusch): to cut down on memory usage, we should cache each intrinsic + # instantiation and include it only once, eliminating the need for unique + # IDs + UNIQ_ID_LEN = 8 + uniq_id = ''.join(random.choices(string.ascii_uppercase, k=UNIQ_ID_LEN)) + + if isinstance(M, tvm.tir.IntImm): + M = M.value + if isinstance(K, tvm.tir.IntImm): + K = K.value + if isinstance(N, tvm.tir.IntImm): + N = N.value + assert K % 4 == 0 + # TODO(weberlo, areusch): support more dtypes? + assert in_dtype == 'int8' + assert out_dtype == 'int32' + A = te.placeholder((M, K), name='a', dtype=in_dtype) + B = te.placeholder((N, K), name='b', dtype=in_dtype) + k = te.reduce_axis((0, K), name='k') + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(out_dtype) * B[j, k].astype(out_dtype), axis=k), + name='c') + A_buf = tvm.tir.decl_buffer( + A.shape, A.dtype, + name="A", + offset_factor=1, + strides=[te.var("A_s"), 1]) + B_buf = tvm.tir.decl_buffer( + B.shape, B.dtype, + name="B", + offset_factor=1, + strides=[te.var("B_s"), 1]) + C_buf = tvm.tir.decl_buffer( + C.shape, C.dtype, + name="C", + offset_factor=1, + strides=[te.var("C_s"), 1]) + def intrin_func(ins, outs): + aa, bb = ins + cc = outs[0] + def _reduce_update(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_update_{uniq_id}", + aa.access_ptr("r"), + bb.access_ptr("r"), + cc.access_ptr("w"), + aa.strides[0], + bb.strides[0], + cc.strides[0])) + return ib.get() + def _reduce_reset(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_reset_{uniq_id}", + cc.access_ptr("w"), + cc.strides[0])) + return ib.get() + def _body(): + ib = tvm.tir.ir_builder.create() + ib.emit(tvm.tir.call_extern("int32", f"gemm_{M}x{K}x{N}_body_{uniq_id}", + aa.access_ptr("r"), + bb.access_ptr("r"), + cc.access_ptr("w"), + aa.strides[0], + bb.strides[0], + cc.strides[0])) + return ib.get() + return _body(), _reduce_reset(), _reduce_update() + + intrin_decl = te.decl_tensor_intrin( + C.op, intrin_func, binds={A: A_buf, B: B_buf, C: C_buf}) + return intrin_decl, uniq_id + + +def gemm_MxKxN_impl(M, K, N, uniq_id): + """Emit C code for gemm impl.""" + # TODO(weberlo, areusch): are there any SIMD tricks to zero out arrays quickly? 
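+    # The three C routines emitted below share one inner product over the
+    # padded operands:
+    #   gemm_{M}x{K}x{N}_body_{uniq_id}    sets        cc[i*C_stride + j]  = sum_k aa[i, k] * bb[j, k]
+    #   gemm_{M}x{K}x{N}_update_{uniq_id}  accumulates cc[i*C_stride + j] += sum_k aa[i, k] * bb[j, k]
+    #   gemm_{M}x{K}x{N}_reset_{uniq_id}   zeroes the (M, N) accumulator tile
+    # In numpy terms the body computes cc = aa.astype(np.int32) @ bb.astype(np.int32).T.
+    # The int8 operands are first widened to int16 with read_and_pad so that
+    # __SMLAD can multiply-accumulate two (a, b) pairs per instruction.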
+ aa_pad_size = M * K + bb_pad_size = N * K + # code reference: CMSIS-NN paper (https://arxiv.org/abs/1801.06601) + cc_code = f""" +#ifdef __cplusplus +extern "C" +#endif +__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_{uniq_id}( + int8_t *aa, int8_t *bb, int32_t *cc, + int A_stride, int B_stride, int C_stride) {{ + int16_t aa_pad[{aa_pad_size}]; + int16_t bb_pad[{bb_pad_size}]; + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&aa[i*A_stride + j*4], (int32_t*) &aa_pad[i*{K} + j*4], (int32_t*) &aa_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {N}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&bb[i*B_stride + j*4], (int32_t*) &bb_pad[i*{K} + j*4], (int32_t*) &bb_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {N}; j++) {{ + int32_t sum = 0; + for (int l = 0; l < {K} / 2; l++) {{ + sum = __SMLAD( + *((int32_t*) &aa_pad[i*{K} + l*2]), + *((int32_t*) &bb_pad[j*{K} + l*2]), + sum); + }} + // NOTE: this is the line where `*_body` differs from `*_update`. here + // we're *setting* the result, instead of accumulating, because we know + // the `i` and `j` itervars span their entire respective axes. + cc[i*C_stride + j] = sum; + }} + }} + + return 0; +}} + +#ifdef __cplusplus +extern "C" +#endif +__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_{uniq_id}( + int8_t *aa, int8_t *bb, int32_t *cc, + int A_stride, int B_stride, int C_stride) {{ + int16_t aa_pad[{aa_pad_size}]; + int16_t bb_pad[{bb_pad_size}]; + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&aa[i*A_stride + j*4], (int32_t*) &aa_pad[i*{K} + j*4], (int32_t*) &aa_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {N}; i++) {{ + for (int j = 0; j < {K} / 4; j++) {{ + read_and_pad(&bb[i*B_stride + j*4], (int32_t*) &bb_pad[i*{K} + j*4], (int32_t*) &bb_pad[i*{K} + j*4 + 2]); + }} + }} + + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {N}; j++) {{ + int32_t sum = 0; + for (int l = 0; l < {K} / 2; l++) {{ + sum = __SMLAD( + *((int32_t*) &aa_pad[i*{K} + l*2]), + *((int32_t*) &bb_pad[j*{K} + l*2]), + sum); + }} + cc[i*C_stride + j] += sum; + }} + }} + + return 0; +}} + +#ifdef __cplusplus +extern "C" +#endif +__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{ + for (int i = 0; i < {M}; i++) {{ + for (int j = 0; j < {N}; j++) {{ + cc[i*C_stride + j] = 0; + }} + }} + return 0; +}} + """ + return cc_code diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py index 135c87d59511..bab91578e77e 100644 --- a/topi/python/topi/arm_cpu/tensor_intrin.py +++ b/topi/python/topi/arm_cpu/tensor_intrin.py @@ -107,5 +107,7 @@ def _instr(index): # body, reset, update return _instr(0), _instr(1), _instr(2) - with tvm.target.build_config(offset_factor=1, partition_const_loop=True): - return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) + buffer_params = {"offset_factor": 1} + return te.decl_tensor_intrin( + C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, + default_buffer_params=buffer_params) diff --git a/topi/python/topi/bifrost/conv2d.py b/topi/python/topi/bifrost/conv2d.py index 92e874afa2a5..ecc67c735a58 100644 --- a/topi/python/topi/bifrost/conv2d.py +++ b/topi/python/topi/bifrost/conv2d.py @@ -142,11 +142,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[data_vec].unroll(vw) if isinstance(kernel_vec.op, tvm.te.ComputeOp) and 
kernel_vec.name == 'kernel_vec': - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: max_threads = tvm.target.Target.current(allow_none=False).max_num_threads co, ci, kh, kw, vc = s[kernel_vec].op.axis fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) @@ -313,10 +309,15 @@ def upround(x, align): data_pad[n][c][h][w], name='d') - if pre_computed: - U = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + VC = cfg['tile_k'].size[-1] + kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC) + U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: - U = _decl_winograd_kernel_transform(kernel, tile_size, G) + if pre_computed: + U = kernel + else: + U = _decl_winograd_kernel_transform(kernel, tile_size, G) # V [alpha * alpha, C, P_round) # Perform the image transform @@ -370,12 +371,7 @@ def _schedule_winograd(cfg, s, op): s[G].compute_inline() eps, _, _, _ = s[U].op.axis y, _, _, _ = s[padded_kernel].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # Kernel transformation will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[U].pragma(eps, 'debug_skip_region') - s[padded_kernel].pragma(y, 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: # Pad kernel y, x, ky, kx = s[padded_kernel].op.axis s[padded_kernel].unroll(ky) diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 2b7a845cd9ec..90f4e6074ffc 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -25,10 +25,12 @@ from .conv2d_hwcn import * from .conv2d_int8 import * from .conv2d_winograd import * +from .conv2d_nhwc_winograd import * from .depthwise_conv2d import * from .group_conv2d_nchw import * from . import conv2d_alter_op from .conv2d_transpose_nchw import * +from .conv3d_transpose_ncdhw import * from .deformable_conv2d import * from .conv3d import * from .conv3d_winograd import * @@ -48,3 +50,5 @@ from .conv2d_nhwc_tensorcore import * from .conv3d_ndhwc_tensorcore import * from .dense_tensorcore import * +from .correlation import * +from .sparse import * diff --git a/topi/python/topi/cuda/batch_matmul.py b/topi/python/topi/cuda/batch_matmul.py index bf801820d25a..7d92edfb97b7 100644 --- a/topi/python/topi/cuda/batch_matmul.py +++ b/topi/python/topi/cuda/batch_matmul.py @@ -14,13 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name,too-many-locals,unused-variable +# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument """cuda batch_matmul operators""" +import tvm +from tvm import autotvm from tvm import te from tvm.contrib import cublas +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity +from .. 
import nn from ..util import traverse_inline, get_const_tuple, get_max_power2_factor -def schedule_batch_matmul(outs): +@autotvm.register_topi_compute("batch_matmul.cuda") +def batch_matmul(cfg, x, y): + """Compute conv2d with NCHW layout""" + return nn.batch_matmul(x, y) + + +@autotvm.register_topi_schedule("batch_matmul.cuda") +def schedule_batch_matmul(cfg, outs): """Schedule for batch_matmul Parameters @@ -37,7 +48,7 @@ def schedule_batch_matmul(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) - def _schedule(op): + def _schedule(cfg, op): C = op.output(0) A, B = s[C].op.input_tensors _, M, N = get_const_tuple(C.shape) @@ -51,16 +62,34 @@ def _schedule(op): C = s.outputs[0].output(0) b, y, x = s[C].op.axis - y_bn = get_max_power2_factor(M, 64) - x_bn = get_max_power2_factor(N, 64) - by, y = s[C].split(y, y_bn) - bx, x = s[C].split(x, x_bn) - y_nthreads = min(y_bn, 8) - x_nthreads = min(x_bn, 8) - ty, yi = s[C].split(y, nparts=y_nthreads) - tx, xi = s[C].split(x, nparts=x_nthreads) - thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x") - thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y") + k, = s[CC].op.reduce_axis + + cfg.define_split("tile_y", y, num_outputs=3) + cfg.define_split("tile_x", x, num_outputs=3) + cfg.define_split("tile_k", k, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64]) + target = tvm.target.Target.current() + if target.target_name in ['nvptx', 'rocm']: + # llvm-based backends cannot do non-explicit unrolling + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + if cfg.is_fallback: + y_bn = get_max_power2_factor(M, 64) + x_bn = get_max_power2_factor(N, 64) + y_nthreads = min(y_bn, 8) + x_nthreads = min(x_bn, 8) + cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads]) + cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads]) + cfg['tile_k'] = SplitEntity([-1, 8]) + cfg['auto_unroll_max_step'] = OtherOptionEntity(16) + + by, ty, yi = cfg["tile_y"].apply(s, C, y) + bx, tx, xi = cfg["tile_x"].apply(s, C, x) + + thread_x = te.thread_axis("threadIdx.x") + thread_y = te.thread_axis("threadIdx.y") s[C].reorder(b, by, bx, ty, tx, yi, xi) s[C].bind(b, te.thread_axis("blockIdx.z")) @@ -68,38 +97,41 @@ def _schedule(op): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(ty, thread_y) s[C].bind(tx, thread_x) - s[C].pragma(yi, "auto_unroll_max_step", 16) + s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val) s[CC].compute_at(s[C], tx) _, yi, xi = s[CC].op.axis - k, = s[CC].op.reduce_axis - ko, ki = s[CC].split(k, 8) + ko, ki = cfg["tile_k"].apply(s, CC, k) s[CC].reorder(ko, ki, yi, xi) - s[CC].pragma(ki, "auto_unroll_max_step", 16) + s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val) s[AA].compute_at(s[CC], ko) s[AL].compute_at(s[CC], ki) s[BB].compute_at(s[CC], ko) s[BL].compute_at(s[CC], ki) _, y, k = s[AA].op.axis - ty, yi = s[AA].split(y, nparts=y_nthreads) - tx, ki = s[AA].split(k, nparts=x_nthreads) + ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1]) + tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1]) s[AA].reorder(ty, tx, yi, ki) s[AA].bind(ty, thread_y) s[AA].bind(tx, thread_x) - s[AA].pragma(yi, "auto_unroll_max_step", 16) + s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[AA].pragma(yi, 
'unroll_explicit', cfg['unroll_explicit'].val) _, x, k = s[BB].op.axis - ty, xi = s[BB].split(x, nparts=y_nthreads) - tx, ki = s[BB].split(k, nparts=x_nthreads) + ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1]) + tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1]) s[BB].bind(ty, thread_y) s[BB].bind(tx, thread_x) s[BB].reorder(ty, tx, xi, ki) - s[BB].pragma(xi, "auto_unroll_max_step", 16) + s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val) + s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val) def _callback(op): if "batch_matmul" in op.tag: - _schedule(op) + _schedule(cfg, op) traverse_inline(s, outs[0].op, _callback) return s diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index c7df3dc96a5e..d98d630d6415 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -66,8 +66,8 @@ def _callback(op): @autotvm.register_topi_compute("conv2d_cudnn.cuda") -def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', - out_dtype='float32'): +def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1, + layout='NCHW', out_dtype='float32'): """Compute conv2d using CuDNN library""" if layout == 'NCHW': tensor_format = 0 # CUDNN_TENSOR_NCHW @@ -89,7 +89,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) OH = (H + pt + pb - KH) // stride_h + 1 OW = (W + pl + pr - KW) // stride_w + 1 - cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \ + cfg.add_flop(groups * 2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) * \ ((KW - 1) * dilation_w + 1)) if data.dtype == "int8" or kernel.dtype == "int8": @@ -107,7 +107,8 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, layout='NCHW', conv_mode=1, tensor_format=tensor_format, algo=-1, # let CUDNN choose the best algo - conv_dtype=dtype) + conv_dtype=dtype, + groups=groups) @autotvm.register_topi_schedule("conv2d_cudnn.cuda") diff --git a/topi/python/topi/cuda/conv2d_alter_op.py b/topi/python/topi/cuda/conv2d_alter_op.py index 8d9e86c192a0..c2a19054434e 100644 --- a/topi/python/topi/cuda/conv2d_alter_op.py +++ b/topi/python/topi/cuda/conv2d_alter_op.py @@ -111,6 +111,42 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.contrib_conv2d_winograd_without_weight_transform( inputs[0], weight, **new_attrs) + if topi_tmpl in ('conv2d_nhwc_winograd_direct.cuda', 'conv2d_nhwc_winograd_tensorcore.cuda'): + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + assert data_layout == "NHWC" and kernel_layout == "HWIO" + N, H, W, CI = get_const_tuple(data.shape) + KH, KW, _, CO = get_const_tuple(kernel.shape) + + # Pre-compute weight transformation in winograd + if H % 8 == 0: + tile_size = 4 + else: + tile_size = 2 + kernel_transform = relay.transpose(inputs[1], axes=[3, 2, 0, 1]) + weight = relay.nn.contrib_conv2d_winograd_weight_transform(kernel_transform, + tile_size=tile_size) + weight = relay.transpose(weight, axes=[0, 1, 3, 2]) + new_attrs['tile_size'] = tile_size + new_attrs['channels'] = CO + # Store the same config for the altered operator (workload) + new_data = data + new_weight = te.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO), + dtype=kernel.dtype) + if topi_tmpl == "conv2d_nhwc_winograd_direct.cuda": + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, 
out_dtype], + "conv2d_nhwc_winograd_direct_without_weight_transform.cuda") + elif topi_tmpl == "conv2d_nhwc_winograd_tensorcore.cuda": + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + "conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs) + if topi_tmpl == "group_conv2d_NCHWc_int8.cuda": assert data_layout == "NCHW" and kernel_layout == "OIHW" N, CI, H, W = get_const_tuple(data.shape) @@ -210,7 +246,8 @@ def _conv2d_legalize(attrs, inputs, arg_types): new_attrs['channels'] = new_out_channel out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + out = relay.strided_slice(out, begin=relay.const([0, 0, 0, 0]), + end=relay.const(original_out_shape)) else: out = relay.nn.conv2d(data, kernel, **new_attrs) return out diff --git a/topi/python/topi/cuda/conv2d_nhwc_winograd.py b/topi/python/topi/cuda/conv2d_nhwc_winograd.py new file mode 100644 index 000000000000..2f5b85eed620 --- /dev/null +++ b/topi/python/topi/cuda/conv2d_nhwc_winograd.py @@ -0,0 +1,639 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +# pylint: disable=too-many-arguments,too-many-locals +# pylint: disable=too-many-statements +"""Winograd template for cuda backend""" + +import tvm +from tvm import te +from tvm import autotvm +from .. 
import nn +from ..util import get_const_int, get_const_tuple, traverse_inline +from ..nn.winograd_util import winograd_transform_matrices +from .tensor_intrin import intrin_wmma_load_matrix_A +from .tensor_intrin import intrin_wmma_load_matrix_W +from .tensor_intrin import intrin_wmma_store_matrix +from .tensor_intrin import intrin_wmma_gemm + +def _infer_tile_size(data, kernel): + """Compute the tile size""" + N, H, W, CI = get_const_tuple(data.shape) + if H % 8 == 0: + return 4 + return 2 + + +def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack): + """Schedule for bgemm tensorcore""" + A = data_pack + B = kernel_pack + C = bgemm + _, _, P, out_dim = get_const_tuple(C.shape) + out_dtype = C.dtype + + # Explicit memory access + AS = s.cache_read(A, 'shared', [C]) + BS = s.cache_read(B, 'shared', [C]) + AF = s.cache_read(AS, 'wmma.matrix_a', [C]) + BF = s.cache_read(BS, 'wmma.matrix_b', [C]) + CF = s.cache_write(C, 'wmma.accumulator') + CS = s.cache_read(CF, 'shared', [C]) + + # Create tuning space + cfg.define_knob("block_row_warps", [1, 2, 4]) + cfg.define_knob("block_col_warps", [1, 2, 4]) + cfg.define_knob("warp_row_tiles", [1, 2, 4, 8]) + cfg.define_knob("warp_col_tiles", [1, 2, 4, 8]) + cfg.define_knob("chunk", [1, 2, 4, 8]) + cfg.define_knob("offset", [0, 1, 2, 4, 8]) + cfg.define_knob("offsetCS", [0, 1, 2, 4, 8]) + cfg.define_knob("vec", [1, 2, 4, 8]) + + # Ensure that the default parameters are applicable when autotvm is not in use + if (P % 16 == 0 and out_dim % 16 == 0): + cfg.define_knob("wmma_m", [16, 8, 32]) + elif (P % 32 == 0 and out_dim % 8 == 0): + cfg.define_knob("wmma_m", [32, 16, 8]) + elif (P % 8 == 0 and out_dim % 32 == 0): + cfg.define_knob("wmma_m", [8, 16, 32]) + + warp_size = 32 + wmma_k = 16 + block_row_warps = cfg["block_row_warps"].val + block_col_warps = cfg["block_col_warps"].val + warp_row_tiles = cfg["warp_row_tiles"].val + warp_col_tiles = cfg["warp_col_tiles"].val + chunk = cfg["chunk"].val + offsetAB = cfg["offset"].val + offsetCS = cfg["offsetCS"].val + wmma_m = cfg["wmma_m"].val + vec = cfg["vec"].val + + if wmma_m == 16: + wmma_n = 16 + elif wmma_m == 8: + wmma_n = 32 + elif wmma_m == 32: + wmma_n = 8 + + # Define the stride of intrin functions + AS_align = chunk * wmma_k + offsetAB + BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB + CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS + AS_stride = [AS_align, 1] + BS_stride = [BS_align, 1] + AF_stride = [wmma_k, 1] + BF_stride = [wmma_n * warp_col_tiles, 1] + CF_stride = [warp_col_tiles * wmma_n, 1] + CS_stride = [CS_align, 1] + block_x = te.thread_axis('blockIdx.x') + block_y = te.thread_axis('blockIdx.y') + block_z = te.thread_axis('blockIdx.z') + thread_x = te.thread_axis('threadIdx.x') + thread_y = te.thread_axis('threadIdx.y') + thread_z = te.thread_axis('threadIdx.z') + + # Schedule for computation + block_factor_b = wmma_m * warp_row_tiles * block_row_warps + block_factor_o = wmma_n * warp_col_tiles * block_col_warps + alpha_1, alpha_2, b, o = C.op.axis + block_k = s[C].fuse(alpha_1, alpha_2) + block_i, bc = s[C].split(b, factor=block_factor_b) + block_j, oc = s[C].split(o, factor=block_factor_o) + s[C].reorder(block_k, block_i, block_j, bc, oc) + t = s[C].fuse(bc, oc) + t, vi = s[C].split(t, factor=vec) + t, tx = s[C].split(t, factor=warp_size) + t, ty = s[C].split(t, factor=block_row_warps) + t, tz = s[C].split(t, factor=block_col_warps) + s[C].bind(block_k, block_z) + s[C].bind(block_i, block_x) + s[C].bind(block_j, block_y) + s[C].bind(tz, 
thread_z) + s[C].bind(ty, thread_y) + s[C].bind(tx, thread_x) + s[C].vectorize(vi) + + # Schedule for wmma store + s[CS].compute_at(s[C], block_j) + _, _, bb, oo = CS.op.axis + s[CS].storage_align(bb, CS_align - 1, CS_align) + bb, bbi = s[CS].split(bb, factor=wmma_m) + oo, ooi = s[CS].split(oo, factor=wmma_n) + bb, bbii = s[CS].split(bb, factor=warp_row_tiles) + oo, ooii = s[CS].split(oo, factor=warp_col_tiles) + s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi) + + # Schedule for wmma computation + s[CF].compute_at(s[CS], oo) + _, _, warp_i, warp_j = CF.op.axis + warp_i, _ii = s[CF].split(warp_i, factor=wmma_m) + warp_j, _jj = s[CF].split(warp_j, factor=wmma_n) + k, = CF.op.reduce_axis + k, _k = s[CF].split(k, factor=wmma_k) + ko, ki = s[CF].split(k, factor=chunk) + s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k) + + # Schedule for wmma_matrix_a load + s[AF].compute_at(s[CF], ki) + _, _, b, i = AF.op.axis + b, b_ii = s[AF].split(b, factor=wmma_m) + i, i_jj = s[AF].split(i, factor=wmma_k) + s[AF].reorder(b, i, b_ii, i_jj) + + # Schedule for wmma_matrix_b load + s[BF].compute_at(s[CF], ki) + _, _, i, o = BF.op.axis + o, o_ii = s[BF].split(o, factor=wmma_n) + i, i_ii = s[BF].split(i, factor=wmma_k) + s[BF].reorder(i, o, i_ii, o_ii) + + # Schedule for A's(B's) shared memory load + def shared_shedule(stage, strides): + s[stage].compute_at(s[CF], ko) + _, _, xo, yo = stage.op.axis + s[stage].storage_align(xo, strides - 1, strides) + t = s[stage].fuse(xo, yo) + t, vi = s[stage].split(t, factor=vec) + t, tx = s[stage].split(t, factor=warp_size) + t, ty = s[stage].split(t, factor=block_row_warps) + _, tz = s[stage].split(t, factor=block_col_warps) + s[stage].bind(ty, thread_y) + s[stage].bind(tz, thread_z) + s[stage].bind(tx, thread_x) + s[stage].vectorize(vi) + + shared_shedule(AS, AS_align) + shared_shedule(BS, BS_align) + + shape = (wmma_m, wmma_n, wmma_k) + in_dtype = 'float16' + AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype) + BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype) + k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm') + CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj: + te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * + BL_gemm[k_gemm, jj].astype(out_dtype), + axis=k_gemm), name='CL_compute') + + # Lower the computation loops down to TensorCore hardware intrinsics + # by mapping the tensorcore to tensor intrinsics + s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, shape, "row_major", + (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16')) + s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, shape, "row_major", + (wmma_k, wmma_n), (wmma_k, wmma_n), 'float16')) + s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, + BF_stride, CF_stride, shape)) + s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, out_dtype, + (wmma_m, wmma_n), (wmma_m, wmma_n))) + + +def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack): + """Schedule for bgemm direct""" + b1, b2, y, x = s[bgemm].op.axis + rc = s[bgemm].op.reduce_axis[0] + alpha = get_const_int(b1.dom.extent) + + # Create tuning space + cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4, + filter=lambda x: x.size[-3:] == [1, 1, 1]) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8]) + cfg.define_knob("vector_bgemm", [1, 2, 4, 8]) + offset_bgemm 
= cfg["offset_bgemm"].val + vector_bgemm = cfg["vector_bgemm"].val + + C = bgemm + A0, B0 = kernel_pack, data_pack + + # Designate the memory hierarchy + OL = s.cache_write(C, 'local') + AA = s.cache_read(A0, 'shared', [OL]) + BB = s.cache_read(B0, 'shared', [OL]) + + # Tile and bind spatial axes + b = s[bgemm].fuse(b1, b2) + bgemm_scope, b = s[bgemm].split(b, nparts=1) + bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b) + by, vy, ty, yi = cfg["tile_y"].apply(s, C, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x) + s[C].bind(bz, te.thread_axis("blockIdx.z")) + s[C].bind(by, te.thread_axis("blockIdx.y")) + s[C].bind(bx, te.thread_axis("blockIdx.x")) + s[C].bind(vz, te.thread_axis("vthread")) + s[C].bind(vy, te.thread_axis("vthread")) + s[C].bind(vx, te.thread_axis("vthread")) + s[C].bind(tz, te.thread_axis("threadIdx.z")) + s[C].bind(ty, te.thread_axis("threadIdx.y")) + s[C].bind(tx, te.thread_axis("threadIdx.x")) + s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi) + + # Tile reduction axes + s[OL].compute_at(s[C], tx) + b1, b2, y, x = s[OL].op.axis + b = s[OL].fuse(b1, b2) + rc, = s[OL].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, OL, rc) + s[OL].reorder(rco, b, y, x, rci) + + s[AA].compute_at(s[OL], rco) + _, _, k, n = s[AA].op.axis + AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * cfg["tile_x"].size[3] + s[AA].storage_align(k, AA_align - 1, AA_align) + + s[BB].compute_at(s[OL], rco) + _, _, m, k = s[BB].op.axis + BB_align = offset_bgemm + cfg["tile_rc"].size[1] + s[BB].storage_align(m, BB_align - 1, BB_align) + + # Schedule for A and B shared memory load + for load in [AA, BB]: + fused = s[load].fuse(*list(s[load].op.axis)) + fused, ti = s[load].split(fused, factor=vector_bgemm) + fused, tx = s[load].split(fused, cfg["tile_x"].size[2]) + fused, ty = s[load].split(fused, cfg["tile_y"].size[2]) + fused, tz = s[load].split(fused, cfg["tile_b"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + s[load].vectorize(ti) + + +def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore, pre_computed): + """Compute declaration for winograd""" + tile_size = _infer_tile_size(data, kernel) + N, H, W, CI = get_const_tuple(data.shape) + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides + + if not pre_computed: # Kernel tensor is raw tensor, do strict check + if dilation_h != 1 or dilation_w != 1: + kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1)) + KH, KW, CI, CO = get_const_tuple(kernel.shape) + alpha = KW + tile_size - 1 + assert HSTR == 1 and WSTR == 1 and KH == KW + else: + # Kernel tensor is pre-transfomred. This op is created by conv2d_alter_op. 
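(A worked instance of the shape bookkeeping in this branch, added for illustration and not part of the original patch: with a 3x3 kernel and an inferred tile_size of 4, alpha = KH + tile_size - 1 = 6, so the pre-transformed kernel arrives here with shape (6, 6, CI, CO), and the original kernel extent is recovered below as KH = KW = alpha + 1 - tile_size = 3.)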
+ # Dilation is not supported + alpha, _, CI, CO = get_const_tuple(kernel.shape) + KH = KW = alpha + 1 - tile_size + assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 + + pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) + data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad") + + r = KW + m = tile_size + H = (H + pt + pb - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + + # Determine whether the shape is available with tensorcore + shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + + if shape_judge and use_tensorcore: + trans_type = "float16" + else: + trans_type = data.dtype + + # Compute transform matrix + A, _, _ = winograd_transform_matrices(m, r, out_dtype) + _, B, G = winograd_transform_matrices(m, r, data.dtype) + + # Transform kernel + if not pre_computed: + # Check if we are currently tuning, if so we want to avoid counting + # prepacking in time costs. Just use a placeholder with the packed shape instead. + if autotvm.GLOBAL_SCOPE.in_tuning: + kernel_pack = te.placeholder((alpha, alpha, CI, CO), + dtype=kernel.dtype, + name='kernel_pack') + else: + r_kh = te.reduce_axis((0, KH), name='r_kh') + r_kw = te.reduce_axis((0, KW), name='r_kw') + kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, ci, co: + te.sum((kernel[r_kh][r_kw][ci][co]) * + G[eps][r_kh] * G[nu][r_kw], + axis=[r_kh, r_kw]), name='kernel_pack') + else: + kernel_pack = kernel + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + + # Pack input tile + input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu: + data_pad[idxdiv(p, (nH * nW)), + idxmod(idxdiv(p, nW), nH) * m + eps, + idxmod(p, nW) * m + nu, + c], name='d') + + # Transform data + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_b') + data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: + te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu], + axis=[r_a, r_b]), name='data_pack') + + # Convert data type of input feature maps and weights for tensorcore + Transdata = te.compute( + data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, ci].astype(trans_type)) + TransFilter = te.compute( + kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, co].astype(trans_type)) + + # Do batch gemm + ci = te.reduce_axis((0, CI), name='ci') + bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co: + te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) * + (TransFilter[eps][nu][ci][co]).astype(out_dtype), + axis=[ci]), name='bgemm') + + # Inverse transform + r_a = te.reduce_axis((0, alpha), 'r_a') + r_b = te.reduce_axis((0, alpha), 'r_a') + inverse = te.compute((P, CO, m, m), lambda p, co, vh, vw: + te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], + axis=[r_a, r_b]), name='inverse') + + # Output + output = te.compute((N, H, W, CO), lambda n, h, w, co: + inverse[n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), + co, + idxmod(h, m), + idxmod(w, m)], + name='output', tag='conv2d_nhwc_winograd') + cfg.add_flop(2 * N * CO * H * W * CI * KH * KW) + return output + + +def data_weight_transform(s, data_trans, input_tile, thread_num_trans, offset_trans, trans_tag): + """Schedule for data or kernel transform""" + kernel_align = thread_num_trans + offset_trans + indata_s = s.cache_read(input_tile, 'shared', [data_trans]) + data_l = 
s.cache_write(data_trans, 'local') + # Schedule for data or kernel transform + eps, nu, p, c = s[data_trans].op.axis + + block_x, thread_x = s[data_trans].split(c, thread_num_trans) + block_x = s[data_trans].fuse(p, block_x) + s[data_trans].reorder(block_x, thread_x, eps, nu) + s[data_trans].bind(thread_x, te.thread_axis("threadIdx.x")) + s[data_trans].bind(block_x, te.thread_axis("blockIdx.x")) + + s[data_l].compute_at(s[data_trans], thread_x) + eps_l, nu_l, p_l, c_l = s[data_l].op.axis + r_a, r_b = s[data_l].op.reduce_axis + block_x_l, thread_x_l = s[data_l].split(c_l, thread_num_trans) + block_x_l = s[data_l].fuse(p_l, block_x_l) + + s[data_l].reorder(block_x_l, thread_x_l, eps_l, nu_l, r_a, r_b) + + for axis in [eps_l, nu_l, r_a, r_b]: + s[data_l].unroll(axis) + + # Schedule for share memory load + s[indata_s].compute_at(s[data_l], block_x_l) + if trans_tag == "data": + p_is, c_is, eps_is, nu_is = s[indata_s].op.axis + data_align = get_const_int(eps_is.dom.extent) * \ + get_const_int(nu_is.dom.extent) + offset_trans + s[indata_s].storage_align(c_is, data_align - 1, data_align) + block_x_is, thread_x_is = s[indata_s].split(c_is, thread_num_trans) + s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x")) + else: + eps_is, nu_is, ci_is, co_is = s[indata_s].op.axis + s[indata_s].storage_align(nu_is, kernel_align - 1, kernel_align) + block_x_is, thread_x_is = s[indata_s].split(co_is, thread_num_trans) + s[indata_s].reorder(ci_is, block_x_is, eps_is, nu_is, thread_x_is) + s[indata_s].bind(thread_x_is, te.thread_axis("threadIdx.x")) + + +def schedule_nhwc_winograd_cuda(cfg, s, output, use_tensorcore, pre_computed): + """Schedule winograd template""" + # Get stages + inverse = s[output].op.input_tensors[0] + bgemm, A = s[inverse].op.input_tensors + Transdata, TransFilter = s[bgemm].op.input_tensors + data_pack = s[Transdata].op.input_tensors[0] + kernel_pack = s[TransFilter].op.input_tensors[0] + s[Transdata].compute_inline() + s[TransFilter].compute_inline() + + input_tile, B = s[data_pack].op.input_tensors + pad_data = s[input_tile].op.input_tensors[0] + + # Define the stride of intrin functions + cfg.define_knob("thread_num_inverse", [1, 32, 64, 128, 256]) + cfg.define_knob("thread_num_data", [1, 32, 64, 128, 256]) + cfg.define_knob("thread_num_kernel", [1, 32, 64, 128, 256]) + cfg.define_knob("offset_inverse", [0, 2, 4]) + cfg.define_knob("offset_data", [0, 1, 2, 4]) + cfg.define_knob("offset_kernel", [0, 1, 2, 4]) + cfg.define_knob("inverse_in_vector", [1, 2, 4]) + + thread_num_data = cfg["thread_num_data"].val + thread_num_kernel = cfg["thread_num_kernel"].val + thread_num_inverse = cfg["thread_num_inverse"].val + offset_data = cfg["offset_data"].val + offset_kernel = cfg["offset_kernel"].val + offset_inverse = cfg["offset_inverse"].val + inverse_in_vector = cfg["inverse_in_vector"].val + + # Data transform + s[B].compute_inline() + data_weight_transform(s, data_pack, input_tile, thread_num_data, offset_data, trans_tag="data") + s[input_tile].compute_inline() + s[pad_data].compute_inline() + + # Kernel transform + if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning: + kernel, G = s[kernel_pack].op.input_tensors + s[G].compute_inline() + data_weight_transform(s, kernel_pack, kernel, thread_num_kernel, + offset_kernel, trans_tag="kernel") + else: + kernel = kernel_pack + + if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + + b1, b2, y, x = s[bgemm].op.axis + alpha = get_const_int(b1.dom.extent) + _, _, P, CI = 
get_const_tuple(Transdata.shape) + _, _, _, CO = get_const_tuple(TransFilter.shape) + + # Determine whether the shape is available with tensorcore + shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \ + (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \ + (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0) + + if shape_judge and use_tensorcore: + schedule_bgemm_tensorcore(cfg, s, bgemm, Transdata, TransFilter) + else: + schedule_bgemm_direct(cfg, s, bgemm, Transdata, TransFilter) + + # Schedule inverse, output and fusion + if output.op in s.outputs: + OL = None + else: + OL = output + s[OL].set_scope('local') + output = s.outputs[0] + + s[A].compute_inline() + inverse_s = s.cache_read(bgemm, 'shared', [inverse]) + + m = alpha - 3 + 1 + offset_inverse_in = offset_inverse + vector_width_inverse_in = inverse_in_vector + + # Schedule for output + n, h, w, co = s[output].op.axis + ho, wo, hi, wi = s[output].tile(h, w, m, m) + s[output].reorder(n, ho, wo, co, hi, wi) + fused = s[output].fuse(n, ho, wo) + + block_x_s, thread_x_s = s[output].split(co, thread_num_inverse) + block_x_s = s[output].fuse(fused, block_x_s) + s[output].reorder(block_x_s, thread_x_s, hi, wi) + + if OL is not None: + s[OL].compute_inline() + + # Schedule for inverse + s[inverse].compute_at(s[output], thread_x_s) + p_inv, co_inv, eps_inv, nu_inv = s[inverse].op.axis + block_x_inv, thread_x_inv = s[inverse].split(co_inv, thread_num_inverse) + r_a, r_b = s[inverse].op.reduce_axis + for axis in [eps_inv, nu_inv, r_a, r_b]: + s[inverse].unroll(axis) + + # Schedule for share memory load + s[inverse_s].compute_at(s[output], block_x_s) + eps_inv_s, nu_inv_s, p_inv_s, co_inv_s = s[inverse_s].op.axis + inverse_in_align = offset_inverse_in + thread_num_inverse + s[inverse_s].storage_align(p_inv_s, inverse_in_align - 1, inverse_in_align) + block_x_inv_s, thread_x_inv_s = s[inverse_s].split(co_inv_s, thread_num_inverse) + block_x_inv_s = s[inverse_s].fuse(p_inv_s, block_x_inv_s) + s[inverse_s].reorder(block_x_inv_s, eps_inv_s, nu_inv_s, thread_x_inv_s) + t = s[inverse_s].fuse(eps_inv_s, nu_inv_s, thread_x_inv_s) + t, ti = s[inverse_s].split(t, factor=vector_width_inverse_in) + t, tx = s[inverse_s].split(t, factor=thread_num_inverse) + s[inverse_s].bind(tx, te.thread_axis("threadIdx.x")) + s[inverse_s].vectorize(ti) + + s[output].bind(thread_x_s, te.thread_axis("threadIdx.x")) + s[output].bind(block_x_s, te.thread_axis("blockIdx.x")) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_direct.cuda") +def conv2d_nhwc_winograd_direct(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=False, pre_computed=False) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct.cuda") +def schedule_conv2d_nhwc_winograd_direct(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False, + pre_computed=False) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore.cuda") +def conv2d_nhwc_winograd_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=True, 
pre_computed=False) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore.cuda") +def schedule_conv2d_nhwc_winograd_tensorcore(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True, + pre_computed=False) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_direct_without_weight_transform.cuda") +def conv2d_nhwc_winograd_direct_without_weight_transform(cfg, data, kernel, strides, + padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=False, pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_direct_without_weight_transform.cuda") +def schedule_conv2d_nhwc_winograd_direct_without_weight_transform(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=False, + pre_computed=True) + + traverse_inline(s, outs[0].op, _callback) + return s + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") +def conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, data, kernel, strides, + padding, dilation, out_dtype): + """Compute conv2d with winograd for NHWC layout""" + return nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, + use_tensorcore=True, pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_tensorcore_without_weight_transform.cuda") +def schedule_conv2d_nhwc_winograd_tensorcore_without_weight_transform(cfg, outs): + """TOPI schedule callback""" + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if 'conv2d_nhwc_winograd' in op.tag: + schedule_nhwc_winograd_cuda(cfg, s, op.output(0), use_tensorcore=True, + pre_computed=True) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/cuda/conv3d.py b/topi/python/topi/cuda/conv3d.py index cc13aa511612..f244c65d0314 100644 --- a/topi/python/topi/cuda/conv3d.py +++ b/topi/python/topi/cuda/conv3d.py @@ -129,7 +129,7 @@ def schedule_conv3d_ndhwc(cfg, outs): The config for this template outs: Array of Tensor - The computation graph description of conv2d + The computation graph description of conv3d in the format of an array of tensors. Returns diff --git a/topi/python/topi/cuda/conv3d_transpose_ncdhw.py b/topi/python/topi/cuda/conv3d_transpose_ncdhw.py new file mode 100644 index 000000000000..bcad3e433e84 --- /dev/null +++ b/topi/python/topi/cuda/conv3d_transpose_ncdhw.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Conv3d transpose template for cuda backend""" + +import tvm +from tvm import te +from tvm import autotvm +from .. import nn +from ..util import get_const_tuple, traverse_inline +from .conv3d_direct import schedule_direct_conv3d_cuda + + +@autotvm.register_topi_compute("conv3d_transpose_ncdhw.cuda") +def conv3d_transpose_ncdhw(cfg, data, kernel, stride, padding, out_dtype): + """Transposed 3D convolution ncdhw forward operator. + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + Input : tvm.te.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + Filter : tvm.te.Tensor + 5-D with shape [in_channel, num_filter, filter_depth, filter_height, filter_width] + strides : int or a list/tuple of three ints + The spatial stride along height and width + padding : int or str + Padding size, or ['VALID', 'SAME'] + out_dtype: str + The output type. This is used in mixed precision + + Returns + ------- + Output : tvm.te.Tensor + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + batch, inp_channels, inp_depth, inp_height, inp_width = get_const_tuple(data.shape) + _, out_channels, kernel_depth, kernel_height, kernel_width = get_const_tuple(kernel.shape) + stride_depth, stride_height, stride_width = stride + cfg.stride = stride + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = nn.get_pad_tuple3d( + padding, (kernel_depth, kernel_height, kernel_width)) + + out_depth = (inp_depth - 1) * stride_depth + \ + kernel_depth - pad_front - pad_back + pad_front = kernel_depth - 1 - pad_front + pad_back = kernel_depth - 1 - pad_back + dilated_depth = stride_depth * (inp_depth - 1) + 1 + + out_width = (inp_width - 1) * stride_width + \ + kernel_width - pad_left - pad_right + pad_left = kernel_width - 1 - pad_left + pad_right = kernel_width - 1 - pad_right + dilated_width = stride_width * (inp_width - 1) + 1 + + out_height = (inp_height - 1) * stride_height + \ + kernel_height - pad_top - pad_bottom + pad_top = kernel_height - 1 - pad_top + pad_bottom = kernel_height - 1 - pad_bottom + dilated_height = stride_height * (inp_height - 1) + 1 + + # compute pad + data = te.compute( + (batch, inp_channels, + pad_front + dilated_depth + pad_back, + pad_top + dilated_height + pad_bottom, + pad_left + dilated_width + pad_right), + lambda n, c, d, y, x: tvm.tir.if_then_else( + tvm.tir.all(x >= pad_left, + x < pad_left + dilated_width, + tvm.tir.indexmod(x - pad_left, stride_width).equal(0), + y >= pad_top, + y < pad_top + dilated_height, + tvm.tir.indexmod(y - pad_top, stride_height).equal(0), + d >= pad_front, + d < pad_front + dilated_depth, + tvm.tir.indexmod(d - pad_front, stride_depth).equal(0)), + data[n, c, + tvm.tir.indexdiv(d - pad_front, stride_depth), + tvm.tir.indexdiv(y - pad_top, stride_height), + tvm.tir.indexdiv(x - pad_left, stride_width)], + tvm.tir.const(0., "float32")), + name='data_pad') + + # compute transposed conv + dc = te.reduce_axis((0, inp_channels), name='dc') + dd = te.reduce_axis((0, kernel_depth), name='dd') + dh = te.reduce_axis((0, kernel_height), name='dh') + 
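(A worked instance of the output-extent arithmetic above, added for illustration with assumed values: for inp_depth = 8, stride_depth = 2, kernel_depth = 4 and pad_front = pad_back = 1, the transposed convolution yields out_depth = (8 - 1) * 2 + 4 - 1 - 1 = 16; the input is then dilated to dilated_depth = 2 * (8 - 1) + 1 = 15 and re-padded with pad_front = pad_back = 4 - 1 - 1 = 2 before the ordinary convolution below is applied against the flipped kernel.)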
dw = te.reduce_axis((0, kernel_width), name='dw') + data_out = te.compute( + (batch, out_channels, out_depth, out_height, out_width), + lambda b, c, d, h, w: te.sum( + data[b, dc, d + dd, h + dh, w + dw].astype(out_dtype) * + kernel[dc, + c, + kernel_depth - 1 - dd, + kernel_height - 1 - dh, + kernel_width - 1 - dw].astype(out_dtype), + axis=[dc, dd, dh, dw]), tag="conv3d_transpose_ncdhw") + + return data_out + +@autotvm.register_topi_schedule("conv3d_transpose_ncdhw.cuda") +def schedule_conv3d_transpose_ncdhw(cfg, outs): + """TOPI Schedule callback for conv3d transpose operator. + + Parameters + ---------- + cfg: ConfigEntity + The parameters for this template + + outs: Array of Tensor + The computation graph description of conv3d transpose + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for conv3d transpose. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'conv3d_transpose_ncdhw': + schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW", + "conv3d_transpose_ncdhw.cuda") + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/cuda/correlation.py b/topi/python/topi/cuda/correlation.py new file mode 100644 index 000000000000..a383e4e7188e --- /dev/null +++ b/topi/python/topi/cuda/correlation.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Correlation operators on CUDA""" +import tvm +from tvm import te +from tvm import autotvm + +from .. import nn +from ..util import traverse_inline + + +@autotvm.register_topi_compute("correlation_nchw.cuda") +def correlation_nchw(cfg, data1, data2, kernel_size, max_displacement, stride1, stride2, padding, + is_multiply): + """Correlation operator in NCHW layout. 
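A minimal usage sketch for the correlation compute and schedule defined in this file (an illustrative example; the shapes and hyper-parameters below are assumptions, not taken from the patch):

    import tvm
    from tvm import te
    import topi

    data1 = te.placeholder((1, 3, 64, 64), name="data1")
    data2 = te.placeholder((1, 3, 64, 64), name="data2")
    with tvm.target.cuda():
        # kernel_size=1, max_displacement=4, stride1=1, stride2=1, padding=4, is_multiply=True
        out = topi.cuda.correlation_nchw(data1, data2, 1, 4, 1, 1, 4, True)
        s = topi.cuda.schedule_correlation_nchw([out])

With max_displacement = 4 and stride2 = 1 the output has (2 * (4 // 1) + 1) ** 2 = 81 channels, one per displacement in the search window.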
+ + Parameters + ---------- + data1 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + data2 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neighborhood centered around data1 + + padding : int or a list/tuple of 2 or 4 ints + Padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + is_multiply: bool + Whether the correlation operation is multiplication or subtraction + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + # pylint: disable=unused-argument + return nn.correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2, + padding, is_multiply) + + +def _schedule_correlation_nchw(cfg, s, correlation): + """Schedule correlation_nchw direct template""" + # pylint: disable=invalid-name + ##### space definition begin ##### + n, f, y, x = s[correlation].op.axis + rc, ry, rx = s[correlation].op.reduce_axis + cfg.define_split("tile_f", f, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_split("tile_rc", rc, num_outputs=2) + cfg.define_split("tile_ry", ry, num_outputs=2) + cfg.define_split("tile_rx", rx, num_outputs=2) + cfg.define_knob("auto_unroll_max_step", [0, 512, 1500]) + + target = tvm.target.Target.current() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + ##### space definition end ##### + + padded_data1, padded_data2 = s[correlation].op.input_tensors + s[padded_data1].compute_inline() + s[padded_data2].compute_inline() + + # create cache stage + s[correlation].set_scope('local') + AA = s.cache_read(padded_data1, 'shared', [correlation]) + BB = s.cache_read(padded_data2, 'shared', [correlation]) + + output = s.outputs[0].output(0) + + # tile and bind spatial axes + n, f, y, x = s[output].op.axis + kernel_scope, n = s[output].split(n, nparts=1) + + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + bf = s[output].fuse(n, bf) + s[output].bind(bf, te.thread_axis("blockIdx.z")) + s[output].bind(by, te.thread_axis("blockIdx.y")) + s[output].bind(bx, te.thread_axis("blockIdx.x")) + s[output].bind(vf, te.thread_axis("vthread")) + s[output].bind(vy, te.thread_axis("vthread")) + s[output].bind(vx, te.thread_axis("vthread")) + s[output].bind(tf, te.thread_axis("threadIdx.z")) + s[output].bind(ty, te.thread_axis("threadIdx.y")) + s[output].bind(tx, te.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) + s[correlation].compute_at(s[output], tx) + + # tile reduction axes + n, f, y, x = s[correlation].op.axis + rc, ry, rx = s[correlation].op.reduce_axis + rco, rci = cfg['tile_rc'].apply(s, correlation, rc) + ryo, ryi = cfg['tile_ry'].apply(s, correlation, ry) + rxo, rxi = cfg['tile_rx'].apply(s, correlation, rx) + s[correlation].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x) + + s[AA].compute_at(s[correlation], rxo) + s[BB].compute_at(s[correlation], rxo) + + # cooperative fetching + for load in [AA, BB]: + n, f, y, x = s[load].op.axis + fused = s[load].fuse(n, f, y, x) + tz, fused = 
s[load].split(fused, nparts=cfg["tile_f"].size[2]) + ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2]) + tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2]) + s[load].bind(tz, te.thread_axis("threadIdx.z")) + s[load].bind(ty, te.thread_axis("threadIdx.y")) + s[load].bind(tx, te.thread_axis("threadIdx.x")) + + # unroll + s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + +@autotvm.register_topi_schedule("correlation_nchw.cuda") +def schedule_correlation_nchw(cfg, outs): + """schedule of correlation_nchw for cuda gpu + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + outs: Array of Tensor + The computation graph description of correlation + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for correlation. + """ + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'correlation_nchw': + _schedule_correlation_nchw(cfg, s, op.output(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index db9da844e3af..b7cb32d58d01 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -167,7 +167,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): b, h, w, c = s[Output].op.axis # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value + num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index d8be3bd1b886..f2c1143b5fb8 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -43,7 +43,8 @@ def atomic_add(x, y): return tvm.tir.call_pure_intrin(y.dtype, "atomic_add", x, y) -def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score_index): +def get_valid_counts_ir(data, valid_count, out, out_indices, + score_threshold, id_index, score_index): """Low level IR to get valid count of bounding boxes given a score threshold. Also prepares to move valid boxes to the top of input data. @@ -83,6 +84,7 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) + out_indices = ib.buffer_ptr(out_indices) atomic_add_return = ib.allocate( valid_count.dtype, (1,), name='atomic_add_return', scope='local') one_count = tvm.tir.const(1, dtype=valid_count.dtype) @@ -115,9 +117,11 @@ def get_valid_counts_ir(data, valid_count, out, score_threshold, id_index, score valid_count[i]), one_count) with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = data[tid * elem_length + k] + out_indices[tid + k] = tid + k with ib.else_scope(): with ib.for_range(0, elem_length) as k: out[tid * elem_length + k] = -one + out_indices[tid + k] = -one_count return ib.get() @@ -149,24 +153,27 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): Rearranged data tensor. 
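An illustrative sketch of how the three outputs of this get_valid_counts kernel feed the updated non_max_suppression signature further down (shapes and thresholds are assumed for illustration):

    import tvm
    from tvm import te
    import topi

    data = te.placeholder((1, 2500, 6), name="data")  # [batch, num_anchors, (class_id, score, x1, y1, x2, y2)]
    with tvm.target.cuda():
        valid_count, out, out_indices = topi.cuda.get_valid_counts(data, score_threshold=0.2)
        nms_out = topi.cuda.non_max_suppression(out, valid_count, out_indices,
                                                iou_threshold=0.5, return_indices=False)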
""" batch_size = data.shape[0] + num_anchors = data.shape[1] data_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "data_buf", data_alignment=8) valid_count_buf = tvm.tir.decl_buffer( (batch_size,), "int32", "valid_count_buf", data_alignment=8) out_buf = tvm.tir.decl_buffer( data.shape, data.dtype, "out_buf", data_alignment=8) + out_indices_buf = tvm.tir.decl_buffer( + (batch_size, num_anchors), "int32", "out_buf", data_alignment=8) - valid_count, out = \ - te.extern([(batch_size,), data.shape], [data], + valid_count, out, out_indices = \ + te.extern([(batch_size,), data.shape, (batch_size, num_anchors)], [data], lambda ins, outs: get_valid_counts_ir( - ins[0], outs[0], outs[1], score_threshold, id_index, score_index), + ins[0], outs[0], outs[1], outs[2], score_threshold, id_index, score_index), dtype=["int32", data.dtype], in_buffers=[data_buf], - out_buffers=[valid_count_buf, out_buf], + out_buffers=[valid_count_buf, out_buf, out_indices_buf], name="get_valid_counts", tag="get_valid_counts_gpu") - return [valid_count, out] + return [valid_count, out, out_indices] def nms_ir(data, sorted_index, valid_count, out, box_indices, @@ -335,7 +342,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): return ib.get() -def non_max_suppression(data, valid_count, max_output_size=-1, +def non_max_suppression(data, valid_count, indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, return_indices=True, invalid_to_bottom=False): @@ -347,9 +354,18 @@ def non_max_suppression(data, valid_count, max_output_size=-1, 3-D tensor with shape [batch_size, num_anchors, elem_length]. The last dimension should be in format of [class_id, score, box_left, box_top, box_right, box_bottom]. + It could be the second output out_tensor of get_valid_counts. valid_count : tvm.te.Tensor - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes. It could be the output + valid_count of get_valid_counts. + + indices : tvm.te.Tensor + 2-D tensor with shape [batch_size, num_anchors], represents + the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the + second dimension are like the output of arange(num_anchors) + if get_valid_counts is not used before non_max_suppression. max_output_size : optional, int Max number of output valid boxes for each instance. diff --git a/topi/python/topi/cuda/pooling.py b/topi/python/topi/cuda/pooling.py index 26c18eeaa306..98399843e55c 100644 --- a/topi/python/topi/cuda/pooling.py +++ b/topi/python/topi/cuda/pooling.py @@ -22,7 +22,7 @@ from ..util import traverse_inline -def schedule_adaptive_pool(outs): +def schedule_adaptive_pool(outs, layout='NCHW'): """Schedule for adaptive_pool. 
Parameters @@ -51,8 +51,12 @@ def _schedule(Pool): else: Out = outs[0].op.output(0) s[Pool].set_scope("local") + by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread) - bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread) + if layout == 'NHWC': + bx, tx = s[Out].split(s[Out].op.axis[3], factor=num_thread) + else: + bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread) s[Out].reorder(by, bx, ty, tx) s[Out].bind(ty, thread_y) s[Out].bind(tx, thread_x) diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index 3546448cd306..f713bb216808 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -43,10 +43,10 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r The last dimension is in format of [w_start, h_start, w_end, h_end, score] scales : list/tuple of float - Scales of anchor windoes. + Scales of anchor windows. ratios : list/tuple of float - Ratios of anchor windoes. + Ratios of anchor windows. feature_stride : int The size of the receptive field each unit in the convolution layer of the rpn, for example @@ -187,7 +187,7 @@ def argsort_ir(data_buf, out_index_buf): index_out[offset + 1] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() @@ -248,7 +248,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): p_out[base_idx + i] = True ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() @@ -325,10 +325,10 @@ def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, thres 2-D with shape [batch, 3] scales : list/tuple of float - Scales of anchor windoes. + Scales of anchor windows. ratios : list/tuple of float - Ratios of anchor windoes. + Ratios of anchor windows. feature_stride : int The size of the receptive field each unit in the convolution layer of the rpn, for example diff --git a/topi/python/topi/cuda/softmax.py b/topi/python/topi/cuda/softmax.py index 62c437ae96ac..5f7402b4e7a0 100644 --- a/topi/python/topi/cuda/softmax.py +++ b/topi/python/topi/cuda/softmax.py @@ -16,12 +16,12 @@ # under the License. # pylint: disable=invalid-name, unused-variable, trailing-whitespace """Schedule for softmax operator""" +from tvm import target as target_ from tvm import te from tvm.contrib import cudnn from .. import generic from .injective import schedule_injective_from_existing - def schedule_softmax(outs): """Schedule for softmax op. @@ -39,6 +39,7 @@ def schedule_softmax(outs): outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs s = te.create_schedule([x.op for x in outs]) softmax = outs[0] + tgt = target_.Target.current(allow_none=False) op_tag = softmax.op.tag if op_tag == 'softmax_output': @@ -53,6 +54,18 @@ def schedule_softmax(outs): raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \ Got {0}'.format(op_tag)) + # The nvptx and rocm backends only supports 32-bits warp shuffle + # instructions. + # + # TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend. 
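(A concrete instance of the warp-based schedule defined below, added for illustration with an assumed shape: for a float32 softmax of shape (batch, 1000) on the "cuda" target, num_thread equals the warp size of 32, so axis 1 is split across 32 lanes of ceil(1000 / 32) = 32 elements each, the innermost run is split again by a factor of 4 for vectorized accesses, and the expsum and max_elem row reductions are bound to the same 32 lanes.)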
+ def sched_warp_softmax(): + if tgt.target_name == "nvptx" or tgt.target_name == "rocm": + return softmax.dtype == "float32" or softmax.dtype == "int32" + if tgt.target_name != "cuda": + # this is used as the gpu schedule for other arches which may not have warp reductions + return False + return True + if len(softmax.shape) > 2: ops = [max_elem.op, expsum.op, softmax.op] if exp is not None: @@ -60,6 +73,46 @@ def schedule_softmax(outs): for op in ops: s = schedule_injective_from_existing(s, op.output(0)) + + elif sched_warp_softmax(): + # A warp of 32 threads performs a row reduction. + num_thread = tgt.thread_warp_size + block_x = te.thread_axis("blockIdx.x") + thread_x = te.thread_axis((0, num_thread), "threadIdx.x") + + # (4) softmax + xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread) + _, xii = s[softmax].split(xi, factor=4) + s[softmax].vectorize(xii) + s[softmax].bind(xo, thread_x) + s[softmax].bind(softmax.op.axis[0], block_x) + + # (3) expsum + k = expsum.op.reduce_axis[0] + ko, _ = s[expsum].split(k, nparts=num_thread) + s[expsum].bind(ko, thread_x) + s[expsum].compute_at(s[softmax], xo) + + # (2) exp + if exp is not None: + xo, xi = s[exp].split(exp.op.axis[1], nparts=num_thread) + _, xii = s[exp].split(xi, factor=4) + s[exp].vectorize(xii) + s[exp].bind(xo, thread_x) + s[exp].compute_at(s[expsum], expsum.op.axis[0]) + s[exp].compute_at(s[softmax], softmax.op.axis[0]) + s[exp].set_scope("warp") + + # (1) max_elem + k = max_elem.op.reduce_axis[0] + ko, _ = s[max_elem].split(k, nparts=num_thread) + s[max_elem].bind(ko, thread_x) + if exp is not None: + s[max_elem].compute_at(s[exp], xo) + else: + s[max_elem].bind(ko, thread_x) + s[max_elem].bind(max_elem.op.axis[0], block_x) + else: num_thread = 64 block_x = te.thread_axis("blockIdx.x") diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index a1c70c44958d..ddae2bd96135 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -117,7 +117,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): tvm.tir.generic.cast(tid, indices_out.dtype) ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod @@ -145,7 +145,7 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): indices_out[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() @@ -237,7 +237,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): output[offset + axis_mul_after] = temp_index[0] ib.emit(tvm.tir.Call(None, 'tvm_storage_sync', tvm.runtime.convert(['shared']), - tvm.tir.Call.Intrinsic, None, 0)) + tvm.tir.Call.Intrinsic)) return ib.get() diff --git a/topi/python/topi/cuda/sparse.py b/topi/python/topi/cuda/sparse.py new file mode 100644 index 000000000000..5b57000f403a --- /dev/null +++ b/topi/python/topi/cuda/sparse.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Sparse operators""" +from tvm import te +from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity +from ..util import traverse_inline +from .. import nn + + +@autotvm.register_topi_compute("sparse_dense.cuda") +def sparse_dense(cfg, data, weight_data, weight_indices, weight_indptr): + """ + Computes sparse-dense matrix multiplication of `data` and + `(weight_data, weight_indices, weight_indptr).T` + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data : tvm.te.Tensor + 2-D with shape [M, K], float32 + + weight_data : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 3-D with shape [num_blocks, bs_r, bs_c] (BSR) + + weight_indices : tvm.te.Tensor + 1-D with shape [nnz] (CSR) or + 1-D with shape [num_blocks] (BSR) + + weight_indptr : tvm.te.Tensor + 1-D with shape [N + 1] (CSR) or + 1-D with shape [(N + 1) // bs_r] (BSR) + + Returns + ------- + output : tvm.te.Tensor + 2-D with shape [M, N] + """ + # pylint:disable=unused-argument + return nn.sparse_dense(data, weight_data, weight_indices, weight_indptr) + + +@autotvm.register_topi_schedule("sparse_dense.cuda") +def schedule_sparse_dense(cfg, outs): + """Create schedule for sparse dense""" + # pylint:disable=invalid-name + s = te.create_schedule([x.op for x in outs]) + def _callback(op): + if op.tag == "sparse_dense_bsrmm": + y_bsrmm = op.input_tensors[0] + assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block" + out = s.outputs[0].output(0) + + if op not in s.outputs: + y_reshape = op.output(0) + s[y_reshape].compute_at(s[out], s[out].op.axis[1]) + + (_, c) = s[y_bsrmm].op.reduce_axis + + (m_o, n_o) = s[out].op.axis + s[out].bind(m_o, te.thread_axis("blockIdx.x")) + s[out].bind(n_o, te.thread_axis("blockIdx.y")) + s[y_bsrmm].compute_at(s[out], n_o) + + thread_x = te.thread_axis("threadIdx.x") + + cfg.define_split("tile_c", c, num_outputs=2) + if cfg.is_fallback: + cfg["tile_c"] = SplitEntity([-1, 8]) + _, ci = cfg['tile_c'].apply(s, y_bsrmm, c) + + y_bsrmm_factored = s.rfactor(y_bsrmm, ci) + tx = s[y_bsrmm].op.reduce_axis[0] + s[y_bsrmm].bind(tx, thread_x) + s[y_bsrmm_factored].compute_at(s[y_bsrmm], tx) + s[y_bsrmm].set_store_predicate(thread_x.var.equal(0)) + s[out].set_store_predicate(thread_x.var.equal(0)) + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index 30784f45a591..22d74438188c 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -459,7 +459,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + out = non_max_suppression(inter_out[0], inter_out[1], inter_out[1], max_output_size=-1, iou_threshold=nms_threshold, force_suppress=force_suppress, top_k=nms_topk, return_indices=False) return out diff --git a/topi/python/topi/cuda/tensor_intrin.py b/topi/python/topi/cuda/tensor_intrin.py index f8fce342e212..3941c00cc464 100644 --- 
a/topi/python/topi/cuda/tensor_intrin.py +++ b/topi/python/topi/cuda/tensor_intrin.py @@ -69,14 +69,15 @@ def _instr(index): return _instr(0), _instr(1), _instr(2) # body, reset, update - with tvm.target.build_config(data_alignment=4, offset_factor=1) as cfg: - scopes = {x: x_scope, y: y_scope, z: z_scope} - binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name, - data_alignment=cfg.data_alignment, - offset_factor=cfg.offset_factor, - scope=scopes[t]) for t in [x, y, z]} - - return te.decl_tensor_intrin(z.op, _intrin_func, binds=binds) + default_buffer_params = { + "data_alignment": 4, "offset_factor": 1 + } + scopes = {x: x_scope, y: y_scope, z: z_scope} + binds = {t: tvm.tir.decl_buffer(t.shape, t.dtype, t.op.name, + scope=scopes[t], **default_buffer_params) for t in [x, y, z]} + + return te.decl_tensor_intrin( + z.op, _intrin_func, binds=binds, default_buffer_params=default_buffer_params) def intrin_wmma_load_matrix_A(strides_dst, strides_from, shape, layout, A_shape, C_shape, in_dtype): diff --git a/topi/python/topi/generic/default.py b/topi/python/topi/generic/default.py index d4c642ab8814..59e5a255c6e1 100644 --- a/topi/python/topi/generic/default.py +++ b/topi/python/topi/generic/default.py @@ -24,7 +24,7 @@ def default_schedule(outs, auto_inline): """Default schedule for llvm.""" target = tvm.target.Target.current(allow_none=False) outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - if target.target_name != "llvm": + if target.target_name not in ("llvm", "c"): raise RuntimeError("schedule not registered for '%s'" % target) s = te.create_schedule([x.op for x in outs]) if auto_inline: diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 2be4bbb456de..767087b0d4f0 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -290,6 +290,24 @@ def schedule_conv3d_ndhwc(outs): """ return _default_schedule(outs, False) + +def schedule_conv3d_transpose_ncdhw(outs): + """Schedule for conv3d_transpose_ncdhw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv3d_transpose_ncdhw + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + def schedule_conv2d_transpose_nchw(outs): """Schedule for conv2d_transpose_nchw @@ -672,3 +690,20 @@ def schedule_batch_matmul(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + + +def schedule_correlation_nchw(outs): + """Schedule for correlation_nchw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of correlation_nchw + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/generic/search.py b/topi/python/topi/generic/search.py index 91b7635108ff..895dadbd130c 100644 --- a/topi/python/topi/generic/search.py +++ b/topi/python/topi/generic/search.py @@ -34,3 +34,19 @@ def schedule_argwhere(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + + +def schedule_scatter(outs): + """Schedule for scatter operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of scatter. + + Returns + ------- + s: Schedule + The computation schedule for the op. 
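Referring back to the sparse_dense template registered for CUDA earlier in this patch (topi/python/topi/cuda/sparse.py), a minimal BSR usage sketch; every size below is an assumption for illustration:

    import tvm
    from tvm import te
    import topi

    M, N, K, BS_R, BS_C, NUM_BLOCKS = 128, 256, 512, 16, 1, 64
    data = te.placeholder((M, K), name="data")
    w_data = te.placeholder((NUM_BLOCKS, BS_R, BS_C), name="w_data")
    w_indices = te.placeholder((NUM_BLOCKS,), dtype="int32", name="w_indices")
    w_indptr = te.placeholder((N // BS_R + 1,), dtype="int32", name="w_indptr")
    with tvm.target.cuda():
        out = topi.cuda.sparse_dense(data, w_data, w_indices, w_indptr)  # shape (M, N)
        s = topi.cuda.schedule_sparse_dense([out])
        mod = tvm.lower(s, [data, w_data, w_indices, w_indptr, out], simple_mode=True)

The schedule rfactors the block reduction so the inner accumulation is spread across threadIdx.x lanes (8 under the fallback config), with the final store predicated on lane 0.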
+ """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/image/__init__.py b/topi/python/topi/image/__init__.py index 86b9825c8d92..914b02e1c219 100644 --- a/topi/python/topi/image/__init__.py +++ b/topi/python/topi/image/__init__.py @@ -21,3 +21,4 @@ from .resize import * from .dilation2d import * +from .grid_sample import * diff --git a/topi/python/topi/image/dilation2d.py b/topi/python/topi/image/dilation2d.py index a71866e60a98..074ca6c02d08 100644 --- a/topi/python/topi/image/dilation2d.py +++ b/topi/python/topi/image/dilation2d.py @@ -29,10 +29,10 @@ def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None): Parameters ---------- - input : tvm.Tensor + input : tvm.te.Tensor 4-D with shape [batch, in_channel, in_height, in_width] - filter : tvm.Tensor + filter : tvm.te.Tensor 3-D with shape [ in_channel, filter_height, filter_width] stride : int or a list/tuple of two ints @@ -49,7 +49,7 @@ def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None): Returns ------- - Output : tvm.Tensor + Output : tvm.te.Tensor 4-D with shape [batch, in_channel, out_height, out_width] """ if out_dtype is None: @@ -100,10 +100,10 @@ def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None): Parameters ---------- - input : tvm.Tensor + input : tvm.te.Tensor 4-D with shape [batch, in_height, in_width, in_channel] - filter : tvm.Tensor + filter : tvm.te.Tensor 3-D with shape [filter_height, filter_width, in_channel] stride : int or a list/tuple of two ints @@ -120,7 +120,7 @@ def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None): Returns ------- - Output : tvm.Tensor + Output : tvm.te.Tensor 4-D with shape [batch, out_height, out_width, in_channel] """ if out_dtype is None: diff --git a/topi/python/topi/image/grid_sample.py b/topi/python/topi/image/grid_sample.py new file mode 100644 index 000000000000..32b6112ddb18 --- /dev/null +++ b/topi/python/topi/image/grid_sample.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""affine_grid and grid_sample operator""" +from tvm import te, tir + + +def affine_grid(data, target_shape): + """affine_grid operator that generates 2D sampling grid. + + This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform + sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine + transformation is then applied on the sampling grid. + + Parameters + ---------- + data : tvm.Tensor + 3-D with shape [batch, 2, 3]. The affine matrix. + + target_shape: list/tuple of two int + Specifies the output shape (H, W). 
+ + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, target_height, target_width] + """ + assert target_shape is not None + assert len(target_shape) == 2 + assert target_shape[0] > 1 and target_shape[1] > 1, \ + "target height/width should be greater than 1" + + dtype = data.dtype + y_step = tir.const((2.0 - 1e-7)/ (target_shape[0] - 1), dtype=dtype) + x_step = tir.const((2.0 - 1e-7)/ (target_shape[1] - 1), dtype=dtype) + start = tir.const(-1.0, dtype=dtype) + + def _compute(n, dim, i, j): + y = start + i * y_step + x = start + j * x_step + return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2] + + oshape = (data.shape[0], len(target_shape), *target_shape) + return te.compute(oshape, _compute, tag='affine_grid') + + +def grid_sample(data, grid, method='bilinear', layout='NCHW'): + """Applies bilinear sampling to input feature map. + + Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by + + .. math:: + + x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ + y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ + output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) + + :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and + :math:`G()` denotes the interpolation method. + The out-boundary points will be padded with zeros. The shape of the output will be + (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]). + + The operator assumes that :math:`grid` has been normalized to [-1, 1]. + + grid_sample is often used together with affine_grid, which generates the sampling grids. + + Parameters + ---------- + data : tvm.Tensor + 4-D with shape [batch, in_channel, in_height, in_width] + + grid : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + + method : str + The interpolation method. Only 'bilinear' is supported. + + layout : str + The layout of input data and the output.
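# Editor's aside (not part of the patch): a small NumPy reference for the
# affine_grid compute above, ignoring the 1e-7 epsilon used in the step sizes.
# The helper name and sizes are hypothetical; it shows that the grid is a
# normalized [-1, 1] mesh with the affine matrix applied to (x, y, 1).
import numpy as np

def affine_grid_ref(theta, height, width):
    # theta: [batch, 2, 3] affine matrices; returns [batch, 2, height, width]
    y, x = np.mgrid[-1.0:1.0:height * 1j, -1.0:1.0:width * 1j]
    coords = np.stack([x, y, np.ones_like(x)]).reshape(3, -1)   # rows: x, y, 1
    out = theta @ coords                                        # [batch, 2, height * width]
    return out.reshape(theta.shape[0], 2, height, width)

identity = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])       # batch of one
grid = affine_grid_ref(identity, 8, 8)                          # grid[:, 0] is x, grid[:, 1] is y
# Feeding such an identity grid to grid_sample should reproduce the input feature map
# (up to border handling), which is a convenient sanity check.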
+ + Returns + ------- + Output : tvm.Tensor + 4-D with shape [batch, 2, out_height, out_width] + """ + batch, in_channel, in_height, in_width = data.shape + out_height, out_width = grid.shape[2:] + assert method == 'bilinear', "Only bilinear is supported" + assert layout == "NCHW", "Only NCHW is supported" + + def _get_pixel_value(n, c, h, w): + return te.if_then_else(te.all(h >= 0, w >= 0, h < in_height, w < in_width), + data[n, c, h, w], tir.const(0.0, dtype=data.dtype)) + + def _bilinear_sample(n, c, h, w): + x = grid[n, 0, h, w] + y = grid[n, 1, h, w] + y = (y + 1) * (in_height - 1) / 2 + x = (x + 1) * (in_width - 1) / 2 + x0 = te.floor(x).astype('int32') + y0 = te.floor(y).astype('int32') + x1 = x0 + tir.const(1, 'int32') + y1 = y0 + tir.const(1, 'int32') + return _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) \ + + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) \ + + _get_pixel_value(n, c, y1, x0) * (y - y0) * (1.0 - (x - x0)) \ + + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0) + + return te.compute((batch, in_channel, out_height, out_width), _bilinear_sample, + tag='grid_sample') diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py index a54941315a1a..650809985279 100644 --- a/topi/python/topi/intel_graphics/depthwise_conv2d.py +++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py @@ -168,7 +168,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): b, h, w, c = s[Output].op.axis # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value + num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads diff --git a/topi/python/topi/mali/conv2d.py b/topi/python/topi/mali/conv2d.py index d19592857086..ed1932674964 100644 --- a/topi/python/topi/mali/conv2d.py +++ b/topi/python/topi/mali/conv2d.py @@ -138,11 +138,7 @@ def _schedule_spatial_pack(cfg, s, output, conv, data_vec, kernel_vec): s[data_vec].unroll(vw) if isinstance(kernel_vec.op, tvm.te.ComputeOp) and kernel_vec.name == 'kernel_vec': - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel packing will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[kernel_vec].pragma(s[kernel_vec].op.axis[0], 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: max_threads = tvm.target.Target.current(allow_none=False).max_num_threads co, ci, kh, kw, vc = s[kernel_vec].op.axis fused = s[kernel_vec].fuse(co, ci, kh, kw, vc) @@ -279,15 +275,21 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype, til data_pad[(b*bnb+bb) // (nH*nW)][ci][(b*bnb+bb) // nW % nH * m + eps] [(b*bnb+bb) % nW * m + nu], tvm.tir.const(0, data_pad.dtype)), name='d') - # transform kernel - if pre_computed: - U = kernel + if autotvm.GLOBAL_SCOPE.in_tuning: + VC = cfg['tile_k'].size[-1] + kvshape = (KH + tile_size - 1, KW + tile_size - 1, tvm.tir.indexdiv(CO, VC), CI, VC) + U = tvm.te.placeholder(kvshape, kernel.dtype, name="U") else: - r_kh = te.reduce_axis((0, KH), 'r_kh') - r_kw = te.reduce_axis((0, KW), 'r_kw') - U = te.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco: - te.sum(kernel[co * bna + vco][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], - axis=[r_kh, r_kw]), name='U') + # transform kernel + if pre_computed: + U = kernel + else: + r_kh = 
te.reduce_axis((0, KH), 'r_kh') + r_kw = te.reduce_axis((0, KW), 'r_kw') + U = te.compute((alpha, alpha, CO // bna, CI, bna), lambda eps, nu, co, ci, vco: + te.sum(kernel[co * bna + vco][ci][r_kh][r_kw] * + G[eps][r_kh] * G[nu][r_kw], + axis=[r_kh, r_kw]), name='U') # transform image r_a = te.reduce_axis((0, alpha), 'r_a') @@ -345,11 +347,7 @@ def _schedule_winograd(cfg, s, op): kernel, G = s[U].op.input_tensors s[G].compute_inline() eps, nu, co, ci, vco, = s[U].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # kernel transformation will be pre-computed during compilation, so we skip - # this part to make tuning records correct - s[U].pragma(eps, 'debug_skip_region') - else: + if not autotvm.GLOBAL_SCOPE.in_tuning: r_kh, r_kw = s[U].op.reduce_axis s[U].reorder(co, ci, eps, nu, r_kh, r_kw, vco) _ = [s[U].unroll(x) for x in [eps, nu, r_kh, r_kw]] diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 6f31cca022d2..b4228a4a9178 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -143,6 +143,23 @@ def cos(x): return te.compute(x.shape, lambda *i: te.cos(x(*i))) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def cosh(x): + """Take cosh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.cosh(x(*i))) + + @tvm.te.tag_scope(tag=tag.ELEMWISE) def sin(x): """Take sin of input x. @@ -160,6 +177,91 @@ def sin(x): return te.compute(x.shape, lambda *i: te.sin(x(*i))) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def sinh(x): + """Take sinh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.sinh(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def acos(x): + """Take arc cos of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.acos(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def acosh(x): + """Take arc cosh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.acosh(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def asin(x): + """Take arc sin of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.asin(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def asinh(x): + """Take arc sinh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.asinh(x(*i))) + + @tvm.te.tag_scope(tag=tag.ELEMWISE) def atan(x): """Take atan of input x. @@ -176,6 +278,22 @@ def atan(x): """ return te.compute(x.shape, lambda *i: te.atan(x(*i))) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def atanh(x): + """Take atanh of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.atanh(x(*i))) + @tvm.te.tag_scope(tag=tag.ELEMWISE) def floor(x): """Take floor of input x. @@ -283,12 +401,12 @@ def isfinite(x): Parameters ---------- - x : tvm.Tensor + x : tvm.te.Tensor Input argument. 
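# Editor's aside (not part of the patch): every intrinsic added above (cosh, sinh,
# acos, acosh, asin, asinh, atanh, log2, log10, ...) follows the same one-line
# ELEMWISE pattern, so a quick end-to-end check looks like this minimal sketch
# (assumes an llvm-enabled build; the compute line is identical to the topi.cosh body):
import numpy as np
import tvm
from tvm import te

x = te.placeholder((64,), name="x", dtype="float32")
y = te.compute(x.shape, lambda *i: te.cosh(x(*i)))
s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], target="llvm")

a = tvm.nd.array(np.random.uniform(-1.0, 1.0, 64).astype("float32"))
b = tvm.nd.array(np.zeros(64, dtype="float32"))
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.cosh(a.asnumpy()), rtol=1e-5)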
Returns ------- - y : tvm.Tensor + y : tvm.te.Tensor The result. """ return te.compute(x.shape, lambda *i: te.isfinite(x(*i))) @@ -300,12 +418,12 @@ def isinf(x): Parameters ---------- - x : tvm.Tensor + x : tvm.te.Tensor Input argument. Returns ------- - y : tvm.Tensor + y : tvm.te.Tensor The result. """ return te.compute(x.shape, lambda *i: te.isinf(x(*i))) @@ -345,6 +463,40 @@ def log(x): return te.compute(x.shape, lambda *i: te.log(x(*i))) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def log2(x): + """Take logarithm to the base 2 of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.log2(x(*i))) + + +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def log10(x): + """Take logarithm to the base 10 of input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input argument. + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + return te.compute(x.shape, lambda *i: te.log10(x(*i))) + + @tvm.te.tag_scope(tag=tag.ELEMWISE) def sqrt(x): """Take square root of input x. @@ -525,12 +677,12 @@ def fast_tanh(x): Parameters ---------- - x : tvm.Tensor + x : tvm.te.Tensor Input argument. Returns ------- - y : tvm.Tensor + y : tvm.te.Tensor The result. """ return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE) diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index bd806b9d0e83..a035f6778c97 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -22,6 +22,7 @@ from .conv1d import * from .conv2d import * from .conv3d import * +from .correlation import * from .deformable_conv2d import * from .depthwise_conv2d import * from .elemwise import * @@ -31,6 +32,7 @@ from .mapping import * from .pooling import * from .softmax import * +from .conv3d_transpose import * from .conv2d_transpose import * from .conv1d_transpose import * from .bnn import * diff --git a/topi/python/topi/nn/conv3d_transpose.py b/topi/python/topi/nn/conv3d_transpose.py new file mode 100644 index 000000000000..29b9e53c22a0 --- /dev/null +++ b/topi/python/topi/nn/conv3d_transpose.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, unused-argument +"""Transposed 3D convolution operators (sometimes called Deconvolution).""" +import tvm +from tvm import te +from tvm import relay +from .dilate import dilate +from .pad import pad +from .util import get_pad_tuple3d +from ..util import simplify + + +def conv3d_transpose_ncdhw(Input, Filter, strides, padding, out_dtype): + """Transposed 3D convolution ncdhw forward operator. 
+ + Parameters + ---------- + Input : tvm.te.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + Filter : tvm.te.Tensor + 5-D with shape [in_channel, num_filter, filter_depth, filter_height, filter_width] + + strides : int or a list/tuple of three ints + The spatial stride along depth,height and width + + padding : int or str + Padding size, or ['VALID', 'SAME'] + + out_dtype : str + The output data type. This is used for mixed precision. + + Returns + ------- + Output : tvm.te.Tensor + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + return declaration_conv3d_transpose_impl(Input, Filter, strides, padding, out_dtype) + + +def conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype): + """Preprocess data and kernel to make the compute pattern + of conv3d_transpose the same as conv3d""" + batch, in_c, in_d, in_h, in_w = data.shape + _, out_c, filter_d, filter_h, filter_w = kernel.shape + stride_d, stride_h, stride_w = strides + # dilate data + data_dilate = dilate(data, [1, 1, stride_d, stride_h, stride_w], name='data_dilate') + # pad data + fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d( + padding, (filter_d, filter_h, filter_w)) + bpad_front = filter_d - 1 - fpad_front + bpad_back = filter_d - 1 - fpad_back + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + data_pad = pad(data_dilate, \ + [0, 0, bpad_front, bpad_top, bpad_left], \ + [0, 0, bpad_back, bpad_bottom, bpad_right], \ + name='data_pad') + # transform kernel layout from IODHW to OIDHW, and rotate kernel by 180 degrees + kernel_transform = te.compute((out_c, in_c, filter_d, filter_h, filter_w), \ + lambda o, i, d, h, w: kernel[i][o][filter_d-1-d] \ + [filter_h-1-h][filter_w-1-w], \ + name='kernel_transform') + return data_pad, kernel_transform + + +def declaration_conv3d_transpose_impl(data, kernel, strides, padding, out_dtype): + """Implementation of conv3d transpose""" + data_pad, kernel_transform = \ + conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype) + batch, in_c, in_d, in_h, in_w = data_pad.shape + out_c, _, filter_d, filter_h, filter_w = kernel_transform.shape + stride_d, stride_h, stride_w = strides + + # convolution stage + out_c = simplify(out_c) + out_d = simplify(in_d - filter_d + 1) + out_h = simplify(in_h - filter_h + 1) + out_w = simplify(in_w - filter_w + 1) + dc = te.reduce_axis((0, in_c), name='dc') + dd = te.reduce_axis((0, filter_d), name='dd') + dh = te.reduce_axis((0, filter_h), name='dh') + dw = te.reduce_axis((0, filter_w), name='dw') + + Output = te.compute( + (batch, out_c, out_d, out_h, out_w), + lambda b, c, d, h, w: te.sum( + data_pad[b, dc, d+dd, h+dh, w+dw].astype(out_dtype) * + kernel_transform[c, dc, dd, dh, dw].astype(out_dtype), + axis=[dc, dd, dh, dw]), tag="conv3d_transpose_ncdhw") + + return Output + + +@tvm.target.generic_func +def conv3d_transpose_legalize(attrs, inputs, types): + """Legalizes Transposed 3D convolution op. 
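# Editor's aside (not part of the patch): the dilate + pad + stride-1 convolution
# formulation of conv3d_transpose_ncdhw above yields the usual transposed-convolution
# output size. A quick arithmetic check of that relation for one symmetric-padding
# case (the helper and the numbers are hypothetical):
def conv3d_transpose_out_dim(in_dim, kernel, stride, pad):
    # dilated length: (in_dim - 1) * stride + 1
    # padded length:  dilated + 2 * (kernel - 1 - pad)
    # stride-1 "valid" conv: padded - kernel + 1
    return (in_dim - 1) * stride + 1 + 2 * (kernel - 1 - pad) - kernel + 1

# matches the familiar (in - 1) * stride - 2 * pad + kernel
assert conv3d_transpose_out_dim(in_dim=8, kernel=3, stride=2, pad=1) == (8 - 1) * 2 - 2 * 1 + 3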
+ + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current Transposed 3D convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + types : list of types + List of input and output types + + Returns + ------- + result : tvm.relay.Expr + The legalized expr + """ + if attrs['data_layout'] == 'NDHWC': + data, kernel = inputs + kernel_layout = attrs['kernel_layout'] + # Convert Kernel layout to IODHW + # kernel_layout is different from input kernel layout - IO is swapped + if kernel_layout == 'DHWIO': + # input kernel layout is swapped to DHWOI + # output kernel layout will be IODHW + kernel = relay.transpose(kernel, axes=(4, 3, 0, 1, 2)) + elif kernel_layout == 'DHWOI': + # input kernel layout is swapped to DHWIO + # output kernel layout will be IODHW + kernel = relay.transpose(kernel, axes=(3, 4, 0, 1, 2)) + elif kernel_layout == 'IODHW': + # input kernel layout is swapped to OIDHW + # output kernel layout will be IODHW + kernel = relay.transpose(kernel, axes=(1, 0, 2, 3, 4)) + elif kernel_layout == 'OIDHW': + # input kernel layout is swapped to IODHW + # output kernel layout will be IODHW + pass + else: + # Skip legalize. Let relay.nn.conv2d_transpose to handle the case + return None + + # Set new attrs for conv3d_transpose. + new_attrs = {k: attrs[k] for k in attrs.keys()} + new_attrs['data_layout'] = 'NCDHW' + # layout of kernel should be IODHW, but kernel_layout should be swapped - OIDHW + new_attrs['kernel_layout'] = 'OIDHW' + + # Convert data to NCDHW. + data = relay.transpose(data, axes=(0, 4, 1, 2, 3)) + deconv = relay.nn.conv3d_transpose(data, kernel, **new_attrs) + # Convert back to original NDHWC layout. + out = relay.transpose(deconv, axes=(0, 2, 3, 4, 1)) + return out + + return None diff --git a/topi/python/topi/nn/correlation.py b/topi/python/topi/nn/correlation.py new file mode 100644 index 000000000000..94aea55d83b9 --- /dev/null +++ b/topi/python/topi/nn/correlation.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Correlation operators""" +from tvm import te + +from .pad import pad +from ..util import get_const_tuple + + +def correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, + is_multiply): + """Correlation operator in NCHW layout. 
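# Editor's aside (not part of the patch): a NumPy sanity check of the data-layout
# permutations used by the conv3d_transpose legalization above; the tensor sizes
# are arbitrary and only the axis orders matter.
import numpy as np

x_ndhwc = np.empty((2, 5, 7, 9, 16))                 # N, D, H, W, C
x_ncdhw = x_ndhwc.transpose(0, 4, 1, 2, 3)           # axes used before calling conv3d_transpose
assert x_ncdhw.shape == (2, 16, 5, 7, 9)
back = x_ncdhw.transpose(0, 2, 3, 4, 1)              # axes used to restore the NDHWC output
assert back.shape == x_ndhwc.shape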
+ + Parameters + ---------- + data1 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + data2 : tvm.te.Tensor + 4-D with shape [batch, channel, height, width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neightborhood centered around data1 + + padding : int or a list/tuple of 2 or 4 ints + Padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + is_multiply: bool + operation type is either multiplication or substraction + + Returns + ------- + Output : tvm.te.Tensor + 4-D with shape [batch, out_channel, out_height, out_width] + """ + # pylint: disable=unnecessary-lambda, invalid-name + data_shape = get_const_tuple(data1.shape) + assert get_const_tuple(data2.shape) == data_shape, "data1 and data2 should have the same shape" + assert kernel_size > 0 and kernel_size % 2, "kernel_size should be non-negative odd number" + if isinstance(padding, (tuple, list)): + if len(padding) == 2: + pad_before_h = pad_after_h = padding[0] + pad_before_w = pad_after_w = padding[1] + elif len(padding) == 4: + pad_before_h, pad_before_w, pad_after_h, pad_after_w = padding + else: + raise ValueError("invalid padding") + elif isinstance(padding, int): + pad_before_h = pad_after_h = pad_before_w = pad_after_w = padding + else: + raise ValueError("invalid padding") + pad_before = [0, 0, pad_before_h, pad_before_w] + pad_after = [0, 0, pad_after_h, pad_after_w] + padded_data1 = pad(data1, pad_before, pad_after) + padded_data2 = pad(data2, pad_before, pad_after) + + batch, channel, height, width = data_shape + + kernel_radius = (kernel_size - 1) // 2 + border_size = max_displacement + kernel_radius + displacement_radius = max_displacement // stride2 + displacement_size = 2 * displacement_radius + 1 + + padded_width = width + pad_before_w + pad_after_w + padded_height = height + pad_before_h + pad_after_h + out_channel = displacement_size * displacement_size + out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1 + out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1 + + rc = te.reduce_axis((0, channel), name='rc') + ry = te.reduce_axis((0, kernel_size), name='ry') + rx = te.reduce_axis((0, kernel_size), name='rx') + + if is_multiply: + corr_func = lambda x, y: x * y + else: + corr_func = lambda x, y: te.abs(x - y) + + def _compute_correlation(n, q, i, j): + # location in data1 + y1 = i * stride1 + max_displacement + x1 = j * stride1 + max_displacement + # location in data2 + y2 = y1 + (te.indexdiv(q, displacement_size) - displacement_radius) * stride2 + x2 = x1 + (te.indexmod(q, displacement_size) - displacement_radius) * stride2 + return te.sum(corr_func(padded_data1[n, rc, y1 + ry, x1 + rx], + padded_data2[n, rc, y2 + ry, x2 + rx]), axis=[rc, ry, rx]) + + correlation = te.compute((batch, out_channel, out_height, out_width), lambda n, q, i, j: + _compute_correlation(n, q, i, j), tag="correlation_nchw") + return correlation / (kernel_size * kernel_size * channel) diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py index f628fadee96e..ebcf478033fb 100644 --- a/topi/python/topi/nn/dilate.py +++ b/topi/python/topi/nn/dilate.py @@ -45,9 +45,9 @@ def dilate(data, strides, name="DilatedInput"): if len(strides) != n: raise ValueError("data dimension and strides size dismatch : %d vs %d" % ( n, len(strides))) - + ana = 
tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) + ana.simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) def _dilate(*indices): not_zero = [] diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py index 8fe53374f2b5..b298a0a2bb95 100644 --- a/topi/python/topi/nn/pad.py +++ b/topi/python/topi/nn/pad.py @@ -55,9 +55,9 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): if len(pad_after) != n: raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % ( n, len(pad_before))) + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify( - (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n)) + ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) pad_value = (pad_value if isinstance(pad_value, tvm.tir.PrimExpr) else tvm.tir.const(pad_value, data.dtype)) def _pad(*indices): @@ -115,8 +115,9 @@ def mirror_pad(data, if len(pad_after) != n: raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (n, len(pad_before))) + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify((data.shape[i] + pad_before[i] + pad_after[i])) + ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) assert mode in ('SYMMETRIC', 'REFLECT') mode = int(mode == 'SYMMETRIC') diff --git a/topi/python/topi/nn/sparse.py b/topi/python/topi/nn/sparse.py index b37bac2a213a..b24121baf85a 100644 --- a/topi/python/topi/nn/sparse.py +++ b/topi/python/topi/nn/sparse.py @@ -30,7 +30,7 @@ def sparse_dense(data, weight_data, weight_indices, weight_indptr): Parameters ---------- - x : tvm.te.Tensor + data : tvm.te.Tensor 2-D with shape [M, K], float32 weight_data : tvm.te.Tensor diff --git a/topi/python/topi/opengl/conv2d_nchw.py b/topi/python/topi/opengl/conv2d_nchw.py deleted file mode 100644 index c93bcc25daef..000000000000 --- a/topi/python/topi/opengl/conv2d_nchw.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -#pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements, too-many-arguments, too-many-branches, line-too-long -"""Schedule for conv2d_nchw with auto fusion""" -import tvm -from tvm import te -from .. import tag - -def schedule_conv2d_nchw(outs): - """Schedule for conv2d_nchw. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of conv2d_nchw - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for conv2d_nchw. 
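# Editor's aside (not part of the patch): the dilate/pad changes above migrate shape
# simplification from the removed tvm.tir.ir_pass.Simplify to tvm.arith.Analyzer.
# A minimal sketch of the new API as used for the dilated output shapes:
import tvm
from tvm import te

n = te.var("n")
ana = tvm.arith.Analyzer()
out_dim = ana.simplify((n - 1) * 1 + 1)   # e.g. a unit-stride dilation folds back to n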
- """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _schedule(conv2d, data): - if conv2d.op in s.outputs: - Out = conv2d - else: - Out = outs[0].op.output(0) - s[conv2d].opengl() - s[Out].opengl() - s[data].opengl() - - def traverse(OP): - """Internal traverse function""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(OP.tag): - if OP not in s.outputs: - s[OP].opengl() - for tensor in OP.input_tensors: - if isinstance(tensor.op, tvm.te.ComputeOp) and tensor.op not in scheduled_ops: - traverse(tensor.op) - # schedule conv2d_nchw - elif OP.tag.startswith('conv2d_nchw'): - conv2d = OP.output(0) - data = OP.input_tensors[0] - kernel = OP.input_tensors[1] - if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - _schedule(conv2d, data) - else: - raise RuntimeError("Unsupported operator: %s" % OP.tag) - - scheduled_ops.append(OP) - - traverse(outs[0].op) - return s diff --git a/topi/python/topi/opengl/dense.py b/topi/python/topi/opengl/dense.py deleted file mode 100644 index 715f713d56d6..000000000000 --- a/topi/python/topi/opengl/dense.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-variable -"""Schedule for dense operator""" -from tvm import te -from .. import tag - -def schedule_dense(outs): - """Schedule for dense operator. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of dense - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for dense. 
- """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _schedule(Dense): - if Dense.op in s.outputs: - Out = Dense - else: - Out = outs[0].op.output(0) - s[Dense].opengl() - s[Out].opengl() - - def traverse(OP): - """Internal traverse function""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(OP.tag): - if OP not in s.outputs: - s[OP].compute_inline() - for tensor in OP.input_tensors: - if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops: - traverse(tensor.op) - # schedule dense - elif OP.tag == 'dense': - Dense = OP.output(0) - _schedule(Dense) - else: - raise RuntimeError("Unsupported operator: %s" % OP.tag) - - scheduled_ops.append(OP) - - traverse(outs[0].op) - return s diff --git a/topi/python/topi/opengl/injective.py b/topi/python/topi/opengl/injective.py deleted file mode 100644 index a5944f7eedb2..000000000000 --- a/topi/python/topi/opengl/injective.py +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-variable, -"""Schedule for composition of injective operator""" -from tvm import te - -def schedule_injective_from_existing(sch, out): - """Schedule for injective op from existing schedule. - - Parameters - ---------- - sch: Schedule - The schedule to update. - out: Tensor - The tensor representing the injective op. - - Returns - ------- - sch: Schedule - The updated schedule. - """ - sch[out].opengl() - return sch - -def schedule_injective(outs): - """Schedule for injective op. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of injective in the format - of an array of tensors. - - Returns - ------- - sch: Schedule - The computation schedule for the op. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - te.schedule.AutoInlineInjective(s) - for out in outs: - schedule_injective_from_existing(s, out) - return s - -schedule_elemwise = schedule_injective -schedule_broadcast = schedule_injective diff --git a/topi/python/topi/opengl/pooling.py b/topi/python/topi/opengl/pooling.py deleted file mode 100644 index c30389c7b72c..000000000000 --- a/topi/python/topi/opengl/pooling.py +++ /dev/null @@ -1,121 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-variable, unused-argument -"""Schedule for pooling operators""" -from tvm import te -from .. import tag - -def schedule_adaptive_pool(outs): - """Schedule for adaptive pool. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of global_pool - in the format of an array of tensors. - - Returns - ------- - s: Schedule - The computation schedule for adaptive pool. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _schedule(Pool): - if Pool.op in s.outputs: - Out = Pool - else: - Out = outs[0].op.output(0) - s[Pool].opengl() - s[Out].opengl() - - def traverse(OP): - """Internal traverse function""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(OP.tag): - if OP not in s.outputs: - s[OP].opengl() - for tensor in OP.input_tensors: - if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops: - traverse(tensor.op) - # schedule global_pool - elif OP.tag.startswith('adaptive_pool'): - Pool = OP.output(0) - _schedule(Pool) - else: - raise RuntimeError("Unsupported operator: %s" % OP.tag) - - scheduled_ops.append(OP) - - traverse(outs[0].op) - return s - - -def schedule_pool(outs, layout): - """Schedule for pool. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of pool - in the format of an array of tensors. - - layout: str - Data layout. - - Returns - ------- - s: Schedule - The computation schedule for pool. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - scheduled_ops = [] - - def _schedule(PaddedInput, Pool): - if isinstance(PaddedInput.op, te.tensor.ComputeOp): - s[PaddedInput].opengl() - if Pool.op in s.outputs: - Out = Pool - else: - Out = outs[0].op.output(0) - s[Pool].opengl() - s[Out].opengl() - - def traverse(OP): - """Internal traverse function""" - # inline all one-to-one-mapping operators except the last stage (output) - if tag.is_broadcast(OP.tag): - if OP not in s.outputs: - s[OP].compute_inline() - for tensor in OP.input_tensors: - if tensor.op not in scheduled_ops and isinstance(tensor.op, te.tensor.ComputeOp): - traverse(tensor.op) - # schedule pool - elif OP.tag.startswith('pool'): - PaddedInput = OP.input_tensors[0] - Pool = OP.output(0) - _schedule(PaddedInput, Pool) - else: - raise RuntimeError("Unsupported operator: %s" % OP.tag) - - scheduled_ops.append(OP) - - traverse(outs[0].op) - return s diff --git a/topi/python/topi/opengl/softmax.py b/topi/python/topi/opengl/softmax.py deleted file mode 100644 index 7b15a5373a3b..000000000000 --- a/topi/python/topi/opengl/softmax.py +++ /dev/null @@ -1,58 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name, unused-variable, trailing-whitespace -"""Schedule for softmax operator""" -from tvm import te - -def schedule_softmax(outs): - """Schedule for softmax op. - - Parameters - ---------- - outs: Array of Tensor - The computation graph description of softmax in the format - of an array of tensors. - - Returns - ------- - sch: Schedule - The computation schedule for the op. - """ - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - softmax = outs[0] - - op_tag = softmax.op.tag - if op_tag == 'softmax_output': - expsum = softmax.op.input_tensors[1] - exp = softmax.op.input_tensors[0] - max_elem = s[exp].op.input_tensors[1] - elif op_tag == 'log_softmax_output': - exp = None - max_elem = softmax.op.input_tensors[1] - expsum = softmax.op.input_tensors[2] - else: - raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \ - Got {0}'.format(op_tag)) - - if exp is not None: - s[exp].opengl() - - s[max_elem].opengl() - s[expsum].opengl() - s[softmax].opengl() - return s diff --git a/topi/python/topi/rocm/conv2d.py b/topi/python/topi/rocm/conv2d.py index 4ee18775b938..bc5d5c3c0688 100644 --- a/topi/python/topi/rocm/conv2d.py +++ b/topi/python/topi/rocm/conv2d.py @@ -66,7 +66,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW)) pad_h, pad_w = pt + pb, pl + pr dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation - + assert (pt == pb) and (pl == pr) OH = (H + 2 * pad_h - KH) // stride_h + 1 OW = (W + 2 * pad_w - KW) // stride_w + 1 cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\ @@ -76,8 +76,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation, kernel, stride_h, stride_w, - pad_h, - pad_w, + pt, + pl, dilation_h, dilation_w, conv_mode=0, diff --git a/topi/python/topi/scatter.py b/topi/python/topi/scatter.py new file mode 100644 index 000000000000..e4e988612cc2 --- /dev/null +++ b/topi/python/topi/scatter.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
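# Editor's aside (not part of the patch): the hybrid-script scatter kernels defined
# below all apply the same rule; a rank-generic NumPy reference with the same
# negative-index handling looks roughly like this sketch (the helper name is hypothetical):
import numpy as np

def scatter_ref(data, indices, updates, axis=0):
    out = data.copy()
    for idx in np.ndindex(*indices.shape):
        pos = list(idx)
        i = int(indices[idx])
        pos[axis] = i if i >= 0 else i + data.shape[axis]   # wrap negative indices
        out[tuple(pos)] = updates[idx]
    return out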
+# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks +"""Scatter operator""" +from tvm.te import hybrid + + +@hybrid.script +def _scatter_1d(data, indices, updates): + out = output_tensor(data.shape, data.dtype) + for i in range(data.shape[0]): + out[i] = data[i] + for i in range(indices.shape[0]): + out[indices[i] if indices[i] >= 0 else indices[i] + + data.shape[0]] = updates[i] + return out + + +@hybrid.script +def _scatter_2d(data, indices, updates, axis): + out = output_tensor(data.shape, data.dtype) + for i in const_range(data.shape[0]): + for j in const_range(data.shape[1]): + out[i, j] = data[i, j] + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + out[indices[i, j] if indices[i, j] >= + 0 else indices[i, j] + data.shape[axis], j] = updates[i, j] + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + out[i, indices[i, j] if indices[i, j] >= + 0 else indices[i, j] + data.shape[axis]] = updates[i, j] + + return out + + +@hybrid.script +def _scatter_3d(data, indices, updates, axis): + out = output_tensor(data.shape, data.dtype) + for i in const_range(data.shape[0]): + for j in const_range(data.shape[1]): + for k in const_range(data.shape[2]): + out[i, j, k] = data[i, j, k] + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + out[indices[i, j, k] if indices[i, j, k] >= + 0 else indices[i, j, k] + data.shape[axis], j, k] = updates[i, j, k] + elif axis == 1: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + out[i, indices[i, j, k] if indices[i, j, k] >= + 0 else indices[i, j, k] + data.shape[axis], k] = updates[i, j, k] + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + out[i, j, indices[i, j, k] if indices[i, j, k] >= + 0 else indices[i, j, k] + data.shape[axis]] = updates[i, j, k] + + return out + + +@hybrid.script +def _scatter_4d(data, indices, updates, axis): + out = output_tensor(data.shape, data.dtype) + for i in const_range(data.shape[0]): + for j in const_range(data.shape[1]): + for k in const_range(data.shape[2]): + for l in const_range(data.shape[3]): + out[i, j, k, l] = data[i, j, k, l] + + if axis == 0: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis], + j, k, l] = updates[i, j, k, l] + elif axis == 1: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[i, + indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis], + k, l] = updates[i, j, k, l] + elif axis == 2: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[i, j, + indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis], + l] = updates[i, j, k, l] + else: + for i in range(indices.shape[0]): + for j in range(indices.shape[1]): + for k in const_range(indices.shape[2]): + for l in const_range(indices.shape[3]): + out[i, j, k, + indices[i, j, k, l] if indices[i, j, k, l] >= + 0 else indices[i, j, k, l] + data.shape[axis] + ] = 
updates[i, j, k, l] + + return out + + +def scatter(data, indices, updates, axis=0): + """Update data at positions defined by indices with values in updates + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + indices : relay.Expr + The index locations to update. + + updates : relay.Expr + The values to update. + + axis : int + The axis to scatter on + + Returns + ------- + ret : relay.Expr + The computed result. + """ + if axis < 0: + axis += len(data.shape) + assert axis >= 0 + assert axis < len(data.shape) + + if len(data.shape) == 1: + return _scatter_1d(data, indices, updates) + if len(data.shape) == 2: + return _scatter_2d(data, indices, updates, axis) + if len(data.shape) == 3: + return _scatter_3d(data, indices, updates, axis) + if len(data.shape) == 4: + return _scatter_4d(data, indices, updates, axis) + raise ValueError("scatter only support for 1-4 dimensions") diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py index 744da622adc2..f79eb52e9266 100644 --- a/topi/python/topi/sort.py +++ b/topi/python/topi/sort.py @@ -31,10 +31,10 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): The input tensor. valid_count : tvm.te.Tensor, optional - 1-D tensor for valid number of boxes only for ssd. + 1-D tensor for valid number of boxes. axis : int, optional - Axis along which to sort the input tensor. + Axis along which to sort the input tensor. By default the flattened array is used. is_ascend : boolean, optional @@ -107,7 +107,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): data : tvm.te.Tensor The input tensor. - k : int, optional + k : int or tvm.te.Tensor, optional Number of top elements to select. Return all elements if k < 1. axis : int, optional @@ -133,7 +133,10 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): assert ret_type in ["both", "values", "indices"] data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) out_shape = list(get_const_tuple(data.shape)) - if k >= 1: + kvar = tvm.te.size_var("k") + if not isinstance(k, int): + out_shape[axis] = kvar + elif k >= 1: out_shape[axis] = k out_bufs = [] if ret_type in ["both", "values"]: @@ -142,10 +145,11 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): out_bufs.append(tvm.tir.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8)) out_shapes = [out_shape] * len(out_bufs) + kv = kvar if not isinstance(k, int) else k out = te.extern(out_shapes, [data], lambda ins, outs: tvm.tir.call_packed( - "tvm.contrib.sort.topk", ins[0], *outs, k, axis, ret_type, is_ascend), + "tvm.contrib.sort.topk", ins[0], *outs, kv, axis, ret_type, is_ascend), in_buffers=[data_buf], out_buffers=out_bufs, name="topk_cpu", diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 36c460e671f5..70ee8e99047c 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -27,8 +27,10 @@ from .conv2d_nhwc_python import conv2d_nhwc_python from .conv3d_ncdhw_python import conv3d_ncdhw_python from .conv3d_ndhwc_python import conv3d_ndhwc_python +from .conv3d_transpose_ncdhw_python import conv3d_transpose_ncdhw_python from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python +from .correlation_nchw_python import correlation_nchw_python from .deformable_conv2d_nchw_python import 
deformable_conv2d_nchw_python from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc from .dilate_python import dilate_python @@ -41,6 +43,7 @@ from .roi_pool_python import roi_pool_nchw_python from .lrn_python import lrn_python from .l2_normalize_python import l2_normalize_python +from .gather_python import gather_python from .gather_nd_python import gather_nd_python from .strided_slice_python import strided_slice_python, strided_set_python from .batch_matmul import batch_matmul @@ -56,3 +59,4 @@ from .common import get_injective_schedule, get_reduce_schedule, get_broadcast_schedule, \ get_elemwise_schedule, get_conv2d_nchw_implement, dispatch from .adaptive_pool_python import adaptive_pool +from .grid_sample_python import affine_grid_python, grid_sample_nchw_python diff --git a/topi/python/topi/testing/common.py b/topi/python/topi/testing/common.py index 5817513f7f65..7bc5c5d8f60a 100644 --- a/topi/python/topi/testing/common.py +++ b/topi/python/topi/testing/common.py @@ -26,7 +26,6 @@ "arm_cpu": topi.arm_cpu.schedule_injective, "gpu": topi.cuda.schedule_injective, "hls": topi.hls.schedule_injective, - "opengl": topi.opengl.schedule_injective } _reduce_schedule = { @@ -64,7 +63,6 @@ def get_reduce_schedule(target): topi.mali.schedule_conv2d_nchw_spatial_pack), "bifrost": (topi.bifrost.conv2d_nchw_spatial_pack, topi.bifrost.schedule_conv2d_nchw_spatial_pack), - "opengl": (topi.nn.conv2d_nchw, topi.opengl.schedule_conv2d_nchw), "intel_graphics": (topi.intel_graphics.conv2d_nchw, topi.intel_graphics.schedule_conv2d_nchw), "hls": (topi.nn.conv2d_nchw, topi.hls.schedule_conv2d_nchw) diff --git a/topi/python/topi/testing/conv2d_nhwc_python.py b/topi/python/topi/testing/conv2d_nhwc_python.py index dc5f915daa22..7c021785544c 100644 --- a/topi/python/topi/testing/conv2d_nhwc_python.py +++ b/topi/python/topi/testing/conv2d_nhwc_python.py @@ -21,7 +21,7 @@ from topi.nn.util import get_pad_tuple -def conv2d_nhwc_python(a_np, w_np, stride, padding): +def _conv2d_nhwc_python(a_np, w_np, stride, padding): """Convolution operator in NHWC layout. Parameters @@ -35,10 +35,8 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): stride : int or a list/tuple of two ints Stride size, or [stride_height, stride_width] - padding : int or str or a list/tuple of 2 or 4 ints - Padding size, or ['VALID', 'SAME'], or - [pad_height, pad_width] for 2 ints, or - [pad_top, pad_left, pad_bottom, pad_right] for 2 ints + padding : int or str or a list/tuple of two ints + Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width] Returns ------- @@ -77,3 +75,38 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding): apad, np.rot90(np.rot90(wt[f, c])), mode='valid') bt[n, f] += out[::stride_h, ::stride_w] return bt.transpose((0, 2, 3, 1)) + +def conv2d_nhwc_python(a_np, w_np, stride, padding, groups=1): + """Convolution operator in NHWC layout. 
+ + Parameters + ---------- + a_np : numpy.ndarray + 4-D with shape [batch, in_height, in_width, in_channel] + + w_np : numpy.ndarray + 4-D with shape [filter_height, filter_width, in_channel // groups, num_filter] + + stride : int or a list/tuple of two ints + Stride size, or [stride_height, stride_width] + + padding : int or str or a list/tuple of 2 or 4 ints + Padding size, or ['VALID', 'SAME'], or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 2 ints + + groups : int + Number of groups + + Returns + ------- + b_np : np.ndarray + 4-D with shape [batch, out_height, out_width, out_channel] + """ + + a_slices = np.array_split(a_np, groups, axis=3) + w_slices = np.array_split(w_np, groups, axis=3) + b_slices = [_conv2d_nhwc_python(a_slice, w_slice, stride, padding) + for a_slice, w_slice in zip(a_slices, w_slices)] + b_np = np.concatenate(b_slices, axis=3) + return b_np diff --git a/topi/python/topi/testing/conv3d_ncdhw_python.py b/topi/python/topi/testing/conv3d_ncdhw_python.py index 063c07d94133..0b2620fc290c 100644 --- a/topi/python/topi/testing/conv3d_ncdhw_python.py +++ b/topi/python/topi/testing/conv3d_ncdhw_python.py @@ -73,6 +73,7 @@ def conv3d_ncdhw_python(a_np, w_np, stride, padding, groups=1): padding : int or str or a list/tuple of three ints Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width] + groups : int Number of groups diff --git a/topi/python/topi/testing/conv3d_transpose_ncdhw_python.py b/topi/python/topi/testing/conv3d_transpose_ncdhw_python.py new file mode 100644 index 000000000000..8140eb76d2db --- /dev/null +++ b/topi/python/topi/testing/conv3d_transpose_ncdhw_python.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches +"""Convolution 3D transpose in python""" +import numpy as np +import topi +from topi.nn.util import get_pad_tuple3d + + +def conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding): + """Transposed 3d convolution operator in NCDHW layout. 
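# Editor's aside (not part of the patch): the grouped NHWC reference above simply
# splits both inputs along their channel axes and reuses the single-group kernel.
# A usage sketch with hypothetical sizes (groups must divide both channel counts):
import numpy as np
from topi.testing import conv2d_nhwc_python

a = np.random.uniform(size=(1, 16, 16, 8)).astype("float32")    # NHWC input
w = np.random.uniform(size=(3, 3, 4, 16)).astype("float32")     # HWIO, in_channel // groups == 4
b = conv2d_nhwc_python(a, w, stride=1, padding="SAME", groups=2)
assert b.shape == (1, 16, 16, 16)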
+ + Parameters + ---------- + a_np : numpy.ndarray + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + w_np : numpy.ndarray + 5-D with shape [in_channel, num_filter, filter_depth, filter_height, filter_width] + + stride : int or a list/tuple of two ints + Stride size, or [stride_depth, stride_height, stride_width] + + padding : int or str + Padding size + + Returns + ------- + b_np : np.ndarray + 5-D with shape [batch, out_channel, out_depth, out_height, out_width] + """ + batch, in_c, in_d, in_h, in_w = a_np.shape + _, out_c, filter_d, filter_h, filter_w = w_np.shape + if isinstance(stride, int): + stride_d = stride_h = stride_w = stride + else: + stride_d, stride_h, stride_w = stride + + # dilate stage + dilated_a_np = topi.testing.dilate_python(a_np, [1, 1, stride_d, stride_h, stride_w]) + + # padding stage + fpad_front, fpad_top, fpad_left, fpad_back, fpad_bottom, fpad_right = get_pad_tuple3d( + padding, (filter_d, filter_h, filter_w)) + + bpad_front = filter_d - 1 - fpad_front + bpad_back = filter_d - 1 - fpad_back + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + padded_a_np = np.zeros((batch, + in_c, + dilated_a_np.shape[2]+bpad_front+bpad_back, + dilated_a_np.shape[3]+bpad_top+bpad_bottom, + dilated_a_np.shape[4]+bpad_left+bpad_right)) + + padded_a_np[:, :, bpad_front:dilated_a_np.shape[2]+bpad_back, + bpad_top:dilated_a_np.shape[3]+bpad_top, + bpad_left:dilated_a_np.shape[4]+bpad_left] = dilated_a_np + + + # convolution stage + out_d = (in_d - 1) * stride_d - bpad_front - bpad_back + filter_d + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + + w_np = np.flip(w_np, axis=[2, 3, 4]).transpose((1, 0, 2, 3, 4)) + b_np = topi.testing.conv3d_ncdhw_python(padded_a_np, w_np, stride=(1, 1, 1), padding=(0, 0, 0)) + + return b_np diff --git a/topi/python/topi/testing/correlation_nchw_python.py b/topi/python/topi/testing/correlation_nchw_python.py new file mode 100644 index 000000000000..f0536560849b --- /dev/null +++ b/topi/python/topi/testing/correlation_nchw_python.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""Convolution 3D in python""" +import numpy as np + + +def correlation_nchw_python(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply): + """Correlationn operator in NCHW layout. 
+ + Parameters + ---------- + data1_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + data2_np : numpy.ndarray + 4-D with shape [batch, in_channel, in_height, in_width] + + kernel_size: int + Kernel size for correlation, must be an odd number + + max_displacement: int + Max displacement of Correlation + + stride1: int + Stride for data1 + + stride2: int + Stride for data2 within the neightborhood centered around data1 + + padding: int + Padding for correlation + + is_multiply: bool + operation type is either multiplication or substraction + + Returns + ------- + c_np : np.ndarray + 4-D with shape [batch, out_channel, out_height, out_width] + """ + # compute output's dimension + pad_data_height = data1.shape[2] + 2 * padding + pad_data_width = data1.shape[3] + 2 * padding + kernel_radius = (kernel_size - 1) // 2 + border_size = max_displacement + kernel_radius + out_width = (pad_data_width - border_size * 2) // stride1 + out_height = (pad_data_height - border_size * 2) // stride1 + neighborhood_grid_radius = max_displacement // stride2 + neighborhood_grid_width = neighborhood_grid_radius * 2 + 1 + out_channel = neighborhood_grid_width * neighborhood_grid_width + + out = np.zeros((data1.shape[0], out_channel, out_height, out_width)) + pad_data1 = np.zeros((data1.shape[0], data1.shape[1], + pad_data_height, pad_data_width)) + pad_data2 = np.zeros((data1.shape[0], data1.shape[1], + pad_data_height, pad_data_width)) + + pad_data1[:, :, padding:padding + data1.shape[2], + padding:padding + data1.shape[3]] = data1[:, :, :, :] + pad_data2[:, :, padding:padding + data2.shape[2], + padding:padding + data2.shape[3]] = data2[:, :, :, :] + + if is_multiply: + corr_func = lambda x, y: x * y + else: + corr_func = lambda x, y: abs(x - y) + + # pylint: disable=too-many-nested-blocks + for i in range(out_height): + for j in range(out_width): + for nbatch in range(data1.shape[0]): + # x1,y1 is the location in data1 , i,j is the location in output + x1 = j * stride1 + max_displacement + y1 = i * stride1 + max_displacement + + for q in range(out_channel): + # location in data2 + x2 = x1 + (q % neighborhood_grid_width - neighborhood_grid_radius) * stride2 + y2 = y1 + (q // neighborhood_grid_width - neighborhood_grid_radius) * stride2 + + for h in range(kernel_size): + for w in range(kernel_size): + for channel in range(data1.shape[1]): + out[nbatch, q, i, j] += corr_func(pad_data1[nbatch, channel, y1 + h, x1 + w], + pad_data2[nbatch, channel, y2 + h, x2 + w]) + + out /= float(kernel_size** 2 *data1.shape[1]) + return out diff --git a/python/tvm/relay/frontend/util.py b/topi/python/topi/testing/gather_python.py similarity index 54% rename from python/tvm/relay/frontend/util.py rename to topi/python/topi/testing/gather_python.py index a7f89a30b996..0f3573cb1679 100644 --- a/python/tvm/relay/frontend/util.py +++ b/topi/python/topi/testing/gather_python.py @@ -14,20 +14,33 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=wildcard-import, redefined-builtin, invalid-name -""" Utility functions that are used across many directories. """ -from __future__ import absolute_import +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""gather in python""" import numpy as np -from .. import expr as _expr -def get_scalar_from_constant(expr): - """ Returns scalar value from Relay constant scalar. 
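For every output location, the correlation reference above compares a kernel_size x kernel_size patch of data1 with a patch of data2 shifted by each displacement in a (2 * (max_displacement // stride2) + 1)^2 grid, so the displacement grid becomes the output channel dimension. A small shape sketch with made-up sizes, assuming the helper is re-exported as topi.testing.correlation_nchw_python:

.. code-block:: python

    import numpy as np
    import topi.testing

    data1 = np.random.uniform(size=(1, 3, 10, 10)).astype("float32")
    data2 = np.random.uniform(size=(1, 3, 10, 10)).astype("float32")
    out = topi.testing.correlation_nchw_python(
        data1, data2, kernel_size=1, max_displacement=4,
        stride1=1, stride2=1, padding=4, is_multiply=True)
    # neighborhood grid is 9 x 9 -> 81 channels; (10 + 2*4 - 2*(4 + 0)) // 1 == 10
    assert out.shape == (1, 81, 10, 10)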
""" - assert isinstance(expr, _expr.Constant) and not expr.data.shape, \ - "Expr is not a constant scalar." - value = expr.data.asnumpy() - if value.dtype == np.dtype(np.int32): - return int(value) - if value.dtype == np.dtype(np.float32): - return float(value) - assert False, "Constant expr must be float32/int32" - return None # To suppress pylint +def gather_python(data, axis, indices): + """ Python version of Gather operator + + Parameters + ---------- + data : numpy.ndarray + Numpy array + + axis: int + integer + + indices : numpy.ndarray + Numpy array + + Returns + ------- + b_np : numpy.ndarray + Numpy array + """ + shape_indices = indices.shape + out = np.zeros(shape_indices, dtype=data.dtype) + for index in np.ndindex(*shape_indices): + new_index = list(index) + new_index[axis] = indices[index] + out[index] = data[tuple(new_index)] + return out diff --git a/topi/python/topi/testing/grid_sample_python.py b/topi/python/topi/testing/grid_sample_python.py new file mode 100644 index 000000000000..964d8a275745 --- /dev/null +++ b/topi/python/topi/testing/grid_sample_python.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals +"""affine_grid and grid_sample operators in python""" +import math +import numpy as np + + +def affine_grid_python(data, target_shape): + yv, xv = np.meshgrid( + np.arange(target_shape[0]), np.arange(target_shape[1])) + yv = yv.T * 2 / (target_shape[0] - 1) - 1 + xv = xv.T * 2 / (target_shape[1] - 1) - 1 + ones = np.ones_like(xv) + grid = np.stack([xv, yv, ones]).reshape(3, -1) + return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape) + + +def _bilinear_sample_nchw_python(data, grid): + batch, in_channel, in_height, in_width = data.shape + _, _, out_height, out_width = grid.shape + out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype) + + def _within_bound(y, x): + return 0 <= y < in_height and 0 <= x < in_width + + for n in range(0, batch): + for h in range(0, out_height): + for w in range(0, out_width): + x, y = grid[n, :, h, w] + y = (y + 1) * (in_height - 1) / 2 + x = (x + 1) * (in_width - 1) / 2 + y0 = int(math.floor(y)) + x0 = int(math.floor(x)) + y1 = y0 + 1 + x1 = x0 + 1 + if _within_bound(y0, x0): + out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0)) + if _within_bound(y0, x1): + out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0) + if _within_bound(y1, x0): + out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0)) + if _within_bound(y1, x1): + out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0) + return out + + +def grid_sample_nchw_python(data, grid, method='bilinear'): + if method == 'bilinear': + return _bilinear_sample_nchw_python(data, grid) + raise ValueError("invalid method") diff --git a/topi/python/topi/testing/pool3d_python.py b/topi/python/topi/testing/pool3d_python.py index 2606650b33cf..457c4015a405 100644 --- a/topi/python/topi/testing/pool3d_python.py +++ b/topi/python/topi/testing/pool3d_python.py @@ -27,9 +27,18 @@ def pool3d_ncdhw_python(np_data, kernel, ceil_mode=False, dtype="float32"): """baseline for max_pool3d and avg_pool3d, default layout is "NCDHW""" in_n, in_c, in_d, in_h, in_w = in_shape = np_data.shape - k_d, k_h, k_w = kernel - s_d, s_h, s_w = strides - pf, pt, pl, pk, pb, pr = padding + if isinstance(kernel, int): + k_d = k_h = k_w = kernel + else: + k_d, k_h, k_w = kernel + if isinstance(strides, int): + s_d = s_h = s_w = strides + else: + s_d, s_h, s_w = strides + if isinstance(padding, int): + pf = pt = pl = pk = pb = pr = padding + else: + pf, pt, pl, pk, pb, pr = padding if ceil_mode: assert out_shape[2] == int(math.ceil(float(in_shape[2] - k_d + pf + pk) / s_d) + 1) diff --git a/topi/python/topi/testing/strided_slice_python.py b/topi/python/topi/testing/strided_slice_python.py index c1c899afe31f..970e1dedd8c9 100644 --- a/topi/python/topi/testing/strided_slice_python.py +++ b/topi/python/topi/testing/strided_slice_python.py @@ -17,7 +17,7 @@ """strided_slice/set in python""" -def strided_slice_python(data, begin, end, strides): +def strided_slice_python(data, begin, end, strides, slice_mode="end"): """Python version of strided slice operator. Parameters @@ -34,6 +34,14 @@ def strided_slice_python(data, begin, end, strides): strides : list The stride of each slice. + slice_mode : str, optional + The slice mode [end, size]. + end: The default slice mode, ending indices for the slice. + size: The input strides will be ignored, input end in this mode indicates + the sizeof a slice starting at the location specified by begin. 
If end[i] is -1, + all remaining elements in that dimension are included in the slice. + + Returns ------- result : numpy.ndarray @@ -42,10 +50,24 @@ def strided_slice_python(data, begin, end, strides): strides = [] if strides is None else strides slices = [] for i in range(len(data.shape)): - slices.append(slice( - begin[i] if i < len(begin) else None, - end[i] if i < len(end) else None, - strides[i] if i < len(strides) else None)) + new_stride = None + if slice_mode == "end" and i < len(strides): + new_stride = strides[i] + + new_begin = begin[i] if i < len(begin) else None + if i >= len(end): + new_end = None + elif slice_mode == "size": + if end[i] < 0: + new_end = None + else: + new_end = new_begin + end[i] + else: + new_end = end[i] + + slices.append(slice(new_begin, + new_end, + new_stride)) return data[tuple(slices)] diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index ef5456095899..f1bcccd9fde8 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -131,7 +131,7 @@ def flip(a, axis=0): """ return cpp.flip(a, axis) -def strided_slice(a, begin, end, strides=None): +def strided_slice(a, begin, end, strides=None, slice_mode="end"): """Slice of an array. Parameters @@ -139,24 +139,31 @@ def strided_slice(a, begin, end, strides=None): a : tvm.te.Tensor The tensor to be sliced. - begin: list of int + begin : list of int The indices to begin with in the slicing. - end: list of int + end : list of int Indicies indicating end of the slice. - strides: list of int, optional + strides : list of int, optional Specifies the stride values, it can be negative in that case, the input tensor will be reversed in that particular axis. + slice_mode : str, optional + The slice mode [end, size]. + end - The ending indices for the slice [default]. + size - The input strides will be ignored, input end in this mode indicates + the sizeof a slice starting at the location specified by begin. If end[i] + is -1, all remaining elements in that dimension are included in the slice. + Returns ------- ret : tvm.te.Tensor """ if strides is None: strides = [] - return cpp.strided_slice(a, begin, end, strides) + return cpp.strided_slice(a, begin, end, strides, slice_mode) @tvm.te.tag_scope(tag=tag.INJECTIVE+",strided_set") def strided_set(a, v, begin, end, strides=None): @@ -367,6 +374,38 @@ def take(a, indices, axis=None, mode="clip"): return cpp.take(a, indices, int(axis), mode) +def gather(data, axis, indices): + """Gather values along given axis from given indices. + + E.g. for a 3D tensor, output is computed as: + + .. code-block:: python + + out[i][j][k] = data[indices[i][j][k]][j][k] # if axis == 0 + out[i][j][k] = data[i][indices[i][j][k]][k] # if axis == 1 + out[i][j][k] = data[i][j][indices[i][j][k]] # if axis == 2 + + ``indices`` must have same shape as ``data``, except at dimension ``axis`` + which must just be not null. Output will have same shape as ``indices``. + + Parameters + ---------- + data : tvm.te.Tensor + The input data to the operator. + + axis: int + The axis along which to index. + + indices : tvm.te.Tensor + The indices of the values to extract. + + Returns + ------- + ret : tvm.te.Tensor + """ + return cpp.gather(data, axis, indices) + + def gather_nd(a, indices): """Gather elements from a n-dimension array.. 
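Both of the behaviours documented above are easy to check against the NumPy reference implementations: gather picks one element per output position along the chosen axis, and slice_mode="size" reinterprets end as per-axis slice lengths (with -1 meaning "to the end of that axis") while ignoring strides. A short sketch, assuming both helpers are exposed under topi.testing:

.. code-block:: python

    import numpy as np
    import topi.testing

    # gather along axis 1: out[i, j] = data[i, indices[i, j]]
    data = np.array([[1, 2], [3, 4]])
    indices = np.array([[0, 0], [1, 0]])
    assert (topi.testing.gather_python(data, 1, indices) == [[1, 1], [4, 3]]).all()

    # slice_mode="size": end holds slice lengths, -1 keeps the rest of the axis
    x = np.arange(12).reshape(3, 4)
    y = topi.testing.strided_slice_python(x, begin=[1, 0], end=[2, -1],
                                          strides=None, slice_mode="size")
    assert y.shape == (2, 4)          # rows 1..2, all columns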
@@ -676,3 +715,32 @@ def unravel_index(indices, shape): """ return cpp.unravel_index(indices, shape) + +def sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value=0): + """Converts a sparse representation into a dense tensor. + + Example:: + - sparse_to_dense([[0, 0], [1, 1]], [2, 2], [3, 3], 0) = [[3, 0], [0, 3]] + + Parameters + ---------- + sparse_indices : tvm.te.Tensor + A 0-D, 1-D, or 2-D tensor of integers containing location of sparse values. + + output_shape : A list of integers + Shape of the dense output tensor. + + sparse_values : tvm.te.Tensor + A 0-D or 1-D tensor containing the sparse values for the sparse indices. + + default_value : tvm.te.Tensor + A 0-D tensor containing the default value for the remaining locations. + Defaults to 0. + + Returns + ------- + result : tvm.te.Tensor + Dense tensor of shape output_shape. Has the same type as sparse_values. + """ + + return cpp.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value) diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 50a6a36edc46..cc437325e0d6 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -101,7 +101,8 @@ def get_const_int(expr): if isinstance(expr, Integral): return expr if not isinstance(expr, tvm.tir.IntImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -123,7 +124,8 @@ def get_const_float(expr): if isinstance(expr, float): return float(expr) if not isinstance(expr, tvm.tir.FloatImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.FloatImm): raise ValueError("Expect value to be constant float") return float(expr.value) @@ -145,7 +147,8 @@ def equal_const_int(expr, value): if isinstance(expr, Integral): return expr == value if not isinstance(expr, tvm.tir.IntImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.IntImm): return False return expr.value == value @@ -165,11 +168,13 @@ def get_const_tuple(in_tuple): The output. """ ret = [] + ana = None for elem in in_tuple: if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)): ret.append(elem) elif not isinstance(elem, (tvm.tir.IntImm, int)): - elem = tvm.tir.ir_pass.Simplify(elem) + ana = tvm.arith.Analyzer() if ana is None else ana + elem = ana.simplify(elem) if not isinstance(elem, tvm.tir.IntImm): ret.append(elem) else: @@ -208,7 +213,7 @@ def simplify(expr): out : Expr or int The simplified output """ - return tvm.tir.ir_pass.Simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr + return tvm.arith.Analyzer().simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr def ravel_index(indices, shape): diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index 28598dedffbd..269c876d647e 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -23,7 +23,7 @@ from ..sort import argsort @hybrid.script -def hybrid_rearrange_out(data, one): +def hybrid_rearrange_box_out(data, one, batch_size, num_anchors): """Hybrid routine to rearrange nms output to move all valid entries to top. @@ -36,14 +36,19 @@ def hybrid_rearrange_out(data, one): one: tvm.tir.const Constant one with the same dtype as data. + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. 
We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. + + num_anchors: tvm.tir.IntImm or tvm.tir.Var + Number of anchors. + Returns ------- output : tvm.te.Tensor or numpy NDArray Transformed NMS output. 3-D tensor with shape [batch_size, num_anchors, 6]. """ - batch_size = data.shape[0] - num_anchors = data.shape[1] elem_length = data.shape[2] output = output_tensor((batch_size, num_anchors, @@ -64,7 +69,59 @@ def hybrid_rearrange_out(data, one): @hybrid.script -def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one): +def hybrid_rearrange_indices_out(data, one, batch_size, num_anchors): + """Hybrid routine to rearrange nms output to + move all valid entries to top. + + Parameters + ---------- + data : tvm.te.Tensor or numpy NDArray + NMS output. 3-D tensor with shape + [batch_size, num_anchors, 6] or + [batch_size, num_anchors, 5], or 2-D + tensor with shape [batch_size, num_anchors]. + + one: tvm.tir.const + Constant one with the same dtype as data. + + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. + + num_anchors: tvm.tir.IntImm or tvm.tir.Var + Number of anchors. + + Returns + ------- + output : tvm.te.Tensor or numpy NDArray + 2-D tensor with shape [batch_size, num_anchors]. + + valid_box_count : tvm.te.Tensor or numpy NDArray + Tensor with shape [batch_size, 1], indicates + the valid number of boxes. + """ + valid_box_count = output_tensor((batch_size, 1), "int32") + output = output_tensor((batch_size, num_anchors), data.dtype) + + for i in parallel(batch_size): + valid_idx = 0 + for j in range(num_anchors): + if data[i, j] >= 0: + output[i, valid_idx] = data[i, j] + valid_idx += 1 + if data[i, j] > num_anchors or data[i, j] < -num_anchors: + output[i, valid_idx] = 0 + valid_idx += 1 + if j >= valid_idx: + output[i, j] = -one + valid_box_count[i, 0] = valid_idx + + return output, valid_box_count + + +@hybrid.script +def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, + one, batch_size, num_anchors): """Hybrid routine to get valid count of bounding boxes given a score threshold. Also moves valid boxes to the top of input data. @@ -87,22 +144,31 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one): one: tvm.tir.const Constant one with the same dtype as data. + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. + + num_anchors: tvm.tir.IntImm or tvm.tir.Var + Number of anchors. + Returns ------- + valid_count : tvm.te.Tensor or numpy NDArray + 1-D tensor for valid number of boxes. + out_tensor : tvm.te.Tensor or numpy NDArray Rearranged data tensor. - valid_count : tvm.te.Tensor or numpy NDArray - 1-D tensor for valid number of boxes. + out_indices: tvm.te.Tensor or numpy NDArray + Related index in input data. 
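In plain NumPy, the behaviour of the three outputs described above can be sketched as follows: boxes with score > threshold (and a non-negative class id when id_index >= 0) are compacted to the top, their original positions are recorded in out_indices, and the remaining slots are filled with -1. The helper name below is hypothetical and the sketch only illustrates the intended semantics, not the hybrid script itself:

.. code-block:: python

    import numpy as np

    def get_valid_counts_ref(data, score_threshold=0, id_index=0, score_index=1):
        batch, num_anchors, _ = data.shape
        out = np.full_like(data, -1.0)
        out_indices = np.full((batch, num_anchors), -1, dtype="int32")
        valid_count = np.zeros((batch,), dtype="int32")
        for b in range(batch):
            keep = 0
            for a in range(num_anchors):
                if data[b, a, score_index] > score_threshold and \
                        (id_index < 0 or data[b, a, id_index] >= 0):
                    out[b, keep] = data[b, a]        # move valid box to the top
                    out_indices[b, keep] = a         # remember its original slot
                    keep += 1
            valid_count[b] = keep
        return valid_count, out, out_indices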
""" - batch_size = data.shape[0] - num_anchors = data.shape[1] box_data_length = data.shape[2] valid_count = output_tensor((batch_size,), "int32") out_tensor = output_tensor((batch_size, num_anchors, box_data_length), data.dtype) + out_indices = output_tensor((batch_size, num_anchors), "int32") for i in parallel(batch_size): valid_count[i] = 0 for j in range(num_anchors): @@ -111,11 +177,13 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one): (id_index < 0 or data[i, j, id_index] >= 0): for k in range(box_data_length): out_tensor[i, valid_count[i], k] = data[i, j, k] + out_indices[i, valid_count[i]] = j valid_count[i] += 1 if j >= valid_count[i]: for k in range(box_data_length): out_tensor[i, j, k] = -one - return valid_count, out_tensor + out_indices[i, j] = -1 + return valid_count, out_tensor, out_indices def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): @@ -139,38 +207,55 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): Returns ------- + valid_count : tvm.te.Tensor + 1-D tensor for valid number of boxes. + out_tensor : tvm.te.Tensor Rearranged data tensor. - valid_count : tvm.te.Tensor - 1-D tensor for valid number of boxes. + out_indices: tvm.te.Tensor or numpy NDArray + Related index in input data. """ score_threshold_const = tvm.tir.const(score_threshold, data.dtype) id_index_const = tvm.tir.const(id_index, "int32") score_index_const = tvm.tir.const(score_index, "int32") return hybrid_get_valid_counts(data, score_threshold_const, id_index_const, score_index_const, - tvm.tir.const(1, data.dtype)) + tvm.tir.const(1, data.dtype), + data.shape[0], data.shape[1]) @hybrid.script -def hybrid_nms(data, sorted_index, valid_count, - max_output_size, iou_threshold, force_suppress, - top_k, coord_start, id_index, score_index, zero, one): +def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors, + max_output_size, iou_threshold, force_suppress, top_k, coord_start, + score_index, id_index, return_indices, zero, one): """Hybrid routing for non-maximum suppression. Parameters ---------- data: tvm.te.Tensor or numpy NDArray Bounding boxes with class and score. 3-D tensor with shape - [batch_size, num_anchors, 6]. + [batch_size, num_anchors, 6]. It could be the second output + out_tensor of get_valid_counts. sorted_index : tvm.te.Tensor or numpy NDArray Bounding box indexes sorted by score, with shape [batch_size, num_anchors]. valid_count : tvm.te.Tensor or numpy NDArray - 1-D tensor for valid number of boxes. + 1-D tensor for valid number of boxes. It could be the output + valid_count of get_valid_counts. + + indices : tvm.te.Tensor or numpy.NDArray + indices in original tensor, with shape [batch_size, num_anchors], + represents the index of box in original data. It could be the third + output out_indices of get_valid_counts. The values in the second + dimension are like the output of arange(num_anchors) if get_valid_counts + is not used before non_max_suppression. + + batch_size: tvm.tir.IntImm or tvm.tir.Var + Batch size. We need to pass it in since hybrid script doesn't support + binding variable to symbolic dim. max_output_size : tvm.tir.const Max number of output valid boxes for each instance. @@ -188,11 +273,14 @@ def hybrid_nms(data, sorted_index, valid_count, coord_start : tvm.tir.const Start index of the consecutive 4 coordinates. + score_index: tvm.tir.const + Index of the scores/confidence of boxes. + id_index : tvm.tir.const index of the class categories, -1 to disable. 
- score_index: tvm.tir.const - Index of the scores/confidence of boxes. + return_indices : tvm.tir.const + Whether to return box indices in input data. zero: tvm.tir.const Constant zero with the same dtype as data. @@ -203,15 +291,17 @@ def hybrid_nms(data, sorted_index, valid_count, Returns ------- output : tvm.te.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. box_indices: tvm.te.Tensor 2-D tensor with shape [batch_size, num_anchors]. """ - batch_size = data.shape[0] - num_anchors = data.shape[1] + box_data_length = data.shape[2] - box_indices = output_tensor((batch_size, num_anchors), "int32") + + # box_indices is the expected value, similar to TF & ONNX + box_indices = output_tensor((batch_size, num_anchors), sorted_index.dtype) output = output_tensor((batch_size, num_anchors, box_data_length,), data.dtype) @@ -232,9 +322,11 @@ def hybrid_nms(data, sorted_index, valid_count, for k in range(box_data_length): output[i, j + nkeep, k] = -one box_indices[i, j + nkeep] = -1 + # Apply nms box_start_idx = coord_start batch_idx = i + for j in range(valid_count[i]): if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0): box_a_idx = j @@ -246,36 +338,62 @@ def hybrid_nms(data, sorted_index, valid_count, check_iou = 1 elif id_index < 0 or output[i, j, id_index] == output[i, k, id_index]: check_iou = 1 + if check_iou > 0: - a_l = output[batch_idx, box_a_idx, box_start_idx] - a_t = output[batch_idx, box_a_idx, box_start_idx + 1] - a_r = output[batch_idx, box_a_idx, box_start_idx + 2] - a_b = output[batch_idx, box_a_idx, box_start_idx + 3] + # a_l: left, a_t: top, a_r: right, a_b: bottom + a_l = min(output[batch_idx, box_a_idx, box_start_idx], + output[batch_idx, box_a_idx, box_start_idx + 2]) + a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1], + output[batch_idx, box_a_idx, box_start_idx + 3]) + a_r = max(output[batch_idx, box_a_idx, box_start_idx], + output[batch_idx, box_a_idx, box_start_idx + 2]) + a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1], + output[batch_idx, box_a_idx, box_start_idx + 3]) + box_b_idx = k - b_t = output[batch_idx, box_b_idx, box_start_idx + 1] - b_b = output[batch_idx, box_b_idx, box_start_idx + 3] - b_l = output[batch_idx, box_b_idx, box_start_idx] - b_r = output[batch_idx, box_b_idx, box_start_idx + 2] + + # b_l: left, b_t: top, b_r: right, b_b: bottom + b_l = min(output[batch_idx, box_b_idx, box_start_idx], + output[batch_idx, box_b_idx, box_start_idx + 2]) + b_t = min(output[batch_idx, box_b_idx, box_start_idx + 1], + output[batch_idx, box_b_idx, box_start_idx + 3]) + b_r = max(output[batch_idx, box_b_idx, box_start_idx], + output[batch_idx, box_b_idx, box_start_idx + 2]) + b_b = max(output[batch_idx, box_b_idx, box_start_idx + 1], + output[batch_idx, box_b_idx, box_start_idx + 3]) + + # Overlapping width and height w = max(zero, min(a_r, b_r) - max(a_l, b_l)) h = max(zero, min(a_b, b_b) - max(a_t, b_t)) + + # Overlapping area area = h * w + + # total area of the figure formed by box a and box b + # except for overlapping area u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area + + # get the iou iou = zero if u <= zero else area / u + if iou >= iou_threshold: output[i, k, score_index] = -one if id_index >= 0: output[i, k, id_index] = -one box_indices[i, k] = -1 + else: for j in parallel(valid_count[i]): for k in range(box_data_length): output[i, j, k] = data[i, j, k] box_indices[i, j] = j + # Set invalid entry 
to be -1 for j in parallel(num_anchors - valid_count[i]): for k in range(box_data_length): output[i, j + valid_count[i], k] = -one box_indices[i, j + valid_count[i]] = -1 + # Only return max_output_size valid boxes num_valid_boxes = 0 if max_output_size > 0: @@ -287,10 +405,17 @@ def hybrid_nms(data, sorted_index, valid_count, box_indices[i, j] = -1 else: num_valid_boxes += 1 - return output, box_indices + if return_indices: + for j in range(valid_count[i]): + idx = box_indices[i, j] + if box_indices[i, j] >= 0: + box_indices[i, j] = indices[i, idx] + + return output, box_indices -def non_max_suppression(data, valid_count, max_output_size=-1, +@tvm.target.generic_func +def non_max_suppression(data, valid_count, indices, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, coord_start=2, score_index=1, id_index=0, return_indices=True, invalid_to_bottom=False): @@ -304,6 +429,9 @@ def non_max_suppression(data, valid_count, max_output_size=-1, valid_count : tvm.te.Tensor 1-D tensor for valid number of boxes. + indices : tvm.te.Tensor + 2-D tensor with shape [batch_size, num_anchors]. + max_output_size : optional, int Max number of output valid boxes for each instance. By default all valid boxes are returned. @@ -334,8 +462,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1, Returns ------- - out : tvm.te.Tensor - 3-D tensor with shape [batch_size, num_anchors, 6]. + out : tvm.te.Tensor or tuple of tvm.te.Tensor + 3-D tensor with shape [batch_size, num_anchors, 6] + or [batch_size, num_anchors, 5]. Out is a tuple of tvm.te.Tensor + if return_indices is True, the Tensor in the tuple is 2-D tensor + with shape [batch_size, num_anchors] and shape + [batch_size, num_valid_anchors] respectively. Example -------- @@ -348,7 +480,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, iou_threshold = 0.7 force_suppress = True top_k = -1 - out = non_max_suppression(data, valid_count, iou_threshold=iou_threshold, + out = non_max_suppression(data, valid_count, indices, iou_threshold=iou_threshold, force_suppress=force_suppress, top_k=top_k) np_data = np.random.uniform(dshape) np_valid_count = np.array([4]) @@ -366,17 +498,27 @@ def non_max_suppression(data, valid_count, max_output_size=-1, score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis]) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) - out, box_indices = hybrid_nms(data, sort_tensor, valid_count, + out, box_indices = hybrid_nms(data, + sort_tensor, + valid_count, + indices, + batch_size, + num_anchors, tvm.tir.const(max_output_size, dtype="int32"), tvm.tir.const(iou_threshold, dtype=data.dtype), tvm.tir.const(force_suppress, dtype="bool"), tvm.tir.const(top_k, dtype="int32"), tvm.tir.const(coord_start, dtype="int32"), - tvm.tir.const(id_index, dtype="int32"), tvm.tir.const(score_index, dtype="int32"), + tvm.tir.const(id_index, dtype="int32"), + tvm.tir.const(return_indices, dtype="bool"), zero=tvm.tir.const(0, dtype=data.dtype), one=tvm.tir.const(1, dtype=data.dtype)) - if not return_indices and invalid_to_bottom: - out = hybrid_rearrange_out(out, one=tvm.tir.const(1, dtype=data.dtype)) - - return box_indices if return_indices else out + if return_indices: + return hybrid_rearrange_indices_out(box_indices, one=tvm.tir.const(1, dtype="int32"), + batch_size=batch_size, num_anchors=num_anchors) + + if invalid_to_bottom: + out = hybrid_rearrange_box_out(out, one=tvm.tir.const(1, dtype=data.dtype), + 
batch_size=batch_size, num_anchors=num_anchors) + return out diff --git a/topi/python/topi/vision/rcnn/proposal.py b/topi/python/topi/vision/rcnn/proposal.py index 23bd24d22fb3..e99ebe0da903 100644 --- a/topi/python/topi/vision/rcnn/proposal.py +++ b/topi/python/topi/vision/rcnn/proposal.py @@ -82,10 +82,10 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r The last dimension is in format of [w_start, h_start, w_end, h_end, score] scales : list/tuple of float - Scales of anchor windoes. + Scales of anchor windows. ratios : list/tuple of float - Ratios of anchor windoes. + Ratios of anchor windows. feature_stride : int The size of the receptive field each unit in the convolution layer of the rpn, for example @@ -335,10 +335,10 @@ def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, thres 2-D with shape [batch, 3] scales : list/tuple of float - Scales of anchor windoes. + Scales of anchor windows. ratios : list/tuple of float - Ratios of anchor windoes. + Ratios of anchor windows. feature_stride : int The size of the receptive field each unit in the convolution layer of the rpn, for example diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index ba0cf5440c9a..e5b92156bdc3 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -304,7 +304,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + out = non_max_suppression(inter_out[0], inter_out[1], inter_out[1], max_output_size=-1, iou_threshold=nms_threshold, force_suppress=force_suppress, top_k=nms_topk, return_indices=False) return out diff --git a/topi/python/topi/x86/__init__.py b/topi/python/topi/x86/__init__.py index ce07c194268a..659668cbbe4c 100644 --- a/topi/python/topi/x86/__init__.py +++ b/topi/python/topi/x86/__init__.py @@ -36,5 +36,6 @@ from .batch_matmul import * from .roi_align import roi_align_nchw from .conv2d_transpose import * +from .conv3d_transpose import * from .sparse import * from .conv2d_alter_op import * diff --git a/topi/python/topi/x86/conv2d_alter_op.py b/topi/python/topi/x86/conv2d_alter_op.py index 5ee691b07362..e9fc4223a9ea 100644 --- a/topi/python/topi/x86/conv2d_alter_op.py +++ b/topi/python/topi/x86/conv2d_alter_op.py @@ -19,6 +19,7 @@ import logging +import re import tvm from tvm import te from tvm import relay @@ -31,6 +32,9 @@ logger = logging.getLogger('topi') +_NCHWc_matcher = re.compile("^NCHW[0-9]+c$") +_OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$") + @conv2d_alter_layout.register("cpu") def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): target = tvm.target.Target.current(allow_none=False) @@ -64,30 +68,33 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if topi_tmpl == "conv2d_NCHWc.x86": # we only convert conv2d_NCHW to conv2d_NCHWc for x86 - assert data_layout == "NCHW" and kernel_layout == "OIHW" - if cfg.is_fallback: - _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, - out_dtype, False, data_layout) - batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) - out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - - # update new attrs - new_attrs['channels'] = out_channel - new_attrs['data_layout'] = 'NCHW%dc' % 
ic_bn - # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) - new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - - # Store altered operator's config - new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data_dtype) - new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn, - kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"], - new_attrs["out_layout"], out_dtype], topi_tmpl) - dispatch_ctx.update(target, new_workload, cfg) + if data_layout == "NCHW" and kernel_layout == "OIHW": + if cfg.is_fallback: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, False, data_layout) + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config + new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = te.placeholder((out_channel//oc_bn, in_channel//ic_bn, + kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs["data_layout"], + new_attrs["out_layout"], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + else: + assert _NCHWc_matcher.match(data_layout) + assert _OIHWio_matcher.match(kernel_layout) return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) if topi_tmpl == "conv2d_NCHWc_int8.x86": @@ -136,30 +143,34 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs) if topi_tmpl == "depthwise_conv2d_NCHWc.x86": - assert data_layout == "NCHW" and kernel_layout == "OIHW" - if cfg.is_fallback: - _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, - out_dtype, True, data_layout) - - batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) - out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) - ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - assert channel_multiplier == 1 - - # update new attrs - new_attrs['channels'] = out_channel - new_attrs['data_layout'] = 'NCHW%dc' % ic_bn - new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn - new_attrs['out_layout'] = 'NCHW%dc' % oc_bn - - # Store altered operator's config. 
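The blocked layouts written into new_attrs follow the NCHW[x]c / OIHW[x]i[y]o naming scheme, which is exactly what the new regular-expression matchers accept when the incoming layout has already been converted, instead of asserting on plain NCHW/OIHW. A small sketch with made-up blocking factors:

.. code-block:: python

    import re

    _NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
    _OIHWio_matcher = re.compile("^OIHW[0-9]+i[0-9]+o$")

    ic_bn, oc_bn = 8, 16                                          # hypothetical tile sizes
    assert _NCHWc_matcher.match('NCHW%dc' % ic_bn)                # 'NCHW8c'
    assert _OIHWio_matcher.match('OIHW%di%do' % (ic_bn, oc_bn))   # 'OIHW8i16o'
    assert not _NCHWc_matcher.match('NCHW')                       # unblocked layouts do not match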
- new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), - dtype=data_dtype) - new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), dtype=kernel_dtype) - new_workload = autotvm.task.args_to_workload( - [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], - new_attrs['out_layout'], out_dtype], topi_tmpl) - dispatch_ctx.update(target, new_workload, cfg) + if data_layout == "NCHW" and kernel_layout == "OIHW": + if cfg.is_fallback: + _get_default_config(cfg, data_tensor, kernel_tensor, strides, padding, + out_dtype, True, data_layout) + + batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) + out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + assert channel_multiplier == 1 + + # update new attrs + new_attrs['channels'] = out_channel + new_attrs['data_layout'] = 'NCHW%dc' % ic_bn + new_attrs['kernel_layout'] = 'OIHW1i%do' % oc_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + # Store altered operator's config. + new_data = te.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data_dtype) + new_kernel = te.placeholder((out_channel//oc_bn, 1, kh, kw, 1, oc_bn), + dtype=kernel_dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs['data_layout'], + new_attrs['out_layout'], out_dtype], topi_tmpl) + dispatch_ctx.update(target, new_workload, cfg) + else: + assert _NCHWc_matcher.match(data_layout) + assert _OIHWio_matcher.match(kernel_layout) return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs) return None @@ -301,7 +312,9 @@ def _conv2d_legalize(attrs, inputs, arg_types): new_attrs['channels'] = new_out_channel out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=(0, 0, 0, 0), end=original_out_shape) + out = relay.strided_slice(out, + begin=relay.const([0, 0, 0, 0], "int32"), + end=relay.const(original_out_shape, "int32")) else: out = relay.nn.conv2d(data, kernel, **new_attrs) diff --git a/topi/python/topi/x86/conv3d.py b/topi/python/topi/x86/conv3d.py index 27f48f8dc69a..f0dee31a9992 100644 --- a/topi/python/topi/x86/conv3d.py +++ b/topi/python/topi/x86/conv3d.py @@ -78,11 +78,11 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype): Parameters ---------- - input : tvm.Tensor + input : tvm.te.Tensor 5-D input data with shapes: [batch, in_channel, in_depth, in_height, in_width] for NCDHW layout - filter : tvm.Tensor + filter : tvm.te.Tensor 5-D filter with shape [out_channels, in_channels, kernel_depth, kernel_height, kernel_width] strides : int or a list/tuple of three ints @@ -96,7 +96,7 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype): Returns ------- - output : tvm.Tensor + output : tvm.te.Tensor 5-D with shape [batch, out_channel, out_depth, out_height, out_width] for NCDHW layout """ layout = "NCDHW" diff --git a/topi/python/topi/x86/conv3d_transpose.py b/topi/python/topi/x86/conv3d_transpose.py new file mode 100644 index 000000000000..ad035d34c3a1 --- /dev/null +++ b/topi/python/topi/x86/conv3d_transpose.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument,no-member +# pylint: disable=no-value-for-parameter + +"""Conv3D Transpose schedule on x86""" +from tvm import te +from ..util import traverse_inline +from .. import nn +from .conv3d import conv3d_ncdhw, schedule_conv3d_ncdhw + +def conv3d_transpose_ncdhw(data, kernel, strides, padding, out_dtype): + data_pad, kernel_transform = \ + nn.conv3d_transpose_ncdhw_preprocess(data, kernel, strides, padding, out_dtype) + + # reuse conv3d_ncdhw implementation + return conv3d_ncdhw(data_pad, kernel_transform, (1, 1, 1), + (0, 0, 0), (1, 1, 1), out_dtype) + +def schedule_conv3d_transpose_ncdhw(outs): + """Create schedule for tensors""" + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = schedule_conv3d_ncdhw(outs) + def _callback(op): + if 'unpack_ncdhwc' in op.tag: + conv_out = op.input_tensors[0] + # retrieve data + data_vec = conv_out.op.input_tensors[0] + data_pad = data_vec.op.input_tensors[0] + data_dilate = data_pad.op.input_tensors[0] + s[data_dilate].compute_inline() + s[data_pad].compute_inline() + # retrieve kernel + kernel_vec = conv_out.op.input_tensors[1] + kernel_transform = kernel_vec.op.input_tensors[0] + s[kernel_transform].compute_inline() + + traverse_inline(s, outs[0].op, _callback) + return s diff --git a/topi/python/topi/x86/sparse.py b/topi/python/topi/x86/sparse.py index 54a5af9ca9f0..02cbd2d76ed3 100644 --- a/topi/python/topi/x86/sparse.py +++ b/topi/python/topi/x86/sparse.py @@ -21,11 +21,9 @@ from ..util import traverse_inline, get_const_int from .util import get_fp32_len - def schedule_sparse_dense(outs): """Create schedule for sparse dense""" s = te.create_schedule([x.op for x in outs]) - def _callback(op): simd_width = get_fp32_len() if op.tag == "sparse_dense_csrmm" and op != outs[0].op: diff --git a/topi/python/topi/x86/tensor_intrin.py b/topi/python/topi/x86/tensor_intrin.py index 955b6b4ad280..ee8d83dbef07 100644 --- a/topi/python/topi/x86/tensor_intrin.py +++ b/topi/python/topi/x86/tensor_intrin.py @@ -110,8 +110,10 @@ def _instr(index): # body, reset, update return _instr(0), _instr(1), _instr(2) - with tvm.target.build_config(offset_factor=1, partition_const_loop=True): - return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) + buffer_params = {"offset_factor" : 1} + return te.decl_tensor_intrin( + C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, + default_buffer_params=buffer_params) def dot_16x1x16_uint8_int8_int16(): @@ -191,9 +193,10 @@ def _instr(index): # body, reset, update return _instr(0), _instr(1), _instr(2) - - with tvm.target.build_config(offset_factor=1, partition_const_loop=True): - return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) + buffer_params = {"offset_factor" : 1} + return te.decl_tensor_intrin( + C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, + default_buffer_params=buffer_params) def 
dot_16x1x16_uint8_int8_int32_cascadelake(): @@ -287,5 +290,7 @@ def _instr(index): # body, reset, update return _instr(0), _instr(1), _instr(2) - with tvm.target.build_config(offset_factor=1, partition_const_loop=True): - return te.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}) + buffer_params = {"offset_factor" : 1} + return te.decl_tensor_intrin( + C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer}, + default_buffer_params=buffer_params) diff --git a/topi/recipe/conv/depthwise_conv2d_test.py b/topi/recipe/conv/depthwise_conv2d_test.py index a2b527356662..72e054e12b14 100644 --- a/topi/recipe/conv/depthwise_conv2d_test.py +++ b/topi/recipe/conv/depthwise_conv2d_test.py @@ -129,10 +129,10 @@ def check_device(device): print("success") for device in ['cuda', 'opencl', 'rocm']: - with tvm.target.build_config(auto_unroll_max_step=128, - unroll_explicit=device == 'rocm', - detect_global_barrier=False, - restricted_func=True): + with tvm.transform.PassContext(config={"tir.UnrollLoop": { + "auto_max_step": 128, + "explicit_unroll": device != "rocm" + }}): check_device(device) def test_depthwise_conv2d_nhwc(): @@ -218,9 +218,10 @@ def check_device(device): print("success") for device in ['cuda', 'opencl', 'rocm']: - with tvm.target.build_config(auto_unroll_max_step=128, - detect_global_barrier=False, - restricted_func=True): + with tvm.transform.PassContext(config={"tir.UnrollLoop": { + "auto_max_step": 128, + "explicit_unroll": device != "cuda" + }}): check_device(device) if __name__ == "__main__": diff --git a/topi/recipe/conv/test_conv2d_hwcn_map.py b/topi/recipe/conv/test_conv2d_hwcn_map.py index 69bda79555a9..35cd477e1f98 100644 --- a/topi/recipe/conv/test_conv2d_hwcn_map.py +++ b/topi/recipe/conv/test_conv2d_hwcn_map.py @@ -77,8 +77,11 @@ def check_device(device): w = tvm.nd.array(w_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) - with tvm.target.build_config(auto_unroll_max_step=128, - unroll_explicit=device == 'rocm'): + + with tvm.transform.PassContext(config={"tir.UrollLoop": { + "auto_unroll_max_step": 128, + "explicit_unroll": device == "rocm" + }}): func1 = tvm.build(s1, [A, W, B], device) func1(a, w, b) tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) diff --git a/topi/recipe/gemm/cuda_gemm_square.py b/topi/recipe/gemm/cuda_gemm_square.py index 196bf72e23a3..b35cd6073d1f 100644 --- a/topi/recipe/gemm/cuda_gemm_square.py +++ b/topi/recipe/gemm/cuda_gemm_square.py @@ -146,8 +146,10 @@ def check_device(device): print("average time cost of %d runs = %g ms, %g GFLOPS." 
% (num_runs, t * 1e3, GFLOPS)) for device in ["cuda", "opencl", "rocm", "nvptx", "vulkan"]: - with tvm.target.build_config(auto_unroll_max_step=128, - unroll_explicit=(device != "cuda")): + with tvm.transform.PassContext(config={"tir.UnrollLoop": { + "auto_max_step": 128, + "explicit_unroll": device != "cuda" + }}): check_device(device) if __name__ == "__main__": diff --git a/topi/recipe/reduce/test_reduce_map.py b/topi/recipe/reduce/test_reduce_map.py index 31f9bae7426c..5e5caec73bc3 100644 --- a/topi/recipe/reduce/test_reduce_map.py +++ b/topi/recipe/reduce/test_reduce_map.py @@ -27,12 +27,6 @@ USE_MANUAL_CODE = False -@tvm.register_func -def tvm_callback_cuda_compile(code): - ptx = nvcc.compile_cuda(code, target="ptx") - return ptx - - def write_code(code, fname): with open(fname, "w") as f: f.write(code) @@ -64,8 +58,9 @@ def test_reduce_map(in_shape, axis, keepdims, type="sum", test_id=0): else: raise NotImplementedError s = topi.cuda.schedule_reduce(B) - with tvm.target.build_config(auto_unroll_max_step=16, - auto_unroll_min_depth=0): + with tvm.transform.PassContext(config={"tir.UnrollLoop": { + "auto_max_step": 16, + }}): fcuda = tvm.build(s, [A, B], "cuda", name="sum") # Test diff --git a/topi/recipe/rnn/lstm.py b/topi/recipe/rnn/lstm.py index 4076eb6a4614..be46d895f444 100644 --- a/topi/recipe/rnn/lstm.py +++ b/topi/recipe/rnn/lstm.py @@ -188,10 +188,12 @@ def check_device(target): print("Time cost=%g" % eval_result.mean) # set unroll_explicit for more readable code. - with tvm.target.build_config( - detect_global_barrier=DETECT_GLOBAL_BARRIER, - auto_unroll_max_step=128, - unroll_explicit=False): + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": { + "auto_max_step": 128, + }, + "tir.detect_global_barrier": DETECT_GLOBAL_BARRIER + }): check_device("cuda") if __name__ == "__main__": diff --git a/topi/recipe/rnn/matexp.py b/topi/recipe/rnn/matexp.py index 9991895ec8dc..444e27fed9a4 100644 --- a/topi/recipe/rnn/matexp.py +++ b/topi/recipe/rnn/matexp.py @@ -127,10 +127,12 @@ def rnn_matexp(): s[SS].bind(tx, thread_x) def check_device(target): - with tvm.target.build_config( - detect_global_barrier=detect_global_barrier, - auto_unroll_max_step=128, - unroll_explicit=False): + with tvm.transform.PassContext(config={ + "tir.UnrollLoop": { + "auto_max_step": 128, + }, + "tir.detect_global_barrier": detect_global_barrier + }): f = tvm.build(s, [s_scan, Whh], target) ctx = tvm.gpu(0) if target == "cuda" else tvm.cl(0) # launch the kernel. diff --git a/topi/src/broadcast.cc b/topi/src/broadcast.cc index b14754573c64..e13c09ebb922 100644 --- a/topi/src/broadcast.cc +++ b/topi/src/broadcast.cc @@ -18,39 +18,33 @@ */ /*! 
-* \brief Registration of broadcast operators -* \file broadcast.cc -*/ -#include -#include - + * \brief Registration of broadcast operators + * \file broadcast.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ - TVM_REGISTER_GLOBAL(OpName) \ - .set_body([](TVMArgs args, TVMRetValue *rv) { \ - bool lhs_is_tensor = args[0].IsObjectRef(); \ - bool rhs_is_tensor = args[1].IsObjectRef(); \ - if (lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (!lhs_is_tensor && rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::te::Tensor()); \ - } else if (lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::te::Tensor(), \ - args[1].operator tvm::PrimExpr()); \ - } else if (!lhs_is_tensor && !rhs_is_tensor) { \ - *rv = Op(args[0].operator tvm::PrimExpr(), \ - args[1].operator tvm::PrimExpr()); \ - } \ - }); \ +#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ + TVM_REGISTER_GLOBAL(OpName).set_body([](TVMArgs args, TVMRetValue* rv) { \ + bool lhs_is_tensor = args[0].IsObjectRef(); \ + bool rhs_is_tensor = args[1].IsObjectRef(); \ + if (lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::te::Tensor()); \ + } else if (!lhs_is_tensor && rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::te::Tensor()); \ + } else if (lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::te::Tensor(), args[1].operator tvm::PrimExpr()); \ + } else if (!lhs_is_tensor && !rhs_is_tensor) { \ + *rv = Op(args[0].operator tvm::PrimExpr(), args[1].operator tvm::PrimExpr()); \ + } \ + }); TOPI_REGISTER_BCAST_OP("topi.add", topi::add); TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract); @@ -77,9 +71,8 @@ TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); -TVM_REGISTER_GLOBAL("topi.broadcast_to") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = broadcast_to(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/src/elemwise.cc b/topi/src/elemwise.cc index 71764cd52c0d..10ac8f8c4cee 100644 --- a/topi/src/elemwise.cc +++ b/topi/src/elemwise.cc @@ -18,142 +18,140 @@ */ /*! 
-* \brief Registration of elemwise operators -* \file elemwise.cc -*/ + * \brief Registration of elemwise operators + * \file elemwise.cc + */ +#include #include #include -#include - namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = exp(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.acos").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = acos(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.acosh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = acosh(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.asin").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = asin(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.asinh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = asinh(args[0]); +}); -TVM_REGISTER_GLOBAL("topi.fast_exp") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.atanh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = atanh(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = exp(args[0]); }); + +TVM_REGISTER_GLOBAL("topi.fast_exp").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_exp(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = erf(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = erf(args[0]); }); -TVM_REGISTER_GLOBAL("topi.fast_erf") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_erf").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_erf(args[0]); - }); +}); + +TVM_REGISTER_GLOBAL("topi.tan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tan(args[0]); }); -TVM_REGISTER_GLOBAL("topi.tan") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = tan(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.cos").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cos(args[0]); }); -TVM_REGISTER_GLOBAL("topi.cos") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = cos(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.cosh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = cosh(args[0]); +}); -TVM_REGISTER_GLOBAL("topi.sin") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = sin(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.sin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sin(args[0]); }); -TVM_REGISTER_GLOBAL("topi.tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sinh").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = sinh(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.tanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.fast_tanh") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.fast_tanh").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = fast_tanh(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.atan") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.atan").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = atan(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sigmoid") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sigmoid").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sigmoid(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sqrt").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sqrt(args[0]); - }); +}); + +TVM_REGISTER_GLOBAL("topi.rsqrt").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = 
rsqrt(args[0]); +}); + +TVM_REGISTER_GLOBAL("topi.log").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = log(args[0]); }); -TVM_REGISTER_GLOBAL("topi.rsqrt") -.set_body([](TVMArgs args, TVMRetValue *rv) { -*rv = rsqrt(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.log2").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = log2(args[0]); +}); -TVM_REGISTER_GLOBAL("topi.log") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = log(args[0]); - }); +TVM_REGISTER_GLOBAL("topi.log10").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = log10(args[0]); +}); -TVM_REGISTER_GLOBAL("topi.identity") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.identity").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = identity(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.negative") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.negative").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = negative(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.clip") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.clip").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = clip(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cast") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cast").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cast(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.reinterpret") -.set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("topi.reinterpret").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = reinterpret(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.elemwise_sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.elemwise_sum").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = elemwise_sum(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.sign") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sign").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = sign(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.full") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.full").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = full(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.full_like") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.full_like").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = full_like(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.logical_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.logical_not").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = logical_not(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.bitwise_not") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.bitwise_not").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = bitwise_not(args[0]); - }); +}); } // namespace topi diff --git a/topi/src/nn.cc b/topi/src/nn.cc index 77b208db0dd0..3ec47787ec6e 100644 --- a/topi/src/nn.cc +++ b/topi/src/nn.cc @@ -18,23 +18,22 @@ */ /*! 
-* \brief Registration of NN operators -* \file nn.cc -*/ -#include -#include - + * \brief Registration of NN operators + * \file nn.cc + */ #include +#include #include #include #include #include #include +#include #include #include #include -#include -#include +#include +#include namespace topi { @@ -42,144 +41,113 @@ using namespace tvm; using namespace tvm::runtime; /* Ops from nn.h */ -TVM_REGISTER_GLOBAL("topi.nn.relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.relu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = relu(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.leaky_relu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = leaky_relu(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.prelu") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.prelu").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = prelu(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.pad") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = pad(args[0], args[1], args[2], args[3]); - }); +}); /* Ops from nn/dense.h */ -TVM_REGISTER_GLOBAL("topi.nn.dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dense(args[0], args[1], args[2], args[3]); - }); +}); /* Ops from nn/bias_add.h */ -TVM_REGISTER_GLOBAL("topi.nn.bias_add") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.bias_add").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::bias_add(args[0], args[1], args[2]); - }); +}); /* Ops from nn/batch_matmul.h */ -TVM_REGISTER_GLOBAL("topi.nn.batch_matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.batch_matmul").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::batch_matmul(args[0], args[1]); - }); +}); /* Ops from nn/dilate.h */ -TVM_REGISTER_GLOBAL("topi.nn.dilate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.dilate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::dilate(args[0], args[1]); - }); +}); /* Ops from nn/flatten.h */ -TVM_REGISTER_GLOBAL("topi.nn.flatten") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.flatten").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::flatten(args[0]); - }); +}); /* Ops from nn/mapping.h */ -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::scale_shift_nchw(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::scale_shift_nhwc(args[0], args[1], args[2]); - }); +}); /* Ops from nn/pooling.h */ -TVM_REGISTER_GLOBAL("topi.nn.pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); -TVM_REGISTER_GLOBAL("topi.nn.pool_grad") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("topi.nn.pool_grad").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4], - static_cast(static_cast(args[5])), - args[6], args[7], args[8]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::global_pool(args[0], - static_cast(static_cast(args[1])), args[2]); - }); - -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool(args[0], args[1], - static_cast(static_cast(args[2])), + static_cast(static_cast(args[5])), args[6], args[7], + args[8]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::global_pool(args[0], static_cast(static_cast(args[1])), args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::adaptive_pool(args[0], args[1], static_cast(static_cast(args[2])), args[3]); }); -TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::adaptive_pool3d(args[0], args[1], - static_cast(static_cast(args[2])), +TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::adaptive_pool3d(args[0], args[1], static_cast(static_cast(args[2])), args[3]); }); -TVM_REGISTER_GLOBAL("topi.nn.pool1d") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool1d").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool1d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); -TVM_REGISTER_GLOBAL("topi.nn.pool3d") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.pool3d").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::pool3d(args[0], args[1], args[2], args[3], - static_cast(static_cast(args[4])), - args[5], args[6], args[7]); - }); + static_cast(static_cast(args[4])), args[5], args[6], args[7]); +}); /* Ops from nn/softmax.h */ -TVM_REGISTER_GLOBAL("topi.nn.softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.log_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.log_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::log_softmax(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = nn::lrn(args[0], args[1], args[2], - static_cast(args[3]), - static_cast(args[4]), - static_cast(args[5])); - }); +TVM_REGISTER_GLOBAL("topi.nn.lrn").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = nn::lrn(args[0], args[1], args[2], static_cast(args[3]), + static_cast(args[4]), static_cast(args[5])); +}); /* Ops from nn/bnn.h */ -TVM_REGISTER_GLOBAL("topi.nn.binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::binarize_pack(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.nn.binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.nn.binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = nn::binary_dense(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/src/reduction.cc 
b/topi/src/reduction.cc index e1fdada73eef..b981495411ba 100644 --- a/topi/src/reduction.cc +++ b/topi/src/reduction.cc @@ -18,58 +18,49 @@ */ /*! -* \brief Registration of reduction operators -* \file reduction.cc -*/ -#include -#include - + * \brief Registration of reduction operators + * \file reduction.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.sum") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sum").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.min") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.min").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.max") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.max").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.argmin") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.argmin").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.argmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.argmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.prod") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.prod").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.all") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.all").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.any") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.any").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); - }); +}); } // namespace topi diff --git a/topi/src/schedule.cc b/topi/src/schedule.cc index 936f39031e6a..b974acaf2dd5 100644 --- a/topi/src/schedule.cc +++ b/topi/src/schedule.cc @@ -18,212 +18,181 @@ */ /*! 
-* \brief Registration of TVM schedules -* \file schedule.cc -*/ + * \brief Registration of TVM schedules + * \file schedule.cc + */ #define TOPI_REDUCE_ATLEAST1D 0 -#include -#include -#include -#include -#include - -#include -#include -#include - #include #include +#include #include #include #include -#include - -#include -#include -#include - +#include +#include +#include +#include #include #include +#include #include #include #include -#include - -#include +#include +#include +#include +#include +#include +#include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.TEST_create_target") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.TEST_create_target").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tvm::Target::Create(args[0]); - }); +}); /* Generic schedules */ -TVM_REGISTER_GLOBAL("topi.generic.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[2]) { *rv = topi::generic::default_schedule_auto_inline(args[0], args[1]); } else { *rv = topi::generic::default_schedule(args[0], args[1]); } - }); +}); -TVM_REGISTER_GLOBAL("topi.generic.schedule_extern") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.schedule_extern").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::generic::schedule_extern(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.generic.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.generic.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::generic::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::generic::schedule_injective_from_existing(args[0], args[1]); + }); /* x86 schedules */ -TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_binarize_pack(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_binary_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.default_schedule") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.default_schedule").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[2]) { *rv = topi::x86::default_schedule_auto_inline(args[0], args[1]); } else { *rv = topi::x86::default_schedule(args[0], args[1]); } - }); +}); -TVM_REGISTER_GLOBAL("topi.x86.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.x86.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::x86::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::x86::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = 
topi::x86::schedule_injective_from_existing(args[0], args[1]); + }); /* ROCm schedules */ -TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]); + }); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_global_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_reduce(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::rocm::schedule_lrn(args[0]); - }); +}); /* CUDA schedules */ -TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_dense(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_injective(args[0], args[1]); - }); +}); TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]); - }); + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = 
topi::cuda::schedule_injective_from_existing(args[0], args[1]); + }); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_global_pool(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_reduce(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_softmax(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::cuda::schedule_lrn(args[0]); - }); +}); /* Utility functions */ -TVM_REGISTER_GLOBAL("topi.util.is_empty_shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.util.is_empty_shape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = topi::detail::is_empty_shape(args[0]); - }); +}); -TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]); - }); +}); /*! \brief Builder function for instantiating schedules. */ -using FTVMScheduleBuilder = std::function< - tvm::te::Schedule(const tvm::Target& target, const tvm::Array& outs)>; +using FTVMScheduleBuilder = std::function& outs)>; /*! 
* \brief Helper function for registering generic functions matching the @@ -242,7 +211,7 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { if (argNodeRef->type_index() == outs->type_index()) { outs = args[0]; } else { - outs = Array { args[0] }; + outs = Array{args[0]}; } *ret = builder(target, outs); @@ -250,49 +219,49 @@ inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) { } TVM_REGISTER_GENERIC_FUNC(schedule_injective) -.set_default(WrapSchedule(topi::generic::schedule_injective)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective)); + .set_default(WrapSchedule(topi::generic::schedule_injective)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_injective)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_injective)); TVM_REGISTER_GENERIC_FUNC(schedule_softmax) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_softmax)); TVM_REGISTER_GENERIC_FUNC(schedule_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) -.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_dense)) + .register_func({"rocm"}, WrapSchedule(topi::rocm::schedule_dense)); TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul) -.set_default(WrapSchedule(topi::generic::default_schedule)); + .set_default(WrapSchedule(topi::generic::default_schedule)); TVM_REGISTER_GENERIC_FUNC(schedule_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_pool)); TVM_REGISTER_GENERIC_FUNC(schedule_global_pool) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_global_pool)); TVM_REGISTER_GENERIC_FUNC(schedule_reduce) -.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline)) -.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce)); + .set_default(WrapSchedule(topi::generic::default_schedule_auto_inline)) + .register_func({"cpu"}, WrapSchedule(topi::x86::default_schedule_auto_inline)) + .register_func({"cuda", "gpu"}, WrapSchedule(topi::cuda::schedule_reduce)); TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ 
"cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binarize_pack)); TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense) -.set_default(WrapSchedule(topi::generic::default_schedule)) -.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense)); + .set_default(WrapSchedule(topi::generic::default_schedule)) + .register_func({"cpu"}, WrapSchedule(topi::x86::schedule_binary_dense)); /*! \brief Builder function for instantiating schedules from existing schedules. */ -using FTVMScheduleFromExistingBuilder = std::function< - tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>; +using FTVMScheduleFromExistingBuilder = + std::function; /*! * \brief Helper function for registering generic functions matching the @@ -304,33 +273,30 @@ using FTVMScheduleFromExistingBuilder = std::function< * \return The wrapped schedule builder */ inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) { - return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { - *ret = builder(args[0], args[1]); - }); + return PackedFunc( + [builder](TVMArgs args, TVMRetValue* ret) { *ret = builder(args[0], args[1]); }); } TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing) -.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) -.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) -.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting( - topi::cuda::schedule_injective_from_existing)); + .set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing)) + .register_func({"cpu"}, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing)) + .register_func({"cuda", "gpu"}, + WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing)); /*! \brief Builder function for instantiating dense ops. */ -using FTVMDenseOpBuilder = std::function; +using FTVMDenseOpBuilder = std::function; /*! -* \brief Helper function for registering dense ops matching the -* FTVMDenseOpBuilder signature. The op builder function is wrapped -* with a PackedFunc suitable for passing to a tvm::GenericFunc. -* -* \param builder The op builder to wrap. -* -* \return The wrapped op builder -*/ + * \brief Helper function for registering dense ops matching the + * FTVMDenseOpBuilder signature. The op builder function is wrapped + * with a PackedFunc suitable for passing to a tvm::GenericFunc. + * + * \param builder The op builder to wrap. 
+ * + * \return The wrapped op builder + */ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) { auto target = Target::Current(false); @@ -344,14 +310,12 @@ inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) { } TVM_REGISTER_GENERIC_FUNC(dense) -.set_default(WrapDenseOp([](const Target& target, - const tvm::te::Tensor& data, - const tvm::te::Tensor& weight, - const tvm::te::Tensor& bias, - const DataType& out_dtype) { - return topi::nn::dense(data, weight, bias, out_dtype); -})) -.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda)) -.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm)); + .set_default(WrapDenseOp([](const Target& target, const tvm::te::Tensor& data, + const tvm::te::Tensor& weight, const tvm::te::Tensor& bias, + const DataType& out_dtype) { + return topi::nn::dense(data, weight, bias, out_dtype); + })) + .register_func({"cuda", "gpu"}, WrapDenseOp(topi::cuda::dense_cuda)) + .register_func({"rocm"}, WrapDenseOp(topi::rocm::dense_rocm)); } // namespace topi diff --git a/topi/src/transform.cc b/topi/src/transform.cc index 4f0d4f8e6825..2791ff7dab1d 100644 --- a/topi/src/transform.cc +++ b/topi/src/transform.cc @@ -18,67 +18,56 @@ */ /*! -* \brief Registration of transform operators -* \file transform.cc -*/ -#include -#include - + * \brief Registration of transform operators + * \file transform.cc + */ #include #include +#include +#include namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.expand_dims") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.expand_dims").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = expand_dims(args[0], args[1], args[2]); - }); +}); -TVM_REGISTER_GLOBAL("topi.transpose") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = transpose(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.flip") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = flip(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.reshape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = reshape(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.squeeze") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.squeeze").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = squeeze(args[0], ArrayOrInt(args[1])); - }); +}); -TVM_REGISTER_GLOBAL("topi.concatenate") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.concatenate").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = concatenate(args[0], args[1]); - }); +}); -TVM_REGISTER_GLOBAL("topi.stack") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.stack").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = stack(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.shape") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.shape").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = shape(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.ndarray_size") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.ndarray_size").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = ndarray_size(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.split") -.set_body([](TVMArgs args, TVMRetValue *rv) { 
+TVM_REGISTER_GLOBAL("topi.split").set_body([](TVMArgs args, TVMRetValue* rv) { if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { *rv = split_sections(args[0], args[1], args[2]); } else { @@ -86,13 +75,11 @@ TVM_REGISTER_GLOBAL("topi.split") } }); -TVM_REGISTER_GLOBAL("topi.layout_transform") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.layout_transform").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = layout_transform(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.take") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.take").set_body([](TVMArgs args, TVMRetValue* rv) { if (args.size() == 3) { std::string mode = args[2]; *rv = take(args[0], args[1], mode); @@ -101,56 +88,63 @@ TVM_REGISTER_GLOBAL("topi.take") std::string mode = args[3]; *rv = take(args[0], args[1], axis, mode); } - }); +}); -TVM_REGISTER_GLOBAL("topi.sequence_mask") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.sequence_mask").set_body([](TVMArgs args, TVMRetValue* rv) { double pad_val = args[2]; int axis = args[3]; *rv = sequence_mask(args[0], args[1], pad_val, axis); }); -TVM_REGISTER_GLOBAL("topi.where") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.where").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = where(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.arange") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.arange").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = arange(args[0], args[1], args[2], args[3]); }); -TVM_REGISTER_GLOBAL("topi.repeat") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.repeat").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = repeat(args[0], args[1], args[2]); }); -TVM_REGISTER_GLOBAL("topi.tile") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.tile").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = tile(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.gather_nd") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.gather").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = gather(args[0], args[1], args[2]); +}); + +TVM_REGISTER_GLOBAL("topi.gather_nd").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = gather_nd(args[0], args[1]); }); -TVM_REGISTER_GLOBAL("topi.unravel_index") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.unravel_index").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = unravel_index(args[0], args[1]); - }); - -TVM_REGISTER_GLOBAL("topi.matmul") -.set_body([](TVMArgs args, TVMRetValue *rv) { - switch ( args.size() ) { - case 2: *rv = matmul(args[0], args[1]); break; - case 3: *rv = matmul(args[0], args[1], args[2]); break; - case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break; - default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; - }}); - -TVM_REGISTER_GLOBAL("topi.tensordot") -.set_body([](TVMArgs args, TVMRetValue *rv) { +}); + +TVM_REGISTER_GLOBAL("topi.sparse_to_dense").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = sparse_to_dense(args[0], args[1], args[2], args[3]); +}); + +TVM_REGISTER_GLOBAL("topi.matmul").set_body([](TVMArgs args, TVMRetValue* rv) { + switch (args.size()) { + case 2: + *rv = matmul(args[0], args[1]); + break; + case 3: + *rv = matmul(args[0], args[1], args[2]); + break; + case 4: + *rv = matmul(args[0], args[1], args[2], args[3]); + break; + default: + CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments"; + } +}); + 
+TVM_REGISTER_GLOBAL("topi.tensordot").set_body([](TVMArgs args, TVMRetValue* rv) { if (args.size() == 2) { *rv = tensordot(args[0], args[1]); } else if (args.size() == 3) { @@ -159,19 +153,17 @@ TVM_REGISTER_GLOBAL("topi.tensordot") Array axes = args[3]; *rv = tensordot(args[0], args[1], args[2], axes); } - }); +}); -TVM_REGISTER_GLOBAL("topi.strided_slice") -.set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = strided_slice(args[0], args[1], args[2], args[3]); - }); +TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = strided_slice(args[0], args[1], args[2], args[3], args[4]); +}); -TVM_REGISTER_GLOBAL("topi.one_hot") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { int depth = args[3]; int axis = args[4]; DataType dtype = args[5]; *rv = one_hot(args[0], args[1], args[2], depth, axis, dtype); - }); +}); } // namespace topi diff --git a/topi/src/vision.cc b/topi/src/vision.cc index 1a4884e8d7c6..0485177cf9d5 100644 --- a/topi/src/vision.cc +++ b/topi/src/vision.cc @@ -18,22 +18,20 @@ */ /*! -* \brief Registration of vision operators -* \file vision.cc -*/ + * \brief Registration of vision operators + * \file vision.cc + */ +#include #include #include -#include - namespace topi { using namespace tvm; using namespace tvm::runtime; -TVM_REGISTER_GLOBAL("topi.vision.reorg") -.set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("topi.vision.reorg").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = vision::reorg(args[0], args[1]); - }); +}); } // namespace topi diff --git a/topi/tests/python/test_topi_batch_matmul.py b/topi/tests/python/test_topi_batch_matmul.py index b8c854746847..716f40700339 100644 --- a/topi/tests/python/test_topi_batch_matmul.py +++ b/topi/tests/python/test_topi_batch_matmul.py @@ -28,7 +28,7 @@ _batch_matmul_implement = { "generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul), "cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul), - "gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul), + "gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul), } def verify_batch_matmul(batch, M, N, K): diff --git a/topi/tests/python/test_topi_conv2d_nchw.py b/topi/tests/python/test_topi_conv2d_nchw.py index d42c8c7c24c0..11b799c712c0 100644 --- a/topi/tests/python/test_topi_conv2d_nchw.py +++ b/topi/tests/python/test_topi_conv2d_nchw.py @@ -75,7 +75,7 @@ def check_device(device): with tvm.target.create(device): if "cudnn" in device: - C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), "NCHW", dtype) + C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), 1, "NCHW", dtype) else: C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype) if add_bias: diff --git a/topi/tests/python/test_topi_conv2d_nhwc_winograd.py b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py new file mode 100644 index 000000000000..7cb40417d2cc --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_winograd.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-arguments +# pylint: disable=bad-whitespace +"""Example code to do convolution.""" + +import numpy as np +import tvm +import topi +import topi.testing +from tvm import te +from tvm.contrib.pickle_memoize import memoize +from tvm.contrib import nvcc +from topi.nn.util import get_pad_tuple +from topi.util import get_const_tuple + + +_conv2d_nhwc_winograd_tensorcore = { + "cuda": (topi.cuda.conv2d_nhwc_winograd_tensorcore, + topi.cuda.schedule_conv2d_nhwc_winograd_tensorcore) +} + +_conv2d_nhwc_winograd_direct = { + "cuda": (topi.cuda.conv2d_nhwc_winograd_direct, + topi.cuda.schedule_conv2d_nhwc_winograd_direct) +} + + +def verify_conv2d_nhwc(batch, in_channel, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, + devices='cuda', bgemm="direct"): + """Test the conv2d with winograd for nhwc layout""" + pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel)) + padding_sum = pad_top + pad_left + pad_bottom + pad_right + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + + in_height = in_width = in_size + + A = te.placeholder((batch, in_height, in_width, in_channel), name='A') + W = te.placeholder((kernel, kernel, in_channel, num_filter), name='W') + bias = te.placeholder((1, 1, 1, num_filter), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv2d_nhwc.verify_conv2d_nhwc") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + c_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + if add_bias: + b_np = np.random.uniform(size=bias_shape).astype(dtype) + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + if bgemm == "direct": + fcompute, fschedule = topi.testing.dispatch(device, + _conv2d_nhwc_winograd_direct) + elif bgemm == "tensorcore": + fcompute, fschedule = topi.testing.dispatch(device, + _conv2d_nhwc_winograd_tensorcore) + C = fcompute(A, W, stride, padding, dilation, 'float32') + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + 
func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % ( + batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation)) + func(a, w, c) + + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=2e-3) + + check_device(devices) + + +def test_conv2d_nhwc_winograd_direct(): + """Test the conv2d with winograd for nhwc layout""" + # resnet 18 workloads + print("test_winograd_direct...") + verify_conv2d_nhwc(1, 64, 56, 64, 3, 1, 1, bgemm="direct") + verify_conv2d_nhwc(1, 128, 28, 128, 3, 1, 1) + verify_conv2d_nhwc(1, 256, 14, 256, 3, 1, 1) + verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, 1) + verify_conv2d_nhwc(1, 48, 35, 64, 5, 1, 2) + + # weird workloads + verify_conv2d_nhwc(1, 1, 1, 1, 3, 1, 1) + verify_conv2d_nhwc(3, 3, 3, 3, 3, 1, 1) + verify_conv2d_nhwc(2, 13, 71, 59, 3, 1, 1) + + # Asymmetric padding + verify_conv2d_nhwc(1, 512, 7, 512, 3, 1, "SAME") + verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, (1, 1), add_relu=True) + verify_conv2d_nhwc(2, 48, 56, 48, 3, 1, "SAME", add_relu=True, add_bias=True) + verify_conv2d_nhwc(1, 48, 35, 48, 5, 1, "VALID") + +def test_conv2d_nhwc_winograd_tensorcore(): + """Test the conv2d with winograd for nhwc layout""" + if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"): + print("skip because cuda is not enabled..") + return + if not nvcc.have_tensorcore(tvm.gpu(0).compute_version): + return + verify_conv2d_nhwc(8, 64, 56, 64, 3, 1, 1, bgemm="tensorcore") + verify_conv2d_nhwc(8, 128, 28, 128, 3, 1, 1, bgemm="tensorcore") + verify_conv2d_nhwc(8, 256, 14, 256, 3, 1, 1, bgemm="tensorcore") + + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True, bgemm="tensorcore") + verify_conv2d_nhwc(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, bgemm="tensorcore") + + +if __name__ == "__main__": + test_conv2d_nhwc_winograd_direct() + test_conv2d_nhwc_winograd_tensorcore() diff --git a/topi/tests/python/test_topi_conv3d_transpose_ncdhw.py b/topi/tests/python/test_topi_conv3d_transpose_ncdhw.py new file mode 100644 index 000000000000..8b081987fd12 --- /dev/null +++ b/topi/tests/python/test_topi_conv3d_transpose_ncdhw.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
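This new test file exercises conv3d_transpose_ncdhw against a NumPy reference in the NCDHW layout. When reading its workloads, the expected output size per spatial axis follows the usual transposed-convolution relation out = (in - 1) * stride - (pad_before + pad_after) + kernel (no output_padding term is involved here). A small illustrative sketch of that arithmetic, not part of the test itself:

def conv_transpose_out_dim(in_dim, kernel, stride, pad_before, pad_after):
    # Transposed-convolution output size along one spatial axis,
    # assuming no output_padding, as in the workloads below.
    return (in_dim - 1) * stride - (pad_before + pad_after) + kernel

# e.g. a 24-wide axis, 3-wide kernel, stride 2, padding 1 on both sides:
# (24 - 1) * 2 - (1 + 1) + 3 = 47
assert conv_transpose_out_dim(24, 3, 2, 1, 1) == 47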
+"""Test code for transposed convolution.""" +import numpy as np +import tvm +from tvm import te +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + + +_conv3d_transpose_ncdhw_implement = { + "generic": (topi.nn.conv3d_transpose_ncdhw, topi.generic.schedule_conv3d_transpose_ncdhw), + "cpu": (topi.x86.conv3d_transpose_ncdhw, topi.x86.schedule_conv3d_transpose_ncdhw), + "gpu": (topi.cuda.conv3d_transpose_ncdhw, topi.cuda.schedule_conv3d_transpose_ncdhw), +} + +def verify_conv3d_transpose_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding): + in_depth, in_height, in_width = in_size + kernel_depth, kernel_height, kernel_width = kernel + stride_depth, stride_height, stride_width = stride + pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = padding + + A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A') + W = te.placeholder((in_channel, num_filter, kernel_depth, kernel_height, kernel_width), name='W') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv3d_transpose.verify_conv3d_transpose_ncdhw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = topi.testing.conv3d_transpose_ncdhw_python(a_np, w_np, stride, padding) + c_np = np.maximum(b_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + fcompute, fschedule = topi.testing.dispatch(device, _conv3d_transpose_ncdhw_implement) + B = fcompute(A, W, + [stride_depth, stride_height, stride_width], + [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right], + A.dtype) + C = topi.nn.relu(B) + s1 = fschedule([B]) + s2 = fschedule([C]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + + func1 = tvm.build(s1, [A, W, B], device) + func2 = tvm.build(s2, [A, W, C], device) + func1(a, w, b) + func2(a, w, c) + tvm.testing.assert_allclose(b.asnumpy(), b_np, atol=1e-4, rtol=1e-4) + tvm.testing.assert_allclose(c.asnumpy(), c_np, atol=1e-4, rtol=1e-4) + for device in get_all_backend(): + check_device(device) + + +def test_conv3d_transpose_ncdhw(): + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 1, (1, 1, 1), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 2, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (3, 3, 3), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (3, 3, 3), (2, 2, 2), (1, 1, 1, 1, 1, 1)) + verify_conv3d_transpose_ncdhw(1, 3, (24, 24, 24), 16, (2, 2, 2), (2, 2, 2), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 32, (5, 5, 5), (1, 1, 1), (0, 0, 0, 0, 0, 0)) + verify_conv3d_transpose_ncdhw(1, 8, (32, 32, 32), 64, (5, 5, 5), (2, 2, 2), (1, 1, 1, 1, 1, 1)) + +if 
__name__ == "__main__": + test_conv3d_transpose_ncdhw() diff --git a/topi/tests/python/test_topi_correlation.py b/topi/tests/python/test_topi_correlation.py new file mode 100644 index 000000000000..663564fab469 --- /dev/null +++ b/topi/tests/python/test_topi_correlation.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License +"""test of correlation operator in NCHW layout""" +import numpy as np +import tvm +from tvm import te +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + + +_correlation_implement = { + "generic": (topi.nn.correlation_nchw, topi.generic.schedule_correlation_nchw), + "cuda": (topi.cuda.correlation_nchw, topi.cuda.schedule_correlation_nchw), +} + + +def verify_correlation_nchw(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply): + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (data_shape[0], data_shape[1], data_shape[2], data_shape[3], + kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply)) + + A = te.placeholder(data_shape, name='data1') + B = te.placeholder(data_shape, name='data2') + dtype = A.dtype + + @memoize("topi.tests.test_topi_correlation_nchw.verify_correlation_nchw") + def get_ref_data(): + a_np = np.random.uniform(size=data_shape).astype(dtype) + b_np = np.random.uniform(size=data_shape).astype(dtype) + c_np = topi.testing.correlation_nchw_python(a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply) + return a_np, b_np, c_np + + a_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch( + device, _correlation_implement) + with tvm.target.create(device): + C = fcompute(A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply) + s = fschedule([C]) + + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.empty(c_np.shape, dtype=dtype, ctx=ctx) + + func = tvm.build(s, [A, B, C], device) + func(a, b, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) + + for device in get_all_backend(): + check_device(device) + + +def test_correlation_nchw(): + verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=4, + stride1=1, stride2=1, pad_size=4, is_multiply=True) + verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=5, + stride1=1, stride2=1, pad_size=5, is_multiply=True) + verify_correlation_nchw((5, 1, 4, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=1, pad_size=2, is_multiply=True) + 
verify_correlation_nchw((5, 1, 6, 4), kernel_size=3, max_displacement=1, + stride1=2, stride2=2, pad_size=2, is_multiply=False) + verify_correlation_nchw((5, 1, 11, 11), kernel_size=5, max_displacement=1, + stride1=1, stride2=1, pad_size=2, is_multiply=False) + + +if __name__ == "__main__": + test_correlation_nchw() diff --git a/topi/tests/python/test_topi_dense.py b/topi/tests/python/test_topi_dense.py index 7498c004c8dd..6294c7d6818f 100644 --- a/topi/tests/python/test_topi_dense.py +++ b/topi/tests/python/test_topi_dense.py @@ -33,7 +33,6 @@ (topi.cuda.dense_large_batch, topi.cuda.schedule_dense_large_batch)], "mali": [(topi.mali.dense, topi.mali.schedule_dense)], "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)], - "opengl": [(topi.nn.dense, topi.opengl.schedule_dense)], "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)], "hls": [(topi.nn.dense, topi.hls.schedule_dense)], } diff --git a/topi/tests/python/test_topi_image.py b/topi/tests/python/test_topi_image.py index 4eea75d68d28..012ed4207a1b 100644 --- a/topi/tests/python/test_topi_image.py +++ b/topi/tests/python/test_topi_image.py @@ -20,6 +20,7 @@ from tvm import te import topi import topi.testing +from tvm.contrib.pickle_memoize import memoize from common import get_all_backend @@ -204,7 +205,89 @@ def check_device(device): size_1, method='nearest_neighbor') verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW") + +def test_affine_grid(): + def verify_affine_grid(num_batch, target_shape): + dtype = "float32" + data_shape = (num_batch, 2, 3) + data = te.placeholder(data_shape, dtype=dtype) + out = topi.image.affine_grid(data, target_shape) + + @memoize("topi.tests.test_affine_grid.verify_affine_grid") + def get_ref_data(): + data_np = np.random.uniform(size=data_shape).astype(dtype) + out_np = topi.testing.affine_grid_python(data_np, target_shape) + return data_np, out_np + + data_np, out_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(out) + tvm_data = tvm.nd.array(data_np, ctx) + tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx) + f = tvm.build(s, [data, out], device) + f(tvm_data, tvm_out) + + tvm.testing.assert_allclose( + tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5) + + for device in get_all_backend(): + check_device(device) + + verify_affine_grid(1, (16, 32)) + verify_affine_grid(4, (16, 32)) + + +def test_grid_sample(): + def verify_grid_sample(data_shape, grid_shape): + dtype = "float32" + data = te.placeholder(data_shape, dtype=dtype) + grid = te.placeholder(grid_shape, dtype=dtype) + out = topi.image.grid_sample(data, grid, 'bilinear') + + @memoize("topi.tests.test_grid_sample.verify_grid_sample") + def get_ref_data(): + data_np = np.random.uniform(size=data_shape).astype(dtype) + # allow grid values to be out-of-bound + grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype) + out_np = topi.testing.grid_sample_nchw_python(data_np, grid_np, 'bilinear') + return data_np, grid_np, out_np + + data_np, grid_np, out_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(out) + tvm_data = 
tvm.nd.array(data_np, ctx) + tvm_grid = tvm.nd.array(grid_np, ctx) + tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx) + f = tvm.build(s, [data, grid, out], device) + f(tvm_data, tvm_grid, tvm_out) + + tvm.testing.assert_allclose( + tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5) + + for device in get_all_backend(): + check_device(device) + + verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8)) + verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32)) + + if __name__ == "__main__": test_resize() test_resize3d() test_crop_and_resize() + test_affine_grid() + test_grid_sample() diff --git a/topi/tests/python/test_topi_pooling.py b/topi/tests/python/test_topi_pooling.py index 9bdbb1073fd0..048de8168aa8 100644 --- a/topi/tests/python/test_topi_pooling.py +++ b/topi/tests/python/test_topi_pooling.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument """Test code for pooling""" import math import numpy as np @@ -44,6 +45,7 @@ } def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True): + """verify function of pool""" iw = ih kw = kh sw = sh @@ -76,15 +78,17 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_ for i in range(oh): for j in range(ow): if count_include_pad: - b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) + b_np[:, :, i, j] = \ + np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) else: - pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2,3)) - b_np[:,:,i,j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) / np.maximum(pad_count, 1) + pad_count = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw] > 0, axis=(2, 3)) + b_np[:, :, i, j] = np.sum(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) \ + / np.maximum(pad_count, 1) - elif pool_type =='max': + elif pool_type == 'max': for i in range(oh): for j in range(ow): - b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3)) + b_np[:, :, i, j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2, 3)) b_np = np.maximum(b_np, 0.0) def check_device(device): @@ -108,11 +112,11 @@ def check_device(device): def verify_pool_grad(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True, add_relu=False): + """verify function of pool_grad""" iw = ih kw = kh sw = sh pt, pl, pb, pr = padding - layout = "NCHW" A = te.placeholder((n, ic, ih, iw), name='A') B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding, pool_type=pool_type, ceil_mode=ceil_mode, @@ -164,6 +168,7 @@ def check_device(device): check_device(device) def test_pool(): + """test cases of pool""" verify_pool(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) verify_pool(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) verify_pool(1, 256, 32, 2, 2, [1, 2, 1, 2], 'avg', False, False) @@ -179,6 +184,7 @@ def test_pool(): verify_pool(1, 256, 31, 3, 3, [3, 2, 1, 0], 'max', True) def test_pool_grad(): + """test cases of pool_grad""" verify_pool_grad(1, 256, 32, 3, 2, [1, 1, 1, 1], 'avg', False, False) verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'avg', False, True) verify_pool_grad(1, 256, 31, 3, 3, [1, 2, 1, 2], 'avg', False, True) @@ -200,10 +206,10 @@ def test_pool_grad(): verify_pool_grad(1, 256, 32, 2, 2, [0, 0, 0, 0], 'max', False, add_relu=True) -def verify_global_pool(n, c, h, w, pool_type, 
layout='NCHW'): - +def verify_global_pool(dshape, pool_type, layout='NCHW'): + """verify function of global_pool""" assert layout in ["NCHW", "NHWC"] - A = te.placeholder((n, c, h, w), name='A') + A = te.placeholder(shape=dshape, name='A') B = topi.nn.global_pool(A, pool_type=pool_type, layout=layout) B = topi.nn.relu(B) @@ -212,7 +218,7 @@ def verify_global_pool(n, c, h, w, pool_type, layout='NCHW'): axis = (layout.find('H'), layout.find('W')) if pool_type == 'avg': b_np = np.mean(a_np, axis=axis, keepdims=True) - elif pool_type =='max': + elif pool_type == 'max': b_np = np.max(a_np, axis=axis, keepdims=True) b_np = np.maximum(b_np, 0.0) @@ -224,7 +230,10 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): s_func = topi.testing.dispatch(device, _adaptive_pool_schedule) - s = s_func(B) + if device == "cuda": + s = s_func(B, layout) + else: + s = s_func(B) a = tvm.nd.array(a_np, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) f = tvm.build(s, [A, B], device) @@ -235,17 +244,19 @@ def check_device(device): check_device(device) def test_global_pool(): - verify_global_pool(1, 1024, 7, 7, 'avg') - verify_global_pool(4, 1024, 7, 7, 'avg') - verify_global_pool(1, 1024, 7, 7, 'max') - verify_global_pool(4, 1024, 7, 7, 'max') - verify_global_pool(1, 1024, 7, 7, 'avg', 'NHWC') - verify_global_pool(4, 1024, 7, 7, 'avg', 'NHWC') - verify_global_pool(1, 1024, 7, 7, 'max', 'NHWC') - verify_global_pool(4, 1024, 7, 7, 'max', 'NHWC') + """test cases of global_pool""" + verify_global_pool((1, 1024, 7, 7), 'avg') + verify_global_pool((4, 1024, 7, 7), 'avg') + verify_global_pool((1, 1024, 7, 7), 'max') + verify_global_pool((4, 1024, 7, 7), 'max') + verify_global_pool((1, 7, 7, 1024), 'avg', 'NHWC') + verify_global_pool((4, 7, 7, 1024), 'avg', 'NHWC') + verify_global_pool((1, 7, 7, 1024), 'max', 'NHWC') + verify_global_pool((4, 7, 7, 1024), 'max', 'NHWC') def verify_adaptive_pool(dshape, out_size, pool_type, layout="NCHW", dtype="float32"): + """verify function of adaptive_pool""" np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype) np_out = topi.testing.adaptive_pool(np_data, out_size, pool_type, layout) oshape = np_out.shape @@ -265,18 +276,22 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): s_func = topi.testing.dispatch(device, _adaptive_pool_schedule) - s = s_func(out) + if device == "cuda": + s = s_func(out, layout) + else: + s = s_func(out) a = tvm.nd.array(np_data, ctx) b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), ctx) f = tvm.build(s, [data, out], device) f(a, b) - tvm.testing.assert_allclose(b.asnumpy(), np_out, rtol=1e-5) + tvm.testing.assert_allclose(b.asnumpy(), np_out, rtol=4e-5, atol=1e-6) for device in get_all_backend(): check_device(device) def test_adaptive_pool(): + """test cases of adaptive_pool""" verify_adaptive_pool((1, 3, 224, 224), (1, 1), "max") verify_adaptive_pool((1, 3, 224, 224), (1, 1), "avg") verify_adaptive_pool((1, 14, 56, 78), (34, 13), "max") @@ -295,6 +310,7 @@ def test_adaptive_pool(): def verify_pool3d(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True, layout='NCDHW'): + """verify function of pool3d""" id = iw = ih kd = kw = kh sd = sw = sh @@ -334,6 +350,7 @@ def check_device(device): def test_pool3d(): + """test cases of pool3d""" verify_pool3d(1, 256, 32, 2, 2, [0, 0, 0, 0, 0, 0], 'avg', False, True) verify_pool3d(1, 256, 31, 3, 3, [1, 1, 2, 2, 2, 1], 'avg', False, True) 
verify_pool3d(1, 256, 32, 2, 2, [1, 1, 2, 2, 2, 1], 'avg', False, False) @@ -351,6 +368,7 @@ def test_pool3d(): def verify_pool1d(n, ic, iw, kw, sw, padding, pool_type, ceil_mode, count_include_pad=True, layout='NCW'): + """verify function of pool1d""" input_shape = (n, ic, iw) kernel = [kw] stride = [sw] @@ -387,6 +405,7 @@ def check_device(device): def test_pool1d(): + """test cases of pool1d""" verify_pool1d(1, 256, 32, 2, 2, [0, 0], 'avg', False, True) verify_pool1d(1, 256, 31, 3, 3, [1, 2], 'avg', False, True) verify_pool1d(1, 256, 32, 2, 2, [1, 2], 'avg', False, False) diff --git a/topi/tests/python/test_topi_softmax.py b/topi/tests/python/test_topi_softmax.py index 485738700300..e21307405db7 100644 --- a/topi/tests/python/test_topi_softmax.py +++ b/topi/tests/python/test_topi_softmax.py @@ -31,7 +31,6 @@ "cpu": topi.x86.schedule_softmax, "gpu": topi.cuda.schedule_softmax, "hls": topi.hls.schedule_softmax, - "opengl": topi.opengl.schedule_softmax, } def check_device(A, B, a_np, b_np, device, name): diff --git a/topi/tests/python/test_topi_sparse.py b/topi/tests/python/test_topi_sparse.py index fc2d26b82842..748181dc650b 100644 --- a/topi/tests/python/test_topi_sparse.py +++ b/topi/tests/python/test_topi_sparse.py @@ -26,6 +26,12 @@ import time import scipy.sparse as sp +_sparse_dense_implement = { + "generic": (topi.nn.sparse_dense, topi.generic.schedule_sparse_dense), + "cuda": (topi.cuda.sparse_dense, topi.cuda.schedule_sparse_dense), + "x86": (topi.nn.sparse_dense, topi.x86.schedule_sparse_dense) +} + def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True): nr, nc, n = te.var("nr"), te.var("nc"), te.var("n") dtype = 'float32' @@ -282,27 +288,47 @@ def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype): assert s.indptr.shape == (M // BS_R + 1, ) return s -def test_sparse_dense_bsr(): - M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9 +def verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu): X_np = np.random.randn(M, K).astype("float32") W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") W_np = W_sp_np.todense() Y_np = X_np.dot(W_np.T) + if use_relu: + Y_np = np.maximum(Y_np, 0.0) W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype)) W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) - Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) - s = te.create_schedule(Y.op) - func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) - func(tvm.nd.array(X_np), - tvm.nd.array(W_sp_np.data), - tvm.nd.array(W_sp_np.indices), - tvm.nd.array(W_sp_np.indptr), - Y_tvm) - tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement) + with tvm.target.create(device): + Y = fcompute(X, W_data, W_indices, W_indptr) + if use_relu: + Y = topi.nn.relu(Y) + s = fschedule([Y]) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + func(tvm.nd.array(X_np, ctx=ctx), + tvm.nd.array(W_sp_np.data, ctx=ctx), + tvm.nd.array(W_sp_np.indices, 
ctx=ctx), + tvm.nd.array(W_sp_np.indptr, ctx=ctx), + Y_tvm) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4) + + for device in ['llvm', 'cuda']: + check_device(device) + +def test_sparse_dense_bsr(): + M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9 + verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=True) + verify_sparse_dense_bsr(M, N, K, BS_R, BS_C, density, use_relu=False) def test_sparse_dense_bsr_randomized(): for _ in range(20): @@ -322,16 +348,28 @@ def test_sparse_dense_bsr_randomized(): W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype)) X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) - Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr) - s = te.create_schedule(Y.op) - func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) - Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype)) - func(tvm.nd.array(X_np), - tvm.nd.array(W_sp_np.data), - tvm.nd.array(W_sp_np.indices), - tvm.nd.array(W_sp_np.indptr), - Y_tvm) - tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + fcompute, fschedule = topi.testing.dispatch(device, _sparse_dense_implement) + with tvm.target.create(device): + Y = fcompute(X, W_data, W_indices, W_indptr) + s = fschedule([Y]) + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype=Y_np.dtype), ctx=ctx) + func(tvm.nd.array(X_np, ctx=ctx), + tvm.nd.array(W_sp_np.data, ctx=ctx), + tvm.nd.array(W_sp_np.indices, ctx=ctx), + tvm.nd.array(W_sp_np.indptr, ctx=ctx), + Y_tvm) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5) + + for device in ['llvm', 'cuda']: + check_device(device) def test_sparse_dense(): diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index b98ce09bfab8..96df101b092e 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -402,6 +402,35 @@ def check_device(device): for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]: check_device(device) +def verify_gather(data, axis, indices): + data = np.asarray(data) + indices = np.asarray(indices) + + var_data = te.placeholder(shape=data.shape, dtype=data.dtype.name, name="data") + var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices") + out_tensor = topi.gather(var_data, axis, var_indices) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(out_tensor) + + func = tvm.build(s, [var_data, var_indices, out_tensor] , device, name="gather") + out_npys = topi.testing.gather_python(data, axis, indices) + + data_nd = tvm.nd.array(data, ctx) + indices_nd = tvm.nd.array(indices, ctx) + out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=data.dtype.name) + func(data_nd, indices_nd, out_nd) + tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys) + + for device in get_all_backend(): + check_device(device) + def verify_gather_nd(src_shape, indices_src, indices_dtype): src_dtype = "float32" indices_src = 
np.array(indices_src, dtype=indices_dtype) @@ -595,6 +624,47 @@ def check_device(device): for device in get_all_backend(): check_device(device) +def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, xpected): + sparse_indices_data = np.array(sparse_indices) + sparse_values_data = np.array(sparse_values) + output_shape_data = np.array(output_shape) + default_value_data = np.array(default_value) + + A = te.placeholder(shape=sparse_indices_data.shape, name="sparse_indices", dtype=str(sparse_indices_data.dtype)) + B = te.placeholder(shape=sparse_values_data.shape, name="sparse_values", dtype=str(sparse_values_data.dtype)) + if default_value is None: + args = [A, B] + D = topi.sparse_to_dense(A, output_shape, B) + else: + C = te.placeholder(shape=(), name="default_value", dtype=str(default_value_data.dtype)) + args = [A, B, C] + D = topi.sparse_to_dense(A, output_shape, B, C) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.testing.get_injective_schedule(device)(D) + + foo = tvm.build(s, args + [D], device, name="sparse_to_dense") + + sparse_indices_nd = tvm.nd.array(sparse_indices_data, ctx) + sparse_values_nd = tvm.nd.array(sparse_values_data, ctx) + out_nd = tvm.nd.empty(output_shape_data, ctx=ctx, dtype=B.dtype) + + if default_value is None: + foo(sparse_indices_nd, sparse_values_nd, out_nd) + else: + default_value_nd = tvm.nd.array(default_value_data, ctx) + foo(sparse_indices_nd, sparse_values_nd, default_value_nd, out_nd) + + tvm.testing.assert_allclose(out_nd.asnumpy(), np.array(xpected)) + + for device in get_all_backend(): + check_device(device) def test_strided_slice(): verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) @@ -732,6 +802,15 @@ def test_take(): verify_take((3,4), [0, 2], axis=0, mode="fast") verify_take((3,4), [0, 2], axis=1, mode="fast") +def test_gather(): + verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]]) + verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(1, 7, 5))) + verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(4, 7, 5))) + verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5))) + verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5))) + verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2))) + verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10))) + def test_gather_nd(): for indices_dtype in ['int32', 'float32']: verify_gather_nd((4,), [[1.8]], indices_dtype) @@ -924,6 +1003,27 @@ def test_unravel_index(): verify_unravel_index(144, [5, 5, 5, 2], dtype) verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype) +def test_sparse_to_dense(): + verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0]) #scalar + verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3]) #vector + verify_sparse_to_dense([[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0],[0, 0, 2, 0],[0, 0, 0, 0]]) #nXd + verify_sparse_to_dense( + [[0, 0, 0], [1, 2, 3]], + [1, 2], + 4, + [2, 3, 4], + [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]] + ) #nXd + verify_sparse_to_dense([0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) #floats + verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0]) # 
default value not specified + + #negative test cases + #sparse indices should be ints + #verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_values should be 0d or 1d only + #verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) + #sparse_indices should not be > 2d tensor + #verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]) if __name__ == "__main__": test_strided_slice() @@ -949,3 +1049,4 @@ def test_unravel_index(): test_where_fusion() test_one_hot() test_unravel_index() + test_sparse_to_dense() diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 3ccb44d0f47c..d2331ee0c7f7 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -69,6 +69,7 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype) np_out1 = np.zeros(shape=(batch_size,)) np_out2 = np.zeros(shape=dshape).astype(dtype) + np_out3 = np.zeros(shape=(batch_size, num_anchor)) for i in range(batch_size): np_out1[i] = 0 inter_idx = 0 @@ -78,10 +79,12 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): for k in range(elem_length): np_out2[i, inter_idx, k] = np_data[i, j, k] np_out1[i] += 1 + np_out3[i, inter_idx] = j inter_idx += 1 if j >= np_out1[i]: for k in range(elem_length): np_out2[i, j, k] = -1.0 + np_out3[i, j] = -1 def check_device(device): ctx = tvm.context(device, 0) @@ -98,10 +101,18 @@ def check_device(device): tvm_input_data = tvm.nd.array(np_data, ctx) tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) tvm_out2 = tvm.nd.array(np.zeros(np_out2.shape, dtype=dtype), ctx) - f = tvm.build(s, [data, outs[0], outs[1]], device) - f(tvm_input_data, tvm_out1, tvm_out2) - tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) - tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) + tvm_out3 = tvm.nd.array(np.zeros(np_out3.shape, dtype="int32"), ctx) + if device == "llvm": + f = tvm.build(s, [data, outs[0], outs[1], outs[2]], device) + f(tvm_input_data, tvm_out1, tvm_out2, tvm_out3) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out3.asnumpy(), np_out3, rtol=1e-3) + else: + f = tvm.build(s, [data, outs[0], outs[1]], device) + f(tvm_input_data, tvm_out1, tvm_out2) + tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) + tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) """ Skip this test as it is intermittent see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094 @@ -114,19 +125,21 @@ def check_device(device): def test_get_valid_counts(): + verify_get_valid_counts((1, 1000, 5), 0.5, -1, 0) verify_get_valid_counts((1, 2500, 6), 0, 0, 1) verify_get_valid_counts((1, 2500, 5), -1, -1, 0) verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0) verify_get_valid_counts((16, 500, 5), 0.95, -1, 1) -def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, iou_threshold, - force_suppress, top_k, coord_start, score_index, id_index): +def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, + iou_threshold, force_suppress, top_k, coord_start, score_index, id_index): dshape = 
np_data.shape batch, num_anchors, _ = dshape indices_dshape = (batch, num_anchors) data = te.placeholder(dshape, name="data") valid_count = te.placeholder((batch,), dtype="int32", name="valid_count") + indices = te.placeholder((batch, num_anchors), dtype="int32", name="indices") def check_device(device): ctx = tvm.context(device, 0) @@ -136,25 +149,31 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = topi.testing.dispatch(device, _nms_implement) - out = fcompute(data, valid_count, -1, iou_threshold, force_suppress, top_k, + out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, coord_start=coord_start, score_index=score_index, id_index=id_index, return_indices=False) - indices_out = fcompute(data, valid_count, -1, iou_threshold, force_suppress, top_k, - coord_start=coord_start, score_index=score_index, id_index=id_index) + indices_out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, + coord_start=coord_start, score_index=score_index, id_index=id_index, + return_indices=True) s = fschedule(out) indices_s = fschedule(indices_out) tvm_data = tvm.nd.array(np_data, ctx) tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_indices = tvm.nd.array(np_indices, ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) - f = tvm.build(s, [data, valid_count, out], device) - f(tvm_data, tvm_valid_count, tvm_out) + f = tvm.build(s, [data, valid_count, indices, out], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4) tvm_indices_out = tvm.nd.array(np.zeros(indices_dshape, dtype="int32"), ctx) - f = tvm.build(indices_s, [data, valid_count, indices_out], device) - f(tvm_data, tvm_valid_count, tvm_indices_out) + if device == 'llvm': + f = tvm.build(indices_s, [data, valid_count, indices, indices_out[0]], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) + else: + f = tvm.build(indices_s, [data, valid_count, indices, indices_out], device) + f(tvm_data, tvm_valid_count, tvm_indices, tvm_indices_out) tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) for device in ['llvm', 'cuda', 'opencl']: @@ -166,23 +185,24 @@ def test_non_max_suppression(): [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") + np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80], [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79], [0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") + np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]]]) np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) - + 
verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): @@ -459,9 +479,9 @@ def test_proposal(): if __name__ == "__main__": test_get_valid_counts() - test_non_max_suppression() test_multibox_prior() test_multibox_detection() test_roi_align() test_roi_pool() test_proposal() + test_non_max_suppression() diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index ffd3e8b9b5cb..3b07097ce696 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -311,7 +311,7 @@ def tune_and_evaluate(tuning_opt): # compile kernels with history best records with autotvm.apply_history_best(log_file): print("Compile...") - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build_module.build( mod, target=target, params=params) diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 4195075ca66d..a6fe45b96263 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -222,7 +222,7 @@ def tune_and_evaluate(tuning_opt): # compile kernels with history best records with autotvm.apply_history_best(log_file): print("Compile...") - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build_module.build( mod, target=target, params=params) diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index ad7460829329..4748f41e96c3 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -308,7 +308,7 @@ def tune_and_evaluate(tuning_opt): # compile kernels with history best records with autotvm.apply_history_best(log_file): print("Compile...") - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build_module.build( mod, target=target, params=params, target_host=target_host) # export library diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index 15ce2de4b82f..dcc5b25c8288 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -189,7 +189,7 @@ def tune_and_evaluate(tuning_opt): # compile kernels with graph-level best records with autotvm.apply_graph_best(graph_opt_sch_file): print("Compile...") - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build_module.build( mod, target=target, params=params) diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index 25ca279bf339..17f864f4414e 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -40,8 +40,6 @@ take a look at ``python/tvm/build_module.py`` to get some basics. 
""" - -from __future__ import absolute_import, print_function import tvm from tvm import te import numpy as np @@ -57,7 +55,7 @@ c = te.compute((n, ), lambda i: a[i] + b[i], name='c') sch = te.create_schedule(c.op) -ir = tvm.lower(sch, [a, b, c], simple_mode=True) +ir = tvm.lower(sch, [a, b, c]) print(ir) ###################################################################### @@ -72,7 +70,7 @@ # # IR Visitor # ~~~~~~~~~~ -# We can use ``tvm.tir.ir_pass.PostOrderVisit(stmt, func)`` to gather information from the Halide IR. +# We can use ``tvm.tir.stmt_functor.post_order_visit(stmt, func)`` to gather information from the Halide IR. # ``func`` is a function callback. This function will be called before exiting the current IR node, # i.e. post-order visit. Then we leverage side effects to store the result of IR visit, because the # return value of ``func`` will be ignored. @@ -86,7 +84,7 @@ loops = [] def find_width8(op): - """ Find all the 'For' nodes whose extent can be divided by 8. """ + """ Find all the 'tir.For' nodes whose extent can be divided by 8. """ if isinstance(op, tvm.tir.For): if isinstance(op.extent, tvm.tir.IntImm): if op.extent.value % 8 == 0: @@ -113,39 +111,35 @@ def vectorize8(op): extent = op.extent.value name = op.loop_var.name lo, li = te.var(name + '.outer'), te.var(name + '.inner') - body = tvm.tir.ir_pass.Substitute(op.body, {op.loop_var: lo * 8 + li}) + body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li}) body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body) body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body) return body return None -def vectorize(stmt): +@tvm.tir.transform.prim_func_pass(opt_level=0) +def vectorize(f, mod, ctx): global loops - tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8) + tvm.tir.stmt_functor.post_order_visit(f.body, find_width8) if not loops: - return stmt + return sf # The last list arugment indicates what kinds of nodes will be transformed. # Thus, in this case only `For` nodes will call `vectorize8` - stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For']) + return f.with_body( + tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['tir.For'])) - return stmt ##################################################################### # Glue to Lowering # ---------------- # So far, we are done with writing this IR transformation pass. What we need to do next is to glue -# this pass to TVM's lower pass. We can first call this function directly as a sanity check. +# this pass to TVM's lower pass. # - -print(vectorize(ir)) - -##################################################################### -# In TVM, there is a property called ``BuildConfig``. You can use this property to customize your -# own lowering options. In this case, we inject the pass written above into the TVM standard lowering -# pass by feeding **a list of tuple** as argument to ``add_lower_pass``. "Tuple" indicates different +# In this case, we inject the pass written above into the TVM standard lowering +# pass by feeding **a list of tuple** as argument to ``tir.add_lower_pass``. "Tuple" indicates different # phases of lowering. In TVM, there are four phases of lowering and user-customized ones will be # called after each phase is done. # @@ -159,15 +153,15 @@ def vectorize(stmt): # Thus, a good place to put this transformation pass is just after Phase 1. 
# -with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg: - print(tvm.lower(sch, [a, b, c], simple_mode=True)) +with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vectorize)]}): + print(tvm.lower(sch, [a, b, c])) ##################################################################### # Quick View # ---------- # This tutorial gives a quick view of writing a customized IR transformation pass: -# - Use ``tvm.tir.ir_pass.PostOrderVisit`` to gather information on each IR nodes. -# - Use ``tvm.tir.ir_pass.IRTransform`` to transform IR nodes. +# - Use ``tvm.tir.stmt_functor.post_order_visit`` to gather information on each IR nodes. +# - Use ``tvm.tir.stmt_functor.ir_transform`` to transform IR nodes. # - Wrap up two above to write an IR-transformation function. -# - Use ``tvm.target.build_config`` to put this function to TVM lowering pass +# - Use ``tvm.transform.PassContext`` to put this function to TVM lowering pass # diff --git a/tutorials/dev/relay_pass_infra.py b/tutorials/dev/relay_pass_infra.py index 980d96ccc119..df40733164e8 100644 --- a/tutorials/dev/relay_pass_infra.py +++ b/tutorials/dev/relay_pass_infra.py @@ -160,7 +160,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): # however, provides a configuration interface # for users to customize the optimization level that they want to execute. -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): mod2 = seq(mod) print(mod2) @@ -173,7 +173,7 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): # EliminateCommonSubexpr as following. The printed module will again show two # identical addition operations. -with relay.build_config(opt_level=3, disabled_pass=["EliminateCommonSubexpr"]): +with tvm.transform.PassContext(opt_level=3, disabled_pass=["EliminateCommonSubexpr"]): mod3 = seq(mod) print(mod3) @@ -182,12 +182,12 @@ def alter_conv2d(attrs, inputs, tinfos, out_type): # provides a means to make pass target-aware. For example, the layout # alteration pass falls in such category. 
-with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): mod4 = seq(mod) print(mod4) seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()]) -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): with tvm.target.create("llvm"): mod5 = seq1(mod) print(mod5) @@ -242,7 +242,7 @@ def visit_constant(self, c): relay.transform.EliminateCommonSubexpr(), relay.transform.FuseOps(), tvm.transform.PrintIR()]) -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): mod = seq(mod) print("done") diff --git a/tutorials/frontend/build_gcn.py b/tutorials/frontend/build_gcn.py index 6ac518e42032..19719a5378eb 100644 --- a/tutorials/frontend/build_gcn.py +++ b/tutorials/frontend/build_gcn.py @@ -336,7 +336,7 @@ def prepare_params(g, data): mod = tvm.IRModule() mod["main"] = func # Build with Relay -with relay.build_config(opt_level=0): # Currently only support opt_level=0 +with tvm.transform.PassContext(opt_level=0): # Currently only support opt_level=0 graph, lib, params = relay.build(mod, target, params=params) # Generate graph runtime diff --git a/tutorials/frontend/deploy_model_on_android.py b/tutorials/frontend/deploy_model_on_android.py index 17ec9cb6baa1..bc5b5239a889 100644 --- a/tutorials/frontend/deploy_model_on_android.py +++ b/tutorials/frontend/deploy_model_on_android.py @@ -263,7 +263,7 @@ def transform_image(image): shape_dict = {input_name: x.shape} mod, params = relay.frontend.from_keras(keras_mobilenet_v2, shape_dict) -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=target, target_host=target_host, params=params) diff --git a/tutorials/frontend/deploy_model_on_rasp.py b/tutorials/frontend/deploy_model_on_rasp.py index ef707feedd2f..25df34128415 100644 --- a/tutorials/frontend/deploy_model_on_rasp.py +++ b/tutorials/frontend/deploy_model_on_rasp.py @@ -179,7 +179,7 @@ def transform_image(image): # The above line is a simple form of # target = tvm.target.create('llvm -device=arm_cpu -model=bcm2837 -target=armv7l-linux-gnueabihf -mattr=+neon') -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=params) # After `relay.build`, you will get three return values: graph, diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index 40279778c045..d6183d68ad4a 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -81,7 +81,7 @@ def get_synset(): def run_tvm_model(mod, params, input_name, inp, target="llvm"): - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): json, lib, params = relay.build(mod, target=target, params=params) runtime = tvm.contrib.graph_runtime.create(json, lib, tvm.context(target, 0)) diff --git a/tutorials/frontend/deploy_prequantized_tflite.py b/tutorials/frontend/deploy_prequantized_tflite.py new file mode 100644 index 000000000000..ecd283ac46c8 --- /dev/null +++ b/tutorials/frontend/deploy_prequantized_tflite.py @@ -0,0 +1,257 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Deploy a Framework-prequantized Model with TVM - Part 3 (TFLite) +================================================================ +**Author**: `Siju Samuel `_ + +Welcome to part 3 of the Deploy Framework-Prequantized Model with TVM tutorial. +In this part, we will start with a Quantized TFLite graph and then compile and execute it via TVM. + + +For more details on quantizing the model using TFLite, readers are encouraged to +go through `Converting Quantized Models +`_. + +The TFLite models can be downloaded from this `link +`_. + +To get started, Tensorflow and TFLite package needs to be installed as prerequisite. + +.. code-block:: bash + + # install tensorflow and tflite + pip install tensorflow==2.1.0 + pip install tflite==2.1.0 + +Now please check if TFLite package is installed successfully, ``python -c "import tflite"`` + +""" + +############################################################################### +# Necessary imports +# ----------------- +import os + +import numpy as np +import tflite + +import tvm +from tvm import relay + + +###################################################################### +# Download pretrained Quantized TFLite model +# ------------------------------------------ + +# Download mobilenet V2 TFLite model provided by Google +from tvm.contrib.download import download_testdata + +model_url = "https://storage.googleapis.com/download.tensorflow.org/models/" \ + "tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz" + +# Download model tar file and extract it to get mobilenet_v2_1.0_224.tflite +model_path = download_testdata(model_url, "mobilenet_v2_1.0_224_quant.tgz", + module=['tf', 'official']) +model_dir = os.path.dirname(model_path) + + +###################################################################### +# Utils for downloading and extracting zip files +# ---------------------------------------------- +def extract(path): + import tarfile + if path.endswith("tgz") or path.endswith("gz"): + dir_path = os.path.dirname(path) + tar = tarfile.open(path) + tar.extractall(path=dir_path) + tar.close() + else: + raise RuntimeError('Could not decompress the file: ' + path) + +extract(model_path) + + +###################################################################### +# Load a test image +# ----------------- + +####################################################################### +# Get a real image for e2e testing +# -------------------------------- +def get_real_image(im_height, im_width): + from PIL import Image + repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/' + img_name = 'elephant-299.jpg' + image_url = os.path.join(repo_base, img_name) + img_path = download_testdata(image_url, img_name, module='data') + image = Image.open(img_path).resize((im_height, im_width)) + x = np.array(image).astype('uint8') + data = np.reshape(x, (1, im_height, im_width, 3)) + return data + +data = get_real_image(224, 224) + 
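# The pre-quantized MobileNet V2 consumes raw uint8 NHWC input, so no mean/std
# normalization is applied above. A purely illustrative sanity check of the
# preprocessed tensor (shapes assumed from get_real_image above):
assert data.shape == (1, 224, 224, 3)
assert data.dtype == np.uint8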
+###################################################################### +# Load a tflite model +# ------------------- + +###################################################################### +# Now we can open mobilenet_v2_1.0_224.tflite +tflite_model_file = os.path.join(model_dir, "mobilenet_v2_1.0_224_quant.tflite") +tflite_model_buf = open(tflite_model_file, "rb").read() + +# Get TFLite model from buffer +try: + import tflite + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) +except AttributeError: + import tflite.Model + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + +############################################################################### +# Lets run TFLite pre-quantized model inference and get the TFLite prediction. +def run_tflite_model(tflite_model_buf, input_data): + """ Generic function to execute TFLite """ + try: + from tensorflow import lite as interpreter_wrapper + except ImportError: + from tensorflow.contrib import lite as interpreter_wrapper + + input_data = input_data if isinstance(input_data, list) else [input_data] + + interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # set input + assert len(input_data) == len(input_details) + for i in range(len(input_details)): + interpreter.set_tensor(input_details[i]['index'], input_data[i]) + + # Run + interpreter.invoke() + + # get output + tflite_output = list() + for i in range(len(output_details)): + tflite_output.append(interpreter.get_tensor(output_details[i]['index'])) + + return tflite_output + +############################################################################### +# Lets run TVM compiled pre-quantized model inference and get the TVM prediction. +def run_tvm(graph, lib, params): + from tvm.contrib import graph_runtime + rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + rt_mod.set_input(**params) + rt_mod.set_input('input', data) + rt_mod.run() + tvm_res = rt_mod.get_output(0).asnumpy() + tvm_pred = np.squeeze(tvm_res).argsort()[-5:][::-1] + return tvm_pred, rt_mod + + +############################################################################### +# TFLite inference +# ---------------- + +############################################################################### +# Run TFLite inference on the quantized model. +tflite_res = run_tflite_model(tflite_model_buf, data) +tflite_pred = np.squeeze(tflite_res).argsort()[-5:][::-1] + +############################################################################### +# TVM compilation and inference +# ----------------------------- + +############################################################################### +# We use the TFLite-Relay parser to convert the TFLite pre-quantized graph into Relay IR. Note that +# frontend parser call for a pre-quantized model is exactly same as frontend parser call for a FP32 +# model. We encourage you to remove the comment from print(mod) and inspect the Relay module. You +# will see many QNN operators, like, Requantize, Quantize and QNN Conv2D. +dtype_dict = {'input': data.dtype.name} +shape_dict = {'input': data.shape} + +mod, params = relay.frontend.from_tflite(tflite_model, + shape_dict=shape_dict, + dtype_dict=dtype_dict) +# print(mod) + +############################################################################### +# Lets now the compile the Relay module. We use the "llvm" target here. 
Please replace it with the +# target platform that you are interested in. +target = 'llvm' +with tvm.transform.PassContext(opt_level=3): + graph, lib, params = relay.build_module.build(mod, target=target, + params=params) + +############################################################################### +# Finally, let's call inference on the TVM compiled module. +tvm_pred, rt_mod = run_tvm(graph, lib, params) + +############################################################################### +# Accuracy comparison +# ------------------- + +############################################################################### +# Print the top-5 labels for TFLite and TVM inference. +# Checking the labels because the requantize implementation is different between +# TFLite and Relay. This causes the final output numbers to mismatch. So, testing accuracy via labels. + +print("TVM Top-5 labels:", tvm_pred) +print("TFLite Top-5 labels:", tflite_pred) + + +########################################################################## +# Measure performance +# ------------------- +# Here we give an example of how to measure performance of TVM compiled models. +n_repeat = 100 # should be bigger to make the measurement more accurate +ctx = tvm.cpu(0) +ftimer = rt_mod.module.time_evaluator("run", ctx, number=1, repeat=n_repeat) +prof_res = np.array(ftimer().results) * 1e3 +print("Elapsed average ms:", np.mean(prof_res)) + +###################################################################### +# .. note:: +# +# Unless the hardware has special support for fast 8 bit instructions, quantized models are +# not expected to be any faster than FP32 models. Without fast 8 bit instructions, TVM does +# quantized convolution in 16 bit, even if the model itself is 8 bit. +# +# For x86, the best performance can be achieved on CPUs with AVX512 instruction set. +# In this case, TVM utilizes the fastest available 8 bit instructions for the given target. +# This includes support for the VNNI 8 bit dot product instruction (CascadeLake or newer). +# For EC2 C5.12x large instance, TVM latency for this tutorial is ~2 ms. +# +# Intel conv2d NCHWc schedule on ARM gives better end-to-end latency compared to ARM NCHW +# conv2d spatial pack schedule for many TFLite networks. ARM winograd performance is higher but +# it has a high memory footprint. +# +# Moreover, the following general tips for CPU performance equally apply: +# +# * Set the environment variable TVM_NUM_THREADS to the number of physical cores +# * Choose the best target for your hardware, such as "llvm -mcpu=skylake-avx512" or +# "llvm -mcpu=cascadelake" (more CPUs with AVX512 would come in the future) +# * Perform autotuning - `Auto-tuning a convolution network for x86 CPU +# `_. +# * To get best inference performance on ARM CPU, change target argument according to your +# device and follow `Auto-tuning a convolution network for ARM CPU +# `_.
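# A minimal sketch of acting on the CPU tips above; the "llvm -mcpu=cascadelake"
# target string is only an example and should be replaced with the closest match
# for the actual machine (TVM_NUM_THREADS is normally exported in the shell,
# e.g. `export TVM_NUM_THREADS=4`, before running the script).
target = "llvm -mcpu=cascadelake"
with tvm.transform.PassContext(opt_level=3):
    graph, lib, params = relay.build_module.build(mod, target=target,
                                                  params=params)
tvm_pred, rt_mod = run_tvm(graph, lib, params)
print("TVM Top-5 labels with a CPU-specific target:", tvm_pred)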
diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index 6126df0e73ab..e2fc3c59cb33 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -87,7 +87,7 @@ def build(target): mod, params = relay.frontend.from_mxnet(block, {"data": dshape}) - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) return graph, lib, params diff --git a/tutorials/frontend/from_caffe2.py b/tutorials/frontend/from_caffe2.py index 8fad80df1d1e..5988525f2da8 100644 --- a/tutorials/frontend/from_caffe2.py +++ b/tutorials/frontend/from_caffe2.py @@ -82,13 +82,13 @@ def transform_image(image): dtype_dict = {input_name: data.dtype} # parse Caffe2 model and convert into Relay computation graph -from tvm import relay +from tvm import relay, transform mod, params = relay.frontend.from_caffe2(resnet50.init_net, resnet50.predict_net, shape_dict, dtype_dict) # compile the model # target x86 CPU target = 'llvm' -with relay.build_config(opt_level=3): +with transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) ###################################################################### diff --git a/tutorials/frontend/from_coreml.py b/tutorials/frontend/from_coreml.py index 2a0c8dbc93f2..beac48325237 100644 --- a/tutorials/frontend/from_coreml.py +++ b/tutorials/frontend/from_coreml.py @@ -74,7 +74,7 @@ # Parse CoreML model and convert into Relay computation graph mod, params = relay.frontend.from_coreml(mlmodel, shape_dict) -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) diff --git a/tutorials/frontend/from_darknet.py b/tutorials/frontend/from_darknet.py index e2c1ea5aacbf..6d84463ca7f0 100644 --- a/tutorials/frontend/from_darknet.py +++ b/tutorials/frontend/from_darknet.py @@ -100,7 +100,7 @@ data = np.empty([batch_size, net.c, net.h, net.w], dtype) shape = {'data': data.shape} print("Compiling the model...") -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=target, target_host=target_host, diff --git a/tutorials/frontend/from_keras.py b/tutorials/frontend/from_keras.py index 928a8acbefa7..7ece790eb177 100644 --- a/tutorials/frontend/from_keras.py +++ b/tutorials/frontend/from_keras.py @@ -79,7 +79,7 @@ # compile the model target = 'cuda' ctx = tvm.gpu(0) -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): executor = relay.build_module.create_executor('graph', mod, ctx, target) ###################################################################### diff --git a/tutorials/frontend/from_mxnet.py b/tutorials/frontend/from_mxnet.py index d0e4c4ab0d18..6e6b2d79b209 100644 --- a/tutorials/frontend/from_mxnet.py +++ b/tutorials/frontend/from_mxnet.py @@ -90,7 +90,7 @@ def transform_image(image): ###################################################################### # now compile the graph target = 'cuda' -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=params) ###################################################################### diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py index 766451c2f8b1..9973a08153dd 100644 --- a/tutorials/frontend/from_onnx.py +++ b/tutorials/frontend/from_onnx.py 
@@ -74,7 +74,7 @@ shape_dict = {input_name: x.shape} mod, params = relay.frontend.from_onnx(onnx_model, shape_dict) -with relay.build_config(opt_level=1): +with tvm.transform.PassContext(opt_level=1): intrp = relay.build_module.create_executor('graph', mod, tvm.cpu(0), target) ###################################################################### diff --git a/tutorials/frontend/from_pytorch.py b/tutorials/frontend/from_pytorch.py index 8354b0eca193..53d29a9447be 100644 --- a/tutorials/frontend/from_pytorch.py +++ b/tutorials/frontend/from_pytorch.py @@ -101,7 +101,7 @@ target = 'llvm' target_host = 'llvm' ctx = tvm.cpu(0) -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=target, target_host=target_host, diff --git a/tutorials/frontend/from_tensorflow.py b/tutorials/frontend/from_tensorflow.py index 0ebd733ef9aa..b7b3d69c780b 100644 --- a/tutorials/frontend/from_tensorflow.py +++ b/tutorials/frontend/from_tensorflow.py @@ -144,7 +144,7 @@ # params: final params after compilation. # lib: target library which can be deployed on target with TVM runtime. -with relay.build_config(opt_level=3): +with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target=target, target_host=target_host, diff --git a/tutorials/frontend/from_tflite.py b/tutorials/frontend/from_tflite.py index 32738551f7c8..35a308c1d334 100644 --- a/tutorials/frontend/from_tflite.py +++ b/tutorials/frontend/from_tflite.py @@ -21,25 +21,12 @@ This article is an introductory tutorial to deploy TFLite models with Relay. -To get started, Flatbuffers and TFLite package needs to be installed as prerequisites. -A quick solution is to install Flatbuffers via pip +To get started, TFLite package needs to be installed as prerequisite. .. code-block:: bash - pip install flatbuffers --user - - -To install TFlite packages, you could use our prebuilt wheel: - -.. code-block:: bash - - # For python3: - wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py3-none-any.whl - pip3 install -U tflite-1.13.1-py3-none-any.whl --user - - # For python2: - wget https://github.com/FrozenGene/tflite/releases/download/v1.13.1/tflite-1.13.1-py2-none-any.whl - pip install -U tflite-1.13.1-py2-none-any.whl --user + # install tflite + pip install tflite=2.1.0 --user or you could generate TFLite package yourself. 
The steps are the following: @@ -141,14 +128,14 @@ def extract(path): input_dtype = "float32" # Parse TFLite model and convert it to a Relay module -from tvm import relay +from tvm import relay, transform mod, params = relay.frontend.from_tflite(tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}) # Build the module against to x86 CPU target = "llvm" -with relay.build_config(opt_level=3): +with transform.PassContext(opt_level=3): graph, lib, params = relay.build(mod, target, params=params) ###################################################################### diff --git a/tutorials/language/tensorize.py b/tutorials/language/tensorize.py index 6224c10ed750..8a77c7764648 100644 --- a/tutorials/language/tensorize.py +++ b/tutorials/language/tensorize.py @@ -115,8 +115,7 @@ def intrin_func(ins, outs): bb.access_ptr("r"), m, l, bb.strides[0])) return ib.get() - with tvm.target.build_config(offset_factor=1): - return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) + return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) ###################################################################### # Here :code:`te.decl_tensor_intrin` declares how to execute the computation :code:`c.op`. @@ -269,8 +268,7 @@ def _reduce_reset(): def _reduce_update(): return _body() return _body(), _reduce_reset(), _reduce_update() - with tvm.target.build_config(offset_factor=1): - return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) + return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) ###################################################################### # Note that :code:`intrin_func` now returns a triplet: diff --git a/tutorials/optimize/opt_conv_tensorcore.py b/tutorials/optimize/opt_conv_tensorcore.py index 44b9de3b99ff..cd40a91ac6c8 100644 --- a/tutorials/optimize/opt_conv_tensorcore.py +++ b/tutorials/optimize/opt_conv_tensorcore.py @@ -331,7 +331,9 @@ def intrin_func(ins, outs): ctx = tvm.gpu(0) if nvcc.have_tensorcore(ctx.compute_version): - with tvm.target.build_config(auto_unroll_max_step=16): + with tvm.transform.PassContext(config={"tir.UnrollLoop": { + "auto_max_step": 16 + }}): func = tvm.build(s, [A, W, Conv], 'cuda') a_np = np.random.uniform(size=data_shape).astype(A.dtype) w_np = np.random.uniform(size=kernel_shape).astype(W.dtype) diff --git a/tutorials/optimize/opt_matmul_auto_tensorcore.py b/tutorials/optimize/opt_matmul_auto_tensorcore.py index 50cc1eb929a2..7dbd4757168f 100644 --- a/tutorials/optimize/opt_matmul_auto_tensorcore.py +++ b/tutorials/optimize/opt_matmul_auto_tensorcore.py @@ -287,10 +287,9 @@ def tune_and_evaluate(M, N, L, dtype, layout): print(best_config) with autotvm.apply_history_best('matmul.log'): with tvm.target.create("cuda"): - with tvm.target.build_config(): - s, arg_bufs = test_gemm(N, L, M, dtype, layout) - print(tvm.lower(s, arg_bufs, simple_mode=True)) - func = tvm.build(s, arg_bufs) + s, arg_bufs = test_gemm(N, L, M, dtype, layout) + print(tvm.lower(s, arg_bufs, simple_mode=True)) + func = tvm.build(s, arg_bufs) dev_module = func.imported_modules[0] print(dev_module.get_source()) diff --git a/tutorials/relay_quick_start.py b/tutorials/relay_quick_start.py index b2174a048035..e52a99aeccd4 100644 --- a/tutorials/relay_quick_start.py +++ b/tutorials/relay_quick_start.py @@ -96,7 +96,7 @@ opt_level = 3 target = tvm.target.cuda() -with relay.build_config(opt_level=opt_level): +with tvm.transform.PassContext(opt_level=opt_level): graph, lib, 
params = relay.build(mod, target, params=params) ##################################################################### diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index 4c33d36d69b5..2d67edb7051f 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -14,25 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=unused-argument +# pylint: disable=unused-argument, invalid-name """VTA specific buildin for runtime.""" import tvm -from . import ir_pass +from . import transform from .environment import get_env -def lift_coproc_scope(x): - """Lift coprocessings cope to the """ - x = ir_pass.lift_alloc_to_scope_begin(x) - x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False) - return x - -def early_rewrite(stmt): +def EarlyRewrite(): """Try to do storage rewrite in early pass.""" - try: - return tvm.tir.ir_pass.StorageRewrite(stmt) - except tvm.error.TVMError: - return stmt + def _transform(mod, ctx): + try: + return tvm.tir.transform.StorageRewrite()(mod) + except tvm.error.TVMError: + return mod + return tvm.transform.module_pass( + _transform, opt_level=0, name="tir.vta.EarlyRewrite") def build_config(debug_flag=0, **kwargs): @@ -48,7 +45,7 @@ def build_config(debug_flag=0, **kwargs): Returns ------- - build_config: BuildConfig + build_config: tvm.transform.PassContext The build config that can be used in TVM. Example @@ -60,28 +57,40 @@ def build_config(debug_flag=0, **kwargs): vta_module = tvm.build(s, ...) """ env = get_env() - def add_debug(stmt): + + @tvm.tir.transform.prim_func_pass(opt_level=0) + def add_debug(f, *_): debug = tvm.tir.call_extern( "int32", "VTASetDebugMode", env.dev.command_handle, debug_flag) - return tvm.tir.stmt_seq(debug, stmt) - pass_list = [(0, ir_pass.inject_conv2d_transpose_skip), - (1, ir_pass.inject_dma_intrin), - (1, ir_pass.inject_skip_copy), - (1, ir_pass.annotate_alu_coproc_scope), - (1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)), - (1, lift_coproc_scope), - (1, ir_pass.inject_coproc_sync), - (1, early_rewrite)] + return f.with_body(tvm.tir.stmt_seq(debug, f.body)) + + + pass_list = [(0, transform.InjectConv2DTransposeSkip()), + (1, transform.InjectDMAIntrin()), + (1, transform.InjectSkipCopy()), + (1, transform.AnnotateALUCoProcScope()), + (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")), + (1, transform.LiftAllocToScopeBegin()), + (1, tvm.tir.transform.LiftAttrScope("coproc_scope")), + (1, transform.InjectCoProcSync()), + (1, EarlyRewrite())] if debug_flag: pass_list.append((1, add_debug)) - pass_list.append((2, ir_pass.inject_alu_intrin)) - pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo)) - pass_list.append((3, ir_pass.fold_uop_loop)) - pass_list.append((3, ir_pass.cpu_access_rewrite)) - return tvm.target.build_config(add_lower_pass=pass_list, **kwargs) + pass_list.append((2, transform.InjectALUIntrin())) + pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo())) + pass_list.append((3, transform.FoldUopLoop())) + pass_list.append((3, transform.CPUAccessRewrite())) + config = { + "tir.add_lower_pass": pass_list + } + if kwargs.get("config"): + config.update(kwargs[config]) + del kwargs["config"] + + return tvm.transform.PassContext(config=config, **kwargs) def lower(*args, **kwargs): @@ -94,8 +103,8 @@ def lower(*args, **kwargs): -------- tvm.lower : The original TVM's lower function """ - cfg = 
tvm.target.BuildConfig.current() - if not cfg.add_lower_pass: + pass_ctx = tvm.transform.PassContext.current() + if not pass_ctx.config.get("add_lower_pass"): with build_config(): return tvm.lower(*args, **kwargs) return tvm.lower(*args, **kwargs) @@ -111,8 +120,8 @@ def build(*args, **kwargs): -------- tvm.build : The original TVM's build function """ - cfg = tvm.target.BuildConfig.current() - if not cfg.add_lower_pass: + pass_ctx = tvm.transform.PassContext.current() + if not pass_ctx.config.get("tir.add_lower_pass"): with build_config(): return tvm.build(*args, **kwargs) return tvm.build(*args, **kwargs) diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index bbaac2ce1797..e68f098ba53f 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -80,7 +80,7 @@ def __init__(self, env): ctx = tvm.tir.call_extern("handle", "VTATLSCommandHandle") self.command_handle = tvm.tir.Call( "handle", "tvm_thread_context", [ctx], - tvm.tir.Call.Intrinsic, None, 0) + tvm.tir.Call.Intrinsic) self.DEBUG_NO_SYNC = False env._dev_ctx = self self.gemm = intrin.gemm(env, env.mock_mode) diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index 9cfd50927041..220da4331dae 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -42,7 +42,7 @@ def server_start(): os.path.abspath(os.path.expanduser(__file__))) proj_root = os.path.abspath(os.path.join(curr_path, "../../../../")) dll_path = find_libvta("libvta")[0] - cfg_path = os.path.abspath(os.path.join(proj_root, "build/vta_config.json")) + cfg_path = os.path.abspath(os.path.join(proj_root, "3rdparty/vta-hw/config/vta_config.json")) runtime_dll = [] _load_module = tvm.get_global_func("tvm.rpc.server.load_module") diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py deleted file mode 100644 index 8a7798ab7e6f..000000000000 --- a/vta/python/vta/ir_pass.py +++ /dev/null @@ -1,993 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Additional IR Pass for VTA""" -# pylint: disable=len-as-condition, no-else-return -import tvm -from tvm import te -from topi import util - -from .environment import get_env - - -def _match_pragma(stmt, key): - """Internal helper to match stmt to pragma stmt. - - Parameters - ---------- - stmt : Stmt - The AttrStmt - - key : str - The pragma key - """ - return ((stmt.attr_key == "pragma_" + key) or - (stmt.attr_key == "pragma_scope" and stmt.value.value == key)) - - -def fold_uop_loop(stmt_in): - """Detect and fold uop loop. - - VTA support uop programming model - that recognizes loop structure. - This pass detect the loop structure - and extract that into uop loop AST. 
- - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Output statement. - """ - env = get_env() - - def _fold_outermost_loop(body): - stmt = body - if not isinstance(stmt, tvm.tir.For): - return None, body, None - - loop_var = stmt.loop_var - gemm_offsets = [None, None, None] - fail = [False] - - def _post_order(op): - assert isinstance(op, tvm.tir.Call) - base_args = 2 - if op.name == "VTAUopPush": - args = [] - args += op.args[:base_args] - for i in range(3): - m = tvm.arith.detect_linear_equation( - op.args[i + base_args], [loop_var]) - if not m: - fail[0] = True - return op - if gemm_offsets[i] is not None: - if not tvm.ir.structural_equal(m[0], gemm_offsets[i]): - fail[0] = True - return op - args.append(m[1]) - else: - gemm_offsets[i] = m[0] - args.append(m[1]) - args += op.args[base_args+3:] - return tvm.tir.call_extern("int32", "VTAUopPush", *args) - if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): - raise RuntimeError("unexpected op %s" % op) - return op - - ret = tvm.tir.ir_pass.IRTransform( - stmt.body, None, _post_order, ["Call"]) - - if not fail[0] and all(x is not None for x in gemm_offsets): - def _visit(op): - if op.same_as(loop_var): - fail[0] = True - tvm.tir.ir_pass.PostOrderVisit(ret, _visit) - if not fail[0]: - begin = tvm.tir.call_extern( - "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) - end = tvm.tir.call_extern("int32", "VTAUopLoopEnd") - return [begin, ret, end] - raise ValueError("Failed to fold the GEMM instructions..") - - def _do_fold(stmt): - if (stmt.attr_key == "coproc_uop_scope" and - isinstance(stmt.value, tvm.tir.StringImm) and - stmt.value.value == env.dev.vta_push_uop.value): - body = stmt.body - begins = [] - ends = [] - try: - begin, body, end = _fold_outermost_loop(body) - if begin is not None: - begins.append(begin) - if end is not None: - ends.append(end) - begin, body, end = _fold_outermost_loop(body) - if begin is not None: - begins.append(begin) - if end is not None: - ends.append(end) - except ValueError: - pass - if body == stmt.body: - return stmt - ends = list(reversed(ends)) - body = tvm.tir.stmt_seq(*(begins + [body] + ends)) - return tvm.tir.AttrStmt( - stmt.node, stmt.attr_key, stmt.value, body) - return None - out = tvm.tir.ir_pass.IRTransform( - stmt_in, _do_fold, None, ["AttrStmt"]) - return out - - -def cpu_access_rewrite(stmt_in): - """Detect CPU access to VTA buffer and get address correctly. - - VTA's buffer is an opaque handle that do not - correspond to address in CPU. - This pass detect CPU access and rewrite to use pointer - returned VTABufferCPUPtr for CPU access. 
- - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - rw_info = {} - def _post_order(op): - if isinstance(op, tvm.tir.Allocate): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - return None - new_var = rw_info[buffer_var] - let_stmt = tvm.tir.LetStmt( - new_var, tvm.tir.call_extern( - "handle", "VTABufferCPUPtr", - env.dev.command_handle, - buffer_var), op.body) - alloc = tvm.tir.Allocate( - buffer_var, op.dtype, op.extents, - op.condition, let_stmt) - del rw_info[buffer_var] - return alloc - if isinstance(op, tvm.tir.Load): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - rw_info[buffer_var] = te.var( - buffer_var.name + "_ptr", "handle") - new_var = rw_info[buffer_var] - return tvm.tir.Load(op.dtype, new_var, op.index) - if isinstance(op, tvm.tir.Store): - buffer_var = op.buffer_var - if not buffer_var in rw_info: - rw_info[buffer_var] = te.var( - buffer_var.name + "_ptr", "handle") - new_var = rw_info[buffer_var] - return tvm.tir.Store(new_var, op.value, op.index) - raise RuntimeError("not reached") - stmt = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _post_order, ["Allocate", "Load", "Store"]) - for buffer_var, new_var in rw_info.items(): - stmt = tvm.tir.LetStmt( - new_var, tvm.tir.call_extern( - "handle", "VTABufferCPUPtr", - env.dev.command_handle, - buffer_var), stmt) - return stmt - - -def lift_alloc_to_scope_begin(stmt_in): - """Lift allocate to beginning of the current scope. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - lift_stmt = [[]] - def _merge_block(slist, body): - for op in slist: - if op.body == body: - body = op - elif isinstance(op, tvm.tir.Allocate): - body = tvm.tir.Allocate( - op.buffer_var, op.dtype, - op.extents, op.condition, body) - elif isinstance(op, tvm.tir.AttrStmt): - body = tvm.tir.AttrStmt( - op.node, op.attr_key, op.value, body) - elif isinstance(op, tvm.tir.For): - body = tvm.tir.For( - op.loop_var, op.min, op.extent, op.for_type, - op.device_api, body) - else: - raise RuntimeError("unexpected op") - del slist[:] - return body - - def _pre_order(op): - if isinstance(op, tvm.tir.For): - lift_stmt.append([]) - elif isinstance(op, tvm.tir.AttrStmt): - if op.attr_key == "virtual_thread": - lift_stmt.append([]) - - def _post_order(op): - if isinstance(op, tvm.tir.Allocate): - lift_stmt[-1].append(op) - return op.body - if isinstance(op, tvm.tir.AttrStmt): - if op.attr_key == "storage_scope": - lift_stmt[-1].append(op) - return op.body - if op.attr_key == "virtual_thread": - return _merge_block(lift_stmt.pop() + [op], op.body) - return op - if isinstance(op, tvm.tir.For): - return _merge_block(lift_stmt.pop() + [op], op.body) - raise RuntimeError("not reached") - stmt = tvm.tir.ir_pass.IRTransform( - stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"]) - assert len(lift_stmt) == 1 - return _merge_block(lift_stmt[0], stmt) - - -def inject_skip_copy(stmt_in): - """Pass to inject skip copy stmt, used for debug purpose. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - def _do_fold(stmt): - if _match_pragma(stmt, "skip_dma_copy"): - return tvm.tir.Evaluate(0) - return None - return tvm.tir.ir_pass.IRTransform( - stmt_in, _do_fold, None, ["AttrStmt"]) - - -def inject_coproc_sync(stmt_in): - """Pass to inject skip copy stmt, used in debug. 
- - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - success = [False] - def _do_fold(stmt): - if _match_pragma(stmt, "coproc_sync"): - success[0] = True - sync = tvm.tir.Call( - "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0) - return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) - if _match_pragma(stmt, "trim_loop"): - op = stmt.body - assert isinstance(op, tvm.tir.For) - return tvm.tir.For( - op.loop_var, op.min, 2, op.for_type, - op.device_api, op.body) - return None - stmt = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _do_fold, ["AttrStmt"]) - stmt = tvm.tir.ir_pass.CoProcSync(stmt) - return stmt - - -def inject_dma_intrin(stmt_in): - """Pass to inject DMA copy intrinsics. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod - - def _check_compact(buf): - ndim = len(buf.shape) - size = tvm.tir.const(1, buf.shape[0].dtype) - for i in reversed(range(ndim)): - if not util.equal_const_int(size - buf.strides[i], 0): - raise RuntimeError( - "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)) - size = size * buf.shape[i] - - def _fold_buffer_dim(buf, scope, elem_block): - ndim = len(buf.shape) - x_size = 1 - base = 0 - for i in range(1, ndim + 1): - if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0): - raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block)) - x_size = x_size * buf.shape[ndim - i] - if util.equal_const_int(x_size - elem_block, 0): - base = i + 1 - break - if base == 0: - raise RuntimeError("scope %s need to have block=%d, shape=%s" % ( - scope, elem_block, buf.shape)) - shape = [elem_block] - strides = [1] - - if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block): - shape.append(1) - strides.append(elem_block) - - while base < ndim + 1: - x_size = 1 - x_stride = buf.strides[ndim - base] - next_base = base - if not util.equal_const_int(idxm(x_stride, elem_block), 0): - raise RuntimeError( - "scope %s need to have block=%d, shape=%s, strides=%s" % ( - scope, elem_block, buf.shape, buf.strides)) - for i in range(base, ndim + 1): - k = ndim - i - if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0): - break - x_size = x_size * buf.shape[k] - next_base = i + 1 - shape.append(tvm.tir.ir_pass.Simplify(x_size)) - strides.append(x_stride) - assert next_base != base - base = next_base - - strides = list(reversed(strides)) - shape = list(reversed(shape)) - return shape, strides - - def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): - elem_block = elem_bytes * 8 // elem_width - if buf.dtype != dtype: - raise RuntimeError("Expect buffer type to be %s instead of %s" % - (dtype, buf.dtype)) - shape, strides = buf.shape, buf.strides - if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0): - raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) - if allow_fold: - shape, strides = _fold_buffer_dim(buf, scope, elem_block) - else: - shape = list(x for x in shape) - strides = list(x for x in strides) - - def raise_error(): - """Internal function to raise error """ - raise RuntimeError( - ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" + - " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides)) - - ndim = len(shape) - - # Check if the inner-tensor is already flat 
- flat = util.equal_const_int(shape[-1], elem_block) - - if flat: - if not util.equal_const_int(strides[-1], 1): - raise_error() - - if ndim == 1: - x_size = 1 - x_stride = 1 - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(strides[-2] - elem_block, 0): - raise_error() - - if ndim == 2: - x_size = shape[-2] - x_stride = shape[-2] - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(idxm(strides[-3], elem_block), 0): - raise_error() - - if ndim == 3: - x_size = shape[-2] - x_stride = idxd(strides[-3], elem_block) - y_size = shape[-3] - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - - else: - if not util.equal_const_int(strides[-1], 1): - raise_error() - if not util.equal_const_int(strides[-2] - shape[-1], 0): - raise_error() - if not util.equal_const_int(shape[-1] * shape[-2], elem_block): - raise_error() - - if ndim == 2: - x_size = 1 - x_stride = 1 - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(strides[-3], elem_block): - raise_error() - - if ndim == 3: - x_size = shape[-3] - x_stride = shape[-3] - y_size = 1 - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - if not util.equal_const_int(idxm(strides[-4], elem_block), 0): - raise_error() - - if ndim == 4: - x_size = shape[-3] - x_stride = idxd(strides[-4], elem_block) - y_size = shape[-4] - return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) - - raise_error() - - - def _inject_copy(src, dst, pad_before, pad_after, pad_value): - # FIXME: pad_value is ignored... - _ = pad_value - if dst.scope == "global": - # Store - if pad_before or pad_after: - raise RuntimeError("Do not support copy into DRAM with pad") - if src.scope == env.acc_scope: - elem_width = env.OUT_WIDTH - elem_bytes = env.OUT_ELEM_BYTES - mem_type = env.dev.MEM_ID_OUT - data_type = "int%d" % env.OUT_WIDTH - task_qid = env.dev.QID_STORE_OUT - else: - raise RuntimeError("Do not support copy %s->dram" % (src.scope)) - _check_compact(src) - x_size, y_size, x_stride, offset = _get_2d_pattern( - dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True) - irb = tvm.tir.ir_builder.create() - irb.scope_attr(env.dev.vta_axis, "coproc_scope", - env.dev.get_task_qid(task_qid)) - irb.emit(tvm.tir.call_extern( - "int32", "VTAStoreBuffer2D", - env.dev.command_handle, - src.access_ptr("r", "int32"), - mem_type, dst.data, offset, x_size, y_size, x_stride)) - return irb.get() - elif src.scope == "global": - if dst.scope == env.acc_scope: - elem_width = env.ACC_WIDTH - elem_bytes = env.ACC_ELEM_BYTES - mem_type = env.dev.MEM_ID_ACC - data_type = "int%d" % env.ACC_WIDTH - task_qid = env.dev.QID_LOAD_OUT - elif dst.scope == env.inp_scope: - elem_width = env.INP_WIDTH - elem_bytes = env.INP_ELEM_BYTES - mem_type = env.dev.MEM_ID_INP - data_type = "int%d" % env.INP_WIDTH - task_qid = env.dev.QID_LOAD_INP - elif dst.scope == env.wgt_scope: - elem_width = env.WGT_WIDTH - elem_bytes = env.WGT_ELEM_BYTES - mem_type = env.dev.MEM_ID_WGT - data_type = "int%d" % env.WGT_WIDTH - task_qid = env.dev.QID_LOAD_WGT - else: - raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) - # collect pad statistics - if pad_before: - assert pad_after - ndim = len(pad_before) - if ndim <= 2 or ndim > 5: - raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim) - if ndim == 5: - # This case occurs when batch size N > 1 - y_pad_before = pad_before[1] - 
x_pad_before = pad_before[2] - y_pad_after = pad_after[1] - x_pad_after = pad_after[2] - for dim in range(3, ndim): - if not util.equal_const_int(pad_before[dim], 0): - raise ValueError("Do not support pad on the innermost block") - if not util.equal_const_int(pad_after[dim], 0): - raise ValueError("Do not support pad on the innermost block") - else: - y_pad_before = pad_before[0] - x_pad_before = pad_before[1] - y_pad_after = pad_after[0] - x_pad_after = pad_after[1] - for dim in range(2, ndim): - if not util.equal_const_int(pad_before[dim], 0): - raise ValueError("Do not support pad on the innermost block") - if not util.equal_const_int(pad_after[dim], 0): - raise ValueError("Do not support pad on the innermost block") - allow_fold = False - else: - x_pad_before = 0 - y_pad_before = 0 - x_pad_after = 0 - y_pad_after = 0 - allow_fold = True - - _check_compact(dst) - x_size, y_size, x_stride, offset = _get_2d_pattern( - src, elem_width, elem_bytes, data_type, - dst.scope, allow_fold=allow_fold) - - irb = tvm.tir.ir_builder.create() - irb.scope_attr(env.dev.vta_axis, "coproc_scope", - env.dev.get_task_qid(task_qid)) - - irb.emit(tvm.tir.call_extern( - "int32", "VTALoadBuffer2D", - env.dev.command_handle, - src.data, offset, x_size, y_size, x_stride, - x_pad_before, y_pad_before, - x_pad_after, y_pad_after, - dst.access_ptr("r", "int32"), mem_type)) - return irb.get() - - else: - raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) - - return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy) - - -def _get_gemm_intrin_buffer(): - env = get_env() - wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH - assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN - wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) - assert wgt_shape[0] * wgt_shape[1] == wgt_lanes - inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH - assert inp_lanes == env.BATCH * env.BLOCK_IN - inp_shape = (env.BATCH, env.BLOCK_IN) - assert inp_shape[0] * inp_shape[1] == inp_lanes - out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH - assert out_lanes == env.BATCH * env.BLOCK_OUT - out_shape = (env.BATCH, env.BLOCK_OUT) - assert out_shape[0] * out_shape[1] == out_lanes - wgt = te.placeholder((wgt_shape[0], wgt_shape[1]), - dtype="int%d" % env.WGT_WIDTH, - name=env.wgt_scope) - inp = te.placeholder((inp_shape[0], inp_shape[1]), - dtype="int%d" % env.INP_WIDTH, - name=env.inp_scope) - k = te.reduce_axis((0, wgt_shape[1]), name="k") - out_dtype = "int%d" % env.ACC_WIDTH - out = te.compute((out_shape[0], out_shape[1]), - lambda i, j: te.sum(inp[i, k].astype(out_dtype) * - wgt[j, k].astype(out_dtype), - axis=[k]), - name="out") - wgt_layout = tvm.tir.decl_buffer( - wgt.shape, wgt.dtype, env.wgt_scope, - scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) - inp_layout = tvm.tir.decl_buffer( - inp.shape, inp.dtype, env.inp_scope, - scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) - out_layout = tvm.tir.decl_buffer( - out.shape, out.dtype, env.acc_scope, - scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) - - return wgt_layout, inp_layout, out_layout - - -def inject_conv2d_transpose_skip(stmt_in): - """Pass to skip 0-weights in conv2d transpose with stride > 1. 
- - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - dwgt, dinp, dout = _get_gemm_intrin_buffer() - - calls = [] - selects = [] - - def _find_basics(op): - if isinstance(op, tvm.tir.Call): - calls.append(op) - elif isinstance(op, tvm.tir.Select): - selects.append(op) - - def _do_fold(op): - if _match_pragma(op, "conv2d_transpose_gemm"): - is_init = ".init" in str(op) - tvm.tir.ir_pass.PostOrderVisit(op, _find_basics) - - if is_init: - # create inner most block - irb = tvm.tir.ir_builder.create() - dev = env.dev - irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) - irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) - irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", - 0, 1, - dout.access_ptr("rw", "int32"), - 0, 0, - 0, 0, 0)) - inner = irb.get() - # TODO(@tmoreau89): This is only a temporary fix, please take a look. - body = op.body.body - while isinstance(body, tvm.tir.IfThenElse): - body = body.then_case - args = body.args - res_tensor = body.func.output(0) - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) - inner = tvm.tir.AttrStmt( - [dout, res_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - return inner - else: - conv_call, data_call, kernel_call = calls[-3:] - pad_data_tensor = data_call.func.output(0) - kernel_tensor = kernel_call.func.output(0) - res_tensor = conv_call.func.output(0) - - if selects: - condition = selects[0].condition - else: - condition = tvm.tir.const(1, 'int') - - # create inner most block - irb = tvm.tir.ir_builder.create() - with irb.if_scope(condition): - dev = env.dev - irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) - irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) - irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", - 0, 0, - dout.access_ptr("rw", "int32"), - dinp.access_ptr("r", "int32"), - dwgt.access_ptr("r", "int32"), - 0, 0, 0)) - inner = irb.get() - - args = conv_call.args - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], - 1, 0, 1, 0, env.BLOCK_OUT) - inner = tvm.tir.AttrStmt( - [dout, res_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - args = kernel_call.args - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], - 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) - inner = tvm.tir.AttrStmt( - [dwgt, kernel_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - args = data_call.args - tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], - 1, 0, 1, 0, env.BLOCK_IN) - inner = tvm.tir.AttrStmt( - [dinp, pad_data_tensor], 'buffer_bind_scope', - tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) - return inner - return None - ret = tvm.tir.ir_pass.IRTransform( - stmt_in, _do_fold, None, ["AttrStmt"]) - return ret - - -def annotate_alu_coproc_scope(stmt_in): - """Pass to insert ALU instruction. 
- - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - def _do_fold(stmt): - if _match_pragma(stmt, "alu"): - irb = tvm.tir.ir_builder.create() - irb.scope_attr(env.dev.vta_axis, "coproc_scope", - env.dev.get_task_qid(env.dev.QID_COMPUTE)) - irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope", - tvm.tir.StringImm("VTAPushALUOp")) - irb.emit(stmt) - return irb.get() - if _match_pragma(stmt, "skip_alu"): - return tvm.tir.Evaluate(0) - return stmt - - stmt_out = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _do_fold, ["AttrStmt"]) - - return stmt_out - - -def inject_alu_intrin(stmt_in): - """Pass to inject ALU micro-ops. - - Parameters - ---------- - stmt_in : Stmt - Input statement - - Returns - ------- - stmt_out : Stmt - Transformed statement - """ - env = get_env() - idxm = tvm.tir.indexmod - - def _do_fold(stmt): - def _equal(x, y): - return tvm.ir.structural_equal(tvm.tir.ir_pass.Simplify(x - y), 0) - - def _flatten_loop(src_coeff, dst_coeff, extents): - src_coeff = list(src_coeff) - dst_coeff = list(dst_coeff) - extents = list(extents) - rev_src_coeff = [src_coeff.pop()] - rev_dst_coeff = [dst_coeff.pop()] - rev_extents = [] - assert src_coeff - vsrc = src_coeff.pop() - vdst = dst_coeff.pop() - vext = extents.pop() - while src_coeff: - next_src = src_coeff.pop() - next_dst = dst_coeff.pop() - next_ext = extents.pop() - - if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): - vext = tvm.tir.ir_pass.Simplify(vext * next_ext) - else: - rev_src_coeff.append(vsrc) - rev_dst_coeff.append(vdst) - rev_extents.append(vext) - vsrc = next_src - vdst = next_dst - vext = next_ext - rev_src_coeff.append(vsrc) - rev_dst_coeff.append(vdst) - rev_extents.append(vext) - rev_src_coeff.reverse() - rev_dst_coeff.reverse() - rev_extents.reverse() - - return rev_src_coeff, rev_dst_coeff, rev_extents - - if _match_pragma(stmt, "alu"): - # Get to the innermost loop body - loop_body = stmt.body - nest_size = 0 - while isinstance(loop_body, tvm.tir.For): - loop_body = loop_body.body - nest_size += 1 - # Get the src/dst arguments - dst_var = loop_body.buffer_var - dst_idx = loop_body.index - # Derive loop variables and extents - tmp_body = stmt.body - indices = [] - extents = [] - for _ in range(nest_size): - indices.append(tmp_body.loop_var) - extents.append(tmp_body.extent) - tmp_body = tmp_body.body - # Derive opcode - if isinstance(loop_body.value, tvm.tir.Add): - alu_opcode = env.dev.ALU_OPCODE_ADD - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Sub): - alu_opcode = env.dev.ALU_OPCODE_SUB - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Mul): - alu_opcode = env.dev.ALU_OPCODE_MUL - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Min): - alu_opcode = env.dev.ALU_OPCODE_MIN - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Max): - alu_opcode = env.dev.ALU_OPCODE_MAX - lhs = loop_body.value.a - rhs = loop_body.value.b - elif isinstance(loop_body.value, tvm.tir.Call): - if loop_body.value.name == 'shift_left': - alu_opcode = env.dev.ALU_OPCODE_SHR - lhs = loop_body.value.args[0] - rhs = tvm.tir.ir_pass.Simplify(-loop_body.value.args[1]) - elif loop_body.value.name == 'shift_right': - alu_opcode = env.dev.ALU_OPCODE_SHR - lhs = loop_body.value.args[0] - rhs = loop_body.value.args[1] - else: - raise RuntimeError( - 
"Function call not recognized %s" % (loop_body.value.name)) - elif isinstance(loop_body.value, tvm.tir.Load): - alu_opcode = env.dev.ALU_OPCODE_SHR - lhs = loop_body.value - rhs = tvm.tir.const(0, "int32") - else: - raise RuntimeError( - "Expression not recognized %s, %s, %s" % ( - type(loop_body.value), str(loop_body.value), str(stmt))) - - # Derive array index coefficients - dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices) - # Check if lhs/rhs is immediate - use_imm = False - imm_val = None - if isinstance(rhs, tvm.tir.IntImm): - assert lhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) - use_imm = True - imm_val = rhs - if isinstance(lhs, tvm.tir.IntImm): - assert rhs.buffer_var.same_as(dst_var) - src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) - use_imm = True - imm_val = lhs - if imm_val is None: - imm_val = 0 - assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) - src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) - src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) - # Determine which side has the same coefficients - lhs_equal = True - rhs_equal = True - for i, coef in enumerate(dst_coeff): - if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]): - lhs_equal = False - if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]): - rhs_equal = False - # Make sure at least one of the source is identical to the - # destination (in-place computation) - assert lhs_equal or rhs_equal - # Assign the source coefficients - if lhs_equal: - src_coeff = src_rhs_coeff - else: - src_coeff = src_lhs_coeff - - # Ensure that we have the proper tensor dimensions in the - # innermost loop (pattern match) - src_coeff = list(src_coeff) - dst_coeff = list(dst_coeff) - extents = list(extents) - assert len(src_coeff) > 1 - assert len(dst_coeff) > 1 - assert len(extents) != 0 - assert tvm.ir.structural_equal( - tvm.tir.ir_pass.Simplify( - idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) - assert tvm.ir.structural_equal( - tvm.tir.ir_pass.Simplify( - idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) - assert tvm.ir.structural_equal(src_coeff[-2], 1) - assert tvm.ir.structural_equal(dst_coeff[-2], 1) - if env.BATCH > 1: - assert len(src_coeff) > 2 - assert len(dst_coeff) > 2 - assert len(extents) > 1 - assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT) - assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT) - - # Apply tensorization of the loop coefficients - src_offset = src_coeff[-1] - dst_offset = dst_coeff[-1] - if env.BATCH == 1: - src_coeff = src_coeff[:-2] - dst_coeff = dst_coeff[:-2] - extents = extents[:-1] - else: - src_coeff = src_coeff[:-3] - dst_coeff = dst_coeff[:-3] - extents = extents[:-2] - src_coeff.append(src_offset) - dst_coeff.append(dst_offset) - src_coeff = [ - tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] - dst_coeff = [ - tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] - - # Flatten the outer loops - if extents: - src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents) - - # Insert ALU micro-ops - irb = tvm.tir.ir_builder.create() - for idx, extent in enumerate(extents): - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopLoopBegin", - extent, dst_coeff[idx], src_coeff[idx], 0)) - use_imm = int(use_imm) - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopPush", - 1, 0, - dst_coeff[len(dst_coeff)-1], - src_coeff[len(src_coeff)-1], - 0, - alu_opcode, 
use_imm, imm_val)) - for extent in extents: - irb.emit(tvm.tir.call_extern( - "int32", "VTAUopLoopEnd")) - return irb.get() - return stmt - - stmt_out = tvm.tir.ir_pass.IRTransform( - stmt_in, None, _do_fold, ["AttrStmt"]) - return stmt_out - - -def debug_print(stmt): - """A debug pass that print the stmt - - Parameters - ---------- - stmt : Stmt - The input statement - - Returns - ------- - stmt : Stmt - The - """ - # pylint: disable=superfluous-parens - print(stmt) - return stmt diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 2334de7e6905..231d40033350 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -345,9 +345,9 @@ def visit_call(self, call): method, align_corners) elif call.op == self.reshape and len(input_types[0].shape) == 4: - data, = args + data, _ = args data = op.transpose(data, axes=(0, 4, 1, 5, 2, 3)) - return op.reshape(data, input_types[0].shape) + return op.reshape(data, [int(x) for x in input_types[0].shape]) return relay.Call( self.visit(call.op), @@ -376,7 +376,7 @@ def _recursion(anf, start_found, stop_found, operator_current_idx): if isinstance(anf, relay.expr.Let): value = anf.value if isinstance(value, relay.expr.Call): - if isinstance(value.op, relay.op.Op): + if isinstance(value.op, tvm.ir.Op): if value.op.name == start_name and not start_found: if operator_current_idx == start_name_idx or start_name_idx is None: value = relay.expr.Call(bitpack_start, [value]) diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py new file mode 100644 index 000000000000..207f784b5885 --- /dev/null +++ b/vta/python/vta/transform.py @@ -0,0 +1,962 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Additional Transformation Passes. for VTA""" +# pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name +import tvm +from tvm import te +from topi import util + +from .environment import get_env + + +def _match_pragma(stmt, key): + """Internal helper to match stmt to pragma stmt. + + Parameters + ---------- + stmt : Stmt + The AttrStmt + + key : str + The pragma key + """ + return ((stmt.attr_key == "pragma_" + key) or + (stmt.attr_key == "pragma_scope" and stmt.value.value == key)) + + +def FoldUopLoop(): + """Detect and fold uop loop. + + VTA support uop programming model + that recognizes loop structure. + This pass detect the loop structure + and extract that into uop loop AST. 
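A usage sketch for the new-style passes in this file (assuming they are wired into lowering the way the PassContext code earlier in this patch suggests): each helper returns a tvm.transform.Pass, so it can be applied to an IRModule directly or injected through the tir.add_lower_pass config.

    import tvm
    from tvm import te
    from vta import transform as vta_transform  # module added by this patch

    # A trivial module just to exercise the plumbing; with no VTA pragmas the pass is a no-op.
    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1, name="B")
    mod = tvm.lower(te.create_schedule(B.op), [A, B])

    mod = vta_transform.FoldUopLoop()(mod)  # apply the returned Pass directly
    # or, hypothetically, register it during lowering:
    # with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vta_transform.FoldUopLoop())]}):
    #     ...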
+ + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _fold_outermost_loop(body): + stmt = body + if not isinstance(stmt, tvm.tir.For): + return None, body, None + + loop_var = stmt.loop_var + gemm_offsets = [None, None, None] + fail = [False] + + def _post_order(op): + assert isinstance(op, tvm.tir.Call) + base_args = 2 + if op.name == "VTAUopPush": + args = [] + args += op.args[:base_args] + for i in range(3): + m = tvm.arith.detect_linear_equation( + op.args[i + base_args], [loop_var]) + if not m: + fail[0] = True + return op + if gemm_offsets[i] is not None: + if not tvm.ir.structural_equal(m[0], gemm_offsets[i]): + fail[0] = True + return op + args.append(m[1]) + else: + gemm_offsets[i] = m[0] + args.append(m[1]) + args += op.args[base_args+3:] + return tvm.tir.call_extern("int32", "VTAUopPush", *args) + if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"): + raise RuntimeError("unexpected op %s" % op) + return op + + ret = tvm.tir.stmt_functor.ir_transform( + stmt.body, None, _post_order, ["tir.Call"]) + + if not fail[0] and all(x is not None for x in gemm_offsets): + def _visit(op): + if op.same_as(loop_var): + fail[0] = True + tvm.tir.stmt_functor.post_order_visit(ret, _visit) + if not fail[0]: + begin = tvm.tir.call_extern( + "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets) + end = tvm.tir.call_extern("int32", "VTAUopLoopEnd") + return [begin, ret, end] + raise ValueError("Failed to fold the GEMM instructions..") + + def _do_fold(stmt): + env = get_env() + if (stmt.attr_key == "coproc_uop_scope" and + isinstance(stmt.value, tvm.tir.StringImm) and + stmt.value.value == env.dev.vta_push_uop.value): + body = stmt.body + begins = [] + ends = [] + try: + begin, body, end = _fold_outermost_loop(body) + if begin is not None: + begins.append(begin) + if end is not None: + ends.append(end) + begin, body, end = _fold_outermost_loop(body) + if begin is not None: + begins.append(begin) + if end is not None: + ends.append(end) + except ValueError: + pass + if body == stmt.body: + return stmt + ends = list(reversed(ends)) + body = tvm.tir.stmt_seq(*(begins + [body] + ends)) + return tvm.tir.AttrStmt( + stmt.node, stmt.attr_key, stmt.value, body) + return None + + def _ftransform(f, mod, ctx): + return f.with_body(tvm.tir.stmt_functor.ir_transform( + f.body, _do_fold, None, ["tir.AttrStmt"])) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.FoldUopLoop") + + +def CPUAccessRewrite(): + """Detect CPU access to VTA buffer and get address correctly. + + VTA's buffer is an opaque handle that do not + correspond to address in CPU. + This pass detect CPU access and rewrite to use pointer + returned VTABufferCPUPtr for CPU access. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(f, mod, ctx): + rw_info = {} + env = get_env() + def _post_order(op): + if isinstance(op, tvm.tir.Allocate): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + return None + new_var = rw_info[buffer_var] + let_stmt = tvm.tir.LetStmt( + new_var, tvm.tir.call_extern( + "handle", "VTABufferCPUPtr", + env.dev.command_handle, + buffer_var), op.body) + alloc = tvm.tir.Allocate( + buffer_var, op.dtype, op.extents, + op.condition, let_stmt) + del rw_info[buffer_var] + return alloc + if isinstance(op, tvm.tir.Load): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + rw_info[buffer_var] = te.var( + buffer_var.name + "_ptr", "handle") + new_var = rw_info[buffer_var] + return tvm.tir.Load(op.dtype, new_var, op.index) + if isinstance(op, tvm.tir.Store): + buffer_var = op.buffer_var + if not buffer_var in rw_info: + rw_info[buffer_var] = te.var( + buffer_var.name + "_ptr", "handle") + new_var = rw_info[buffer_var] + return tvm.tir.Store(new_var, op.value, op.index) + raise RuntimeError("not reached") + + stmt_in = f.body + stmt = tvm.tir.stmt_functor.ir_transform( + stmt_in, None, _post_order, ["tir.Allocate", "tir.Load", "tir.Store"]) + + for buffer_var, new_var in rw_info.items(): + stmt = tvm.tir.LetStmt( + new_var, tvm.tir.call_extern( + "handle", "VTABufferCPUPtr", + env.dev.command_handle, + buffer_var), stmt) + return f.with_body(stmt) + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite") + + +def LiftAllocToScopeBegin(): + """Lift allocate to beginning of the current scope. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(f, mod, ctx): + lift_stmt = [[]] + def _merge_block(slist, body): + for op in slist: + if op.body == body: + body = op + elif isinstance(op, tvm.tir.Allocate): + body = tvm.tir.Allocate( + op.buffer_var, op.dtype, + op.extents, op.condition, body) + elif isinstance(op, tvm.tir.AttrStmt): + body = tvm.tir.AttrStmt( + op.node, op.attr_key, op.value, body) + elif isinstance(op, tvm.tir.For): + body = tvm.tir.For( + op.loop_var, op.min, op.extent, op.for_type, + op.device_api, body) + else: + raise RuntimeError("unexpected op") + del slist[:] + return body + + def _pre_order(op): + if isinstance(op, tvm.tir.For): + lift_stmt.append([]) + elif isinstance(op, tvm.tir.AttrStmt): + if op.attr_key == "virtual_thread": + lift_stmt.append([]) + + def _post_order(op): + if isinstance(op, tvm.tir.Allocate): + lift_stmt[-1].append(op) + return op.body + if isinstance(op, tvm.tir.AttrStmt): + if op.attr_key == "storage_scope": + lift_stmt[-1].append(op) + return op.body + if op.attr_key == "virtual_thread": + return _merge_block(lift_stmt.pop() + [op], op.body) + return op + if isinstance(op, tvm.tir.For): + return _merge_block(lift_stmt.pop() + [op], op.body) + raise RuntimeError("not reached") + stmt_in = f.body + stmt = tvm.tir.stmt_functor.ir_transform( + stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt", "tir.For"]) + assert len(lift_stmt) == 1 + return f.with_body(_merge_block(lift_stmt[0], stmt)) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin") + + +def InjectSkipCopy(): + """Pass to inject skip copy stmt, used for debug purpose. 
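A self-contained way to watch this pass fire (the IR is hand-built for illustration and fake_dma_copy is a made-up extern name; nothing here comes from the patch itself): any statement wrapped in the skip_dma_copy pragma is replaced by an empty Evaluate(0), which helps separate DMA cost from compute cost while debugging.

    import tvm
    from tvm import te
    from vta import transform as vta_transform

    x = te.var("x")
    copy_like = tvm.tir.Evaluate(tvm.tir.call_extern("int32", "fake_dma_copy", x))
    marked = tvm.tir.AttrStmt(x, "pragma_skip_dma_copy",
                              tvm.tir.StringImm("skip_dma_copy"), copy_like)
    mod = tvm.IRModule({"main": tvm.tir.PrimFunc([x], marked)})
    mod = vta_transform.InjectSkipCopy()(mod)
    # mod["main"].body is now Evaluate(0): the marked copy has been dropped.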
+ + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _do_fold(stmt): + if _match_pragma(stmt, "skip_dma_copy"): + return tvm.tir.Evaluate(0) + return None + + def _ftransform(f, mod, ctx): + return f.with_body(tvm.tir.stmt_functor.ir_transform( + f.body, _do_fold, None, ["tir.AttrStmt"])) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy") + + +def InjectCoProcSync(): + """Pass inject coproc sync + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(f, *_): + success = [False] + def _do_fold(stmt): + if _match_pragma(stmt, "coproc_sync"): + success[0] = True + sync = tvm.tir.Call( + "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic) + return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)]) + if _match_pragma(stmt, "trim_loop"): + op = stmt.body + assert isinstance(op, tvm.tir.For) + return tvm.tir.For( + op.loop_var, op.min, 2, op.for_type, + op.device_api, op.body) + return None + return f.with_body(tvm.tir.stmt_functor.ir_transform( + f.body, None, _do_fold, ["tir.AttrStmt"])) + return tvm.transform.Sequential( + [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"), + tvm.tir.transform.CoProcSync()], + opt_level=0, name="tir.vta.InjectCoProcSync") + + +def InjectDMAIntrin(): + """Pass to inject DMA copy intrinsics. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + idxd = tvm.tir.indexdiv + idxm = tvm.tir.indexmod + + def _check_compact(buf): + ndim = len(buf.shape) + size = tvm.tir.const(1, buf.shape[0].dtype) + for i in reversed(range(ndim)): + if not util.equal_const_int(size - buf.strides[i], 0): + raise RuntimeError( + "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)) + size = size * buf.shape[i] + + def _fold_buffer_dim(buf, scope, elem_block): + ndim = len(buf.shape) + x_size = 1 + base = 0 + for i in range(1, ndim + 1): + if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0): + raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block)) + x_size = x_size * buf.shape[ndim - i] + if util.equal_const_int(x_size - elem_block, 0): + base = i + 1 + break + if base == 0: + raise RuntimeError("scope %s need to have block=%d, shape=%s" % ( + scope, elem_block, buf.shape)) + shape = [elem_block] + strides = [1] + + if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block): + shape.append(1) + strides.append(elem_block) + + analyzer = tvm.arith.Analyzer() + while base < ndim + 1: + x_size = 1 + x_stride = buf.strides[ndim - base] + next_base = base + if not util.equal_const_int(idxm(x_stride, elem_block), 0): + raise RuntimeError( + "scope %s need to have block=%d, shape=%s, strides=%s" % ( + scope, elem_block, buf.shape, buf.strides)) + for i in range(base, ndim + 1): + k = ndim - i + if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0): + break + x_size = x_size * buf.shape[k] + next_base = i + 1 + shape.append(analyzer.simplify(x_size)) + strides.append(x_stride) + assert next_base != base + base = next_base + + strides = list(reversed(strides)) + shape = list(reversed(shape)) + return shape, strides + + def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): + elem_block = elem_bytes * 8 // elem_width + if buf.dtype != dtype: + raise RuntimeError("Expect buffer type to be %s instead of %s" % + (dtype, buf.dtype)) + shape, strides = buf.shape, buf.strides + if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0): + 
raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) + if allow_fold: + shape, strides = _fold_buffer_dim(buf, scope, elem_block) + else: + shape = list(x for x in shape) + strides = list(x for x in strides) + + def raise_error(): + """Internal function to raise error """ + raise RuntimeError( + ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" + + " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides)) + + ndim = len(shape) + + # Check if the inner-tensor is already flat + flat = util.equal_const_int(shape[-1], elem_block) + + if flat: + if not util.equal_const_int(strides[-1], 1): + raise_error() + + if ndim == 1: + x_size = 1 + x_stride = 1 + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(strides[-2] - elem_block, 0): + raise_error() + + if ndim == 2: + x_size = shape[-2] + x_stride = shape[-2] + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(idxm(strides[-3], elem_block), 0): + raise_error() + + if ndim == 3: + x_size = shape[-2] + x_stride = idxd(strides[-3], elem_block) + y_size = shape[-3] + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + + else: + if not util.equal_const_int(strides[-1], 1): + raise_error() + if not util.equal_const_int(strides[-2] - shape[-1], 0): + raise_error() + if not util.equal_const_int(shape[-1] * shape[-2], elem_block): + raise_error() + + if ndim == 2: + x_size = 1 + x_stride = 1 + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(strides[-3], elem_block): + raise_error() + + if ndim == 3: + x_size = shape[-3] + x_stride = shape[-3] + y_size = 1 + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + if not util.equal_const_int(idxm(strides[-4], elem_block), 0): + raise_error() + + if ndim == 4: + x_size = shape[-3] + x_stride = idxd(strides[-4], elem_block) + y_size = shape[-4] + return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block) + + raise_error() + + + def _inject_copy(src, dst, pad_before, pad_after, pad_value): + # FIXME: pad_value is ignored... 
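# Worked example of the tuple _get_2d_pattern returns (illustrative, not from the patch):
# for a compact tile of shape (Y, X, elem_block) with strides (X*elem_block, elem_block, 1)
# and elem_offset = k * elem_block, it yields
#     x_size = X, y_size = Y, x_stride = X, offset = k
# i.e. a Y-row by X-element 2D burst that VTALoadBuffer2D / VTAStoreBuffer2D consume directly.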
+ env = get_env() + _ = pad_value + if dst.scope == "global": + # Store + if pad_before or pad_after: + raise RuntimeError("Do not support copy into DRAM with pad") + if src.scope == env.acc_scope: + elem_width = env.OUT_WIDTH + elem_bytes = env.OUT_ELEM_BYTES + mem_type = env.dev.MEM_ID_OUT + data_type = "int%d" % env.OUT_WIDTH + task_qid = env.dev.QID_STORE_OUT + else: + raise RuntimeError("Do not support copy %s->dram" % (src.scope)) + _check_compact(src) + x_size, y_size, x_stride, offset = _get_2d_pattern( + dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True) + irb = tvm.tir.ir_builder.create() + irb.scope_attr(env.dev.vta_axis, "coproc_scope", + env.dev.get_task_qid(task_qid)) + irb.emit(tvm.tir.call_extern( + "int32", "VTAStoreBuffer2D", + env.dev.command_handle, + src.access_ptr("r", "int32"), + mem_type, dst.data, offset, x_size, y_size, x_stride)) + return irb.get() + elif src.scope == "global": + if dst.scope == env.acc_scope: + elem_width = env.ACC_WIDTH + elem_bytes = env.ACC_ELEM_BYTES + mem_type = env.dev.MEM_ID_ACC + data_type = "int%d" % env.ACC_WIDTH + task_qid = env.dev.QID_LOAD_OUT + elif dst.scope == env.inp_scope: + elem_width = env.INP_WIDTH + elem_bytes = env.INP_ELEM_BYTES + mem_type = env.dev.MEM_ID_INP + data_type = "int%d" % env.INP_WIDTH + task_qid = env.dev.QID_LOAD_INP + elif dst.scope == env.wgt_scope: + elem_width = env.WGT_WIDTH + elem_bytes = env.WGT_ELEM_BYTES + mem_type = env.dev.MEM_ID_WGT + data_type = "int%d" % env.WGT_WIDTH + task_qid = env.dev.QID_LOAD_WGT + else: + raise RuntimeError("Do not support copy dram->%s" % (dst.scope)) + # collect pad statistics + if pad_before: + assert pad_after + ndim = len(pad_before) + if ndim <= 2 or ndim > 5: + raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim) + if ndim == 5: + # This case occurs when batch size N > 1 + y_pad_before = pad_before[1] + x_pad_before = pad_before[2] + y_pad_after = pad_after[1] + x_pad_after = pad_after[2] + for dim in range(3, ndim): + if not util.equal_const_int(pad_before[dim], 0): + raise ValueError("Do not support pad on the innermost block") + if not util.equal_const_int(pad_after[dim], 0): + raise ValueError("Do not support pad on the innermost block") + else: + y_pad_before = pad_before[0] + x_pad_before = pad_before[1] + y_pad_after = pad_after[0] + x_pad_after = pad_after[1] + for dim in range(2, ndim): + if not util.equal_const_int(pad_before[dim], 0): + raise ValueError("Do not support pad on the innermost block") + if not util.equal_const_int(pad_after[dim], 0): + raise ValueError("Do not support pad on the innermost block") + allow_fold = False + else: + x_pad_before = 0 + y_pad_before = 0 + x_pad_after = 0 + y_pad_after = 0 + allow_fold = True + + _check_compact(dst) + x_size, y_size, x_stride, offset = _get_2d_pattern( + src, elem_width, elem_bytes, data_type, + dst.scope, allow_fold=allow_fold) + + irb = tvm.tir.ir_builder.create() + irb.scope_attr(env.dev.vta_axis, "coproc_scope", + env.dev.get_task_qid(task_qid)) + + irb.emit(tvm.tir.call_extern( + "int32", "VTALoadBuffer2D", + env.dev.command_handle, + src.data, offset, x_size, y_size, x_stride, + x_pad_before, y_pad_before, + x_pad_after, y_pad_after, + dst.access_ptr("r", "int32"), mem_type)) + return irb.get() + + else: + raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope)) + + return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy) + + +def _get_gemm_intrin_buffer(): + env = get_env() + wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH + assert 
wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN + wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN) + assert wgt_shape[0] * wgt_shape[1] == wgt_lanes + inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH + assert inp_lanes == env.BATCH * env.BLOCK_IN + inp_shape = (env.BATCH, env.BLOCK_IN) + assert inp_shape[0] * inp_shape[1] == inp_lanes + out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH + assert out_lanes == env.BATCH * env.BLOCK_OUT + out_shape = (env.BATCH, env.BLOCK_OUT) + assert out_shape[0] * out_shape[1] == out_lanes + wgt = te.placeholder((wgt_shape[0], wgt_shape[1]), + dtype="int%d" % env.WGT_WIDTH, + name=env.wgt_scope) + inp = te.placeholder((inp_shape[0], inp_shape[1]), + dtype="int%d" % env.INP_WIDTH, + name=env.inp_scope) + k = te.reduce_axis((0, wgt_shape[1]), name="k") + out_dtype = "int%d" % env.ACC_WIDTH + out = te.compute((out_shape[0], out_shape[1]), + lambda i, j: te.sum(inp[i, k].astype(out_dtype) * + wgt[j, k].astype(out_dtype), + axis=[k]), + name="out") + wgt_layout = tvm.tir.decl_buffer( + wgt.shape, wgt.dtype, env.wgt_scope, + scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes) + inp_layout = tvm.tir.decl_buffer( + inp.shape, inp.dtype, env.inp_scope, + scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes) + out_layout = tvm.tir.decl_buffer( + out.shape, out.dtype, env.acc_scope, + scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes) + + return wgt_layout, inp_layout, out_layout + + +def InjectConv2DTransposeSkip(): + """Pass to skip 0-weights in conv2d transpose with stride > 1. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(func, mod, ctx): + env = get_env() + dwgt, dinp, dout = _get_gemm_intrin_buffer() + + calls = [] + selects = [] + + def _find_basics(op): + if isinstance(op, tvm.tir.BufferLoad): + calls.append(op) + elif isinstance(op, tvm.tir.Select): + selects.append(op) + + def _do_fold(op): + if _match_pragma(op, "conv2d_transpose_gemm"): + is_init = ".init" in str(op) + tvm.tir.stmt_functor.post_order_visit(op, _find_basics) + + if is_init: + # create inner most block + irb = tvm.tir.ir_builder.create() + dev = env.dev + irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", + 0, 1, + dout.access_ptr("rw", "int32"), + 0, 0, + 0, 0, 0)) + inner = irb.get() + # TODO(@tmoreau89): This is only a temporary fix, please take a look. 
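# (Illustrative note, not part of the patch: the loop below peels any IfThenElse guards
#  off the init body to reach the underlying buffer access node, whose .buffer and
#  .indices then feed the buffer_bind_scope binding for the output tile.)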
+ body = op.body.body + while isinstance(body, tvm.tir.IfThenElse): + body = body.then_case + args = body.indices + res_buffer = body.buffer + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT) + inner = tvm.tir.AttrStmt( + [dout, res_buffer], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + else: + conv_call, data_call, kernel_call = calls[-3:] + pad_data_tensor = data_call.buffer + kernel_tensor = kernel_call.buffer + res_tensor = conv_call.buffer + + if selects: + condition = selects[0].condition + else: + condition = tvm.tir.const(1, 'int') + + # create inner most block + irb = tvm.tir.ir_builder.create() + with irb.if_scope(condition): + dev = env.dev + irb.scope_attr( + dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE)) + irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop) + irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", + 0, 0, + dout.access_ptr("rw", "int32"), + dinp.access_ptr("r", "int32"), + dwgt.access_ptr("r", "int32"), + 0, 0, 0)) + inner = irb.get() + + args = conv_call.indices + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, 1, 0, env.BLOCK_OUT) + inner = tvm.tir.AttrStmt( + [dout, res_tensor], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = kernel_call.indices + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN) + inner = tvm.tir.AttrStmt( + [dwgt, kernel_tensor], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + args = data_call.indices + tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], + 1, 0, 1, 0, env.BLOCK_IN) + inner = tvm.tir.AttrStmt( + [dinp, pad_data_tensor], 'buffer_bind_scope', + tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner) + return inner + return None + + return func.with_body(tvm.tir.stmt_functor.ir_transform( + func.body, _do_fold, None, ["tir.AttrStmt"])) + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip") + + +def AnnotateALUCoProcScope(): + """Pass to insert ALU instruction. + + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(func, mod, ctx): + env = get_env() + def _do_fold(stmt): + if _match_pragma(stmt, "alu"): + irb = tvm.tir.ir_builder.create() + irb.scope_attr(env.dev.vta_axis, "coproc_scope", + env.dev.get_task_qid(env.dev.QID_COMPUTE)) + irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope", + tvm.tir.StringImm("VTAPushALUOp")) + irb.emit(stmt) + return irb.get() + if _match_pragma(stmt, "skip_alu"): + return tvm.tir.Evaluate(0) + return stmt + + return func.with_body(tvm.tir.stmt_functor.ir_transform( + func.body, None, _do_fold, ["tir.AttrStmt"])) + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope") + + +def InjectALUIntrin(): + """Pass to inject ALU micro-ops. 
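A minimal scheduling sketch showing where this pass and AnnotateALUCoProcScope above come in (assumes a readable vta_config.json; the toy compute is not a complete VTA kernel): stages tagged with env.alu are the ones that get wrapped in the compute-queue scope and then lowered to VTAUopLoopBegin / VTAUopPush / VTAUopLoopEnd sequences.

    import vta
    from tvm import te

    env = vta.get_env()
    m = te.placeholder((env.BATCH, env.BLOCK_OUT), dtype=env.acc_dtype, name="m")
    n = te.compute(m.shape, lambda *i: m(*i) + 1, name="n")
    s = te.create_schedule(n.op)
    # Tag the element-wise stage for the VTA ALU; "alu" is the pragma key matched above.
    s[n].pragma(s[n].op.axis[0], env.alu)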
+ + Returns + ------- + fpass : tvm.transform.Pass + The pass + """ + def _ftransform(func, mod, ctx): + env = get_env() + idxm = tvm.tir.indexmod + analyzer = tvm.arith.Analyzer() + + def _do_fold(stmt): + def _equal(x, y): + return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) + + def _flatten_loop(src_coeff, dst_coeff, extents): + src_coeff = list(src_coeff) + dst_coeff = list(dst_coeff) + extents = list(extents) + rev_src_coeff = [src_coeff.pop()] + rev_dst_coeff = [dst_coeff.pop()] + rev_extents = [] + assert src_coeff + vsrc = src_coeff.pop() + vdst = dst_coeff.pop() + vext = extents.pop() + while src_coeff: + next_src = src_coeff.pop() + next_dst = dst_coeff.pop() + next_ext = extents.pop() + + if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): + vext = analyzer.simplify(vext * next_ext) + else: + rev_src_coeff.append(vsrc) + rev_dst_coeff.append(vdst) + rev_extents.append(vext) + vsrc = next_src + vdst = next_dst + vext = next_ext + rev_src_coeff.append(vsrc) + rev_dst_coeff.append(vdst) + rev_extents.append(vext) + rev_src_coeff.reverse() + rev_dst_coeff.reverse() + rev_extents.reverse() + + return rev_src_coeff, rev_dst_coeff, rev_extents + + if _match_pragma(stmt, "alu"): + # Get to the innermost loop body + loop_body = stmt.body + nest_size = 0 + while isinstance(loop_body, tvm.tir.For): + loop_body = loop_body.body + nest_size += 1 + # Get the src/dst arguments + dst_var = loop_body.buffer_var + dst_idx = loop_body.index + # Derive loop variables and extents + tmp_body = stmt.body + indices = [] + extents = [] + for _ in range(nest_size): + indices.append(tmp_body.loop_var) + extents.append(tmp_body.extent) + tmp_body = tmp_body.body + # Derive opcode + if isinstance(loop_body.value, tvm.tir.Add): + alu_opcode = env.dev.ALU_OPCODE_ADD + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Sub): + alu_opcode = env.dev.ALU_OPCODE_SUB + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Mul): + alu_opcode = env.dev.ALU_OPCODE_MUL + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Min): + alu_opcode = env.dev.ALU_OPCODE_MIN + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Max): + alu_opcode = env.dev.ALU_OPCODE_MAX + lhs = loop_body.value.a + rhs = loop_body.value.b + elif isinstance(loop_body.value, tvm.tir.Call): + if loop_body.value.name == 'shift_left': + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value.args[0] + rhs = analyzer.simplify(-loop_body.value.args[1]) + elif loop_body.value.name == 'shift_right': + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value.args[0] + rhs = loop_body.value.args[1] + else: + raise RuntimeError( + "Function call not recognized %s" % (loop_body.value.name)) + elif isinstance(loop_body.value, tvm.tir.Load): + alu_opcode = env.dev.ALU_OPCODE_SHR + lhs = loop_body.value + rhs = tvm.tir.const(0, "int32") + else: + raise RuntimeError( + "Expression not recognized %s, %s, %s" % ( + type(loop_body.value), str(loop_body.value), str(stmt))) + + # Derive array index coefficients + dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices) + # Check if lhs/rhs is immediate + use_imm = False + imm_val = None + if isinstance(rhs, tvm.tir.IntImm): + assert lhs.buffer_var.same_as(dst_var) + src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) + use_imm = True + imm_val = rhs + if isinstance(lhs, tvm.tir.IntImm): + assert 
rhs.buffer_var.same_as(dst_var) + src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) + use_imm = True + imm_val = lhs + if imm_val is None: + imm_val = 0 + assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var) + src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices) + src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices) + # Determine which side has the same coefficients + lhs_equal = True + rhs_equal = True + for i, coef in enumerate(dst_coeff): + if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]): + lhs_equal = False + if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]): + rhs_equal = False + # Make sure at least one of the source is identical to the + # destination (in-place computation) + assert lhs_equal or rhs_equal + # Assign the source coefficients + if lhs_equal: + src_coeff = src_rhs_coeff + else: + src_coeff = src_lhs_coeff + + # Ensure that we have the proper tensor dimensions in the + # innermost loop (pattern match) + src_coeff = list(src_coeff) + dst_coeff = list(dst_coeff) + extents = list(extents) + assert len(src_coeff) > 1 + assert len(dst_coeff) > 1 + assert len(extents) != 0 + assert tvm.ir.structural_equal( + analyzer.simplify( + idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) + assert tvm.ir.structural_equal( + analyzer.simplify( + idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) + assert tvm.ir.structural_equal(src_coeff[-2], 1) + assert tvm.ir.structural_equal(dst_coeff[-2], 1) + if env.BATCH > 1: + assert len(src_coeff) > 2 + assert len(dst_coeff) > 2 + assert len(extents) > 1 + assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT) + assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT) + + # Apply tensorization of the loop coefficients + src_offset = src_coeff[-1] + dst_offset = dst_coeff[-1] + if env.BATCH == 1: + src_coeff = src_coeff[:-2] + dst_coeff = dst_coeff[:-2] + extents = extents[:-1] + else: + src_coeff = src_coeff[:-3] + dst_coeff = dst_coeff[:-3] + extents = extents[:-2] + src_coeff.append(src_offset) + dst_coeff.append(dst_offset) + src_coeff = [ + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] + dst_coeff = [ + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] + + # Flatten the outer loops + if extents: + src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents) + + # Insert ALU micro-ops + irb = tvm.tir.ir_builder.create() + for idx, extent in enumerate(extents): + irb.emit(tvm.tir.call_extern( + "int32", "VTAUopLoopBegin", + extent, dst_coeff[idx], src_coeff[idx], 0)) + use_imm = int(use_imm) + irb.emit(tvm.tir.call_extern( + "int32", "VTAUopPush", + 1, 0, + dst_coeff[len(dst_coeff)-1], + src_coeff[len(src_coeff)-1], + 0, + alu_opcode, use_imm, imm_val)) + for extent in extents: + irb.emit(tvm.tir.call_extern( + "int32", "VTAUopLoopEnd")) + return irb.get() + return stmt + + return func.with_body(tvm.tir.stmt_functor.ir_transform( + func.body, None, _do_fold, ["tir.AttrStmt"])) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin") diff --git a/vta/runtime/device_api.cc b/vta/runtime/device_api.cc index 047a6fdbd50d..298403ca840d 100644 --- a/vta/runtime/device_api.cc +++ b/vta/runtime/device_api.cc @@ -22,12 +22,11 @@ * \brief TVM device API for VTA */ -#include #include +#include -#include "runtime.h" #include "../../src/runtime/workspace_pool.h" - +#include "runtime.h" namespace tvm { namespace runtime { @@ -42,25 +41,14 @@ class 
VTADeviceAPI final : public DeviceAPI { } } - void* AllocDataSpace(TVMContext ctx, - size_t size, - size_t alignment, - DLDataType type_hint) final { + void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) final { return VTABufferAlloc(size); } - void FreeDataSpace(TVMContext ctx, void* ptr) final { - VTABufferFree(ptr); - } + void FreeDataSpace(TVMContext ctx, void* ptr) final { VTABufferFree(ptr); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - DLDataType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, TVMStreamHandle stream) final { int kind_mask = 0; if (ctx_from.device_type != kDLCPU) { @@ -69,33 +57,27 @@ class VTADeviceAPI final : public DeviceAPI { if (ctx_to.device_type != kDLCPU) { kind_mask |= 1; } - VTABufferCopy(from, from_offset, - to, to_offset, - size, kind_mask); + VTABufferCopy(from, from_offset, to, to_offset, size, kind_mask); } - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {} void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); return inst; } }; struct VTAWorkspacePool : public WorkspacePool { - VTAWorkspacePool() : - WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} + VTAWorkspacePool() : WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} }; void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { - return dmlc::ThreadLocalStore::Get() - ->AllocWorkspace(ctx, size); + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(ctx, size); } void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { @@ -104,10 +86,10 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { // Register device api with override. static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ = -::tvm::runtime::Registry::Register("device_api.ext_dev", true) -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = VTADeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); + ::tvm::runtime::Registry::Register("device_api.ext_dev", true) + .set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = VTADeviceAPI::Global().get(); + *rv = static_cast(ptr); + }); } // namespace runtime } // namespace tvm diff --git a/vta/runtime/runtime.cc b/vta/runtime/runtime.cc index 038d5cfa398c..49fe9c557336 100644 --- a/vta/runtime/runtime.cc +++ b/vta/runtime/runtime.cc @@ -24,24 +24,23 @@ * The runtime depends on specific instruction * stream spec as specified in hw_spec.h */ -#include -#include +#include "runtime.h" + #include #include +#include +#include #include #include #include -#include #include - -#include "runtime.h" +#include namespace vta { // Avoid bad configurations. -static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, - "VTA_UOP_WIDTH do not match VTAUop size"); +static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, "VTA_UOP_WIDTH do not match VTAUop size"); /*! \brief Enable coherent access of data buffers between VTA and CPU */ static const bool kBufferCoherent = VTA_COHERENT_ACCESSES; @@ -53,13 +52,9 @@ static const bool kAlwaysCache = true; */ struct DataBuffer { /*! 
\return Virtual address of the data. */ - void* virt_addr() const { - return data_; - } + void* virt_addr() const { return data_; } /*! \return Physical address of the data. */ - vta_phy_addr_t phy_addr() const { - return phy_addr_; - } + vta_phy_addr_t phy_addr() const { return phy_addr_; } /*! * \brief Invalidate the cache of given location in data buffer. * \param offset The offset to the data. @@ -67,9 +62,7 @@ struct DataBuffer { */ void InvalidateCache(size_t offset, size_t size) { if (!kBufferCoherent && kAlwaysCache) { - VTAInvalidateCache(reinterpret_cast(data_) + offset, - phy_addr_ + offset, - size); + VTAInvalidateCache(reinterpret_cast(data_) + offset, phy_addr_ + offset, size); } } /*! @@ -79,16 +72,14 @@ struct DataBuffer { */ void FlushCache(size_t offset, size_t size) { if (!kBufferCoherent && kAlwaysCache) { - VTAFlushCache(reinterpret_cast(data_) + offset, - phy_addr_ + offset, - size); + VTAFlushCache(reinterpret_cast(data_) + offset, phy_addr_ + offset, size); } } /*! * \brief Performs a copy operation from host memory to buffer allocated with VTAMemAlloc. - * \param dst The desination buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc(). - * \param src The source buffer in host memory. - * \param size Size of the region in Bytes. + * \param dst The desination buffer in FPGA-accessible memory. Has to be allocated with + * VTAMemAlloc(). \param src The source buffer in host memory. \param size Size of the region in + * Bytes. */ void MemCopyFromHost(void* dst, const void* src, size_t size) { VTAMemCopyFromHost(dst, src, size); @@ -99,9 +90,7 @@ struct DataBuffer { * \param src The source buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc(). * \param size Size of the region in Bytes. */ - void MemCopyToHost(void* dst, const void* src, size_t size) { - VTAMemCopyToHost(dst, src, size); - } + void MemCopyToHost(void* dst, const void* src, size_t size) { VTAMemCopyToHost(dst, src, size); } /*! * \brief Allocate a buffer of a given size. * \param size The size of the buffer. @@ -128,8 +117,7 @@ struct DataBuffer { * \return The corresponding data buffer header. */ static DataBuffer* FromHandle(const void* buffer) { - return const_cast( - reinterpret_cast(buffer)); + return const_cast(reinterpret_cast(buffer)); } private: @@ -157,9 +145,7 @@ class UopKernel { * \param signature The pointer to signature. * \param nbytes Number of bytes. */ - UopKernel(const char* signature, int nbytes) - : signature_(signature, signature + nbytes) { - } + UopKernel(const char* signature, int nbytes) : signature_(signature, signature + nbytes) {} /*! * \brief Verify if the signature is correct. * \param signature Signature ptr. @@ -170,21 +156,13 @@ class UopKernel { return memcmp(signature, signature_.data(), nbytes) == 0; } /*! \return Whether the kernel is cached in SRAM. */ - bool cached() const { - return sram_begin_ != sram_end_; - } + bool cached() const { return sram_begin_ != sram_end_; } /*! \return The length of the micro op sequence. */ - size_t size() const { - return seq_.size(); - } + size_t size() const { return seq_.size(); } /*! \return The micro-op data. */ - const VTAUop* data() const { - return seq_.data(); - } + const VTAUop* data() const { return seq_.data(); } /*! \return The loop structure. */ - const std::vector& loop() const { - return loop_; - } + const std::vector& loop() const { return loop_; } /*! * \brief Declare loop start. * \param extent The loop extent. 
@@ -192,9 +170,7 @@ class UopKernel { * \param src_factor Loop factor of input index * \param wgt_factor Loop factor of weight index. */ - void PushLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, + void PushLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor) { LoopEntry le; le.extent = extent; @@ -209,9 +185,7 @@ class UopKernel { /*! * \brief Declare loop end. */ - void PushLoopEnd() { - --loop_ptr_; - } + void PushLoopEnd() { --loop_ptr_; } /*! * \brief Push micro op into kernel. * \param mode Set to GEMM mode if set to 0, ALU mode is set to 1. @@ -223,14 +197,8 @@ class UopKernel { * \param use_imm Use immediate in ALU mode if set to true. * \param imm_val Immediate value in ALU mode. */ - void Push(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val) { + void Push(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) { // The loop nest structure VerifyDep(dst_index); VTAUop op; @@ -268,10 +236,7 @@ class UopKernel { uint32_t size = seq_.size(); printf("There are %u uops\n", size); for (uint32_t i = 0; i < size; ++i) { - printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", - i, - seq_[i].dst_idx, - seq_[i].src_idx, + printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", i, seq_[i].dst_idx, seq_[i].src_idx, seq_[i].wgt_idx); } printf("\n"); @@ -294,7 +259,7 @@ class UopKernel { } } // The uop buffer - template + template friend class UopQueue; friend class CommandQueue; // SRAM location if begin != end @@ -322,26 +287,21 @@ class BaseQueue { } } /*! \return Content of DRAM buffer. */ - char* dram_buffer() const { - return dram_buffer_; - } + char* dram_buffer() const { return dram_buffer_; } /*! \return Physical address of DRAM. */ vta_phy_addr_t dram_phy_addr() const { CHECK(fpga_buff_phy_); return fpga_buff_phy_; } /*! \return Whether there is pending information. */ - bool pending() const { - return sram_begin_ != sram_end_; - } + bool pending() const { return sram_begin_ != sram_end_; } /*! \brief Initialize the space of the buffer. */ void InitSpace(uint32_t elem_bytes, uint32_t max_bytes, bool coherent, bool always_cache) { coherent_ = coherent; always_cache_ = always_cache; elem_bytes_ = elem_bytes; // Allocate buffer ahead of time - fpga_buff_ = static_cast(VTAMemAlloc( - max_bytes, coherent_ || always_cache_)); + fpga_buff_ = static_cast(VTAMemAlloc(max_bytes, coherent_ || always_cache_)); CHECK(fpga_buff_ != nullptr); fpga_buff_phy_ = VTAMemGetPhyAddr(fpga_buff_); } @@ -351,6 +311,9 @@ class BaseQueue { */ virtual void Reset() { dram_buffer_.clear(); + // reset to 0 as we always copy data to area starting from fpga_buff base + // we do mem copy for every DeviceRun + sram_end_ = 0; sram_begin_ = sram_end_; } @@ -376,14 +339,12 @@ class BaseQueue { /*! * \brief Micro op buffer that manages the micro op cache. */ -template +template class UopQueue : public BaseQueue { public: - void InitSpace() { - BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); - } + void InitSpace() { BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); } // Push data to the queue - template + template void Push(UopKernel* kernel, FAutoSync fautosync) { // if the micro-op is cached in VTA SRAM, skip if (kernel->cached()) return; @@ -446,13 +407,18 @@ class UopQueue : public BaseQueue { } /*! 
\brief clear cache and reset base queue buffer.*/ void Reset() { + // unmark "cached" status + // as we cannot assume it is still in SRAM across DeviceRun + for (UopKernel* kernel : cache_) { + kernel->sram_begin_ = 0; + kernel->sram_end_ = 0; + } + cache_.clear(); cache_idx_ = 0; BaseQueue::Reset(); } - void AutoReadBarrier() { - ReadBarrier(); - } + void AutoReadBarrier() { ReadBarrier(); } /*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. */ void ReadBarrier() { CHECK(fpga_buff_ != nullptr); @@ -467,18 +433,14 @@ class UopQueue : public BaseQueue { uint32_t offset = 0; for (uint32_t i = 0; i < cache_.size(); ++i) { uint32_t ksize = cache_[i]->size() * kElemBytes; - VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, - cache_[i]->data(), - ksize); + VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, cache_[i]->data(), ksize); // Update offset offset += ksize; } // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { - VTAFlushCache(fpga_buff_, - fpga_buff_phy_, - offset); + VTAFlushCache(fpga_buff_, fpga_buff_phy_, offset); } } @@ -497,8 +459,7 @@ class UopQueue : public BaseQueue { class UopKernelMap { public: // Simple hash map - UopKernel** Get(void* signature, - int nbytes) { + UopKernel** Get(void* signature, int nbytes) { uint32_t key = 0; CHECK(nbytes == 0 || nbytes == sizeof(int)); if (nbytes == sizeof(int)) { @@ -516,15 +477,10 @@ class UopKernelMap { std::vector kmap_; }; -enum PipelineStage : int { - kNoneStage = 0, - kLoadStage = 1, - kComputeStage = 2, - kStoreStage = 3 -}; +enum PipelineStage : int { kNoneStage = 0, kLoadStage = 1, kComputeStage = 2, kStoreStage = 3 }; // Instruction Queue -template +template class InsnQueue : public BaseQueue { public: /*! \brief Initialize the space. */ @@ -535,13 +491,9 @@ class InsnQueue : public BaseQueue { std::fill(pending_pop_next_, pending_pop_next_ + 4, 0); } /*! \return The data pointer. */ - VTAGenericInsn* data() { - return dram_buffer_.data(); - } + VTAGenericInsn* data() { return dram_buffer_.data(); } /*! \return Number of instructions. 
*/ - uint32_t count() { - return dram_buffer_.size(); - } + uint32_t count() { return dram_buffer_.size(); } // Insert dependency push of load void DepPop(int from, int to) { // NOTE: This instruction executes on queue[to] @@ -569,10 +521,12 @@ class InsnQueue : public BaseQueue { if (GetPipelineStage(mptr) == from) { if (from < to && !mptr->push_next_dep) { // push(LD->C) or push(C->ST) - mptr->push_next_dep = true; return; + mptr->push_next_dep = true; + return; } else if (from > to && !mptr->push_prev_dep) { // push(C->LD) or push(ST->C) - mptr->push_prev_dep = true; return; + mptr->push_prev_dep = true; + return; } } } @@ -585,25 +539,15 @@ class InsnQueue : public BaseQueue { } } // Create a new instruction for a GEMM stage - VTAGemInsn* CreateGemInsn() { - return reinterpret_cast( - Create(kComputeStage)); - } + VTAGemInsn* CreateGemInsn() { return reinterpret_cast(Create(kComputeStage)); } // Create a new instruction for a ALU stage - VTAAluInsn* CreateAluInsn() { - return reinterpret_cast( - Create(kComputeStage)); - } + VTAAluInsn* CreateAluInsn() { return reinterpret_cast(Create(kComputeStage)); } // Create a new instruction for a memory stage VTAMemInsn* CreateMemInsn(int memory_type) { - return reinterpret_cast( - Create(GetMemPipelineStage(memory_type))); + return reinterpret_cast(Create(GetMemPipelineStage(memory_type))); } // create a new instruction for a store stage - VTAMemInsn* CreateStoreInsn() { - return reinterpret_cast( - Create(kStoreStage)); - } + VTAMemInsn* CreateStoreInsn() { return reinterpret_cast(Create(kStoreStage)); } // Rewrite instruction stream to force serial execution void RewriteForceSerial() { int insn_count = count(); @@ -653,7 +597,7 @@ class InsnQueue : public BaseQueue { } CommitPendingPop(kComputeStage); } else { - pending_pop_next_[kComputeStage] = 0; + pending_pop_next_[kComputeStage] = 0; } DepPush(kComputeStage, kLoadStage); DepPop(kLoadStage, kComputeStage); @@ -666,30 +610,30 @@ class InsnQueue : public BaseQueue { } // Helper function: Get Opcode string const char* getOpcodeString(int opcode, bool use_imm) { - // The string name - if (opcode == VTA_ALU_OPCODE_MIN) { - if (use_imm) { - return "min imm"; - } else { - return "min"; - } - } else if (opcode == VTA_ALU_OPCODE_MAX) { - if (use_imm) { - return "max imm"; - } else { - return "max"; - } - } else if (opcode == VTA_ALU_OPCODE_ADD) { - if (use_imm) { - return "add imm"; - } else { - return "add"; - } - } else if (opcode == VTA_ALU_OPCODE_SHR) { - return "shr"; + // The string name + if (opcode == VTA_ALU_OPCODE_MIN) { + if (use_imm) { + return "min imm"; + } else { + return "min"; + } + } else if (opcode == VTA_ALU_OPCODE_MAX) { + if (use_imm) { + return "max imm"; + } else { + return "max"; + } + } else if (opcode == VTA_ALU_OPCODE_ADD) { + if (use_imm) { + return "add imm"; + } else { + return "add"; } + } else if (opcode == VTA_ALU_OPCODE_SHR) { + return "shr"; + } - return "unknown op"; + return "unknown op"; } // Dump instructions in the queue void DumpInsn() { @@ -718,10 +662,8 @@ class InsnQueue : public BaseQueue { printf("NOP-MEMORY-STAGE\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); // Count status in queues if (c.mem.opcode == VTA_OPCODE_STORE) { 
CHECK(c.mem.pop_next_dep == false); @@ -729,8 +671,7 @@ class InsnQueue : public BaseQueue { if (c.mem.pop_prev_dep) g2s_queue--; if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD && - (c.mem.memory_type == VTA_MEM_ID_INP || - c.mem.memory_type == VTA_MEM_ID_WGT) ) { + (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) { CHECK(c.mem.pop_prev_dep == false); CHECK(c.mem.push_prev_dep == false); if (c.mem.pop_next_dep) g2l_queue--; @@ -757,65 +698,44 @@ class InsnQueue : public BaseQueue { printf("STORE:\n"); } printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); - printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", - static_cast(c.mem.dram_base), + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); + printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", static_cast(c.mem.dram_base), static_cast(c.mem.sram_base)); - printf("\ty: size=%d, pad=[%d, %d]\n", - static_cast(c.mem.y_size), - static_cast(c.mem.y_pad_0), - static_cast(c.mem.y_pad_1)); - printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", - static_cast(c.mem.x_size), - static_cast(c.mem.x_stride), - static_cast(c.mem.x_pad_0), + printf("\ty: size=%d, pad=[%d, %d]\n", static_cast(c.mem.y_size), + static_cast(c.mem.y_pad_0), static_cast(c.mem.y_pad_1)); + printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", static_cast(c.mem.x_size), + static_cast(c.mem.x_stride), static_cast(c.mem.x_pad_0), static_cast(c.mem.x_pad_1)); } else if (c.mem.opcode == VTA_OPCODE_GEMM) { // Print instruction field information printf("GEMM\n"); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), static_cast(c.mem.push_next_dep)); printf("\treset_out: %d\n", static_cast(c.gemm.reset_reg)); - printf("\trange (%d, %d)\n", - static_cast(c.gemm.uop_bgn), + printf("\trange (%d, %d)\n", static_cast(c.gemm.uop_bgn), static_cast(c.gemm.uop_end)); printf("\touter loop - iter: %d, wgt: %d, inp: %d, acc: %d\n", - static_cast(c.gemm.iter_out), - static_cast(c.gemm.wgt_factor_out), - static_cast(c.gemm.src_factor_out), - static_cast(c.gemm.dst_factor_out)); + static_cast(c.gemm.iter_out), static_cast(c.gemm.wgt_factor_out), + static_cast(c.gemm.src_factor_out), static_cast(c.gemm.dst_factor_out)); printf("\tinner loop - iter: %d, wgt: %d, inp: %d, acc: %d\n", - static_cast(c.gemm.iter_in), - static_cast(c.gemm.wgt_factor_in), - static_cast(c.gemm.src_factor_in), - static_cast(c.gemm.dst_factor_in)); + static_cast(c.gemm.iter_in), static_cast(c.gemm.wgt_factor_in), + static_cast(c.gemm.src_factor_in), static_cast(c.gemm.dst_factor_in)); } else if (c.mem.opcode == VTA_OPCODE_ALU) { // Print instruction field information printf("ALU - %s\n", getOpcodeString(c.alu.alu_opcode, c.alu.use_imm)); printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n", - static_cast(c.mem.pop_prev_dep), - static_cast(c.mem.pop_next_dep), - static_cast(c.mem.push_prev_dep), - static_cast(c.mem.push_next_dep)); + static_cast(c.mem.pop_prev_dep), static_cast(c.mem.pop_next_dep), + static_cast(c.mem.push_prev_dep), 
static_cast(c.mem.push_next_dep)); printf("\treset_out: %d\n", static_cast(c.alu.reset_reg)); - printf("\trange (%d, %d)\n", - static_cast(c.alu.uop_bgn), + printf("\trange (%d, %d)\n", static_cast(c.alu.uop_bgn), static_cast(c.alu.uop_end)); - printf("\touter loop - iter: %d, dst: %d, src: %d\n", - static_cast(c.alu.iter_out), - static_cast(c.alu.dst_factor_out), - static_cast(c.alu.src_factor_out)); - printf("\tinner loop - iter: %d, dst: %d, src: %d\n", - static_cast(c.alu.iter_in), - static_cast(c.alu.dst_factor_in), - static_cast(c.alu.src_factor_in)); + printf("\touter loop - iter: %d, dst: %d, src: %d\n", static_cast(c.alu.iter_out), + static_cast(c.alu.dst_factor_out), static_cast(c.alu.src_factor_out)); + printf("\tinner loop - iter: %d, dst: %d, src: %d\n", static_cast(c.alu.iter_in), + static_cast(c.alu.dst_factor_in), static_cast(c.alu.src_factor_in)); } else if (c.mem.opcode == VTA_OPCODE_FINISH) { printf("FINISH\n"); } @@ -823,25 +743,23 @@ class InsnQueue : public BaseQueue { // Count status in queues if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) { if (c.mem.opcode == VTA_OPCODE_STORE) { - CHECK(c.mem.pop_next_dep == false); - CHECK(c.mem.push_next_dep == false); - if (c.mem.pop_prev_dep) g2s_queue--; - if (c.mem.push_prev_dep) s2g_queue++; + CHECK(c.mem.pop_next_dep == false); + CHECK(c.mem.push_next_dep == false); + if (c.mem.pop_prev_dep) g2s_queue--; + if (c.mem.push_prev_dep) s2g_queue++; } else if (c.mem.opcode == VTA_OPCODE_LOAD && - (c.mem.memory_type == VTA_MEM_ID_INP || - c.mem.memory_type == VTA_MEM_ID_WGT) ) { - CHECK(c.mem.pop_prev_dep == false); - CHECK(c.mem.push_prev_dep == false); - if (c.mem.pop_next_dep) g2l_queue--; - if (c.mem.push_next_dep) l2g_queue++; + (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) { + CHECK(c.mem.pop_prev_dep == false); + CHECK(c.mem.push_prev_dep == false); + if (c.mem.pop_next_dep) g2l_queue--; + if (c.mem.push_next_dep) l2g_queue++; } else { - if (c.mem.pop_prev_dep) l2g_queue--; - if (c.mem.push_prev_dep) g2l_queue++; - if (c.mem.pop_next_dep) s2g_queue--; - if (c.mem.push_next_dep) g2s_queue++; + if (c.mem.pop_prev_dep) l2g_queue--; + if (c.mem.push_prev_dep) g2l_queue++; + if (c.mem.pop_next_dep) s2g_queue--; + if (c.mem.push_next_dep) g2s_queue++; } - } else if (c.mem.opcode == VTA_OPCODE_GEMM || - c.mem.opcode == VTA_OPCODE_ALU) { + } else if (c.mem.opcode == VTA_OPCODE_GEMM || c.mem.opcode == VTA_OPCODE_ALU) { // Print instruction field information if (c.gemm.pop_prev_dep) l2g_queue--; if (c.gemm.push_prev_dep) g2l_queue++; @@ -857,11 +775,8 @@ class InsnQueue : public BaseQueue { // Handle the LD<->compute queue // NOTE: pop executes on target(stage) CHECK(stage > 0 && stage < 4); - if (pending_pop_prev_[stage] || - pending_pop_next_[stage]) { - PushNoop(stage, false, false, - pending_pop_prev_[stage], - pending_pop_next_[stage]); + if (pending_pop_prev_[stage] || pending_pop_next_[stage]) { + PushNoop(stage, false, false, pending_pop_prev_[stage], pending_pop_next_[stage]); pending_pop_prev_[stage] = 0; pending_pop_next_[stage] = 0; } @@ -878,9 +793,7 @@ class InsnQueue : public BaseQueue { } return false; } - void AutoReadBarrier() { - ReadBarrier(); - } + void AutoReadBarrier() { ReadBarrier(); } /*! \brief Writer barrier to make sure that data written by CPU is visible to VTA. 
*/ void ReadBarrier() { CHECK(fpga_buff_ != nullptr); @@ -888,15 +801,11 @@ class InsnQueue : public BaseQueue { uint32_t buff_size = dram_buffer_.size() * elem_bytes_; CHECK(buff_size <= kMaxBytes); // Copy contents of DRAM buffer to FPGA buff - VTAMemCopyFromHost(fpga_buff_, - dram_buffer_.data(), - buff_size); + VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size); // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { - VTAFlushCache(fpga_buff_, - fpga_buff_phy_, - buff_size); + VTAFlushCache(fpga_buff_, fpga_buff_phy_, buff_size); } } @@ -947,15 +856,14 @@ class InsnQueue : public BaseQueue { // Get stage of memory and computation static PipelineStage GetPipelineStageAll(VTAMemInsn* insn) { - PipelineStage stage = GetPipelineStage(insn); - if (stage != kNoneStage) return stage; - return GetMemPipelineStage(insn->memory_type); + PipelineStage stage = GetPipelineStage(insn); + if (stage != kNoneStage) return stage; + return GetMemPipelineStage(insn->memory_type); } // Push no-op - void PushNoop(int stage, - bool push_prev_dep, bool push_next_dep, - bool pop_prev_dep, bool pop_next_dep) { + void PushNoop(int stage, bool push_prev_dep, bool push_next_dep, bool pop_prev_dep, + bool pop_next_dep) { VTAMemInsn* insn = reinterpret_cast(NextInsn()); insn->opcode = (stage == kStoreStage ? VTA_OPCODE_STORE : VTA_OPCODE_LOAD); insn->push_prev_dep = push_prev_dep; @@ -987,9 +895,7 @@ class InsnQueue : public BaseQueue { */ class CommandQueue { public: - CommandQueue() { - this->InitSpace(); - } + CommandQueue() { this->InitSpace(); } void InitSpace() { uop_queue_.InitSpace(); insn_queue_.InitSpace(); @@ -997,31 +903,29 @@ class CommandQueue { CHECK(device_ != nullptr); } - ~CommandQueue() { - VTADeviceFree(device_); - } + ~CommandQueue() { VTADeviceFree(device_); } uint32_t GetElemBytes(uint32_t memory_id) { uint32_t elem_bytes = 0; switch (memory_id) { case VTA_MEM_ID_UOP: - elem_bytes = VTA_UOP_ELEM_BYTES; - break; + elem_bytes = VTA_UOP_ELEM_BYTES; + break; case VTA_MEM_ID_INP: - elem_bytes = VTA_INP_ELEM_BYTES; - break; + elem_bytes = VTA_INP_ELEM_BYTES; + break; case VTA_MEM_ID_WGT: - elem_bytes = VTA_WGT_ELEM_BYTES; - break; + elem_bytes = VTA_WGT_ELEM_BYTES; + break; case VTA_MEM_ID_ACC: - elem_bytes = VTA_ACC_ELEM_BYTES; - break; + elem_bytes = VTA_ACC_ELEM_BYTES; + break; case VTA_MEM_ID_OUT: - elem_bytes = VTA_OUT_ELEM_BYTES; - break; + elem_bytes = VTA_OUT_ELEM_BYTES; + break; default: - LOG(FATAL) << "Memory id not recognized:" << memory_id; - break; + LOG(FATAL) << "Memory id not recognized:" << memory_id; + break; } /* * elements size should not larger than VTA_PAGE_BYTES. 
@@ -1031,16 +935,9 @@ class CommandQueue { return elem_bytes; } - void LoadBuffer2D(void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, + void LoadBuffer2D(void* src_dram_addr, uint32_t src_elem_offset, uint32_t x_size, uint32_t y_size, + uint32_t x_stride, uint32_t x_pad_before, uint32_t y_pad_before, + uint32_t x_pad_after, uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(dst_memory_type); insn->opcode = VTA_OPCODE_LOAD; @@ -1058,12 +955,8 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void StoreBuffer2D(uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, + void StoreBuffer2D(uint32_t src_sram_index, uint32_t src_memory_type, void* dst_dram_addr, + uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size, uint32_t x_stride) { VTAMemInsn* insn = insn_queue_.CreateStoreInsn(); insn->opcode = VTA_OPCODE_STORE; @@ -1081,27 +974,21 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void DepPush(int from_qid, int to_qid) { - insn_queue_.DepPush(from_qid, to_qid); - } + void DepPush(int from_qid, int to_qid) { insn_queue_.DepPush(from_qid, to_qid); } - void DepPop(int from_qid, int to_qid) { - insn_queue_.DepPop(from_qid, to_qid); - } + void DepPop(int from_qid, int to_qid) { insn_queue_.DepPop(from_qid, to_qid); } void ReadBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { if (!(debug_flag_ & VTA_DEBUG_SKIP_READ_BARRIER)) { uint32_t elem_bytes = (elem_bits + 8 - 1) / 8; - DataBuffer::FromHandle(buffer)->FlushCache( - elem_bytes * start, elem_bytes * extent); + DataBuffer::FromHandle(buffer)->FlushCache(elem_bytes * start, elem_bytes * extent); } } void WriteBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { if (!(debug_flag_ & VTA_DEBUG_SKIP_WRITE_BARRIER)) { uint32_t elem_bytes = (elem_bits + 8 - 1) / 8; - DataBuffer::FromHandle(buffer)->InvalidateCache( - elem_bytes * start, elem_bytes * extent); + DataBuffer::FromHandle(buffer)->InvalidateCache(elem_bytes * start, elem_bytes * extent); } } @@ -1131,16 +1018,13 @@ class CommandQueue { insn_queue_.DumpInsn(); } // Make sure that the last instruction is a finish instruction - CHECK(reinterpret_cast( - insn_queue_.data())[insn_queue_.count()-1].opcode == VTA_OPCODE_FINISH); + CHECK(reinterpret_cast(insn_queue_.data())[insn_queue_.count() - 1].opcode == + VTA_OPCODE_FINISH); // Make sure that we don't exceed contiguous physical memory limits CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER); - int timeout = VTADeviceRun( - device_, - insn_queue_.dram_phy_addr(), - insn_queue_.count(), - wait_cycles); + int timeout = + VTADeviceRun(device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles); CHECK_EQ(timeout, 0); // Reset buffers uop_queue_.Reset(); @@ -1154,14 +1038,9 @@ class CommandQueue { } // Set debug flag - void SetDebugFlag(int debug_flag) { - debug_flag_ = debug_flag; - } + void SetDebugFlag(int debug_flag) { debug_flag_ = debug_flag; } - void PushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { + void PushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { UopKernelMap** uptr = reinterpret_cast(uop_handle); if (uptr[0] == nullptr) { uptr[0] = new 
UopKernelMap(); @@ -1180,10 +1059,7 @@ class CommandQueue { this->CheckInsnOverFlow(); } - void PushALUUop(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { + void PushALUUop(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { UopKernelMap** uptr = reinterpret_cast(uop_handle); if (uptr[0] == nullptr) { uptr[0] = new UopKernelMap(); @@ -1203,23 +1079,19 @@ class CommandQueue { } static std::shared_ptr& ThreadLocal() { - static std::shared_ptr inst = - std::make_shared(); + static std::shared_ptr inst = std::make_shared(); if (inst == nullptr) { inst = std::make_shared(); } return inst; } - static void Shutdown() { - ThreadLocal().reset(); - } + static void Shutdown() { ThreadLocal().reset(); } private: // Push GEMM uop to the command buffer void PushGEMMOp(UopKernel* kernel) { - uop_queue_.Push(kernel, - [this]() { this->AutoSync(); }); + uop_queue_.Push(kernel, [this]() { this->AutoSync(); }); if (uop_queue_.pending()) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP); insn->opcode = VTA_OPCODE_LOAD; @@ -1230,7 +1102,7 @@ class CommandQueue { insn->reset_reg = kernel->reset_out_; insn->uop_bgn = kernel->sram_begin_; insn->uop_end = kernel->sram_end_; - const std::vector &loop = kernel->loop(); + const std::vector& loop = kernel->loop(); if (loop.size() > 0) { insn->iter_out = loop[0].extent; insn->wgt_factor_out = loop[0].wgt_factor; @@ -1257,8 +1129,7 @@ class CommandQueue { // Push ALU uop to the command buffer void PushALUUop(UopKernel* kernel) { - uop_queue_.Push(kernel, - [this]() { this->AutoSync(); }); + uop_queue_.Push(kernel, [this]() { this->AutoSync(); }); if (uop_queue_.pending()) { VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP); insn->opcode = VTA_OPCODE_LOAD; @@ -1272,7 +1143,7 @@ class CommandQueue { insn->alu_opcode = kernel->opcode_; insn->use_imm = kernel->use_imm_; insn->imm = kernel->imm_val_; - const std::vector &loop = kernel->loop(); + const std::vector& loop = kernel->loop(); if (loop.size() == 0) { insn->iter_out = 1; insn->dst_factor_out = 0; @@ -1305,9 +1176,7 @@ class CommandQueue { } } // Auto sync when instruction overflow - void AutoSync() { - this->Synchronize(1 << 31); - } + void AutoSync() { this->Synchronize(1 << 31); } // Internal debug flag int debug_flag_{0}; @@ -1323,19 +1192,11 @@ class CommandQueue { } // namespace vta -void* VTABufferAlloc(size_t size) { - return vta::DataBuffer::Alloc(size); -} +void* VTABufferAlloc(size_t size) { return vta::DataBuffer::Alloc(size); } -void VTABufferFree(void* buffer) { - vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); -} +void VTABufferFree(void* buffer) { vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); } -void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, +void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, int kind_mask) { vta::DataBuffer* from_buffer = nullptr; vta::DataBuffer* to_buffer = nullptr; @@ -1353,143 +1214,87 @@ void VTABufferCopy(const void* from, // This is an FPGA to host mem transfer from_buffer->InvalidateCache(from_offset, size); from_buffer->MemCopyToHost(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); } else if (to_buffer) { // This is a host to FPGA mem transfer to_buffer->MemCopyFromHost(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); + static_cast(from) + from_offset, size); 
to_buffer->FlushCache(to_offset, size); } } -VTACommandHandle VTATLSCommandHandle() { - return vta::CommandQueue::ThreadLocal().get(); -} +VTACommandHandle VTATLSCommandHandle() { return vta::CommandQueue::ThreadLocal().get(); } -void VTARuntimeShutdown() { - vta::CommandQueue::Shutdown(); -} +void VTARuntimeShutdown() { vta::CommandQueue::Shutdown(); } void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) { - static_cast(cmd)-> - SetDebugFlag(debug_flag); + static_cast(cmd)->SetDebugFlag(debug_flag); } void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) { return vta::DataBuffer::FromHandle(buffer)->virt_addr(); } -void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { - static_cast(cmd)-> - WriteBarrier(buffer, elem_bits, start, extent); + static_cast(cmd)->WriteBarrier(buffer, elem_bits, start, extent); } -void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) { - static_cast(cmd)-> - ReadBarrier(buffer, elem_bits, start, extent); + static_cast(cmd)->ReadBarrier(buffer, elem_bits, start, extent); } -void VTALoadBuffer2D(VTACommandHandle cmd, - void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, - uint32_t dst_memory_type) { - static_cast(cmd)-> - LoadBuffer2D(src_dram_addr, src_elem_offset, - x_size, y_size, x_stride, - x_pad_before, y_pad_before, - x_pad_after, y_pad_after, - dst_sram_index, dst_memory_type); +void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset, + uint32_t x_size, uint32_t y_size, uint32_t x_stride, uint32_t x_pad_before, + uint32_t y_pad_before, uint32_t x_pad_after, uint32_t y_pad_after, + uint32_t dst_sram_index, uint32_t dst_memory_type) { + static_cast(cmd)->LoadBuffer2D( + src_dram_addr, src_elem_offset, x_size, y_size, x_stride, x_pad_before, y_pad_before, + x_pad_after, y_pad_after, dst_sram_index, dst_memory_type); } -void VTAStoreBuffer2D(VTACommandHandle cmd, - uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride) { - static_cast(cmd)-> - StoreBuffer2D(src_sram_index, src_memory_type, - dst_dram_addr, dst_elem_offset, - x_size, y_size, x_stride); +void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, uint32_t src_memory_type, + void* dst_dram_addr, uint32_t dst_elem_offset, uint32_t x_size, + uint32_t y_size, uint32_t x_stride) { + static_cast(cmd)->StoreBuffer2D( + src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, x_size, y_size, x_stride); } -void VTAUopPush(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val) { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->Push(mode, reset_out, dst_index, src_index, - wgt_index, opcode, use_imm, imm_val); +void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) { + vta::CommandQueue::ThreadLocal()->record_kernel()->Push(mode, 
reset_out, dst_index, src_index, + wgt_index, opcode, use_imm, imm_val); } -void VTAUopLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, +void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor) { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->PushLoopBegin(extent, dst_factor, src_factor, wgt_factor); + vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopBegin(extent, dst_factor, src_factor, + wgt_factor); } -void VTAUopLoopEnd() { - vta::CommandQueue::ThreadLocal()->record_kernel() - ->PushLoopEnd(); -} +void VTAUopLoopEnd() { vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopEnd(); } -int VTAPushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { - vta::CommandQueue::ThreadLocal()-> - PushGEMMOp(uop_handle, finit, signature, nbytes); +int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { + vta::CommandQueue::ThreadLocal()->PushGEMMOp(uop_handle, finit, signature, nbytes); return 0; } -int VTAPushALUOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes) { - vta::CommandQueue::ThreadLocal()-> - PushALUUop(uop_handle, finit, signature, nbytes); +int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) { + vta::CommandQueue::ThreadLocal()->PushALUUop(uop_handle, finit, signature, nbytes); return 0; } int VTADepPush(VTACommandHandle cmd, int from_qid, int to_qid) { - static_cast(cmd)-> - DepPush(from_qid, to_qid); + static_cast(cmd)->DepPush(from_qid, to_qid); return 0; } int VTADepPop(VTACommandHandle cmd, int from_qid, int to_qid) { - static_cast(cmd)-> - DepPop(from_qid, to_qid); + static_cast(cmd)->DepPop(from_qid, to_qid); return 0; } void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) { - static_cast(cmd)-> - Synchronize(wait_cycles); + static_cast(cmd)->Synchronize(wait_cycles); } diff --git a/vta/runtime/runtime.h b/vta/runtime/runtime.h index bb16d3a3bfc2..24ebb8e1247b 100644 --- a/vta/runtime/runtime.h +++ b/vta/runtime/runtime.h @@ -64,12 +64,8 @@ TVM_DLL void VTABufferFree(void* buffer); * \param size Size of copy. * \param kind_mask The memory copy kind. */ -TVM_DLL void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - int kind_mask); +TVM_DLL void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, + size_t size, int kind_mask); /*! \brief VTA command handle */ typedef void* VTACommandHandle; @@ -99,10 +95,7 @@ TVM_DLL void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer); * \param start The start of the region (in elements). * \param extent The end of the region (in elements). */ -TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent); /*! @@ -113,10 +106,7 @@ TVM_DLL void VTAWriteBarrier(VTACommandHandle cmd, * \param start The start of the region (in elements). * \param extent The end of the region (in elements). */ -TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, - void* buffer, - uint32_t elem_bits, - uint32_t start, +TVM_DLL void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent); /*! @@ -142,17 +132,10 @@ TVM_DLL void VTASetDebugMode(VTACommandHandle cmd, int debug_flag); * \param dst_sram_index Destination SRAM index. 
* \param dst_memory_type Destination memory type. */ -TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, - void* src_dram_addr, - uint32_t src_elem_offset, - uint32_t x_size, - uint32_t y_size, - uint32_t x_stride, - uint32_t x_pad_before, - uint32_t y_pad_before, - uint32_t x_pad_after, - uint32_t y_pad_after, - uint32_t dst_sram_index, +TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset, + uint32_t x_size, uint32_t y_size, uint32_t x_stride, + uint32_t x_pad_before, uint32_t y_pad_before, uint32_t x_pad_after, + uint32_t y_pad_after, uint32_t dst_sram_index, uint32_t dst_memory_type); /*! @@ -167,13 +150,9 @@ TVM_DLL void VTALoadBuffer2D(VTACommandHandle cmd, * \param y_size The number of rows. * \param x_stride The x axis stride. */ -TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, - uint32_t src_sram_index, - uint32_t src_memory_type, - void* dst_dram_addr, - uint32_t dst_elem_offset, - uint32_t x_size, - uint32_t y_size, +TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, + uint32_t src_memory_type, void* dst_dram_addr, + uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size, uint32_t x_stride); /*! @@ -207,14 +186,8 @@ TVM_DLL void VTAStoreBuffer2D(VTACommandHandle cmd, * \param use_imm Use immediate in ALU mode if set to true. * \param imm_val Immediate value in ALU mode. */ -TVM_DLL void VTAUopPush(uint32_t mode, - uint32_t reset_out, - uint32_t dst_index, - uint32_t src_index, - uint32_t wgt_index, - uint32_t opcode, - uint32_t use_imm, - int32_t imm_val); +TVM_DLL void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index, + uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val); /*! * \brief Mark start of a micro op loop. @@ -223,9 +196,7 @@ TVM_DLL void VTAUopPush(uint32_t mode, * \param src_factor The input factor. * \param wgt_factor The weight factor. */ -TVM_DLL void VTAUopLoopBegin(uint32_t extent, - uint32_t dst_factor, - uint32_t src_factor, +TVM_DLL void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor, uint32_t wgt_factor); /*! @@ -241,10 +212,7 @@ TVM_DLL void VTAUopLoopEnd(); * \param nbytes Number of bytes to in the closure arguments. * \return 0 if success. */ -TVM_DLL int VTAPushGEMMOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes); +TVM_DLL int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes); /*! * \brief Push ALU uop kernel into the command handle. @@ -254,10 +222,7 @@ TVM_DLL int VTAPushGEMMOp(void** uop_handle, * \param nbytes Number of bytes to in the closure arguments. * \return 0 if success. */ -TVM_DLL int VTAPushALUOp(void** uop_handle, - int (*finit)(void*), - void* signature, - int nbytes); +TVM_DLL int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes); /*! * \brief Push dependence token. 
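In generated code these entry points are not called by hand; the TIR passes in this change (for example the `InjectALUIntrin` pass above) emit them as `call_extern` nodes. Below is a minimal sketch of the emitted micro-op sequence for a single-loop, in-place ALU operation; the extents, coefficients, opcode, and immediate are illustrative placeholders, not values taken from a real schedule.

```python
import tvm

# Sketch only: build the uop-loop call sequence that the ALU injection pass
# emits. All numeric arguments below are placeholders, not real VTA constants.
irb = tvm.tir.ir_builder.create()
# VTAUopLoopBegin(extent, dst_factor, src_factor, wgt_factor)
irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopBegin", 8, 1, 1, 0))
# VTAUopPush(mode, reset_out, dst_index, src_index, wgt_index,
#            opcode, use_imm, imm_val); mode=1 selects ALU mode
irb.emit(tvm.tir.call_extern("int32", "VTAUopPush", 1, 0, 0, 0, 0, 2, 1, 1))
irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopEnd"))
print(irb.get())
```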
diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py index 1de35c024203..2d358d335389 100644 --- a/vta/scripts/tune_resnet.py +++ b/vta/scripts/tune_resnet.py @@ -127,7 +127,7 @@ def compile_network(opt, env, target): # Perform quantization in Relay # Note: We set opt_level to 3 in order to fold batch norm - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]): relay_prog = relay.quantize.quantize(mod["main"], params=params) @@ -271,16 +271,16 @@ def tune_tasks(tasks, # Compile network print("Compiling network with best tuning parameters...") - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - if target.device_name != "vta": + if target.device_name != "vta": + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + else: + with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, params = relay.build( relay_prog, target=target, params=params, target_host=env.target_host) - else: - with vta.build_config(): - graph, lib, params = relay.build( - relay_prog, target=target, - params=params, target_host=env.target_host) # Export library temp = util.tempdir() diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py index 571dde669d2a..a92b1ee5d90b 100644 --- a/vta/tutorials/autotvm/tune_relay_vta.py +++ b/vta/tutorials/autotvm/tune_relay_vta.py @@ -92,7 +92,7 @@ def compile_network(env, target, model, start_pack, stop_pack): # Perform quantization in Relay # Note: We set opt_level to 3 in order to fold batch norm - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]): mod = relay.quantize.quantize(mod, params=params) @@ -392,19 +392,19 @@ def tune_and_evaluate(tuning_opt): with autotvm.tophub.context(target, extra_files=[log_file]): # Compile network print("Compile...") - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - if target.device_name != "vta": + if target.device_name != "vta": + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, params = relay.build(relay_prog, - target=target, - params=params, - target_host=env.target_host) - else: - with vta.build_config(): - graph, lib, params = relay.build( - relay_prog, - target=target, - params=params, - target_host=env.target_host) + target=target, + params=params, + target_host=env.target_host) + else: + with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, + target=target, + params=params, + target_host=env.target_host) # Export library print("Upload...") diff --git a/vta/tutorials/frontend/deploy_classification.py b/vta/tutorials/frontend/deploy_classification.py index 62fb32165a18..3a367851ed25 100644 --- a/vta/tutorials/frontend/deploy_classification.py +++ b/vta/tutorials/frontend/deploy_classification.py @@ -171,7 +171,7 @@ if target.device_name == "vta": # Perform quantization in Relay # Note: We set opt_level to 3 in order to fold batch norm - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]): mod = relay.quantize.quantize(mod, params=params) @@ -188,16 +188,16 @@ relay_prog = 
mod["main"] # Compile Relay program with AlterOpLayout disabled - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - if target.device_name != "vta": + if target.device_name != "vta": + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): + graph, lib, params = relay.build( + relay_prog, target=target, + params=params, target_host=env.target_host) + else: + with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, params = relay.build( relay_prog, target=target, params=params, target_host=env.target_host) - else: - with vta.build_config(): - graph, lib, params = relay.build( - relay_prog, target=target, - params=params, target_host=env.target_host) # Measure Relay build time build_time = time.time() - build_start diff --git a/vta/tutorials/frontend/deploy_detection.py b/vta/tutorials/frontend/deploy_detection.py index efcd2c43591d..5039488149d5 100644 --- a/vta/tutorials/frontend/deploy_detection.py +++ b/vta/tutorials/frontend/deploy_detection.py @@ -207,7 +207,7 @@ if target.device_name == "vta": # Perform quantization in Relay # Note: We set opt_level to 3 in order to fold batch norm - with relay.build_config(opt_level=3): + with tvm.transform.PassContext(opt_level=3): with relay.quantize.qconfig(global_scale=33.0, skip_conv_layers=[0], store_lowbit_output=True, diff --git a/web/.eslintignore b/web/.eslintignore new file mode 100644 index 000000000000..1521c8b7652b --- /dev/null +++ b/web/.eslintignore @@ -0,0 +1 @@ +dist diff --git a/web/.eslintrc.json b/web/.eslintrc.json new file mode 100644 index 000000000000..0724c440041b --- /dev/null +++ b/web/.eslintrc.json @@ -0,0 +1,34 @@ +{ + "env": { + "browser": true, + "es6": true + }, + "extends": ["eslint:recommended"], + "root": true, + "parser": "@typescript-eslint/parser", + "parserOptions": { + "ecmaVersion": 2018, + "sourceType": "module" + }, + "overrides": [ + { + "files": ["src/**.ts", "src/**.tsx"], + "plugins": ["@typescript-eslint"], + "extends": [ + "plugin:@typescript-eslint/eslint-recommended", + "plugin:@typescript-eslint/recommended" + ], + "rules": { + "require-jsdoc": 0, + "@typescript-eslint/no-explicit-any": 0, + "@typescript-eslint/no-empty-function": 0 + } + }, + { + "files": ["tests/node/*.js", "apps/node/*.js"], + "env": { + "node": true + } + } + ] +} diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 000000000000..a3135cf24b9d --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,6 @@ +.vscode +*~ +out +node_modules +package-lock.json +build diff --git a/web/.jsdoc_conf.json b/web/.jsdoc_conf.json deleted file mode 100644 index 33783b3bbb21..000000000000 --- a/web/.jsdoc_conf.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "templates": { - "default": { - "includeDate": false - } - } -} diff --git a/web/Makefile b/web/Makefile new file mode 100644 index 000000000000..eaf5a954accb --- /dev/null +++ b/web/Makefile @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +TVM_ROOT=$(shell cd ..; pwd) + +INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\ + -I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include + +.PHONY: clean all rmtypedep preparetest + +all: dist/wasm/tvmjs_runtime.wasm dist/wasm/tvmjs_runtime.wasi.js + +EMCC = emcc + +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++14 -Wno-ignored-attributes \ + -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1 -s ERROR_ON_UNDEFINED_SYMBOLS=0 + +EMCC_LDFLAGS = --pre-js emcc/preload.js + +dist/wasm/%.bc: emcc/%.cc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/$*.bc $< >dist/wasm/$*.d + $(EMCC) $(EMCC_CFLAGS) -c -o dist/wasm/$*.bc $< + + +dist/wasm/tvmjs_runtime.wasm: dist/wasm/wasm_runtime.bc dist/wasm/tvmjs_support.bc dist/wasm/webgpu_runtime.bc + @mkdir -p $(@D) + $(EMCC) $(EMCC_CFLAGS) -o dist/wasm/tvmjs_runtime.js $+ $(EMCC_LDFLAGS) + + +dist/wasm/tvmjs_runtime.wasi.js: dist/wasm/tvmjs_runtime.wasm emcc/decorate_as_wasi.py + python3 emcc/decorate_as_wasi.py dist/wasm/tvmjs_runtime.js $@ + +clean: + @rm -rf dist/wasm + +# Patch to remove require("@webgpu/types") +rmtypedep: + grep -v webgpu/types dist/webgpu.js > dist/webgpu.temp.js + mv dist/webgpu.temp.js dist/webgpu.js + +-include dist/wasm/*.d diff --git a/web/README.md b/web/README.md index 5dfd6917934b..358884ca26b1 100644 --- a/web/README.md +++ b/web/README.md @@ -15,163 +15,83 @@ -# TVM WebAssembly and Javascript Backend +# TVM WebAssembly Runtime -This folder contains TVM WebAssembly and Javascript backend through Emscripten. +This folder contains TVM WebAssembly Runtime. ## Installation -While the LLVM main branch support webassembly as a target. We still need a good runtime with libc and other -system library support. Emscripten toolchain offers that nicely. The general idea is to build TVM against -the fastcomp LLVM backend in the Emscripten project and allow us to generate ```asmjs-unknown-emscripten``` -as a backend target. + +The LLVM main branch support webassembly as a target, we can directly +build TVM with LLVM mainline to generate wasm modules. +Note that, however, we still need emscripten to compile the runtime and provide system library support. + +Note that so far we requires everything to be in the source and setup PYTHONPATH(instead of use setup.py install). ### Setup Emscripten -Checkout [Emscripten Portable SDK Downloads](https://kripken.github.io/emscripten-site/docs/getting_started/downloads.html) -to download emsdk-portable and unzip it on a local folder. Follow the installation guide from emscripten document. -```bash -./emsdk update -./emsdk install latest -./emsdk activate latest -``` +We use emscripten to compile our runtime wasm library as well as a WASI variant that we can deploy +to the browser environment. -Because we need to compile against the LLVM backend of emscripten, we will need the source and llvm library. -Which can be installed via following command. +Follow [Emscripten](https://emscripten.org/) to download emsdk and install emcc on your local environment. 
-```bash -./emsdk install clang-incoming-64bit -./emsdk activate clang-incoming-64bit -``` +### Build TVM Wasm Runtime -### Setup Environment Variable +After the emcc is setup correctly. We can build tvm's wasm runtime by typing `make` in the web folder. -In normal setting, we can setup the necessary environment variable with the following command. ```bash -source /path-to-emsdk-portable/emsdk_env.sh +make ``` -However, this will put emscripten's clang and llvm path ahead of the current system path. -What you can do is to set the path manually, by putting emscripten's path after the PATH like the following ones. -You can get the detailed path by type ```./emsdk activate``` -```bash -export PATH=${PATH}:/emsdk-related-path-here +This command will create the follow files: +- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link into. +- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing purposes. +- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by emscripten that can be fed into runtime. -``` -### Build TVM with Fastcomp LLVM +### Build TVM Wasm JS Frontend -To build TVM with Emscripten's Fastcomp LLVM, we can modify the LLVM_CONFIG in ```config.mk``` -to point to fastcomp's llvm-config and build TVM normally. +Type the following command in the web folder. ```bash -LLVM_CONFIG = /path/to/emsdk-portable/clang/fastcomp/build_incoming_64/bin/llvm-config +npm run bundle ``` -### Build TVM Web Runtime +This command will create the tvmjs library that we can use to interface with the wasm runtime. -The above command gives us the TVM compiling environment. Now we need to build runtime, -to do so, make sure we set the environment correctly as in previous section and type -```bash -make web -``` +## Use TVM to Generate Wasm Library and Run it -This will create ```build/libtvm_web_runtime.bc``` and ```build/libtvm_web_runtime.js```. - -## Use TVM to Generate Javascript Library - -The general idea is to use TVM as normally and set target to be ```llvm -target=asmjs-unknown-emscripten -system-lib```. - -The following code snippet from [tests/web/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/tests/web/prepare_test_libs.py) demonstrate -the compilation process. - -```python -import tvm -from tvm import te -from tvm.contrib import emscripten -import os -def prepare_test_libs(base_path): - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - if not tvm.runtime.enabled(target): - raise RuntimeError("Target %s is not enbaled" % target) - n = te.var("n") - A = te.placeholder((n,), name='A') - B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') - s = te.create_schedule(B.op) - fadd1 = tvm.build(s, [A, B], target, name="add_one") - obj_path = os.path.join(base_path, "test_add_one.bc") - fadd1.save(obj_path) - emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path) - -if __name__ == "__main__": - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - prepare_test_libs(os.path.join(curr_path, "../../build")) -``` +Check code snippet in -In this workflow, we use TVM to generate a ```.bc``` file and statically link -that with the ```build/libtvm_web_runtime.bc```(emscripten.create_js will help you do that). -The result js library is a library that contains both TVM runtime and the compiled function. 
- - -## Run the Generated Library - -The following code snippet from [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/tests/web/test_module_load.js) demonstrate -how to run the compiled library. - -```js -// Load Emscripten Module, need to change path to root/build -const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. -const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); - -// Load system library, the compiled function is registered in sysLib. -var sysLib = tvm.systemLib(); - -function randomArray(length, max) { - return Array.apply(null, Array(length)).map(function() { - return Math.random() * max; - }); -} - -function testAddOne() { - // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); - // call the function. - faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array - // verify - for (var i = 0; i < BB.length; ++i) { - assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); - } - faddOne.release(); -} - -testAddOne(); -sysLib.release(); -console.log("Finish verifying test_module_load"); -``` +- [tests/python/prepare_test_libs.py](https://github.com/apache/incubator-tvm/tree/master/web/tests/pythob/prepare_test_libs.py) + shows how to create a wasm library that links with tvm runtime. + - Note that all wasm libraries have to created using the `--system-lib` option + - emcc.create_wasm will automatically link the runtime library `dist/wasm/libtvm_runtime.bc` +- [tests/web/test_module_load.js](https://github.com/apache/incubator-tvm/tree/master/web/tests/node/test_module_load.js) demonstrate + how to run the generated library through tvmjs API. -Current example supports static linking, which is the preferred way to get more efficiency -in javascript backend. -## Proxy based RPC +## Run Wasm Remotely through WebSocket RPC. -We can now use javascript end to start an RPC server and connect to it from python side, +We can now use js side to start an RPC server and connect to it from python side, making the testing flow easier. -The following is an example to reproduce this. This requires everything to be in the git source and setup PYTHONPATH(instead of use setup.py install) -- run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy. -- Open broswer, goto the server webpage click Connect to proxy. - - Alternatively run "node web/example_rpc_node.js" -- run "python tests/web/websock_rpc_test.py" to run the rpc client. +The following is an example to reproduce this. +- run `python -m tvm.exec.rpc_proxy --example-rpc=1` to start proxy. +- Start the WebSocket RPC + - Browswer version: open https://localhost:8888, click connect to proxy + - NodeJS version: `npm run rpc` +- run `python tests/node/websock_rpc_test.py` to run the rpc client. + + +## WebGPU Experiments + +Web gpu is still experimental, so apis can change. +Right now we use the SPIRV to generate shaders that can be accepted by Chrome and Firefox. -The general idea is to use Emscripten's dynamic linking to dynamically load modules. +- Obtain a browser that support webgpu. + - So far only Chrome Canary on MacOS works + - Firefox should be close pending the support of Fence. 
+- Download vulkan SDK (1.1 or higher) that supports SPIRV 1.3 +- Start the WebSocket RPC +- run `python tests/node/webgpu_rpc_test.py` diff --git a/web/apps/browser/rpc_server.html b/web/apps/browser/rpc_server.html new file mode 100644 index 000000000000..6d353e29b08d --- /dev/null +++ b/web/apps/browser/rpc_server.html @@ -0,0 +1,79 @@ + + + + + + + + + + + + + + + + + + + + TVM RPC Test Page + + + + + +

TVM WebSocket RPC Server

+ To use this page
+   • Run "make" and "npm run bundle" to create the libraries.
+   • run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
+   • Click Connect to proxy.
+   • run "python tests/python/websock_rpc_test.py" to run the rpc client.

Options

+ Proxy URL
+ RPC Server Key
+ + +
+ + + diff --git a/web/apps/node/example.js b/web/apps/node/example.js new file mode 100644 index 000000000000..f81a9c903e5d --- /dev/null +++ b/web/apps/node/example.js @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Example code to start the runtime. + */ +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +// the async version of the API. +tvmjs.instantiate(wasmSource, new EmccWASI()) +.then((tvm) => { + // List all the global functions from the runtime. + console.log("Runtime functions using EmccWASI\n", tvm.listGlobalFuncNames()); +}); + diff --git a/include/tvm/arith/util.h b/web/apps/node/wasi_example.js similarity index 54% rename from include/tvm/arith/util.h rename to web/apps/node/wasi_example.js index adfcefcd2e21..95ec2e0b1d07 100644 --- a/include/tvm/arith/util.h +++ b/web/apps/node/wasi_example.js @@ -16,30 +16,21 @@ * specific language governing permissions and limitations * under the License. */ - -/*! - * \file tvm/arith/util.h - * \brief Utils for arithmetic analysis. +/** + * Example code to start the runtime. */ -#ifndef TVM_ARITH_UTIL_H_ -#define TVM_ARITH_UTIL_H_ - -#include -#include +const { WASI } = require('wasi'); +const path = require("path"); +const fs = require("fs"); +const tvmjs = require("../../dist"); -namespace tvm { -/*! \brief namespace of arithmetic analysis. */ -namespace arith { +const wasmPath = tvmjs.wasmPath(); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); -/*! - * \brief Calculate the extended greatest common divisor for two values. - * See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm. - * \param a an integer number - * \param b an integer number - * \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div - */ -std::tuple xgcd(int64_t a, int64_t b); +const wasi = new WASI({ args: process.argv, env: process.env }); +// Here we pass the javascript module generated by emscripten as the +// LibraryProvider to provide WASI related libraries. +const tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), wasi); -} // namespace arith -} // namespace tvm -#endif // TVM_ARITH_UTIL_H_ +// List all the global functions from the runtime. 
+console.log("Runtime using WASI\n", tvm.listGlobalFuncNames()); diff --git a/web/apps/node/wasi_rpc_server.js b/web/apps/node/wasi_rpc_server.js new file mode 100644 index 000000000000..eb4c6ed52be9 --- /dev/null +++ b/web/apps/node/wasi_rpc_server.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Example code to start the RPC server on nodejs using WASI + */ +const { WASI } = require("wasi"); +const tvmjs = require("../../dist"); + +// Get import returns a fresh library in each call. +const getImports = () => { + return new WASI({ + args: process.argv, + env: process.env + }); +}; + +const proxyUrl = "ws://localhost:8888/ws"; + +new tvmjs.RPCServer(proxyUrl, "wasm", getImports, console.log); diff --git a/cmake/modules/OpenGL.cmake b/web/emcc/decorate_as_wasi.py similarity index 54% rename from cmake/modules/OpenGL.cmake rename to web/emcc/decorate_as_wasi.py index 38054f195650..741e33bb22ea 100644 --- a/cmake/modules/OpenGL.cmake +++ b/web/emcc/decorate_as_wasi.py @@ -14,22 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Decorate emcc generated js to a WASI compatible API.""" -find_package(OpenGL QUIET) +import sys -if(OpenGL_FOUND) - # always set the includedir when dir is available - # avoid global retrigger of cmake - include_directories(${OPENGL_INCLUDE_DIRS}) -endif(OpenGL_FOUND) +template_head = """ +function EmccWASI() { +""" -if(USE_OPENGL) - find_package(OpenGL REQUIRED) - find_package(glfw3 QUIET REQUIRED) - message(STATUS "Build with OpenGL support") - file(GLOB RUNTIME_OPENGL_SRCS src/runtime/opengl/*.cc) - list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenGL_LIBRARIES} glfw) - list(APPEND RUNTIME_SRCS ${RUNTIME_OPENGL_SRCS}) -else(USE_OPENGL) - list(APPEND COMPILER_SRCS src/target/opt/build_opengl_off.cc) -endif(USE_OPENGL) +template_tail = """ + this.Module = Module; + this.start = Module.wasmLibraryProvider.start; + this.imports = Module.wasmLibraryProvider.imports; + this.wasiImport = this.imports["wasi_snapshot_preview1"]; +} + +if (typeof module !== "undefined" && module.exports) { + module.exports = EmccWASI; +} +""" + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage ") + result = template_head + open(sys.argv[1]).read() + template_tail + with open(sys.argv[2], "w") as fo: + fo.write(result) diff --git a/web/emcc/preload.js b/web/emcc/preload.js new file mode 100644 index 000000000000..882280f9cac0 --- /dev/null +++ b/web/emcc/preload.js @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-unused-vars */ +/** + * JS config used by --pre-js in emcc. + * Wrap module as a LibraryProvider. + */ + +var __wasmLib = {}; + +function __wasmLibInstantiateWasm(imports, successCallback) { + __wasmLib.imports = imports; + __wasmLib.successCallback = successCallback; +} + +function __wasmLibStart(wasmInstance) { + __wasmLib.successCallback(wasmInstance); +} + +__wasmLib.start = __wasmLibStart; + +var Module = { + "instantiateWasm": __wasmLibInstantiateWasm, + "wasmLibraryProvider": __wasmLib +}; diff --git a/web/emcc/tvmjs_support.cc b/web/emcc/tvmjs_support.cc new file mode 100644 index 000000000000..6abd12252d1d --- /dev/null +++ b/web/emcc/tvmjs_support.cc @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvmjs_support.cc + * \brief Support functions to be linked with wasm_runtime to provide + * PackedFunc callbacks in tvmjs. + * We do not need to link this file in standalone wasm. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include +#include +#include +#include + +#include "../../src/runtime/rpc/rpc_local_session.h" + +extern "C" { +// --- Additional C API for the Wasm runtime --- +/*! + * \brief Allocate space aligned to 64 bit. + * \param size The size of the space. + * \return The allocated space. + */ +TVM_DLL void* TVMWasmAllocSpace(int size); + +/*! + * \brief Free the space allocated by TVMWasmAllocSpace. + * \param data The data pointer. + */ +TVM_DLL void TVMWasmFreeSpace(void* data); + +/*! + * \brief Create PackedFunc from a resource handle. + * \param resource_handle The handle to the resource. + * \param out The output PackedFunc. + * \sa TVMWasmPackedCFunc, TVMWasmPackedCFuncFinalizer +3A * \return 0 if success. + */ +TVM_DLL int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out); + +// --- APIs to be implemented by the frontend. --- +/*! 
+ * \brief Wasm frontend packed function caller. + * + * \param args The arguments + * \param type_codes The type codes of the arguments + * \param num_args Number of arguments. + * \param ret The return value handle. + * \param resource_handle The handle additional resouce handle from fron-end. + * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. + */ +extern int TVMWasmPackedCFunc(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, + void* resource_handle); + +/*! + * \brief Wasm frontend resource finalizer. + * \param resource_handle The pointer to the external resource. + */ +extern void TVMWasmPackedCFuncFinalizer(void* resource_handle); +} // extern "C" + +void* TVMWasmAllocSpace(int size) { + int num_count = (size + 7) / 8; + return new int64_t[num_count]; +} + +void TVMWasmFreeSpace(void* arr) { delete[] static_cast(arr); } + +int TVMWasmFuncCreateFromCFunc(void* resource_handle, TVMFunctionHandle* out) { + return TVMFuncCreateFromCFunc(TVMWasmPackedCFunc, resource_handle, TVMWasmPackedCFuncFinalizer, + out); +} + +namespace tvm { +namespace runtime { + +// A special local session that can interact with async +// functions in the JS runtime. +class AsyncLocalSession : public LocalSession { + public: + AsyncLocalSession() {} + + PackedFuncHandle GetFunction(const std::string& name) final { + if (name == "runtime.RPCTimeEvaluator") { + return get_time_eval_placeholder_.get(); + } else if (auto* fp = tvm::runtime::Registry::Get(name)) { + // return raw handle because the remote need to explicitly manage it. + return new PackedFunc(*fp); + } else if (auto* fp = tvm::runtime::Registry::Get("__async." + name)) { + auto* rptr = new PackedFunc(*fp); + async_func_set_.insert(rptr); + return rptr; + } else { + return nullptr; + } + } + + void FreeHandle(void* handle, int type_code) final { + if (type_code == kTVMPackedFuncHandle) { + auto it = async_func_set_.find(handle); + if (it != async_func_set_.end()) { + async_func_set_.erase(it); + } + } + if (handle != get_time_eval_placeholder_.get()) { + LocalSession::FreeHandle(handle, type_code); + } + } + + void AsyncCallFunc(PackedFuncHandle func, const TVMValue* arg_values, const int* arg_type_codes, + int num_args, FAsyncCallback callback) final { + auto it = async_func_set_.find(func); + if (it != async_func_set_.end()) { + PackedFunc packed_callback([callback, this](TVMArgs args, TVMRetValue*) { + int code = args[0]; + TVMRetValue rv; + rv = args[1]; + this->EncodeReturn(std::move(rv), + [&](TVMArgs encoded_args) { callback(RPCCode::kReturn, encoded_args); }); + }); + + TVMRetValue temp; + std::vector values(arg_values, arg_values + num_args); + std::vector type_codes(arg_type_codes, arg_type_codes + num_args); + values.emplace_back(TVMValue()); + type_codes.emplace_back(0); + + TVMArgsSetter setter(&values[0], &type_codes[0]); + // pass the callback as the last argument. + setter(num_args, packed_callback); + + auto* pf = static_cast(func); + pf->CallPacked(TVMArgs(values.data(), type_codes.data(), num_args + 1), &temp); + } else if (func == get_time_eval_placeholder_.get()) { + // special handle time evaluator. + try { + TVMArgs args(arg_values, arg_type_codes, num_args); + PackedFunc retfunc = + this->GetTimeEvaluator(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + TVMRetValue rv; + rv = retfunc; + this->EncodeReturn(std::move(rv), [&](TVMArgs encoded_args) { + // mark as async. 
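+          // The returned time-evaluator handle is recorded in async_func_set_ so that
+          // subsequent calls to it are dispatched through the async callback path above.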
+ async_func_set_.insert(encoded_args.values[1].v_handle); + callback(RPCCode::kReturn, encoded_args); + }); + } catch (const std::runtime_error& e) { + this->SendException(callback, e.what()); + } + } else { + LocalSession::AsyncCallFunc(func, arg_values, arg_type_codes, num_args, callback); + } + } + + void AsyncCopyToRemote(void* local_from, size_t local_from_offset, void* remote_to, + size_t remote_to_offset, size_t nbytes, TVMContext remote_ctx_to, + DLDataType type_hint, FAsyncCallback on_complete) final { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + try { + this->GetDeviceAPI(remote_ctx_to) + ->CopyDataFromTo(local_from, local_from_offset, remote_to, remote_to_offset, nbytes, + cpu_ctx, remote_ctx_to, type_hint, nullptr); + this->AsyncStreamWait(remote_ctx_to, nullptr, on_complete); + } catch (const std::runtime_error& e) { + this->SendException(on_complete, e.what()); + } + } + + void AsyncCopyFromRemote(void* remote_from, size_t remote_from_offset, void* local_to, + size_t local_to_offset, size_t nbytes, TVMContext remote_ctx_from, + DLDataType type_hint, FAsyncCallback on_complete) final { + TVMContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + try { + this->GetDeviceAPI(remote_ctx_from) + ->CopyDataFromTo(remote_from, remote_from_offset, local_to, local_to_offset, nbytes, + remote_ctx_from, cpu_ctx, type_hint, nullptr); + this->AsyncStreamWait(remote_ctx_from, nullptr, on_complete); + } catch (const std::runtime_error& e) { + this->SendException(on_complete, e.what()); + } + } + + void AsyncStreamWait(TVMContext ctx, TVMStreamHandle stream, FAsyncCallback on_complete) final { + if (ctx.device_type == kDLCPU) { + TVMValue value; + int32_t tcode = kTVMNullptr; + value.v_handle = nullptr; + on_complete(RPCCode::kReturn, TVMArgs(&value, &tcode, 1)); + } else { + CHECK(ctx.device_type == static_cast(kDLWebGPU)); + if (async_wait_ == nullptr) { + async_wait_ = tvm::runtime::Registry::Get("__async.wasm.WebGPUWaitForTasks"); + } + CHECK(async_wait_ != nullptr); + PackedFunc packed_callback([on_complete](TVMArgs args, TVMRetValue*) { + int code = args[0]; + on_complete(static_cast(code), + TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1)); + }); + (*async_wait_)(packed_callback); + } + } + + bool IsAsync() const final { return true; } + + private: + std::unordered_set async_func_set_; + std::unique_ptr get_time_eval_placeholder_ = std::make_unique(); + const PackedFunc* async_wait_{nullptr}; + + // time evaluator + PackedFunc GetTimeEvaluator(Optional opt_mod, std::string name, int device_type, + int device_id, int number, int repeat, int min_repeat_ms) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + + if (opt_mod.defined()) { + Module m = opt_mod.value(); + std::string tkey = m->type_key(); + return WrapWasmTimeEvaluator(m.GetFunction(name, false), ctx, number, repeat, min_repeat_ms); + } else { + auto* pf = runtime::Registry::Get(name); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global function"; + return WrapWasmTimeEvaluator(*pf, ctx, number, repeat, min_repeat_ms); + } + } + + // time evaluator + PackedFunc WrapWasmTimeEvaluator(PackedFunc pf, TVMContext ctx, int number, int repeat, + int min_repeat_ms) { + auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue* rv) { + // the function is a async function. 
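+      // The last argument is expected to be an on_complete callback;
+      // the timing result is delivered through it rather than through the return value.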
+ PackedFunc on_complete = args[args.size() - 1]; + // keep argument alive in finvoke so that they + // can be used throughout the async benchmark + std::vector values(args.values, args.values + args.size() - 1); + std::vector type_codes(args.type_codes, args.type_codes + args.size() - 1); + + auto finvoke = [pf, values, type_codes](int n) { + TVMRetValue temp; + TVMArgs invoke_args(values.data(), type_codes.data(), values.size()); + for (int i = 0; i < n; ++i) { + pf.CallPacked(invoke_args, &temp); + } + }; + auto* time_exec = runtime::Registry::Get("__async.wasm.TimeExecution"); + CHECK(time_exec != nullptr) << "Cannot find wasm.GetTimer in the global function"; + (*time_exec)(TypedPackedFunc(finvoke), ctx, number, repeat, min_repeat_ms, + on_complete); + }; + return PackedFunc(ftimer); + } +}; + +TVM_REGISTER_GLOBAL("wasm.LocalSession").set_body_typed([]() { + return CreateRPCSessionModule(std::make_shared()); +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc new file mode 100644 index 000000000000..a67b4c3dcd14 --- /dev/null +++ b/web/emcc/wasm_runtime.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file wasm_runtime.cc + * \brief TVM wasm runtime library pack. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include + +#include "src/runtime/c_runtime_api.cc" +#include "src/runtime/cpu_device_api.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/graph/graph_runtime.cc" +#include "src/runtime/library_module.cc" +#include "src/runtime/module.cc" +#include "src/runtime/ndarray.cc" +#include "src/runtime/object.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/rpc/rpc_channel.cc" +#include "src/runtime/rpc/rpc_endpoint.cc" +#include "src/runtime/rpc/rpc_event_impl.cc" +#include "src/runtime/rpc/rpc_local_session.cc" +#include "src/runtime/rpc/rpc_module.cc" +#include "src/runtime/rpc/rpc_session.cc" +#include "src/runtime/system_library.cc" +#include "src/runtime/workspace_pool.cc" + +// --- Implementations of backend and wasm runtime API. 
--- + +int TVMBackendParallelLaunch(FTVMParallelLambda flambda, void* cdata, int num_task) { + TVMParallelGroupEnv env; + env.num_task = 1; + flambda(0, &env, cdata); + return 0; +} + +int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { return 0; } + +// --- Environment PackedFuncs for testing --- +namespace tvm { +namespace runtime { + +TVM_REGISTER_GLOBAL("testing.echo").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = args[0]; +}); + +TVM_REGISTER_GLOBAL("testing.add_one").set_body_typed([](int x) { return x + 1; }); + +TVM_REGISTER_GLOBAL("testing.wrap_callback").set_body([](TVMArgs args, TVMRetValue* ret) { + PackedFunc pf = args[0]; + *ret = runtime::TypedPackedFunc([pf]() { pf(); }); +}); +} // namespace runtime +} // namespace tvm diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc new file mode 100644 index 000000000000..7f0b0d9f72cb --- /dev/null +++ b/web/emcc/webgpu_runtime.cc @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file webgpu_runtime.cc + * \brief WebGPU runtime based on the TVM JS. + */ + +// configurations for the dmlc log. +#define DMLC_LOG_CUSTOMIZE 0 +#define DMLC_LOG_STACK_TRACE 0 +#define DMLC_LOG_DEBUG 0 +#define DMLC_LOG_NODATE 1 +#define DMLC_LOG_FATAL_THROW 0 + +#include +#include +#include +#include +#include + +#include "../../src/runtime/meta_data.h" +#include "../../src/runtime/vulkan/vulkan_shader.h" +#include "../../src/runtime/workspace_pool.h" + +namespace tvm { +namespace runtime { + +/*! \brief Thread local workspace */ +class WebGPUThreadEntry { + public: + /*! \brief thread local pool*/ + WorkspacePool pool; + /*! \brief constructor */ + WebGPUThreadEntry(); + // get the threadlocal workspace + static WebGPUThreadEntry* ThreadLocal(); +}; + +// All the implementations are redirectly to the JS side. 
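+// The constructor fetches the "wasm.WebGPUDeviceAPI" getter registered by the JS
+// runtime and caches the per-operation PackedFuncs it returns.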
+class WebGPUDeviceAPI : public DeviceAPI { + public: + WebGPUDeviceAPI() { + auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUDeviceAPI"); + CHECK(fp != nullptr) << "Cannot find wasm.WebGPUContext in the env"; + auto getter = TypedPackedFunc(*fp); + alloc_space_ = getter("deviceAllocDataSpace"); + free_space_ = getter("deviceFreeDataSpace"); + copy_to_gpu_ = getter("deviceCopyToGPU"); + copy_from_gpu_ = getter("deviceCopyFromGPU"); + copy_within_gpu_ = getter("deviceCopyWithinGPU"); + } + + void SetDevice(TVMContext ctx) final {} + void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { + if (kind == kExist) { + *rv = 1; + } + } + + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + DLDataType type_hint) final { + double ptr_number = alloc_space_(nbytes); + return reinterpret_cast(static_cast(ptr_number)); + } + + void FreeDataSpace(TVMContext ctx, void* ptr) final { return free_space_(ptr); } + + void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, + TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, + TVMStreamHandle stream) final { + if (static_cast(ctx_from.device_type) == kDLWebGPU && + static_cast(ctx_to.device_type) == kDLWebGPU) { + CHECK_EQ(ctx_from.device_id, ctx_to.device_id); + copy_within_gpu_(const_cast(from), from_offset, to, to_offset, size); + } else if (static_cast(ctx_from.device_type) == kDLWebGPU && + ctx_to.device_type == kDLCPU) { + void* to_ptr = static_cast(to) + to_offset; + copy_from_gpu_(const_cast(from), from_offset, to_ptr, size); + } else if (ctx_from.device_type == kDLCPU && + static_cast(ctx_to.device_type) == kDLWebGPU) { + void* from_ptr = static_cast(const_cast(from)) + from_offset; + copy_to_gpu_(from_ptr, to, to_offset, size); + } else { + LOG(FATAL) << "expect copy from/to WebGPU or between WebGPU"; + } + } + + TVMStreamHandle CreateStream(TVMContext ctx) final { + LOG(FATAL) << "Not implemented"; + return nullptr; + } + + void FreeStream(TVMContext ctx, TVMStreamHandle stream) final { + LOG(FATAL) << "Not implemented"; + return; + } + + void SyncStreamFromTo(TVMContext ctx, TVMStreamHandle event_src, TVMStreamHandle event_dst) { + LOG(FATAL) << "Not implemented"; + return; + } + + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { LOG(FATAL) << "Not implemented"; } + + void SetStream(TVMContext ctx, TVMStreamHandle stream) final { + LOG(FATAL) << "Not implemented"; + return; + } + + void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { + return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); + } + + void FreeWorkspace(TVMContext ctx, void* data) final { + WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); + } + + static const std::shared_ptr& Global() { + static std::shared_ptr inst = std::make_shared(); + return inst; + } + + private: + // NOTE: js return number as double. 
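+  // JS numbers are IEEE-754 doubles, so pointers returned from deviceAllocDataSpace
+  // arrive as doubles and are cast back to pointers in AllocDataSpace above.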
+ TypedPackedFunc alloc_space_; + TypedPackedFunc free_space_; + TypedPackedFunc copy_to_gpu_; + TypedPackedFunc copy_from_gpu_; + TypedPackedFunc + copy_within_gpu_; +}; + +typedef dmlc::ThreadLocalStore WebGPUThreadStore; + +WebGPUThreadEntry::WebGPUThreadEntry() + : pool(static_cast(kDLWebGPU), WebGPUDeviceAPI::Global()) {} + +WebGPUThreadEntry* WebGPUThreadEntry::ThreadLocal() { return WebGPUThreadStore::Get(); } + +class WebGPUModuleNode final : public runtime::ModuleNode { + public: + explicit WebGPUModuleNode(std::unordered_map smap, + std::unordered_map fmap, std::string source) + : smap_(smap), fmap_(fmap), source_(source) { + auto* fp = tvm::runtime::Registry::Get("wasm.WebGPUCreateShader"); + CHECK(fp != nullptr); + create_shader_ = *fp; + } + + const char* type_key() const final { return "webgpu"; } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { + auto it = smap_.find(name); + if (it != smap_.end()) { + FunctionInfo info = fmap_.at(name); + info.name = name; + std::ostringstream os; + dmlc::JSONWriter writer(&os); + info.Save(&writer); + TVMByteArray arr; + arr.data = reinterpret_cast(it->second.data.data()); + arr.size = it->second.data.size() * sizeof(it->second.data[0]); + return create_shader_(os.str(), arr); + } else { + return PackedFunc(nullptr); + } + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + LOG(FATAL) << "Not implemented"; + } + + void SaveToBinary(dmlc::Stream* stream) final { LOG(FATAL) << "Not implemented"; } + + std::string GetSource(const std::string& format) final { + // can only return source code. + return source_; + } + + private: + // function information table. + std::unordered_map smap_; + // function information table. + std::unordered_map fmap_; + // The source + std::string source_; + // Callback to get the GPU function. + TypedPackedFunc create_shader_; +}; + +Module WebGPUModuleLoadBinary(void* strm) { + dmlc::Stream* stream = static_cast(strm); + std::unordered_map smap; + std::unordered_map fmap; + + std::string fmt; + stream->Read(&fmt); + stream->Read(&fmap); + stream->Read(&smap); + return Module(make_object(smap, fmap, "")); +} + +// for now webgpu is hosted via a vulkan module. +TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary); + +TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = WebGPUDeviceAPI::Global().get(); + *rv = static_cast(ptr); +}); + +} // namespace runtime +} // namespace tvm diff --git a/web/example_rpc.html b/web/example_rpc.html deleted file mode 100644 index ae2b1dd9c44b..000000000000 --- a/web/example_rpc.html +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - - - - - - - - - - - - TVM RPC Test Page - - - - -

TVM Test Page

- To use this page, the easiest way is to do
-   • run "python -m tvm.exec.rpc_proxy --example-rpc=1" to start proxy.
-   • Click Connect to proxy.
-   • run "python tests/web/websock_rpc_test.py" to run the rpc client.
-

Options

- Proxy URL
- RPC Server Key
- - -
- - - - diff --git a/web/.eslintrc.js b/web/jest.config.js similarity index 70% rename from web/.eslintrc.js rename to web/jest.config.js index 2e82ba50e3c4..23cbd8b4ab04 100644 --- a/web/.eslintrc.js +++ b/web/jest.config.js @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -17,29 +17,11 @@ * under the License. */ +/* eslint-disable no-undef */ module.exports = { - "env": { - "browser": true, - "node": true, - "es6": true - }, - "extends": "eslint:recommended", - "rules": { - "indent": [ - "error", - 2 - ], - "linebreak-style": [ - "error", - "unix" - ], - "quotes": [ - "error", - "double" - ], - "semi": [ - "error", - "always" - ] - } + testEnvironment: "node", + + testMatch: [ + "**/tests/node/*.js" + ], }; diff --git a/web/package.json b/web/package.json new file mode 100644 index 000000000000..25fca5088e78 --- /dev/null +++ b/web/package.json @@ -0,0 +1,32 @@ +{ + "name": "tvmjs", + "displayName": "TVM Wasm JS runtime", + "license": "Apache-2.0", + "version": "0.7.0", + "scripts": { + "prepwasm": "make && python3 tests/python/prepare_test_libs.py", + "build": "tsc -b && make rmtypedep", + "lint": "eslint -c .eslintrc.json .", + "typedoc": "typedoc .", + "test": "jest", + "bundle": "npm run build && rollup -c rollup.config.js", + "example": "npm run bundle && node apps/node/example.js", + "example:wasi": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_example.js", + "rpc": "npm run bundle && node --experimental-wasi-unstable-preview1 --experimental-wasm-bigint apps/node/wasi_rpc_server.js" + }, + "devDependencies": { + "@rollup/plugin-commonjs": "^11.1.0", + "@rollup/plugin-node-resolve": "^7.1.3", + "@types/node": "^12.12.37", + "@typescript-eslint/eslint-plugin": "^2.29.0", + "@typescript-eslint/parser": "^2.29.0", + "@webgpu/types": "^0.0.24", + "eslint": "^6.8.0", + "jest": "^26.0.1", + "rollup": "^2.7.6", + "rollup-plugin-typescript2": "^0.27.0", + "typedoc": "^0.17.6", + "typescript": "^3.8.3", + "ws": "^7.2.5" + } +} diff --git a/web/rollup.config.js b/web/rollup.config.js new file mode 100644 index 000000000000..9090c77868fe --- /dev/null +++ b/web/rollup.config.js @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import commonjs from '@rollup/plugin-commonjs'; +import resolve from '@rollup/plugin-node-resolve'; + +export default { + input: 'dist/index.js', + output: { + file: 'dist/tvmjs.bundle.js', + format: 'umd', + name: 'tvmjs', + exports: 'named', + globals: {'ws': 'ws', + 'perf_hooks': 'perf_hooks', + '@webgpu/types': 'webgputypes'} + }, + plugins: [commonjs(), resolve()], + external: ['ws', 'perf_hooks', '@webgpu/types'] +}; diff --git a/web/src/compact.ts b/web/src/compact.ts new file mode 100644 index 000000000000..29569b5d005d --- /dev/null +++ b/web/src/compact.ts @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** NodeJS and Web compact layer */ + +/** + * Get performance masurement. + */ +export function getPeformance(): Performance { + if (typeof performance == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const performanceNode = require("perf_hooks"); + return performanceNode.performance as Performance; + } else { + return performance as Performance; + } +} + +/** + * Create a new websocket for a given URL + * @param url The url. + */ +export function createWebSocket(url: string): WebSocket { + if (typeof WebSocket == "undefined") { + // eslint-disable-next-line @typescript-eslint/no-var-requires + const WebSocket = require("ws"); + return new WebSocket(url); + } else { + return new (WebSocket as any)(url); + } + +} \ No newline at end of file diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts new file mode 100644 index 000000000000..66c46fe7ed91 --- /dev/null +++ b/web/src/ctypes.ts @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Types for C API. + */ + +/** A pointer to points to the raw address space. */ +export type Pointer = number; + +/** A pointer offset, need to add a base address to get a valid ptr. 
*/ +export type PtrOffset = number; + +// -- TVM runtime C API -- +/** + * const char *TVMGetLastError(); + */ +export type FTVMGetLastError = () => Pointer; + +/** + * int TVMModGetFunction(TVMModuleHandle mod, + * const char* func_name, + * int query_imports, + * TVMFunctionHandle *out); + */ +export type FTVMModGetFunction = ( + mod: Pointer, funcName: Pointer, queryImports: number, out: Pointer) => number; +/** + * int TVMModImport(TVMModuleHandle mod, + * TVMModuleHandle dep); + */ +export type FTVMModImport = (mod: Pointer, dep: Pointer) => number; +/** + * int TVMModFree(TVMModuleHandle mod); + */ +export type FTVMModFree = (mod: Pointer) => number; + +/** + * int TVMFuncFree(TVMFunctionHandle func); + */ +export type FTVMFuncFree = (func: Pointer) => number; + +/** + * int TVMFuncCall(TVMFunctionHandle func, + * TVMValue* arg_values, + * int* type_codes, + * int num_args, + * TVMValue* ret_val, + * int* ret_type_code); + */ +export type FTVMFuncCall = ( + func: Pointer, argValues: Pointer, typeCode: Pointer, + nargs: number, retValue: Pointer, retCode: Pointer) => number; + +/** + * int TVMCFuncSetReturn(TVMRetValueHandle ret, + * TVMValue* value, + * int* type_code, + * int num_ret); + */ +export type FTVMCFuncSetReturn = ( + ret: Pointer, value: Pointer, typeCode: Pointer, numRet: number) => number; + +/** + * int TVMCbArgToReturn(TVMValue* value, int* code); + */ +export type FTVMCbArgToReturn = (value: Pointer, code: Pointer) => number; + +/** + * int TVMFuncListGlobalNames(int* outSize, const char*** outArray); + */ +export type FTVMFuncListGlobalNames = (outSize: Pointer, outArray: Pointer) => number; + +/** + * int TVMFuncRegisterGlobal( + * const char* name, TVMFunctionHandle f, int override); + */ +export type FTVMFuncRegisterGlobal = ( + name: Pointer, f: Pointer, override: number) => number; + +/** + *int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out); + */ +export type FTVMFuncGetGlobal = (name: Pointer, out: Pointer) => number; + +/** + * int TVMArrayAlloc(const tvm_index_t* shape, + * int ndim, + * int dtype_code, + * int dtype_bits, + * int dtype_lanes, + * int device_type, + * int device_id, + * TVMArrayHandle* out); + */ +export type FTVMArrayAlloc = ( + shape: Pointer, ndim: number, + dtypeCode: number, dtypeBits: number, + dtypeLanes: number, deviceType: number, deviceId: number, + out: Pointer) => number; + +/** + * int TVMArrayFree(TVMArrayHandle handle); + */ +export type FTVMArrayFree = (handle: Pointer) => number; + +/** + * int TVMArrayCopyFromBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyFromBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyToBytes(TVMArrayHandle handle, + * void* data, + * size_t nbytes); + */ +export type FTVMArrayCopyToBytes = ( + handle: Pointer, data: Pointer, nbytes: number) => number; + +/** + * int TVMArrayCopyFromTo(TVMArrayHandle from, + * TVMArrayHandle to, + * TVMStreamHandle stream); + */ +export type FTVMArrayCopyFromTo = ( + from: Pointer, to: Pointer, stream: Pointer) => number; + +/** + * int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream); + */ +export type FTVMSynchronize = ( + deviceType: number, deviceId: number, stream: Pointer) => number; + +/** + * typedef int (*TVMBackendPackedCFunc)(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMValue* out_ret_value, + * int* out_ret_tcode); + */ +export type FTVMBackendPackedCFunc = ( + argValues: Pointer, argCodes: Pointer, nargs: 
number, + outValue: Pointer, outCode: Pointer) => number; + +// -- TVM Wasm Auxiliary C API -- + +/** void* TVMWasmAllocSpace(int size); */ +export type FTVMWasmAllocSpace = (size: number) => Pointer; + +/** void TVMWasmFreeSpace(void* data); */ +export type FTVMWasmFreeSpace = (ptr: Pointer) => void; + +/** + * int TVMWasmPackedCFunc(TVMValue* args, + * int* type_codes, + * int num_args, + * TVMRetValueHandle ret, + * void* resource_handle); + */ +export type FTVMWasmPackedCFunc = ( + args: Pointer, typeCodes: Pointer, nargs: number, + ret: Pointer, resourceHandle: Pointer) => number; + +/** + * int TVMWasmFuncCreateFromCFunc(void* resource_handle, + * TVMFunctionHandle *out); + */ +export type FTVMWasmFuncCreateFromCFunc = ( + resource: Pointer, out: Pointer) => number; + +/** + * void TVMWasmPackedCFuncFinalizer(void* resource_handle); + */ +export type FTVMWasmPackedCFuncFinalizer = (resourceHandle: Pointer) => void; + +/** + * Size of common data types. + */ +export const enum SizeOf { + U8 = 1, + U16 = 2, + I32 = 4, + I64 = 8, + F32 = 4, + F64 = 8, + TVMValue = 8, + DLDataType = I32, + DLContext = I32 + I32, +} + +/** + * Argument Type code in TVM FFI. + */ +export const enum ArgTypeCode { + Int = 0, + UInt = 1, + Float = 2, + TVMOpaqueHandle = 3, + Null = 4, + TVMDataType = 5, + TVMContext = 6, + TVMDLTensorHandle = 7, + TVMObjectHandle = 8, + TVMModuleHandle = 9, + TVMPackedFuncHandle = 10, + TVMStr = 11, + TVMBytes = 12, + TVMNDArrayHandle = 13, + TVMObjectRValueRefArg = 14 +} diff --git a/web/src/environment.ts b/web/src/environment.ts new file mode 100644 index 000000000000..df0fe68c81e0 --- /dev/null +++ b/web/src/environment.ts @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Runtime environment that provide js libaries calls. + */ +import { Pointer } from "./ctypes"; +import { LibraryProvider } from "./types"; +import { assert } from "./support"; +import * as ctypes from "./ctypes"; + +/** + * Detect library provider from the importObject. + * + * @param importObject The import object. + */ +function detectLibraryProvider( + importObject: Record +): LibraryProvider | undefined { + if ( + importObject["wasmLibraryProvider"] && + importObject["wasmLibraryProvider"]["start"] && + importObject["wasmLibraryProvider"]["imports"] !== undefined + ) { + const item = importObject as { wasmLibraryProvider: LibraryProvider }; + // create provider so that we capture imports in the provider. 
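+    // Forward start() so the inner emscripten provider still receives the instance
+    // once the wasm module has been instantiated.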
+ return { + imports: item.wasmLibraryProvider.imports, + start: (inst: WebAssembly.Instance): void => { + item.wasmLibraryProvider.start(inst); + }, + }; + } else if (importObject["imports"] && importObject["start"] !== undefined) { + return importObject as LibraryProvider; + } else if (importObject["wasiImport"] && importObject["start"] !== undefined) { + // WASI + return { + imports: { + "wasi_snapshot_preview1": importObject["wasiImport"], + }, + start: (inst: WebAssembly.Instance): void => { + importObject["start"](inst); + } + }; + } else { + return undefined; + } +} + +/** + * Environment to impelement most of the JS library functions. + */ +export class Environment implements LibraryProvider { + logger: (msg: string) => void; + imports: Record; + /** + * Maintains a table of FTVMWasmPackedCFunc that the C part + * can call via TVMWasmPackedCFunc. + * + * We maintain a separate table so that we can have un-limited amount + * of functions that do not maps to the address space. + */ + packedCFuncTable: Array = [ + undefined, + ]; + /** + * Free table index that can be recycled. + */ + packedCFuncTableFreeId: Array = []; + + private libProvider?: LibraryProvider; + + constructor( + importObject: Record = {}, + logger: (msg: string) => void = console.log + ) { + this.logger = logger; + this.libProvider = detectLibraryProvider(importObject); + // get imports from the provider + if (this.libProvider !== undefined) { + this.imports = this.libProvider.imports; + } else { + this.imports = importObject; + } + // update with more functions + this.imports.env = this.environment(this.imports.env); + } + + /** Mark the start of the instance. */ + start(inst: WebAssembly.Instance): void { + if (this.libProvider !== undefined) { + this.libProvider.start(inst); + } + } + + private environment(initEnv: Record): Record { + // default env can be be overriden by libraries. + const defaultEnv = { + "__cxa_thread_atexit": (): void => {}, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + "emscripten_notify_memory_growth": (index: number): void => {} + }; + const wasmPackedCFunc: ctypes.FTVMWasmPackedCFunc = ( + args: Pointer, + typeCodes: Pointer, + nargs: number, + ret: Pointer, + resourceHandle: Pointer + ): number => { + const cfunc = this.packedCFuncTable[resourceHandle]; + assert(cfunc !== undefined); + return cfunc(args, typeCodes, nargs, ret, resourceHandle); + }; + + const wasmPackedCFuncFinalizer: ctypes.FTVMWasmPackedCFuncFinalizer = ( + resourceHandle: Pointer + ): void => { + this.packedCFuncTable[resourceHandle] = undefined; + this.packedCFuncTableFreeId.push(resourceHandle); + }; + + const newEnv = { + TVMWasmPackedCFunc: wasmPackedCFunc, + TVMWasmPackedCFuncFinalizer: wasmPackedCFuncFinalizer, + "__console_log": (msg: string): void => { + this.logger(msg); + } + }; + return Object.assign(defaultEnv, initEnv, newEnv); + } +} \ No newline at end of file diff --git a/web/src/index.ts b/web/src/index.ts new file mode 100644 index 000000000000..2d99fc9106cc --- /dev/null +++ b/web/src/index.ts @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +export { + Scalar, DLContext, DLDataType, + PackedFunc, Module, NDArray, Instance, + instantiate +} from "./runtime"; +export { Disposable, LibraryProvider } from "./types"; +export { RPCServer } from "./rpc_server"; +export { wasmPath } from "./support"; +export { detectGPUDevice } from "./webgpu"; +export { assert } from "./support"; \ No newline at end of file diff --git a/web/src/memory.ts b/web/src/memory.ts new file mode 100644 index 000000000000..ac737b7c297d --- /dev/null +++ b/web/src/memory.ts @@ -0,0 +1,408 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/** + * Classes to manipulate Wasm memories. + */ +import { Pointer, PtrOffset, SizeOf } from "./ctypes"; +import { Disposable } from "./types"; +import { assert, StringToUint8Array } from "./support"; + +import * as ctypes from "./ctypes"; + +/** + * Wasm Memory wrapper to perform JS side raw memory access. + */ +export class Memory { + memory: WebAssembly.Memory; + wasm32 = true; + private buffer: ArrayBuffer | SharedArrayBuffer; + private viewU8: Uint8Array; + private viewU16: Uint16Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF32: Float32Array; + private viewF64: Float64Array; + + constructor(memory: WebAssembly.Memory) { + this.memory = memory; + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } + + loadU8(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU8[ptr >> 0]; + } + + loadU16(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU16[ptr >> 1]; + } + + loadU32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewU32[ptr >> 2]; + } + + loadI32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewI32[ptr >> 2]; + } + + loadI64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const base = ptr >> 2; + // assumes little endian, for now truncate high. 
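+    // Only the low 32-bit word is read, so values outside the signed 32-bit
+    // range are truncated.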
+ return this.viewI32[base]; + } + + loadF32(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF32[ptr >> 2]; + } + + loadF64(ptr: Pointer): number { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + return this.viewF64[ptr >> 3]; + } + + loadPointer(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + loadUSize(ptr: Pointer): Pointer { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + if (this.wasm32) { + return this.loadU32(ptr); + } else { + return this.loadI64(ptr); + } + } + sizeofPtr(): number { + return this.wasm32 ? SizeOf.I32 : SizeOf.I64; + } + /** + * Load raw bytes from ptr. + * @param ptr The head address + * @param numBytes The number + */ + loadRawBytes(ptr: Pointer, numBytes: number): Uint8Array { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + const result = new Uint8Array(numBytes); + result.set(this.viewU8.slice(ptr, ptr + numBytes)); + return result; + } + /** + * Load TVMByteArray from ptr. + * + * @param ptr The address of the header. + */ + loadTVMBytes(ptr: Pointer): Uint8Array { + const data = this.loadPointer(ptr); + const length = this.loadUSize(ptr + this.sizeofPtr()); + return this.loadRawBytes(data, length); + } + /** + * Load null-terminated C-string from ptr. + * @param ptr The head address + */ + loadCString(ptr: Pointer): string { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + // NOTE: the views are still valid for read. + const ret = []; + let ch = 1; + while (ch != 0) { + ch = this.viewU8[ptr]; + if (ch != 0) { + ret.push(String.fromCharCode(ch)); + } + ++ptr; + } + return ret.join(""); + } + /** + * Store raw bytes to the ptr. + * @param ptr The head address. + * @param bytes The bytes content. + */ + storeRawBytes(ptr: Pointer, bytes: Uint8Array): void { + if (this.buffer != this.memory.buffer) { + this.updateViews(); + } + this.viewU8.set(bytes, ptr); + } + + /** + * Update memory view after the memory growth. + */ + private updateViews(): void { + this.buffer = this.memory.buffer; + this.viewU8 = new Uint8Array(this.buffer); + this.viewU16 = new Uint16Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF32 = new Float32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} + +/** + * Auxiliary call stack for the FFI calls. + * + * Lifecyle of a call stack. + * - Calls into allocXX to allocate space, mixed with storeXXX to store data. + * - Calls into ptrFromOffset, no further allocation(as ptrFromOffset can change), + * can still call into storeXX + * - Calls into commitToWasmMemory once. + * - reset. + */ +export class CachedCallStack implements Disposable { + /** List of temporay arguments that can be disposed during reset. 
*/ + tempArgs: Array = []; + + private memory: Memory; + private cAllocSpace: ctypes.FTVMWasmAllocSpace; + private cFreeSpace: ctypes.FTVMWasmFreeSpace; + + private buffer: ArrayBuffer; + private viewU8: Uint8Array; + private viewI32: Int32Array; + private viewU32: Uint32Array; + private viewF64: Float64Array; + + private stackTop: PtrOffset = 0; + private basePtr: Pointer = 0; + + private addressToSetTargetValue: Array<[PtrOffset, PtrOffset]> = []; + + constructor( + memory: Memory, + allocSpace: ctypes.FTVMWasmAllocSpace, + freeSpace: ctypes.FTVMWasmFreeSpace + ) { + const initCallStackSize = 128; + this.memory = memory; + this.cAllocSpace = allocSpace; + this.cFreeSpace = freeSpace; + this.buffer = new ArrayBuffer(initCallStackSize); + this.basePtr = this.cAllocSpace(initCallStackSize); + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + this.updateViews(); + } + + dispose(): void { + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + this.basePtr = 0; + } + } + /** + * Rest the call stack so that it can be reused again. + */ + reset(): void { + this.stackTop = 0; + assert(this.addressToSetTargetValue.length == 0); + while (this.tempArgs.length != 0) { + (this.tempArgs.pop() as Disposable).dispose(); + } + } + + /** + * Commit all the cached data to WasmMemory. + * This function can only be called once. + * No further store function should be called. + * + * @param nbytes Number of bytes to be stored. + */ + commitToWasmMemory(nbytes: number = this.stackTop): void { + // commit all pointer values. + while (this.addressToSetTargetValue.length != 0) { + const [targetOffset, valueOffset] = this.addressToSetTargetValue.pop() as [ + number, + number + ]; + this.storePtr(targetOffset, this.ptrFromOffset(valueOffset)); + } + this.memory.storeRawBytes(this.basePtr, this.viewU8.slice(0, nbytes)); + } + + /** + * Allocate space by number of bytes + * @param nbytes Number of bytes. + * @note This function always allocate space that aligns to 64bit. + */ + allocRawBytes(nbytes: number): PtrOffset { + // always aligns to 64bit + nbytes = ((nbytes + 7) >> 3) << 3; + + if (this.stackTop + nbytes > this.buffer.byteLength) { + const newSize = Math.max( + this.buffer.byteLength * 2, + this.stackTop + nbytes + ); + const oldU8 = this.viewU8; + this.buffer = new ArrayBuffer(newSize); + this.updateViews(); + this.viewU8.set(oldU8); + if (this.basePtr != 0) { + this.cFreeSpace(this.basePtr); + } + this.basePtr = this.cAllocSpace(newSize); + } + const retOffset = this.stackTop; + this.stackTop += nbytes; + return retOffset; + } + + /** + * Allocate space for pointers. + * @param count Number of pointers. + * @returns The allocated pointer array. + */ + allocPtrArray(count: number): PtrOffset { + return this.allocRawBytes(this.memory.sizeofPtr() * count); + } + + /** + * Get the real pointer from offset values. + * Note that the returned value becomes obsolete if alloc is called on the stack. + * @param offset The allocated offset. 
+ */ + ptrFromOffset(offset: PtrOffset): Pointer { + return this.basePtr + offset; + } + + // Store APIs + storePtr(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeUSize(offset: PtrOffset, value: Pointer): void { + if (this.memory.wasm32) { + this.storeU32(offset, value); + } else { + this.storeI64(offset, value); + } + } + + storeI32(offset: PtrOffset, value: number): void { + this.viewI32[offset >> 2] = value; + } + + storeU32(offset: PtrOffset, value: number): void { + this.viewU32[offset >> 2] = value; + } + + storeI64(offset: PtrOffset, value: number): void { + // For now, just store as 32bit + // NOTE: wasm always uses little endian. + const low = value & 0xffffffff; + const base = offset >> 2; + this.viewI32[base] = low; + this.viewI32[base + 1] = 0; + } + + storeF64(offset: PtrOffset, value: number): void { + this.viewF64[offset >> 3] = value; + } + + storeRawBytes(offset: PtrOffset, bytes: Uint8Array): void { + this.viewU8.set(bytes, offset); + } + + /** + * Allocate then set C-String pointer to the offset. + * This function will call into allocBytes to allocate necessary data. + * The address won't be set immediately(because the possible change of basePtr) + * and will be filled when we commit the data. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgString(offset: PtrOffset, data: string): void { + const strOffset = this.allocRawBytes(data.length + 1); + this.storeRawBytes(strOffset, StringToUint8Array(data)); + this.addressToSetTargetValue.push([offset, strOffset]); + } + /** + * Allocate then set the argument location with a TVMByteArray. + * Allocate new temporary space for bytes. + * + * @param offset The offset to set ot data pointer. + * @param data The string content. + */ + allocThenSetArgBytes(offset: PtrOffset, data: Uint8Array): void { + // Note: size of size_t equals sizeof ptr. + const headerOffset = this.allocRawBytes(this.memory.sizeofPtr() * 2); + const dataOffset = this.allocRawBytes(data.length); + this.storeRawBytes(dataOffset, data); + this.storeUSize(headerOffset + this.memory.sizeofPtr(), data.length); + + this.addressToSetTargetValue.push([offset, headerOffset]); + this.addressToSetTargetValue.push([headerOffset, dataOffset]); + } + + /** + * Update internal cache views. + */ + private updateViews(): void { + this.viewU8 = new Uint8Array(this.buffer); + this.viewI32 = new Int32Array(this.buffer); + this.viewU32 = new Uint32Array(this.buffer); + this.viewF64 = new Float64Array(this.buffer); + } +} diff --git a/web/src/rpc_server.ts b/web/src/rpc_server.ts new file mode 100644 index 000000000000..542558aa157f --- /dev/null +++ b/web/src/rpc_server.ts @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { SizeOf, ArgTypeCode } from "./ctypes"; +import { assert, StringToUint8Array, Uint8ArrayToString } from "./support"; +import { detectGPUDevice } from "./webgpu"; +import * as compact from "./compact"; +import * as runtime from "./runtime"; + +enum RPCServerState { + InitHeader, + InitHeaderKey, + InitServer, + WaitForCallback, + ReceivePacketHeader, + ReceivePacketBody, +} + +/** RPC magic header */ +const RPC_MAGIC = 0xff271; + +/** + * An utility class to read from binary bytes. + */ +class ByteStreamReader { + offset = 0; + bytes: Uint8Array; + + constructor(bytes: Uint8Array) { + this.bytes = bytes; + } + + readU32(): number { + const i = this.offset; + const b = this.bytes; + const val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24); + this.offset += 4; + return val; + } + + readU64(): number { + const val = this.readU32(); + this.offset += 4; + return val; + } + + readByteArray(): Uint8Array { + const len = this.readU64(); + assert(this.offset + len <= this.bytes.byteLength); + const ret = new Uint8Array(len); + ret.set(this.bytes.slice(this.offset, this.offset + len)); + this.offset += len; + return ret; + } +} + +/** + * A websocket based RPC + */ +export class RPCServer { + url: string; + key: string; + socket: WebSocket; + state: RPCServerState = RPCServerState.InitHeader; + logger: (msg: string) => void; + getImports: () => Record; + private pendingSend: Promise = Promise.resolve(); + private name: string; + private inst?: runtime.Instance = undefined; + private serverRecvData?: (header: Uint8Array, body: Uint8Array) => void; + private currPacketHeader?: Uint8Array; + private currPacketLength = 0; + private remoteKeyLength = 0; + private pendingBytes = 0; + private buffredBytes = 0; + private messageQueue: Array = []; + + constructor( + url: string, + key: string, + getImports: () => Record, + logger: (msg: string) => void = console.log + ) { + this.url = url; + this.key = key; + this.name = "WebSocketRPCServer[" + this.key + "]: "; + this.getImports = getImports; + this.logger = logger; + + this.checkLittleEndian(); + this.socket = compact.createWebSocket(url); + this.socket.binaryType = "arraybuffer"; + + this.socket.addEventListener("open", (event: Event) => { + return this.onOpen(event); + }); + this.socket.addEventListener("message", (event: MessageEvent) => { + return this.onMessage(event); + }); + this.socket.addEventListener("close", (event: CloseEvent) => { + return this.onClose(event); + }); + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onClose(_event: CloseEvent): void { + if (this.inst !== undefined) { + this.inst.dispose(); + } + if (this.state == RPCServerState.ReceivePacketHeader) { + this.log("Closing the server in clean state"); + this.log("Automatic reconnecting.."); + new RPCServer(this.url, this.key, this.getImports, this.logger); + } else { + this.log("Closing the server, final state=" + this.state); + } + } + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + private onOpen(_event: Event): void { + // Send the headers + let bkey = StringToUint8Array("server:" + this.key); + bkey = bkey.slice(0, bkey.length - 1); + const intbuf = new Int32Array(1); + intbuf[0] = RPC_MAGIC; + this.socket.send(intbuf); + intbuf[0] = bkey.length; + this.socket.send(intbuf); + this.socket.send(bkey); + this.log("connected..."); + // request bytes: magic + keylen + this.requestBytes(SizeOf.I32 + 
SizeOf.I32); + this.state = RPCServerState.InitHeader; + } + + /** Handler for raw message. */ + private onMessage(event: MessageEvent): void { + const buffer = event.data; + this.buffredBytes += buffer.byteLength; + this.messageQueue.push(new Uint8Array(buffer)); + this.processEvents(); + } + /** Process ready events. */ + private processEvents(): void { + while (this.buffredBytes >= this.pendingBytes && this.pendingBytes != 0) { + this.onDataReady(); + } + } + /** State machine to handle each request */ + private onDataReady(): void { + switch (this.state) { + case RPCServerState.InitHeader: { + this.handleInitHeader(); + break; + } + case RPCServerState.InitHeaderKey: { + this.handleInitHeaderKey(); + break; + } + case RPCServerState.ReceivePacketHeader: { + this.currPacketHeader = this.readFromBuffer(SizeOf.I64); + const reader = new ByteStreamReader(this.currPacketHeader); + this.currPacketLength = reader.readU64(); + assert(this.pendingBytes == 0); + this.requestBytes(this.currPacketLength); + this.state = RPCServerState.ReceivePacketBody; + break; + } + case RPCServerState.ReceivePacketBody: { + const body = this.readFromBuffer(this.currPacketLength); + assert(this.pendingBytes == 0); + assert(this.currPacketHeader !== undefined); + this.onPacketReady(this.currPacketHeader, body); + break; + } + case RPCServerState.WaitForCallback: { + assert(this.pendingBytes == 0); + break; + } + default: { + throw new Error("Cannot handle state " + this.state); + } + } + } + + private onPacketReady(header: Uint8Array, body: Uint8Array): void { + if (this.inst === undefined) { + // initialize server. + const reader = new ByteStreamReader(body); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const code = reader.readU32(); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const ver = Uint8ArrayToString(reader.readByteArray()); + const nargs = reader.readU32(); + const tcodes = []; + const args = []; + for (let i = 0; i < nargs; ++i) { + tcodes.push(reader.readU32()); + } + + for (let i = 0; i < nargs; ++i) { + const tcode = tcodes[i]; + if (tcode == ArgTypeCode.TVMStr) { + const str = Uint8ArrayToString(reader.readByteArray()); + args.push(str); + } else if (tcode == ArgTypeCode.TVMBytes) { + args.push(reader.readByteArray()); + } else { + throw new Error("cannot support type code " + tcode); + } + } + this.onInitServer(args, header, body); + } else { + assert(this.serverRecvData !== undefined); + this.serverRecvData(header, body); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + } + + /** Event handler during server initialization. 
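+ * Expects args[0] == "rpc.WasmSession" and args[1] to contain the wasm binary bytes.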
*/ + private onInitServer( + args: Array, + header: Uint8Array, + body: Uint8Array + ): void { + // start the server + assert(args[0] == "rpc.WasmSession"); + assert(this.pendingBytes == 0); + + const asyncInitServer = async (): Promise => { + assert(args[1] instanceof Uint8Array); + const inst = await runtime.instantiate( + args[1].buffer, + this.getImports(), + this.logger + ); + try { + const gpuDevice: GPUDevice | undefined = await detectGPUDevice(); + if (gpuDevice !== undefined) { + const label = gpuDevice.label?.toString() || "WebGPU"; + this.log("Initialize GPU device: " + label); + inst.initWebGPU(gpuDevice); + } + } catch (err) { + this.log("Cannnot initialize WebGPU, " + err.toString()); + } + + this.inst = inst; + const fcreate = this.inst.getGlobalFunc("rpc.CreateEventDrivenServer"); + + const messageHandler = fcreate( + (cbytes: Uint8Array): runtime.Scalar => { + assert(this.inst !== undefined); + if (this.socket.readyState == 1) { + // WebSocket will automatically close the socket + // if we burst send data that exceeds its internal buffer + // wait a bit before we send next one. + const sendDataWithCongestionControl = async (): Promise => { + const packetSize = 4 << 10; + const maxBufferAmount = 4 * packetSize; + const waitTimeMs = 20; + for ( + let offset = 0; + offset < cbytes.length; + offset += packetSize + ) { + const end = Math.min(offset + packetSize, cbytes.length); + while (this.socket.bufferedAmount >= maxBufferAmount) { + await new Promise((r) => setTimeout(r, waitTimeMs)); + } + this.socket.send(cbytes.slice(offset, end)); + } + }; + // Chain up the pending send so that the async send is always in-order. + this.pendingSend = this.pendingSend.then( + sendDataWithCongestionControl + ); + // Directly return since the data are "sent" from the caller's pov. + return this.inst.scalar(cbytes.length, "int32"); + } else { + return this.inst.scalar(0, "int32"); + } + }, + this.name, + this.key + ); + + fcreate.dispose(); + const writeFlag = this.inst.scalar(3, "int32"); + + this.serverRecvData = (header: Uint8Array, body: Uint8Array): void => { + if (messageHandler(header, writeFlag) == 0) { + this.socket.close(); + } + if (messageHandler(body, writeFlag) == 0) { + this.socket.close(); + } + }; + + // Forward the same init sequence to the wasm RPC. + // The RPC will look for "rpc.wasmSession" + // and we will redirect it to the correct local session. + // register the callback to redirect the session to local. + const flocal = this.inst.getGlobalFunc("wasm.LocalSession"); + const localSession = flocal(); + flocal.dispose(); + assert(localSession instanceof runtime.Module); + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + this.inst.registerFunc( + "rpc.WasmSession", + // eslint-disable-next-line @typescript-eslint/no-unused-vars + (_args: unknown): runtime.Module => { + return localSession; + } + ); + messageHandler(header, writeFlag); + messageHandler(body, writeFlag); + localSession.dispose(); + + this.log("Finish initializing the Wasm Server.."); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + // call process events in case there are bufferred data. 
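+ // (bytes may have been queued while the server was still in the WaitForCallback state).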
+ this.processEvents(); + }; + + this.state = RPCServerState.WaitForCallback; + asyncInitServer(); + } + + private log(msg: string): void { + this.logger(this.name + msg); + } + + private handleInitHeader(): void { + const reader = new ByteStreamReader(this.readFromBuffer(SizeOf.I32 * 2)); + const magic = reader.readU32(); + if (magic == RPC_MAGIC + 1) { + throw new Error("key: " + this.key + " has already been used in proxy"); + } else if (magic == RPC_MAGIC + 2) { + throw new Error("RPCProxy do not have matching client key " + this.key); + } + assert(magic == RPC_MAGIC, this.url + " is not an RPC Proxy"); + this.remoteKeyLength = reader.readU32(); + assert(this.pendingBytes == 0); + this.requestBytes(this.remoteKeyLength); + this.state = RPCServerState.InitHeaderKey; + } + + private handleInitHeaderKey(): void { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const remoteKey = Uint8ArrayToString( + this.readFromBuffer(this.remoteKeyLength) + ); + assert(this.pendingBytes == 0); + this.requestBytes(SizeOf.I64); + this.state = RPCServerState.ReceivePacketHeader; + } + + private checkLittleEndian(): void { + const a = new ArrayBuffer(4); + const b = new Uint8Array(a); + const c = new Uint32Array(a); + b[0] = 0x11; + b[1] = 0x22; + b[2] = 0x33; + b[3] = 0x44; + assert(c[0] === 0x44332211, "RPCServer little endian to work"); + } + + private requestBytes(nbytes: number): void { + this.pendingBytes += nbytes; + } + + private readFromBuffer(nbytes: number): Uint8Array { + const ret = new Uint8Array(nbytes); + let ptr = 0; + while (ptr < nbytes) { + assert(this.messageQueue.length != 0); + const nleft = nbytes - ptr; + if (this.messageQueue[0].byteLength <= nleft) { + const buffer = this.messageQueue.shift() as Uint8Array; + ret.set(buffer, ptr); + ptr += buffer.byteLength; + } else { + const buffer = this.messageQueue[0]; + ret.set(buffer.slice(0, nleft), ptr); + this.messageQueue[0] = buffer.slice(nleft, buffer.byteLength); + ptr += nleft; + } + } + this.buffredBytes -= nbytes; + this.pendingBytes -= nbytes; + return ret; + } +} diff --git a/web/src/runtime.ts b/web/src/runtime.ts new file mode 100644 index 000000000000..5c9b9d8181d7 --- /dev/null +++ b/web/src/runtime.ts @@ -0,0 +1,1372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * TVM JS Wasm Runtime library. + */ +import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; +import { Disposable } from "./types"; +import { Memory, CachedCallStack } from "./memory"; +import { assert, StringToUint8Array } from "./support"; +import { Environment } from "./environment"; +import { WebGPUContext } from "./webgpu"; + +import * as compact from "./compact"; +import * as ctypes from "./ctypes"; + +/** + * Type for PackedFunc inthe TVMRuntime. 
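+ * A PackedFunc is a callable JS function that also carries a dispose() method and the underlying PackedFuncCell handle.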
+ */ +export type PackedFunc = ((...args: any) => any) & + Disposable & { _tvmPackedCell: PackedFuncCell }; + +/** + * @internal + * FFI Library wrapper, maintains most runtime states. + */ +class FFILibrary implements Disposable { + wasm32: boolean; + memory: Memory; + exports: Record; + webGPUContext?: WebGPUContext; + private wasmInstance: WebAssembly.Instance; + private recycledCallStacks: Array = []; + + constructor( + wasmInstance: WebAssembly.Instance, + imports: Record + ) { + this.wasmInstance = wasmInstance; + this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); + assert( + this.wasmInstance.exports !== undefined, + "Expect the library module contains exports" + ); + this.exports = this.wasmInstance.exports as Record; + this.wasm32 = this.memory.wasm32; + this.validateInstance(); + } + + dispose(): void { + while (this.recycledCallStacks.length != 0) { + (this.recycledCallStacks.pop() as Disposable).dispose(); + } + } + + sizeofPtr(): number { + return this.memory.sizeofPtr(); + } + + checkCall(code: number): void { + if (code != 0) { + const msgPtr = (this.exports + .TVMGetLastError as ctypes.FTVMGetLastError)(); + throw new Error("TVMError: " + this.memory.loadCString(msgPtr)); + } + } + + getOrAllocCallStack(): CachedCallStack { + if (this.recycledCallStacks.length != 0) { + return this.recycledCallStacks.pop() as CachedCallStack; + } + return new CachedCallStack( + this.memory, + this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, + this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace + ); + } + + recycleCallStack(callstack: CachedCallStack): void { + callstack.reset(); + this.recycledCallStacks.push(callstack); + } + + private validateInstance(): void { + this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]); + } + + private checkExports(funcNames: Array): void { + const missList = []; + for (const name of funcNames) { + const f = this.exports[name]; + if (!(f instanceof Function)) { + missList.push(name); + } + } + if (missList.length != 0) { + throw new Error("Cannot find " + missList + " in exports"); + } + } + + private detectWasmMemory( + instance: WebAssembly.Instance, + imports: Record + ): WebAssembly.Memory { + if (instance.exports.memory instanceof WebAssembly.Memory) { + return instance.exports.memory; + } + if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { + return imports.env.memory; + } + + throw new Error( + "Cannt detect wasm memory from imports " + + imports + + " or exports" + + instance.exports + ); + } +} + +/** + * A typed scalar constant used to represent a typed number + * argument to PackedFunc calls. + */ +export class Scalar { + /** The value. */ + value: number; + /** The data type of the scalar. */ + dtype: string; + + constructor(value: number, dtype: string) { + this.value = value; + this.dtype = dtype; + } +} + +/** + * Cell holds the PackedFunc object. 
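+ * The cell owns the TVM function handle and releases it through TVMFuncFree on dispose().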
+ */ +class PackedFuncCell implements Disposable { + handle: Pointer; + private lib: FFILibrary; + + constructor(handle: Pointer, lib: FFILibrary) { + this.handle = handle; + this.lib = lib; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(this.handle) + ); + this.handle = 0; + } + } +} + +const DeviceEnumToStr: Record = { + 1: "cpu", + 2: "gpu", + 4: "opencl", + 8: "metal", + 15: "webgpu" +}; + +const DeviceStrToEnum: Record = { + cpu: 1, + gpu: 2, + cuda: 2, + cl: 4, + opencl: 4, + vulkan: 7, + metal: 8, + webgpu: 15 +}; + +/** + * Represent a runtime context where a NDArray can reside. + */ +export class DLContext { + /** The device type code of the context. */ + deviceType: number; + /** The device index. */ + deviceId: number; + + private lib: FFILibrary; + + constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { + const tp = typeof deviceType; + if (tp == "string") { + this.deviceType = DeviceStrToEnum[deviceType]; + if (this.deviceType == undefined) { + throw new Error("Cannot recogonize deviceType " + deviceType); + } + } else if (tp == "number") { + this.deviceType = deviceType as number; + } else { + throw new Error("Cannot take type " + tp + " as deviceType"); + } + this.deviceId = deviceId; + this.lib = lib; + } + + /** + * Synchronize the context + */ + async sync(): Promise { + if (this.deviceType == DeviceStrToEnum.webgpu) { + assert(this.lib.webGPUContext !== undefined); + await this.lib.webGPUContext.sync(); + } + } + + toString(): string { + return ( + DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")" + ); + } +} +/** + * The data type code in DLDataType + */ +export const enum DLDataTypeCode { + Int = 0, + UInt = 1, + Float = 2, + OpaqueHandle = 3 +} + +const DLDataTypeCodeToStr: Record = { + 0: "int", + 1: "uint", + 2: "float", + 3: "handle", +}; + +/** + * Runtime data type of NDArray. + */ +export class DLDataType { + /** The type code */ + code: number; + /** Number of bits in the data type. */ + bits: number; + /** Number of vector lanes. */ + lanes: number; + + constructor(code: number, bits: number, lanes: number) { + this.code = code; + this.bits = bits; + this.lanes = lanes; + } + + toString(): string { + const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); + if (this.lanes != 1) { + return ret + "x" + this.lanes.toString(); + } else { + return ret; + } + } + + numStorageBytes(): number { + return (this.bits * this.lanes + 7) >> 3; + } +} + +/** + * n-dimnesional array. + */ +export class NDArray implements Disposable { + /** Internal array handle. */ + handle: Pointer; + /** Number of dimensions. */ + ndim: number; + /** Data type of the array. */ + dtype: string; + /** Shape of the array. */ + shape: Array; + /** Context of the array. */ + context: DLContext; + /** Whether it is a temporary view that can become invalid after the call. */ + private isView: boolean; + private byteOffset: number; + private dltensor: Pointer; + private dataPtr: Pointer; + private lib: FFILibrary; + private dlDataType: DLDataType; + + constructor(handle: Pointer, isView: boolean, lib: FFILibrary) { + this.handle = handle; + this.isView = isView; + this.lib = lib; + + if (this.isView) { + this.dltensor = handle; + } else { + this.dltensor = this.getDLTensorFromArrayHandle(this.handle); + } + // constant offsets. 
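+ // The offsets below follow the DLTensor field layout: data, ctx, ndim, dtype, shape, strides, byte_offset.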
+ const arrayOffsetData = 0; + const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); + const arrayOffsetDevType = arrayOffsetContext; + const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; + const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext; + const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; + const arrayOffsetDtypeCode = arrayOffsetDtype; + const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; + const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; + const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; + const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); + const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); + // dataPtr + this.dataPtr = lib.memory.loadPointer(this.dltensor); + // ndim + this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); + // shape + const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); + this.shape = []; + for (let i = 0; i < this.ndim; ++i) { + this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); + } + // dtype + const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); + const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); + const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); + this.dlDataType = new DLDataType(code, bits, lanes); + this.dtype = this.dlDataType.toString(); + + // ctx + const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); + const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); + this.context = new DLContext(deviceType, deviceId, lib); + + // byte_offset + this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); + } + + dispose(): void { + if (this.handle != 0 && !this.isView) { + this.lib.checkCall( + (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle) + ); + this.handle = 0; + } + } + /** + * Copy data from another NDArray or javascript array. + * The number of elements must match. + * + * @param data The source data array. + * @returns this + */ + copyFrom(data: NDArray | Array | Float32Array): this { + if (data instanceof NDArray) { + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( + data.handle, + this.handle, + 0 + ) + ); + return this; + } else { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + if (data.length != size) { + throw new Error( + "data size and shape mismatch data.length" + + data.length + + " vs " + + size + ); + } + let buffer: ArrayBuffer; + if (this.dtype == "float32") { + buffer = Float32Array.from(data).buffer; + } else if (this.dtype == "float64") { + buffer = Float64Array.from(data).buffer; + } else if (this.dtype == "int32") { + buffer = Int32Array.from(data).buffer; + } else if (this.dtype == "int8") { + buffer = Int8Array.from(data).buffer; + } else if (this.dtype == "uint8") { + buffer = Uint8Array.from(data).buffer; + } else { + throw new Error("Unsupported data type " + this.dtype); + } + return this.copyFromRawBytes(new Uint8Array(buffer)); + } + } + /** + * Copy data from raw bytes. + * @param data Uint8Array of bytes. 
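+ * @note data.length must equal numStorageBytes() times the number of elements.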
+ * @returns this + */ + copyFromRawBytes(data: Uint8Array): this { + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + const nbytes = this.dlDataType.numStorageBytes() * size; + if (nbytes != data.length) { + throw new Error("Expect the data's length equals nbytes=" + nbytes); + } + + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.memory.storeRawBytes(tempPtr, data); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + + this.lib.recycleCallStack(stack); + return this; + } + /** + * Return a copied Uint8Array of the raw bytes in the NDArray. + * @returns The result array. + */ + toRawBytes(): Uint8Array { + if (this.context.deviceType != DeviceStrToEnum.cpu) { + throw new Error("Can only synchronize copy for GPU array, use copyfrom instead."); + } + const size = this.shape.reduce((a, b) => { + return a * b; + }, 1); + + const nbytes = this.dlDataType.numStorageBytes() * size; + const stack = this.lib.getOrAllocCallStack(); + + const tempOffset = stack.allocRawBytes(nbytes); + const tempPtr = stack.ptrFromOffset(tempOffset); + this.lib.checkCall( + (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)( + this.handle, + tempPtr, + nbytes + ) + ); + const ret = this.lib.memory.loadRawBytes(tempPtr, nbytes); + + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Return a TypedArray copy of the NDArray, the specific type depends on + * the dtype of the NDArray. + * @returns The result array. + */ + toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { + const stype = this.dtype; + if (stype == "float32") { + return new Float32Array(this.toRawBytes().buffer); + } else if (stype == "float64") { + return new Float64Array(this.toRawBytes().buffer); + } else if (stype == "int32") { + return new Int32Array(this.toRawBytes().buffer); + } else if (stype == "int8") { + return new Int8Array(this.toRawBytes().buffer); + } else if (stype == "uint8") { + return new Uint8Array(this.toRawBytes().buffer); + } else { + throw new Error("Unsupported data type " + this.dtype); + } + } + + private getDLTensorFromArrayHandle(handle: Pointer): Pointer { + // Note: this depends on the NDArray C ABI. + // keep this function in case of ABI change. + return handle; + } +} + +/** + * Runtime Module. + */ +export class Module implements Disposable { + handle: Pointer; + private lib: FFILibrary; + private makePackedFunc: (ptr: Pointer) => PackedFunc; + + constructor( + handle: Pointer, + lib: FFILibrary, + makePackedFunc: (ptr: Pointer) => PackedFunc + ) { + this.handle = handle; + this.lib = lib; + this.makePackedFunc = makePackedFunc; + } + + dispose(): void { + if (this.handle != 0) { + this.lib.checkCall( + (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle) + ); + this.handle = 0; + } + } + + /** + * Get a function in the module. + * @param name The name of the function. + * @returns The result function. 
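+ * @throws Error if the function cannot be found in the module.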
+ */ + getFunction(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)( + this.handle, + stack.ptrFromOffset(nameOffset), + 1, + outPtr + ) + ); + const handle = this.lib.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Import another module into the current runtime module. + * @param mod The module to be imported. + */ + importModule(mod: Module): void { + this.lib.checkCall( + (this.lib.exports.TVMModImport as ctypes.FTVMModImport)( + this.handle, + mod.handle + ) + ); + } +} + +/** + * Graph runtime. + * + * This is a thin wrapper of the underlying TVM module. + * you can also directly call set_input, run, and get_output + * of underlying module functions + */ +class GraphRuntime implements Disposable { + module: Module; + private packedSetInput: PackedFunc; + private packedRun: PackedFunc; + private packedGetOutput: PackedFunc; + private packedLoadParams: PackedFunc; + + /** + * COnstructor + * @param module The underlying module. + */ + constructor(module: Module) { + this.module = module; + this.packedSetInput = module.getFunction("set_input"); + this.packedRun = module.getFunction("run"); + this.packedGetOutput = module.getFunction("get_output"); + this.packedLoadParams = module.getFunction("load_params"); + } + + dispose(): void { + this.packedSetInput.dispose(); + this.packedRun.dispose(); + this.packedGetOutput.dispose(); + } + + /** + * Set input to the executor. + * + * @param key The input key. + * @param value The value to get set. + */ + setInput(key: number | string, value: NDArray): void { + if (typeof key == "number") { + this.packedSetInput(new Scalar(key, "int32"), value); + } else { + this.packedSetInput(key, value); + + } + } + + /** + * Execute the underlying graph. + */ + run(): void { + this.packedRun(); + } + + /** + * Get index-th output. + * @param index The index number. + * @param out The optional output storage parameters. + * @returns The output array. + */ + getOutput(index: number, out: NDArray | undefined = undefined): NDArray { + if (out !== undefined) { + this.packedGetOutput(new Scalar(index, "int32"), out) + return out; + } else { + return this.packedGetOutput(new Scalar(index, "int32")); + } + } + + /** + * Load parameters from parameter binary. + * @param paramBinary The parameter binary. + */ + loadParams(paramBinary: Uint8Array): void { + this.packedLoadParams(paramBinary); + } + + /** + * Benchmark stable execution of the graph(without data copy). + * @params ctx The context to sync during each run. + * @number The number of times to compute the average. + * @repeat The number of times to repeat the run. + */ + async benchmarkRuns(ctx: DLContext, number=10, repeat=4): Promise { + // Skip first run as it can involve GPU warmup and module loading time. 
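+ // Each entry of the returned array is the average wall-clock time per run, in milliseconds, for one repeat.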
+ const perf = compact.getPeformance(); + const results = []; + this.run(); + await ctx.sync(); + for (let k = 0; k < repeat; ++k) { + const tstart = perf.now(); + for (let i = 0; i < number; ++i) { + this.run(); + } + await ctx.sync(); + const tend = perf.now(); + results.push((tend - tstart) / number); + } + return results; + } +} + +/** Code used as the first argument of the async callback. */ +const enum AyncCallbackCode { + kReturn = 4, + kException = 5, +} + +/** + * TVM runtime instance. + */ +export class Instance implements Disposable { + memory: Memory; + exports: Record; + private lib: FFILibrary; + private env: Environment; + + /** + * Internal function(registered by the runtime) + */ + private wasmCreateLibraryModule?: PackedFunc & + ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); + + /** + * Constructor + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * + * @param wasmModule The input module or instance. + * @param importObject The imports to initialize the wasmInstance if it is not provided. + * @param wasmInstance Additional wasm instance argument for deferred construction. + * @param env Directly specified environment module. + * + * @see Please use the async version {@link instantiate} when targeting browsers. + */ + constructor( + wasmModule: WebAssembly.Module, + importObject: Record = {}, + wasmInstance?: WebAssembly.Instance, + env?: Environment + ) { + if (wasmInstance instanceof WebAssembly.Instance) { + assert( + env instanceof Environment, + "env must be provided when passing in instance" + ); + } else { + assert(env === undefined); + env = new Environment(importObject); + wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); + } + + env.start(wasmInstance); + this.env = env; + this.lib = new FFILibrary(wasmInstance, env.imports); + this.memory = this.lib.memory; + this.exports = this.lib.exports; + this.registerEnvGlobalPackedFuncs(); + } + + dispose(): void { + this.lib.dispose(); + } + /** + * Get system-wide library module in the wasm. + * System lib is a global module that contains self register functions in startup. + * @returns The system library module. + */ + systemLib(): Module { + const getSysLib = this.getGlobalFunc("runtime.SystemLib"); + const mod = getSysLib() as Module; + getSysLib.dispose(); + return mod; + } + /** + * List all the global function names registered in the runtime. + * @returns The name list. + */ + listGlobalFuncNames(): Array { + const stack = this.lib.getOrAllocCallStack(); + + const outSizeOffset = stack.allocPtrArray(2); + + const outSizePtr = stack.ptrFromOffset(outSizeOffset); + const outArrayPtr = stack.ptrFromOffset( + outSizeOffset + this.lib.sizeofPtr() + ); + + this.lib.checkCall( + (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)( + outSizePtr, + outArrayPtr + ) + ); + + const size = this.memory.loadI32(outSizePtr); + const array = this.memory.loadPointer(outArrayPtr); + const names: Array = []; + + for (let i = 0; i < size; ++i) { + names.push( + this.memory.loadCString( + this.memory.loadPointer(array + this.lib.sizeofPtr() * i) + ) + ); + } + + this.lib.recycleCallStack(stack); + return names; + } + + /** + * Register function to be global function in tvm runtime. + * @param name The name of the function. + * @param f function to be registered. + * @param override Whether overwrite function in existing registry. 
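+ * @note A plain JS function is first converted into a PackedFunc via toPackedFunc before registration.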
+ */ + registerFunc( + name: string, + func: PackedFunc | Function, + override = false + ): void { + const packedFunc = this.toPackedFunc(func); + const ioverride = override ? 1 : 0; + + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + stack.commitToWasmMemory(); + + this.lib.checkCall( + (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)( + stack.ptrFromOffset(nameOffset), + packedFunc._tvmPackedCell.handle, + ioverride + ) + ); + } + + /** + * Get global PackedFunc from the runtime. + * @param name The name of the function. + * @returns The result function. + */ + getGlobalFunc(name: string): PackedFunc { + const stack = this.lib.getOrAllocCallStack(); + const nameOffset = stack.allocRawBytes(name.length + 1); + stack.storeRawBytes(nameOffset, StringToUint8Array(name)); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)( + stack.ptrFromOffset(nameOffset), + outPtr + ) + ); + const handle = this.memory.loadPointer(outPtr); + this.lib.recycleCallStack(stack); + if (handle == 0) { + throw Error("Cannot find global function " + name); + } + const ret = this.makePackedFunc(handle); + return ret; + } + + /** + * Check if func is PackedFunc. + * + * @param func The input. + * @returns The check result. + */ + isPackedFunc(func: unknown): boolean { + // eslint-disable-next-line no-prototype-builtins + return typeof func == "function" && func.hasOwnProperty("_tvmPackedCell"); + } + + /** + * Convert func to PackedFunc + * + * @param func Input function. + * @returns The converted function. + */ + toPackedFunc(func: Function): PackedFunc { + if (this.isPackedFunc(func)) return func as PackedFunc; + return this.createPackedFuncFromCFunc(this.wrapJSFuncAsPackedCFunc(func)); + } + + /** + * Convert dtype to {@link DLDataType} + * + * @param dtype The input dtype string or DLDataType. + * @returns The converted result. + */ + toDLDataType(dtype: string | DLDataType): DLDataType { + if (dtype instanceof DLDataType) return dtype; + if (typeof dtype == "string") { + let pattern = dtype; + let code, + bits = 32, + lanes = 1; + if (pattern.substring(0, 5) == "float") { + pattern = pattern.substring(5, pattern.length); + code = DLDataTypeCode.Float; + } else if (pattern.substring(0, 3) == "int") { + pattern = pattern.substring(3, pattern.length); + code = DLDataTypeCode.Int; + } else if (pattern.substring(0, 4) == "uint") { + pattern = pattern.substring(4, pattern.length); + code = DLDataTypeCode.UInt; + } else if (pattern.substring(0, 6) == "handle") { + pattern = pattern.substring(5, pattern.length); + code = DLDataTypeCode.OpaqueHandle; + bits = 64; + } else { + throw new Error("Unknown dtype " + dtype); + } + + const arr = pattern.split("x"); + if (arr.length >= 1) { + const parsed = parseInt(arr[0]); + if (parsed + "" == arr[0]) { + bits = parsed; + } + } + if (arr.length >= 2) { + lanes = parseInt(arr[1]); + } + return new DLDataType(code, bits, lanes); + } else { + throw new Error("Unknown dtype " + dtype); + } + } + + /** + * Create a new {@link Scalar} that can be passed to a PackedFunc. + * @param value The number value. + * @param dtype The dtype string. + * @returns The created scalar. 
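+ * @note Plain JS numbers are passed to PackedFuncs as float64; wrap integers in a Scalar,
+ * e.g. someFunc(inst.scalar(3, "int32")), to pass them as typed ints.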
+ */ + scalar(value: number, dtype: string): Scalar { + return new Scalar(value, dtype); + } + + /** + * Create a new {@link DLContext} + * @param deviceType The device type. + * @param deviceId The device index. + * @returns The created context. + */ + context(deviceType: number | string, deviceId = 0): DLContext { + return new DLContext(deviceType, deviceId, this.lib); + } + + /** + * Create a new cpu {@link DLContext} + * @param deviceId The device index. + */ + cpu(deviceId = 0): DLContext { + return this.context("cpu", deviceId); + } + + /** + * Create a new webgpu {@link DLContext} + * @param deviceId The device index. + */ + webgpu(deviceId = 0): DLContext { + return this.context("webgpu", deviceId); + } + + /** + * Create an empty {@link NDArray} with given shape and dtype. + * + * @param shape The shape of the array. + * @param dtype The data type of the array. + * @param ctx The context of the ndarray. + * @returns The created ndarray. + */ + empty( + shape: Array | number, + dtype: string | DLDataType = "float32", + ctx: DLContext = this.context("cpu", 0) + ): NDArray { + dtype = this.toDLDataType(dtype); + shape = typeof shape == "number" ? [shape] : shape; + + const stack = this.lib.getOrAllocCallStack(); + const shapeOffset = stack.allocRawBytes(shape.length * SizeOf.I64); + for (let i = 0; i < shape.length; ++i) { + stack.storeI64(shapeOffset + i * SizeOf.I64, shape[i]); + } + + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + stack.commitToWasmMemory(outOffset); + + this.lib.checkCall( + (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)( + stack.ptrFromOffset(shapeOffset), + shape.length, + dtype.code, + dtype.bits, + dtype.lanes, + ctx.deviceType, + ctx.deviceId, + outPtr + ) + ); + const ret = new NDArray(this.memory.loadPointer(outPtr), false, this.lib); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Create a new graph runtime. + * + * @param graphJson The graph runtime json file. + * @param lib The underlying library. + * @param ctx The execution context of the graph. + */ + createGraphRuntime( + graphJson: string, + lib: Module, + ctx: DLContext + ): GraphRuntime { + const fcreate = this.getGlobalFunc("tvm.graph_runtime.create"); + const module = fcreate( + graphJson, + lib, + this.scalar(ctx.deviceType, "int32"), + this.scalar(ctx.deviceId, "int32")) as Module; + return new GraphRuntime(module); + } + + + /** + * Register an asyncfunction to be global function in the server. + * @param name The name of the function. + * @param func function to be registered. + * @param override Whether overwrite function in existing registry. + * + * @note The async function will only be used for serving remote calls in the rpc. + */ + registerAsyncServerFunc( + name: string, + func: Function, + override = false + ): void { + const asyncVariant = (...args: Array): void => { + const fargs = args.slice(0, args.length - 1); + const callback = args[args.length - 1] as PackedFunc; + const promise: Promise = func(...fargs); + promise.then((rv: any) => { + callback(this.scalar(AyncCallbackCode.kReturn, "int32"), rv); + }); + }; + this.registerFunc("__async." + name, asyncVariant, override); + } + + /** + * Initialize webgpu in the runtime. + * @param device The given GPU device. 
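+ * @note Also registers the wasm.WebGPUDeviceAPI and wasm.WebGPUCreateShader globals and the async wasm.WebGPUWaitForTasks server function.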
+ */ + initWebGPU(device: GPUDevice): void { + const webGPUContext = new WebGPUContext( + this.memory, device + ); + this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => { + return webGPUContext.getDeviceAPI(name); + }); + this.registerFunc("wasm.WebGPUCreateShader", (info: string, data: Uint8Array) => { + return webGPUContext.createShader(info, data); + }); + this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { + await webGPUContext.sync(); + }); + this.lib.webGPUContext = webGPUContext; + } + + /** Register global packed functions needed by the backend to the env. */ + private registerEnvGlobalPackedFuncs(): void { + // Register the timer function to enable the time_evaluator. + const perf = compact.getPeformance(); + + // Helper function to time the finvoke + const timeExecution = async ( + finvoke: PackedFunc, + ctx: DLContext, + nstep: number, + repeat: number, + minRepeatMs: number + ): Promise => { + finvoke(this.scalar(1, "int32")); + await ctx.sync(); + const result = []; + let setupNumber: number = nstep; + + for (let i = 0; i < repeat; ++i) { + let durationMs = 0.0; + do { + if (durationMs > 0.0) { + setupNumber = Math.floor( + Math.max(minRepeatMs / (durationMs / nstep) + 1, nstep * 1.618) + ); + } + const tstart: number = perf.now(); + finvoke(this.scalar(setupNumber, "int32")); + await ctx.sync(); + const tend: number = perf.now(); + + durationMs = tend - tstart; + } while (durationMs < minRepeatMs); + const speed = durationMs / setupNumber / 1000; + result.push(speed); + } + const ret = new Float64Array(result.length); + ret.set(result); + return new Uint8Array(ret.buffer); + }; + + const addOne = async (x: number): Promise => { + await new Promise(resolve => setTimeout(resolve, 100)); + return x + 1; + }; + + this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution); + this.registerAsyncServerFunc("testing.asyncAddOne", addOne); + } + + private createPackedFuncFromCFunc( + func: ctypes.FTVMWasmPackedCFunc + ): PackedFunc { + let findex = this.env.packedCFuncTable.length; + if (this.env.packedCFuncTableFreeId.length != 0) { + findex = this.env.packedCFuncTableFreeId.pop() as number; + } else { + this.env.packedCFuncTable.push(undefined); + } + this.env.packedCFuncTable[findex] = func; + + const stack = this.lib.getOrAllocCallStack(); + const outOffset = stack.allocPtrArray(1); + const outPtr = stack.ptrFromOffset(outOffset); + this.lib.checkCall( + (this.exports + .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)( + findex, + outPtr + ) + ); + const ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); + this.lib.recycleCallStack(stack); + return ret; + } + + /** + * Set packed function arguments into the location indicated by argsValue and argsCode. + * Allocate new temporary space from the stack if necessary. + * + * @parma stack The call stack + * @param args The input arguments. + * @param argsValue The offset of argsValue. + * @param argsCode The offset of argsCode. 
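+ * @note Plain JS functions passed as arguments are wrapped into temporary PackedFuncs, tracked in stack.tempArgs and disposed when the stack is reset.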
+ */ + setPackedArguments( + stack: CachedCallStack, + args: Array, + argsValue: PtrOffset, + argsCode: PtrOffset + ): void { + for (let i = 0; i < args.length; ++i) { + let val = args[i]; + const tp = typeof val; + const valueOffset = argsValue + i * SizeOf.TVMValue; + const codeOffset = argsCode + i * SizeOf.I32; + if (val instanceof NDArray) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle); + } else if (val instanceof Scalar) { + if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { + stack.storeI64(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.Int); + } else if (val.dtype.startsWith("float")) { + stack.storeF64(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.Float); + } else { + assert(val.dtype == "handle", "Expect handle"); + stack.storePtr(valueOffset, val.value); + stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle); + } + } else if (val instanceof DLContext) { + stack.storeI32(valueOffset, val.deviceType); + stack.storeI32(valueOffset + SizeOf.I32, val.deviceType); + stack.storeI32(codeOffset, ArgTypeCode.TVMContext); + } else if (tp == "number") { + stack.storeF64(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.Float); + // eslint-disable-next-line no-prototype-builtins + } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) { + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + } else if (val === null || val == undefined) { + stack.storePtr(valueOffset, 0); + stack.storeI32(codeOffset, ArgTypeCode.Null); + } else if (tp == "string") { + stack.allocThenSetArgString(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.TVMStr); + } else if (val instanceof Uint8Array) { + stack.allocThenSetArgBytes(valueOffset, val); + stack.storeI32(codeOffset, ArgTypeCode.TVMBytes); + } else if (val instanceof Function) { + val = this.toPackedFunc(val); + stack.tempArgs.push(val); + stack.storePtr(valueOffset, val._tvmPackedCell.handle); + stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle); + } else if (val instanceof Module) { + stack.storePtr(valueOffset, val.handle); + stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle); + } else { + throw new Error("Unsupported argument type " + tp); + } + } + } + + private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc { + const lib = this.lib; + return ( + argValues: Pointer, + argCodes: Pointer, + nargs: number, + ret: Pointer, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + _handle: Pointer + ): number => { + const jsArgs = []; + for (let i = 0; i < nargs; ++i) { + const valuePtr = argValues + i * SizeOf.TVMValue; + const codePtr = argCodes + i * SizeOf.I32; + let tcode = lib.memory.loadI32(codePtr); + + if ( + tcode == ArgTypeCode.TVMObjectHandle || + tcode == ArgTypeCode.TVMObjectRValueRefArg || + tcode == ArgTypeCode.TVMPackedFuncHandle || + tcode == ArgTypeCode.TVMModuleHandle + ) { + lib.checkCall( + (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)( + valuePtr, + codePtr + ) + ); + } + tcode = lib.memory.loadI32(codePtr); + jsArgs.push(this.retValueToJS(valuePtr, tcode, true)); + } + + const rv = func(...jsArgs); + + if (rv !== undefined && rv !== null) { + const stack = lib.getOrAllocCallStack(); + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const codeOffset = stack.allocRawBytes(SizeOf.I32); + this.setPackedArguments(stack, [rv], valueOffset, codeOffset); + const 
valuePtr = stack.ptrFromOffset(valueOffset); + const codePtr = stack.ptrFromOffset(codeOffset); + stack.commitToWasmMemory(); + lib.checkCall( + (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)( + ret, + valuePtr, + codePtr, + 1 + ) + ); + lib.recycleCallStack(stack); + } + return 0; + }; + } + + private makePackedFunc(handle: Pointer): PackedFunc { + const cell = new PackedFuncCell(handle, this.lib); + + const packedFunc = (...args: any): any => { + const stack = this.lib.getOrAllocCallStack(); + + const valueOffset = stack.allocRawBytes(SizeOf.TVMValue * args.length); + const tcodeOffset = stack.allocRawBytes(SizeOf.I32 * args.length); + + this.setPackedArguments(stack, args, valueOffset, tcodeOffset); + + const rvalueOffset = stack.allocRawBytes(SizeOf.TVMValue); + const rcodeOffset = stack.allocRawBytes(SizeOf.I32); + const rvaluePtr = stack.ptrFromOffset(rvalueOffset); + const rcodePtr = stack.ptrFromOffset(rcodeOffset); + + // commit to wasm memory, till rvalueOffset (the return value don't need to be committed) + stack.commitToWasmMemory(rvalueOffset); + + this.lib.checkCall( + (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)( + handle, + stack.ptrFromOffset(valueOffset), + stack.ptrFromOffset(tcodeOffset), + args.length, + rvaluePtr, + rcodePtr + ) + ); + + const ret = this.retValueToJS(rvaluePtr, this.memory.loadI32(rcodePtr), false); + this.lib.recycleCallStack(stack); + return ret; + }; + // Attach attributes to the function type. + // This is because javascript do not allow us to overload call. + const ret: any = packedFunc; + ret.dispose = (): void => { + cell.dispose(); + }; + ret._tvmPackedCell = cell; + return ret as PackedFunc; + } + + private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any { + switch (tcode) { + case ArgTypeCode.Int: + case ArgTypeCode.UInt: + return this.memory.loadI64(rvaluePtr); + case ArgTypeCode.Float: + return this.memory.loadF64(rvaluePtr); + case ArgTypeCode.TVMOpaqueHandle: { + return this.memory.loadPointer(rvaluePtr); + } + case ArgTypeCode.TVMNDArrayHandle: { + return new NDArray(this.memory.loadPointer(rvaluePtr), false, this.lib); + } + case ArgTypeCode.TVMDLTensorHandle: { + assert(callbackArg); + return new NDArray(this.memory.loadPointer(rvaluePtr), true, this.lib); + } + case ArgTypeCode.TVMPackedFuncHandle: { + return this.makePackedFunc(this.memory.loadPointer(rvaluePtr)); + } + case ArgTypeCode.TVMModuleHandle: { + return new Module( + this.memory.loadPointer(rvaluePtr), + this.lib, + (ptr: Pointer) => { + return this.makePackedFunc(ptr); + } + ); + } + case ArgTypeCode.Null: return undefined; + case ArgTypeCode.TVMContext: { + const deviceType = this.memory.loadI32(rvaluePtr); + const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32); + return this.context(deviceType, deviceId); + } + case ArgTypeCode.TVMStr: { + const ret = this.memory.loadCString(this.memory.loadPointer(rvaluePtr)); + return ret; + } + case ArgTypeCode.TVMBytes: { + return this.memory.loadTVMBytes(this.memory.loadPointer(rvaluePtr)); + } + default: + throw new Error("Unsupported return type code=" + tcode); + } + } +} + +/** + * Asynchrously instantiate a new {@link Instance}. + * + * importObject can also be a {@link LibraryProvider} object, + * a WASI object, or an object containing wasmLibraryProvider field. + * We can take benefit of syslib implementations from the Emscripten + * by passing its generated js Module as the imports. + * + * @param bufferSource The source to be compiled. 
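+ *   (an ArrayBuffer holding the wasm binary to compile).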
+ * @param importObject The import objects. + * @param logger The system logger. + */ +export function instantiate( + bufferSource: ArrayBuffer, + importObject: Record = {}, + logger: (msg: string) => void = console.log +): Promise { + const env = new Environment(importObject, logger); + + return WebAssembly.instantiate(bufferSource, env.imports).then( + (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { + return new Instance(result.module, {}, result.instance, env); + } + ); +} diff --git a/web/src/support.ts b/web/src/support.ts new file mode 100644 index 000000000000..7a2667a2299f --- /dev/null +++ b/web/src/support.ts @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Convert string to Uint8array. + * @param str The string. + * @returns The corresponding Uint8Array. + */ +export function StringToUint8Array(str: string): Uint8Array { + const arr = new Uint8Array(str.length + 1); + for (let i = 0; i < str.length; ++i) { + arr[i] = str.charCodeAt(i); + } + arr[str.length] = 0; + return arr; +} + +/** + * Convert Uint8array to string. + * @param array The array. + * @returns The corresponding string. + */ +export function Uint8ArrayToString(arr: Uint8Array): string { + const ret = []; + for (const ch of arr) { + ret.push(String.fromCharCode(ch)); + } + return ret.join(""); +} + +/** + * Internal assert helper + * @param condition condition The condition to fail. + * @param msg msg The message. + */ +export function assert(condition: boolean, msg?: string): asserts condition { + if (!condition) { + throw new Error("AssertError:" + (msg || "")); + } +} + +/** + * Get the path to the wasm library in nodejs. + * @return The wasm path. + */ +export function wasmPath(): string { + return __dirname + "/wasm"; +} \ No newline at end of file diff --git a/web/src/types.ts b/web/src/types.ts new file mode 100644 index 000000000000..621375a23f5f --- /dev/null +++ b/web/src/types.ts @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +/** Common type definitions. */ + +/** + * Library interface provider that can provide + * syslibs(e.g. libs provided by WASI and beyond) for the Wasm runtime. + * + * It can be viewed as a generalization of imports used in WebAssembly instance creation. + * + * The {@link LibraryProvider.start} callback will be called + * to allow the library provider to initialize related resources during startup time. + * + * We can use Emscripten generated js Module as a { wasmLibraryProvider: LibraryProvider }. + */ +export interface LibraryProvider { + /** The imports that can be passed to WebAssembly instance creation. */ + imports: Record; + /** + * Callback function to notify the provider the created instance. + * @param inst The created instance. + */ + start: (inst: WebAssembly.Instance) => void; +} + +/** + * Disposable classes that contains resources (WasmMemory, GPU buffer) + * which needs to be explicitly disposed. + */ +export interface Disposable { + /** + * Dispose the internal resource + * This function can be called multiple times, + * only the first call will take effect. + */ + dispose: () => void; +} diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts new file mode 100644 index 000000000000..640f7b4a7163 --- /dev/null +++ b/web/src/webgpu.ts @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +import "@webgpu/types"; +import { assert } from "./support"; +import { Pointer } from "./ctypes"; +import { Memory } from "./memory"; + +/** A pointer to points to the raw address space. */ +export type GPUPointer = number; + +/** + * DetectGPU device in the environment. + */ +export async function detectGPUDevice(): Promise { + if (typeof navigator !== "undefined" && navigator.gpu !== undefined) { + const adapter = await navigator.gpu.requestAdapter(); + return await adapter.requestDevice(); + } else { + return undefined; + } +} + +interface FunctionInfo { + name: string; + arg_types: Array; + thread_axis_tags: Array; +} + +/** + * WebGPU context + * Manages all the webgpu resources here. 
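+ * GPU buffers are exposed to the wasm side as integer handles (GPUPointer) kept in an internal buffer table.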
+ */ +export class WebGPUContext { + device: GPUDevice; + memory: Memory; + + //private readBuffer:; + private bufferTable: Array = [undefined]; + private bufferTableFreeId: Array = []; + private pendingRead: Promise = Promise.resolve(); + private numPendingReads = 0; + + constructor(memory: Memory, device: GPUDevice) { + this.memory = memory; + this.device = device; + } + + /** + * Wait for all pending GPU tasks to complete + */ + async sync(): Promise { + const fence = this.device.defaultQueue.createFence(); + this.device.defaultQueue.signal(fence, 1); + if (this.numPendingReads != 0) { + // eslint-disable-next-line @typescript-eslint/no-empty-function + await Promise.all([fence.onCompletion(1), this.pendingRead]); + } else { + await fence.onCompletion(1); + } + } + + /** + * Create a PackedFunc that runs the given shader + * + * @param info The function information in json. + * @param data The shader data(in SPIRV) + */ + createShader(info: string, data: Uint8Array): Function { + const finfo = JSON.parse(info); + const layoutEntries: Array = []; + for (let i = 0; i < finfo.arg_types.length; ++i) { + const dtype = finfo.arg_types[i]; + if (dtype == "handle") { + layoutEntries.push({ + binding: i, + visibility: GPUShaderStage.COMPUTE, + type: "storage-buffer" + }); + } else { + throw new Error("Cannot handle argument type " + dtype + " in WebGPU shader"); + } + } + const bindGroupLayout = this.device.createBindGroupLayout({ + entries: layoutEntries + }); + + const pipeline = this.device.createComputePipeline({ + layout: this.device.createPipelineLayout({ + bindGroupLayouts: [ bindGroupLayout ] + }), + computeStage: { + module: this.device.createShaderModule({ + code: new Uint32Array(data.buffer) + }), + entryPoint: "main" + } + }); + + const dispatchToDim: Array = []; + + for (let i = 0; i < finfo.thread_axis_tags.length; ++i) { + const tag: string = finfo.thread_axis_tags[i]; + if (tag.startsWith("blockIdx.")) { + const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target); + } else if (tag.startsWith("threadIdx.")) { + const target: number = tag.charCodeAt(tag.length - 1) - ("x".charCodeAt(0)); + assert(target >= 0 && target < 3); + dispatchToDim.push(target + 3); + } else { + throw new Error("Cannot handle thread_axis " + tag); + } + } + + const submitShader = (...args: Array): void => { + const commandEncoder = this.device.createCommandEncoder(); + const compute = commandEncoder.beginComputePass(); + compute.setPipeline(pipeline); + const bindGroupEntries: Array = []; + assert(args.length == layoutEntries.length + dispatchToDim.length); + + for (let i = 0; i < layoutEntries.length; ++i) { + bindGroupEntries.push({ + binding: i, + resource: { + buffer: this.gpuBufferFromPtr(args[i]) + } + }); + } + + compute.setBindGroup(0, this.device.createBindGroup({ + layout: bindGroupLayout, + entries: bindGroupEntries + })); + const wl: Array = [1, 1, 1, 1, 1, 1]; + for (let i = 0; i < dispatchToDim.length; ++i) { + wl[dispatchToDim[i]] = args[layoutEntries.length + i]; + } + compute.dispatch(wl[0], wl[1], wl[2]); + compute.endPass(); + const command = commandEncoder.finish(); + this.device.defaultQueue.submit([command]); + }; + + return submitShader; + } + + /** + * Get the device API according to its name + * @param The name of the API. + * @returns The corresponding device api. 
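+ * @throws Error if the name does not correspond to a supported DeviceAPI function.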
+ */ + getDeviceAPI(name: string): Function { + if (name == "deviceAllocDataSpace") { + return (nbytes: number): GPUPointer => { + return this.deviceAllocDataSpace(nbytes); + }; + } else if (name == "deviceFreeDataSpace") { + return (ptr: GPUPointer): void => { + return this.deviceFreeDataSpace(ptr); + }; + } else if (name == "deviceCopyToGPU") { + return ( + from: Pointer, + to: GPUPointer, + toOffset: number, + nbytes: number + ): void => { + this.deviceCopyToGPU(from, to, toOffset, nbytes); + }; + } else if (name == "deviceCopyFromGPU") { + return ( + from: GPUPointer, + fromOffset: number, + to: Pointer, + nbytes: number + ): void => { + this.deviceCopyFromGPU(from, fromOffset, to, nbytes); + }; + } else if (name == "deviceCopyWithinGPU") { + return ( + from: GPUPointer, + fromOffset: number, + to: Pointer, + toOffset: number, + nbytes: number + ): void => { + this.deviceCopyWithinGPU(from, fromOffset, to, toOffset, nbytes); + }; + } else { + throw new Error("Unknown DeviceAPI function " + name); + } + + } + + // DeviceAPI + private deviceAllocDataSpace(nbytes: number): GPUPointer { + const buffer = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST, + }); + return this.attachToBufferTable(buffer); + } + + private deviceFreeDataSpace(ptr: GPUPointer): void { + const idx = ptr; + const buffer = this.bufferTable[idx]; + this.bufferTable[idx] = undefined; + assert(buffer !== undefined); + this.bufferTableFreeId.push(idx); + buffer.destroy(); + } + + private deviceCopyToGPU( + from: Pointer, + to: GPUPointer, + toOffset: number, + nbytes: number + ): void { + // Perhaps it would be more useful to use a staging buffer? + const [gpuTemp, cpuTemp] = this.device.createBufferMapped({ + size: nbytes, + usage: GPUBufferUsage.MAP_WRITE | GPUBufferUsage.COPY_SRC, + }); + + const viewU8 = new Uint8Array(cpuTemp); + viewU8.set(this.memory.loadRawBytes(from, nbytes)); + gpuTemp.unmap(); + + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + gpuTemp, + 0, + this.gpuBufferFromPtr(to), + toOffset, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.defaultQueue.submit([copyCommands]); + gpuTemp.destroy(); + } + + private deviceCopyFromGPU( + from: GPUPointer, + fromOffset: number, + to: Pointer, + nbytes: number + ): void { + // Perhaps it would be more useful to resuse a staging buffer? 
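The read-back path that follows chains every outstanding mapReadAsync into a single pending promise so that sync() can await them all. The same bookkeeping, extracted into a self-contained sketch (illustrative only, not part of the runtime API):

// Sketch of the numPendingReads / pendingRead pattern used by deviceCopyFromGPU and sync().
class PendingReads {
  private pending: Promise<void> = Promise.resolve();
  private count = 0;

  track(read: Promise<void>): void {
    this.count += 1;
    const done = read.then(() => { this.count -= 1; });
    // the first outstanding read replaces the chain, later ones are folded into it
    this.pending = this.count === 1
      ? done
      : Promise.all([this.pending, done]).then(() => {});
  }

  async wait(): Promise<void> {
    if (this.count !== 0) {
      await this.pending;
    }
  }
}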
+ const gpuTemp = this.device.createBuffer({ + size: nbytes, + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + }); + + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + this.gpuBufferFromPtr(from), + fromOffset, + gpuTemp, + 0, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.defaultQueue.submit([copyCommands]); + + this.numPendingReads += 1; + const readEvent = gpuTemp.mapReadAsync().then((data: ArrayBuffer) => { + this.memory.storeRawBytes(to, new Uint8Array(data)); + this.numPendingReads -= 1; + gpuTemp.destroy(); + }); + + if (this.numPendingReads == 1) { + this.pendingRead = readEvent; + } else { + this.pendingRead = Promise.all([ + this.pendingRead, + readEvent, + // eslint-disable-next-line @typescript-eslint/no-empty-function + ]).then(() => {}); + } + } + + private deviceCopyWithinGPU( + from: GPUPointer, + fromOffset: number, + to: Pointer, + toOffset: number, + nbytes: number + ): void { + const copyEncoder = this.device.createCommandEncoder(); + copyEncoder.copyBufferToBuffer( + this.gpuBufferFromPtr(from), + fromOffset, + this.gpuBufferFromPtr(to), + toOffset, + nbytes + ); + const copyCommands = copyEncoder.finish(); + this.device.defaultQueue.submit([copyCommands]); + } + + private gpuBufferFromPtr(ptr: GPUPointer): GPUBuffer { + const buffer = this.bufferTable[ptr]; + assert(buffer !== undefined); + return buffer; + } + + private attachToBufferTable(buffer: GPUBuffer): GPUPointer { + if (this.bufferTableFreeId.length != 0) { + const idx = this.bufferTableFreeId.pop() as number; + this.bufferTable[idx] = buffer; + return idx; + } else { + const idx = this.bufferTable.length; + this.bufferTable.push(buffer); + return idx; + } + } +} diff --git a/tests/web/test_module_load.js b/web/tests/node/test_module_load.js similarity index 58% rename from tests/web/test_module_load.js rename to web/tests/node/test_module_load.js index f4c809536bb5..561de8aa5786 100644 --- a/tests/web/test_module_load.js +++ b/web/tests/node/test_module_load.js @@ -16,43 +16,45 @@ * specific language governing permissions and limitations * under the License. */ - +/* eslint-disable no-undef */ // Load Emscripten Module, need to change path to root/lib const path = require("path"); -process.chdir(path.join(__dirname, "../../build")); -var Module = require("../../build/test_module.js"); -// Bootstrap TVMruntime with emscripten module. 
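The GPUPointer values handed out earlier in webgpu.ts are plain indices into bufferTable, recycled through bufferTableFreeId (see attachToBufferTable and gpuBufferFromPtr above). The same handle-table pattern in a self-contained sketch, purely for illustration:

// Minimal index-based handle table mirroring bufferTable / bufferTableFreeId.
class HandleTable<T> {
  private slots: Array<T | undefined> = [undefined]; // slot 0 reserved as a null handle
  private free: number[] = [];

  attach(value: T): number {
    const idx = this.free.length ? (this.free.pop() as number) : this.slots.length;
    if (idx === this.slots.length) this.slots.push(value);
    else this.slots[idx] = value;
    return idx;
  }

  get(idx: number): T {
    const v = this.slots[idx];
    if (v === undefined) throw new Error("bad or freed handle " + idx);
    return v;
  }

  release(idx: number): void {
    this.slots[idx] = undefined;
    this.free.push(idx);
  }
}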
-const tvm_runtime = require("../../web/tvm_runtime.js"); -const tvm = tvm_runtime.create(Module); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "test_addone.wasm")); + +const tvm = new tvmjs.Instance( + new WebAssembly.Module(wasmSource), + new EmccWASI() +); // Load system library -var sysLib = tvm.systemLib(); +const sysLib = tvm.systemLib(); function randomArray(length, max) { - return Array.apply(null, Array(length)).map(function() { + return Array.apply(null, Array(length)).map(function () { return Math.random() * max; }); } -function testAddOne() { +test("add one", () => { // grab pre-loaded function - var faddOne = sysLib.getFunction("add_one"); - var assert = require('assert'); - tvm.assert(tvm.isPackedFunc(faddOne)); - var n = 124; - var A = tvm.empty(n).copyFrom(randomArray(n, 1)); - var B = tvm.empty(n); + const faddOne = sysLib.getFunction("add_one"); + assert(tvm.isPackedFunc(faddOne)); + const n = 124; + const A = tvm.empty(n).copyFrom(randomArray(n, 1)); + const B = tvm.empty(n); // call the function. faddOne(A, B); - AA = A.asArray(); // retrieve values in js array - BB = B.asArray(); // retrieve values in js array + const AA = A.toArray(); // retrieve values in js array + const BB = B.toArray(); // retrieve values in js array // verify for (var i = 0; i < BB.length; ++i) { assert(Math.abs(BB[i] - (AA[i] + 1)) < 1e-5); } - faddOne.release(); -} - -testAddOne(); -sysLib.release(); -console.log("Finish verifying test_module_load"); + faddOne.dispose(); +}); diff --git a/web/tests/node/test_ndarray.js b/web/tests/node/test_ndarray.js new file mode 100644 index 000000000000..eb0a8f446d4c --- /dev/null +++ b/web/tests/node/test_ndarray.js @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-undef */ +const path = require("path"); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist/tvmjs.bundle") + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance(new WebAssembly.Module(wasmSource), new EmccWASI()); + +// Basic fields. 
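The ndarray test below deliberately calls dispose() twice, relying on the Disposable contract from types.ts (only the first call takes effect). A minimal guard implementing that contract might look like this sketch; the class name is hypothetical and it assumes the GPUBuffer type from @webgpu/types:

// Idempotent dispose guard, illustrative only.
class GPUBufferHolder {
  private buffer?: GPUBuffer;

  constructor(buffer: GPUBuffer) {
    this.buffer = buffer;
  }

  dispose(): void {
    if (this.buffer !== undefined) {
      this.buffer.destroy();    // release the underlying resource exactly once
      this.buffer = undefined;  // later calls become no-ops
    }
  }
}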
+assert(tvm.listGlobalFuncNames() !== undefined); + +// Test ndarray +function testArrayCopy(dtype, arrayType) { + let data = [1, 2, 3, 4, 5, 6]; + let a = tvm.empty([2, 3], dtype).copyFrom(data); + + assert(a.context.toString() == "cpu(0)"); + assert(a.shape[0] == 2 && a.shape[1] == 3); + + let ret = a.toArray(); + assert(ret instanceof arrayType); + assert(ret.toString() == arrayType.from(data).toString()); + // test multiple dispose. + a.dispose(); + a.dispose(); +} + +test("array copy", () => { + testArrayCopy("float32", Float32Array); + testArrayCopy("int", Int32Array); + testArrayCopy("int8", Int8Array); + testArrayCopy("uint8", Uint8Array); + testArrayCopy("float64", Float64Array); +}); + diff --git a/web/tests/node/test_packed_func.js b/web/tests/node/test_packed_func.js new file mode 100644 index 000000000000..e18c0aecfdc0 --- /dev/null +++ b/web/tests/node/test_packed_func.js @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* eslint-disable no-undef */ +const path = require("path"); +const fs = require("fs"); +const assert = require("assert"); +const tvmjs = require("../../dist"); + +const wasmPath = tvmjs.wasmPath(); +const EmccWASI = require(path.join(wasmPath, "tvmjs_runtime.wasi.js")); +const wasmSource = fs.readFileSync(path.join(wasmPath, "tvmjs_runtime.wasm")); + +let tvm = new tvmjs.Instance( + new WebAssembly.Module(wasmSource), + new EmccWASI() +); + +test("GetGlobal", () => { + let flist = tvm.listGlobalFuncNames(); + let faddOne = tvm.getGlobalFunc("testing.add_one"); + let fecho = tvm.getGlobalFunc("testing.echo"); + + assert(faddOne(tvm.scalar(1, "int")) == 2); + // check function argument with different types. 
+ assert(fecho(1123) == 1123); + assert(fecho("xyz") == "xyz"); + + let bytes = new Uint8Array([1, 2, 3]); + let rbytes = fecho(bytes); + assert(rbytes.length == bytes.length); + + for (let i = 0; i < bytes.length; ++i) { + assert(rbytes[i] == bytes[i]); + } + + assert(fecho(undefined) == undefined); + + let arr = tvm.empty([2, 2]).copyFrom([1, 2, 3, 4]); + let arr2 = fecho(arr); + assert(arr.handle == arr2.handle); + assert(arr2.toArray().toString() == arr.toArray().toString()); + + let mod = tvm.systemLib(); + let ret = fecho(mod); + assert(ret.handle == mod.handle); + assert(flist.length != 0); + + mod.dispose(); + ret.dispose(); + arr.dispose(); + arr2.dispose(); + fecho.dispose(); + faddOne.dispose(); +}); + +test("ReturnFunc", () => { + function addy(y) { + function add(x, z) { + return x + y + z; + } + return add; + } + + let fecho = tvm.getGlobalFunc("testing.echo"); + let myf = tvm.toPackedFunc(addy); + assert(tvm.isPackedFunc(myf)); + let myf2 = tvm.toPackedFunc(myf); + assert(myf2._tvmPackedCell.handle === myf._tvmPackedCell.handle); + let f = myf(10); + + assert(tvm.isPackedFunc(f)); + assert(f(11, 0) == 21); + assert(f("x", 1) == "x101"); + assert(f("x", "yz") == "x10yz"); + + fecho.dispose(); + myf.dispose(); + myf2.dispose(); + // test multiple dispose. + f.dispose(); + f.dispose(); +}); + +test("RegisterGlobal", () => { + tvm.registerFunc("xyz", function (x, y) { + return x + y; + }); + + let f = tvm.getGlobalFunc("xyz"); + assert(f(1, 2) == 3); + f.dispose(); + + let syslib = tvm.systemLib(); + syslib.dispose(); +}); diff --git a/tests/web/prepare_test_libs.py b/web/tests/python/prepare_test_libs.py similarity index 69% rename from tests/web/prepare_test_libs.py rename to web/tests/python/prepare_test_libs.py index a0e2c13eab82..ec4eb5be1536 100644 --- a/tests/web/prepare_test_libs.py +++ b/web/tests/python/prepare_test_libs.py @@ -14,27 +14,28 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# Prepare test library for js. +# Prepare test library for standalone wasm runtime test. 
+
 import tvm
 from tvm import te
-from tvm.contrib import emscripten
+from tvm.contrib import emcc
 import os
+
+
 def prepare_test_libs(base_path):
-    target = "llvm -target=asmjs-unknown-emscripten -system-lib"
+    target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib"
     if not tvm.runtime.enabled(target):
         raise RuntimeError("Target %s is not enbaled" % target)
     n = te.var("n")
     A = te.placeholder((n,), name='A')
     B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
     s = te.create_schedule(B.op)
-    fadd1 = tvm.build(s, [A, B], target, name="add_one")
-    obj_path = os.path.join(base_path, "test_add_one.bc")
-    fadd1.save(obj_path)
-    emscripten.create_js(os.path.join(base_path, "test_module.js"), obj_path,
-                         options=["-s", "WASM=0", "-s", "USE_GLFW=3", "-s",
-                                  "USE_WEBGL2=1", "-lglfw"])
+    fadd = tvm.build(s, [A, B], target, name="add_one")
+
+    wasm_path = os.path.join(base_path, "test_addone.wasm")
+    fadd.export_library(wasm_path, emcc.create_tvmjs_wasm)
+

 if __name__ == "__main__":
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    prepare_test_libs(os.path.join(curr_path, "../../build"))
+    prepare_test_libs(os.path.join(curr_path, "../../dist/wasm"))
diff --git a/web/tests/python/webgpu_rpc_test.py b/web/tests/python/webgpu_rpc_test.py
new file mode 100644
index 000000000000..d16ba3f3304e
--- /dev/null
+++ b/web/tests/python/webgpu_rpc_test.py
@@ -0,0 +1,79 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Simple test code to test JavaScript RPC.
+
+To use it, start an RPC proxy with "python -m tvm.exec.rpc_proxy".
+Connect the JavaScript end to the WebSocket port and connect to the RPC.
+""" + +import tvm +from tvm import te +from tvm import rpc +from tvm.contrib import util, emcc +import numpy as np + +proxy_host = "localhost" +proxy_port = 9090 + + +def test_rpc(): + if not tvm.runtime.enabled("rpc"): + return + # generate the wasm library + target_device = "webgpu" + target_host = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" + if not tvm.runtime.enabled(target_host): + raise RuntimeError("Target %s is not enbaled" % target_host) + + n = 2048 + A = te.placeholder((n,), name='A') + B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') + s = te.create_schedule(B.op) + + num_thread = 2 + xo, xi = s[B].split(B.op.axis[0], factor=num_thread) + s[B].bind(xi, te.thread_axis("threadIdx.x")) + s[B].bind(xo, te.thread_axis("blockIdx.x")) + + + fadd = tvm.build(s, [A, B], target_device, target_host=target_host, name="addone") + temp = util.tempdir() + + wasm_path = temp.relpath("addone_gpu.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + + wasm_binary = open(wasm_path, "rb").read() + remote = rpc.connect(proxy_host, proxy_port, key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary]) + + def check(remote): + # basic function checks. + ctx = remote.webgpu(0) + adata = np.random.uniform(size=n).astype(A.dtype) + a = tvm.nd.array(adata, ctx) + b = tvm.nd.array(np.zeros(n, dtype=A.dtype), ctx) + + np.testing.assert_equal(a.asnumpy(), adata) + f1 = remote.system_lib() + addone = f1.get_function("addone") + addone(a, b) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + print("Test pass..") + + check(remote) + +test_rpc() diff --git a/tests/web/websock_rpc_test.py b/web/tests/python/websock_rpc_test.py similarity index 53% rename from tests/web/websock_rpc_test.py rename to web/tests/python/websock_rpc_test.py index 8be8ce04cb75..f7c07924a210 100644 --- a/tests/web/websock_rpc_test.py +++ b/web/tests/python/websock_rpc_test.py @@ -22,45 +22,63 @@ import tvm from tvm import te -import os from tvm import rpc -from tvm.contrib import util, emscripten +from tvm.contrib import util, emcc import numpy as np proxy_host = "localhost" proxy_port = 9090 -def test_rpc_array(): +def test_rpc(): if not tvm.runtime.enabled("rpc"): return - # graph - n = tvm.runtime.convert(1024) + # generate the wasm library + target = "llvm -target=wasm32-unknown-unknown-wasm -system-lib" + if not tvm.runtime.enabled(target): + raise RuntimeError("Target %s is not enbaled" % target) + n = te.var("n") A = te.placeholder((n,), name='A') B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = te.create_schedule(B.op) - remote = rpc.connect(proxy_host, proxy_port, key="js") - target = "llvm -target=asmjs-unknown-emscripten -system-lib" - def check_remote(): - if not tvm.runtime.enabled(target): - print("Skip because %s is not enabled" % target) - return - temp = util.tempdir() + + fadd = tvm.build(s, [A, B], target, name="addone") + temp = util.tempdir() + + wasm_path = temp.relpath("addone.wasm") + fadd.export_library(wasm_path, emcc.create_tvmjs_wasm) + + wasm_binary = open(wasm_path, "rb").read() + + remote = rpc.connect(proxy_host, proxy_port, key="wasm", + session_constructor_args=["rpc.WasmSession", wasm_binary]) + + def check(remote): + # basic function checks. 
+ faddone = remote.get_function("testing.asyncAddOne") + fecho = remote.get_function("testing.echo") + assert(faddone(100) == 101) + assert(fecho(1, 2, 3) == 1) + assert(fecho(1, 2, 3) == 1) + assert(fecho(100, 2, 3) == 100) + assert(fecho("xyz") == "xyz") + assert(bytes(fecho(bytearray(b"123"))) == b"123") + + # run the generated library. + f1 = remote.system_lib() ctx = remote.cpu(0) - f = tvm.build(s, [A, B], target, name="myadd") - path_obj = temp.relpath("dev_lib.bc") - path_dso = temp.relpath("dev_lib.js") - f.save(path_obj) - emscripten.create_js(path_dso, path_obj, side_module=True) - # Upload to suffix as dso so it can be loaded remotely - remote.upload(path_dso, "dev_lib.dso") - data = remote.download("dev_lib.dso") - f1 = remote.load_module("dev_lib.dso") a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) - time_f = f1.time_evaluator(f1.entry_name, remote.cpu(0), number=10) + # invoke the function + addone = f1.get_function("addone") + addone(a, b) + + # time evaluator + time_f = f1.time_evaluator("addone", ctx, number=100, repeat=10) + time_f(a, b) cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote() -test_rpc_array() + check(remote) + +test_rpc() diff --git a/web/tsconfig.json b/web/tsconfig.json new file mode 100644 index 000000000000..6aec44858a7a --- /dev/null +++ b/web/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "module": "commonjs", + "target": "es6", + "outDir": "dist", + "rootDir": "src", + "declaration": true, + "sourceMap": true, + "strict": true + }, + "include": ["src"], + "exclude": ["node_modules"] +} diff --git a/web/tvm_runtime.js b/web/tvm_runtime.js deleted file mode 100644 index b62b298d969e..000000000000 --- a/web/tvm_runtime.js +++ /dev/null @@ -1,1274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/** - * TVM Javascript web runtime library. - * - * @projectname tvm - * @version 0.7.dev1 - */ -/* eslint no-unused-vars: "off" */ -/* eslint no-unexpected-multiline: "off" */ -/* eslint indent: "off" */ -/* eslint no-console: "off" */ -/** - * TVM Runtime namespace. - * Provide tvm_runtime.create to create a {@link tvm.TVMRuntime}. - * - * @namespace tvm_runtime - */ -var tvm_runtime = tvm_runtime || {}; - -/** - * TVM root namespace. - * The classes inside this namespace need to be constructed by factory functions. - * Use {@link tvm_runtime}.create to get started. - * - * @namespace tvm - */ -(function() { - /** - * TVMRuntime object for interacting with TVM runtime. 
- * This object can be constructed using {@link tvm_runtime}.create - * - * @class - * @memberof tvm - */ - function TVMRuntime() { - "use strict"; - var runtime_ref = this; - // Utility function to throw error - function throwError(message) { - if (typeof runtime_ref.logger !== "undefined") { - runtime_ref.logger(message); - } - if (typeof Error !== "undefined") { - throw new Error(message); - } - throw message; - } - var Module = this.Module; - var Runtime = this.Runtime; - if (typeof Module === "undefined") { - throwError("Emscripten Module is not available"); - } - // constants - var SIZEOF_POINTER = 4; - var SIZEOF_SIZE_T = 4; - var SIZEOF_FLOAT = 4; - var SIZEOF_INT = 4; - var SIZEOF_INT8 = 1; - var SIZEOF_INT64 = 8; - var SIZEOF_DOUBLE = 8; - var SIZEOF_TYPE = 4; - var SIZEOF_CTX = SIZEOF_INT + SIZEOF_INT; - var SIZEOF_TVMVALUE = SIZEOF_DOUBLE; - var ARRAY_OFFSET_DATA = 0; - var ARRAY_OFFSET_CTX = ARRAY_OFFSET_DATA + SIZEOF_POINTER; - var ARRAY_OFFSET_DEV_TYPE = ARRAY_OFFSET_CTX; - var ARRAY_OFFSET_DEV_ID = ARRAY_OFFSET_CTX + SIZEOF_INT; - var ARRAY_OFFSET_NDIM = ARRAY_OFFSET_CTX + SIZEOF_CTX; - var ARRAY_OFFSET_DTYPE = ARRAY_OFFSET_NDIM + SIZEOF_INT; - var ARRAY_OFFSET_DTYPE_CODE = ARRAY_OFFSET_DTYPE; - var ARRAY_OFFSET_DTYPE_BITS = ARRAY_OFFSET_DTYPE_CODE + SIZEOF_INT8; - var ARRAY_OFFSET_DTYPE_LANES = ARRAY_OFFSET_DTYPE_BITS + SIZEOF_INT8; - var ARRAY_OFFSET_SHAPE = ARRAY_OFFSET_DTYPE + SIZEOF_TYPE; - var ARRAY_OFFSET_STRIDES = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - var ARRAY_OFFSET_BYTE_OFFSET = ARRAY_OFFSET_STRIDES + SIZEOF_POINTER; - // Type codes - var kInt = 0; - var kUInt = 1; - var kFloat = 2; - var kTVMOpaqueHandle = 3; - var kNull = 4; - var kTVMDataType = 5; - var kTVMContext = 6; - var kTVMDLTensorHandle = 7; - var kTVMObjectHandle = 8; - var kTVMModuleHandle = 9; - var kTVMPackedFuncHandle = 10; - var kTVMStr = 11; - var kTVMBytes = 12; - var kTVMObjectRValueRefArg = 14; - //----------------------------------------- - // TVM CWrap library - // ---------------------------------------- - var TVMGetLastError = Module.cwrap( - "TVMGetLastError", - "string", // const char* - []); - - var TVMAPISetLastError = Module.cwrap - ("TVMAPISetLastError", - null, - ["string" // const char* - ]); - - var TVMModImport = Module.cwrap - ("TVMModImport", - "number", - ["number", // TVMModuleHandle mod - "number" // TVMModuleHandle dep - ]); - - var TVMModGetFunction = Module.cwrap - ("TVMModGetFunction", - "number", - ["number", // TVMModuleHandle mod - "string", // const char* func_name - "number", // int query_imports - "number" // TVMFunctionHandle *out - ]); - - var TVMModFree = Module.cwrap - ("TVMModFree", - "number", - ["number" // TVMModeHandle mod - ]); - - var TVMFuncFree = Module.cwrap - ("TVMFuncFree", - "number", - ["number" // TVMFunctionHandle func - ]); - - var TVMFuncCall = Module.cwrap - ("TVMFuncCall", - "number", - ["number", // TVMFunctionHandle func - "number", // TVMValue* arg_values - "number", // int* arg_tcodes - "number", // int num_args - "number", // int ret_val - "number" // int ret_type_code - ]); - - var TVMCFuncSetReturn = Module.cwrap - ("TVMCFuncSetReturn", - "number", - ["number", // TVMRetValueHandle ret - "number", // TVMValue* value - "number", // int* type_code - "number" // int num_ret - ]); - - var TVMCbArgToReturn = Module.cwrap - ("TVMCbArgToReturn", - "number", - ["number", // TVMValue* value - "number" // int* code - ]); - - var TVMFuncCreateFromCFunc = Module.cwrap - ("TVMFuncCreateFromCFunc", - "number", - ["number", // TVMPackedCFunc 
func, - "number", // void* resource_handle - "number", // TVMPackedCFuncFinalizer fin - "number" // TVMFunctionHandle *out - ]); - - var TVMFuncRegisterGlobal = Module.cwrap - ("TVMFuncRegisterGlobal", - "number", - ["string", // name - "number", // TVMFunctionHandle f - "number" // int override - ]); - - var TVMFuncGetGlobal = Module.cwrap - ("TVMFuncGetGlobal", - "number", - ["string", // const char* name - "number" // TVMFunctionHandle* out - ]); - - var TVMFuncListGlobalNames = Module.cwrap - ("TVMFuncListGlobalNames", - "number", - ["number", // int* out_size - "number" // const char*** out_array - ]); - - - var TVMArrayAlloc = Module.cwrap - ("TVMArrayAlloc", - "number", - ["number", // const tvm_index_t* shape - "number", // int ndim - "number", // int dtype_code - "number", // int dtype_bits - "number", // int dtype_lanes - "number", // int device_type - "number", // int device_id - "number" // int TVMArrayHandle* out - ]); - - var TVMArrayFree = Module.cwrap - ("TVMArrayFree", - "number", - ["number" // TVMArrayHandle handle - ]); - - var TVMArrayCopyFromTo = Module.cwrap - ("TVMArrayCopyFromTo", - "number", - ["number", // TVMArrayHandle from - "number" // TVMArrayHandle to - ]); - - var TVMArrayCopyFromBytes = Module.cwrap - ("TVMArrayCopyFromBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMArrayCopyToBytes = Module.cwrap - ("TVMArrayCopyToBytes", - "number", - ["number", // TVMArrayHandle handle - "number", // int data - "number" // size_t nbytes - ]); - - var TVMModLoadFromFile = Module.cwrap - ("TVMModLoadFromFile", - "number", - ["string", // const char* file_name - "string", // const char* format - "number" // TVMModuleHandle* out - ]) - - //----------------------------------------- - // Static utility functions - // ---------------------------------------- - this.assert = function(condition, message) { - if (!condition) { - message = message || "assert failed"; - throwError(message); - } - }; - /** - * Logging function. - * Override this to change logger behavior. 
- * - * @param {string} message - */ - this.logger = function(message) { - console.log(message); - }; - - function logging(message) { - runtime_ref.logger(message); - } - // Override print error to logging - Module.printErr = logging; - var CHECK = this.assert; - - function TVM_CALL(ret) { - if (ret != 0) { - throwError(TVMGetLastError()); - } - } - - function CInt64ArrayToJS(ptr, size) { - var ret = []; - for (var i = 0; i < size; ++i) { - ret.push(Module.getValue(ptr + i * SIZEOF_INT64, "i64")); - } - return ret; - } - - function CStringToJS(ptr) { - var ret = []; - var ch = 1; - while (ch != 0) { - ch = Module.getValue(ptr, "i8"); - if (ch != 0) { - ret.push(String.fromCharCode(ch)); - } - ++ptr; - } - return ret.join(""); - } - - function CBytesToJS(ptr) { - var data = Module.getValue(ptr, "*"); - var size = Module.getValue(ptr + SIZEOF_POINTER, "i32"); - var ret = new Uint8Array(new ArrayBuffer(size)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, data, size)); - return ret; - } - - function StringToUint8Array(str) { - var arr = new Uint8Array(str.length + 1); - for(var i = 0; i < str.length; ++i) { - arr[i] = str.charCodeAt(i); - } - arr[str.length] = 0; - return arr; - } - //----------------------------------------- - // Class declarations - // ---------------------------------------- - function CBuffer(nbytes) { - this.data = Module._malloc(nbytes); - } - - function RefTVMValue() { - this.data = Module._malloc(SIZEOF_TVMVALUE); - } - - function TVMArgs(nargs) { - this.nargs = nargs; - this.value = Module._malloc(SIZEOF_TVMVALUE * nargs); - this.tcode = Module._malloc(SIZEOF_INT * nargs); - this.temp = []; - } - - function TVMType(code, bits, lanes) { - this.code = code; - this.bits = bits; - this.lanes = lanes; - } - /** - * TVM device context. - * @class - * @memberof tvm - */ - function TVMContext(device_type, device_id) { - this.device_type = device_type; - this.device_id = device_id; - } - /** - * TVM n-dimensional array. - * - * Use {@link tvm.TVMRuntime}.empty to create an instance. - * @class - * @memberof tvm - */ - function NDArray(handle) { - this.handle = handle; - this.ndim = Module.getValue(this.handle + ARRAY_OFFSET_NDIM, "i32"); - // shape - var cshape = Module.getValue(this.handle + ARRAY_OFFSET_SHAPE, "*"); - this.shape = CInt64ArrayToJS(cshape, this.ndim); - // dtype - var code = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_CODE, "i8"); - var bits = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_BITS, "i8"); - var lanes = Module.getValue(this.handle + ARRAY_OFFSET_DTYPE_LANES, "i16"); - var dtype = new TVMType(code, bits, lanes); - this.dtype = dtype; - this.BYTES_PER_ELEMENT = (dtype.bits * dtype.lanes / 8); - // ctx - var device_type = Module.getValue(this.handle + ARRAY_OFFSET_DEV_TYPE, "i32"); - var device_id = Module.getValue(this.handle + ARRAY_OFFSET_DEV_ID, "i32"); - this.context = new TVMContext(device_type, device_id); - // byte_offset - this.byteOffset = Module.getValue(this.handle + ARRAY_OFFSET_BYTE_OFFSET, "i64"); - } - - function TVMFunction(handle) { - this.handle = handle; - } - /** - * Module container of TVM generated functions. - * - * @class - * @memberof tvm - */ - function TVMModule(handle) { - this.handle = handle; - } - /** - * A typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * Use {@link tvm.TVMRuntime}.constant to create an instance. 
- * @class - * @memberof tvm - */ - function TVMConstant(value, dtype) { - this.value = value; - this.dtype = dtype; - } - //----------------------------------------- - // Private Functions - // ---------------------------------------- - function getTVMType(dtype) { - if (dtype instanceof TVMType) return dtype; - if (typeof dtype == "string") { - var pattern = dtype; - var code, bits = 32, lanes = 1; - if (pattern.substring(0, 5) == "float") { - pattern = pattern.substring(5, pattern.length); - code = kFloat; - } else if (pattern.substring(0, 3) == "int") { - pattern = pattern.substring(3, pattern.length); - code = kInt; - } else if (pattern.substring(0, 4) == "uint") { - pattern = pattern.substring(4, pattern.length); - code = kUInt; - } else if (pattern.substring(0, 6) == "handle") { - pattern = pattern.substring(5, pattern.length); - code = kTVMOpaqueHandle; - bits = 64; - } else { - throw throwError("Unknown dtype " + dtype); - } - var arr = pattern.split("x"); - if (arr.length >= 1) { - var parsed = parseInt(arr[0]); - if (parsed == arr[0]) { - bits = parsed; - } - } - if (arr.length >= 2) { - lanes = parseInt(arr[1]); - } - return new TVMType(code, bits, lanes); - } else { - throw throwError("Unknown dtype " + dtype); - } - } - - function TVMRetValueToJS(vptr, tcode) { - switch (tcode) { - case kInt: - case kUInt: return Module.getValue(vptr, "i64"); - case kFloat: return Module.getValue(vptr, "double"); - case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*")); - case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*")); - case kNull: return null; - case kTVMStr: return CStringToJS(Module.getValue(vptr, "*")); - case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*")); - default: throwError("Unsupported return type code=" + tcode); - } - } - - function makeTVMFunction(handle) { - var func = new TVMFunction(handle); - var ret = function () { - // alloc - var args = new TVMArgs(arguments.length); - var rvalue = new RefTVMValue(); - var rtcode = new RefTVMValue(); - args.setArguments(arguments); - TVM_CALL(TVMFuncCall(handle, args.value, args.tcode, - args.nargs, rvalue.data, rtcode.data)); - var rv = TVMRetValueToJS(rvalue.data, rtcode.asInt()); - // release - args.release(); - rvalue.release(); - rtcode.release(); - return rv; - }; - var release = function() { - func.release(); - }; - ret._tvm_function = func; - ret.release = release; - return ret; - } - //----------------------------------------- - // Javascript PackedCallback System - // ---------------------------------------- - var funcTable = [0]; - var freeFuncId = []; - - function invokeCallback(arg_value, arg_tcode, nargs, ret, handle) { - var args = []; - for (var i = 0; i < nargs; ++i) { - var vptr = arg_value + i * SIZEOF_TVMVALUE; - var tcodeptr = arg_tcode + i * SIZEOF_INT; - var tcode = Module.getValue(tcodeptr, "i32"); - if (tcode == kTVMObjectHandle || - tcode == kTVMObjectRValueRefArg || - tcode == kTVMPackedFuncHandle || - tcode == kTVMModuleHandle) { - TVM_CALL(TVMCbArgToReturn(vptr, tcodeptr)); - } - tcode = Module.getValue(tcodeptr, "i32"); - args.push(TVMRetValueToJS(vptr, tcode)); - } - var rv = funcTable[handle].apply(null, args); - if (typeof rv !== "undefined") { - // alloc - var rarg = new TVMArgs(1); - rarg.setArguments([rv]); - TVM_CALL(TVMCFuncSetReturn(ret, rarg.value, rarg.tcode, 1)); - // release - rarg.release(); - } - return 0; - } - function freeCallback(handle) { - funcTable[handle] = 0; - freeFuncId.push(handle); - } - var fptrInvokeCallback = null; - var 
fptrFreeCallback = null; - if (typeof Runtime !== "undefined" && - typeof Runtime.addFunction !== "undefined") { - fptrInvokeCallback = Runtime.addFunction(invokeCallback); - fptrFreeCallback = Runtime.addFunction(freeCallback); - } - /** - * Check if a function is TVM PackedFunc - * @param {Function} f function to be checked. - * @return {boolean} Whether f is PackedFunc - */ - this.isPackedFunc = function(f) { - return (typeof f == "function") && f.hasOwnProperty("_tvm_function"); - }; - var isPackedFunc = this.isPackedFunc; - /** - * Convert a javascript function to TVM function. - * @param {Function} f javascript function. - * @return {Function} The created TVMFunction. - */ - this.convertFunc = function(f) { - if (isPackedFunc(f)) return f; - CHECK(fptrInvokeCallback !== null, - "Emscripten Runtime addFunction is not available"); - var fid; - if (freeFuncId.length != 0) { - fid = freeFuncId.pop(); - } else { - fid = funcTable.length; - funcTable.push(0); - } - funcTable[fid] = f; - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncCreateFromCFunc( - fptrInvokeCallback, fid, fptrFreeCallback, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - return makeTVMFunction(out_handle); - }; - var convertFunc = this.convertFunc; - //----------------------------------------- - // Private Class declarations - // ---------------------------------------- - CBuffer.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - }; - // RefTVMValue - RefTVMValue.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.data != 0) { - Module._free(this.data); - this.data = 0; - } - }, - asInt : function() { - return Module.getValue(this.data, "i32"); - }, - asInt64 : function() { - return Module.getValue(this.data, "i64"); - }, - asDouble : function() { - return Module.getValue(this.data, "double"); - }, - asHandle : function() { - return Module.getValue(this.data, "*"); - } - }; - // TVMArgs - TVMArgs.prototype = { - release : function() { - if (this.value != 0) { - Module._free(this.value); - Module._free(this.tcode); - this.value = 0; - for (var i = 0; i< this.temp.length; ++i) { - if (this.temp[i].release instanceof Function) { - this.temp[i].release(); - } - } - } - }, - setInt : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kInt, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "i64"); - }, - setDouble : function(index, value) { - Module.setValue(this.tcode + index * SIZEOF_INT, kFloat, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "double"); - }, - setHandle : function(index, value, tcode) { - Module.setValue(this.tcode + index * SIZEOF_INT, tcode, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, value, "*"); - }, - setString : function(index, value) { - var sdata = new CBuffer(value.length + 1); - Module.HEAPU8.set(StringToUint8Array(value), sdata.data); - this.temp.push(sdata); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*"); - }, - setBytes : function(index, value) { - CHECK(value instanceof Uint8Array); - var sdata = new CBuffer(value.length); - var sheader = new CBuffer(SIZEOF_POINTER + SIZEOF_SIZE_T); - Module.HEAPU8.set(new Uint8Array(value), sdata.data); - Module.setValue(sheader.data, sdata.data, "*"); - 
Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32"); - this.temp.push(sdata); - this.temp.push(sheader); - Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32"); - Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*"); - }, - setArguments : function(args) { - for (var i = 0; i < args.length; ++i) { - var v = args[i]; - var tp = typeof v; - if (v instanceof NDArray) { - this.setHandle(i, v.handle, kTVMDLTensorHandle); - } else if (v instanceof TVMConstant) { - var code = getTVMType(v.dtype).code; - if (code == kInt || code == kUInt) { - this.setInt(i, v.value); - } else if (code == kFloat) { - this.setDouble(i, v.value); - } else { - CHECK(code == kTVMOpaqueHandle); - this.setHandle(i, v.value, kTVMOpaqueHandle); - } - } else if (tp == "number") { - this.setDouble(i, v); - } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) { - this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v === null) { - this.setHandle(i, 0, kNull); - } else if (tp == "string") { - this.setString(i, v); - } else if (v instanceof Uint8Array) { - this.setBytes(i, v); - } else if (v instanceof Function) { - v = convertFunc(v); - this.temp.push(v); - this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle); - } else if (v instanceof TVMModule) { - this.setHandle(i, v.handle, kTVMModuleHandle); - } else { - throwError("Unsupported argument type " + tp); - } - } - } - }; - // TVMType - var TYPE_CODE2STR = { - 0 : "int", - 1 : "uint", - 2 : "float", - 4 : "handle" - }; - - TVMType.prototype = { - toString : function() { - var ret = TYPE_CODE2STR[this.code] + this.bits.toString(); - if (this.lanes != 1) { - return ret + "x" + this.lanes.toString(); - } else { - return ret; - } - } - }; - // TVMFunction - TVMFunction.prototype = { - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMFuncFree(this.handle)); - this.handle = 0; - } - } - }; - // TVMContext - var CTX_MASK2STR = { - 1 : "cpu", - 2 : "gpu", - 4 : "opencl", - 7 : "vulkan", - 8 : "metal", - 9 : "vpi", - 11 : "opengl", - }; - var CTX_STR2MASK = { - "cpu": 1, - "gpu": 2, - "cuda": 2, - "cl": 4, - "opencl": 4, - "vulkan": 7, - "metal": 8, - "vpi": 9, - "opengl": 11, - }; - TVMContext.prototype = { - toString : function() { - return CTX_MASK2STR[this.device_type] + "(" + this.device_id.toString() + ")"; - } - }; - //----------------------------------------- - // Public Functions - // ---------------------------------------- - /** - * Construct a TVMContext given device type and id. - * - * @param {number} device_type, string or int, The device type. - * @param {number} device_id, the device id. - * @return {tvm.TVMContext} The created TVMContext - */ - this.context = function(device_type, device_id) { - if (typeof device_type == "string") { - device_type = CTX_STR2MASK[device_type]; - } - return new TVMContext(device_type, device_id); - }; - var context = this.context; - /** - * Create empty ndarray with given shape. - * - * @param {Array.} shape The shape of the array. - * @param {string} dtype The data type of the array, optional, default="float32" - * @param {tvm.TVMContext} ctx The context of the array, optional, default=cpu(0). - * @return {tvm.NDArray} The created ndarray. - */ - this.empty = function(shape, dtype, ctx) { - dtype = (typeof dtype !== "undefined") ? dtype: "float32"; - ctx = (typeof ctx !== "undefined") ? ctx : context("cpu", 0); - shape = (typeof shape == "number") ? 
[shape] : shape; - // alloc - var cshape = Module._malloc(SIZEOF_INT64 * shape.length); - var out = new RefTVMValue(); - for (var i = 0; i < shape.length; ++i) { - Module.setValue(cshape + i * SIZEOF_INT64, shape[i], "i64"); - } - dtype = getTVMType(dtype); - TVM_CALL(TVMArrayAlloc(cshape, shape.length, - dtype.code, dtype.bits, dtype.lanes, - ctx.device_type, ctx.device_id, - out.data)); - var out_handle = out.asHandle(); - // release - Module._free(cshape); - out.release(); - return new NDArray(out_handle); - }; - /** - * List all global function names in the TVM runtime. - * @return {Array.} List of global function names. - */ - this.listGlobalFuncNames = function() { - // alloc - var out_size = new RefTVMValue(); - var out_array = new RefTVMValue(); - TVM_CALL(TVMFuncListGlobalNames(out_size.data, out_array.data)); - var length = out_size.asInt(); - var base = out_array.asHandle(); - var names = []; - for (var i = 0 ; i < length; ++i) { - names.push( - CStringToJS(Module.getValue(base + i * SIZEOF_POINTER, "*"))); - } - // release - out_size.release(); - out_array.release(); - return names; - }; - var listGlobalFuncNames = this.listGlobalFuncNames; - /** - * Get a global function from TVM runtime. - * - * @param {string} The name of the function. - * @return {Function} The corresponding function, null if function do not exist - */ - this.getGlobalFunc = function (name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMFuncGetGlobal(name, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return makeTVMFunction(out_handle); - } else { - return null; - } - }; - var getGlobalFunc = this.getGlobalFunc; - /** - * Register function to be global function in tvm runtime. - * @param {string} name The name of the function. - * @param {Function} f function to be registered. - * @param {boolean} override Whether overwrite function in existing registry. - */ - this.registerFunc = function(name, f, override) { - f = convertFunc(f); - override = (typeof override !== "undefined") ? override: false; - var ioverride = override ? 1 : 0; - TVM_CALL(TVMFuncRegisterGlobal(name, f._tvm_function.handle, ioverride)); - }; - /** - * Create a typed scalar constant. - * This can be used to pass number as integer types to tvm function. - * - * @param {number} value The value of the data. - * @param {string} dtype The data type. - * @param {tvm.TVMConstant} The created typed scalar. - */ - this.constant = function(value, dtype) { - return new TVMConstant(value, dtype); - }; - //----------------------------------------- - // Wrap of TVM Functions. - // ---------------------------------------- - var systemFunc = {}; - /** - * Get system-wide library module singleton.5A - * System lib is a global module that contains self register functions in startup. - * @return {tvm.TVMModule} The system module singleton. 
- */ - this.systemLib = function() { - if (typeof systemFunc.fGetSystemLib === "undefined") { - systemFunc.fGetSystemLib = getGlobalFunc("runtime.SystemLib"); - } - return systemFunc.fGetSystemLib(); - }; - - this.startRPCServer = function(url, key, counter) { - if (typeof key === "undefined") { - key = ""; - } - if (typeof counter === "undefined") { - counter = 1; - } - // Node js, import websocket - var bkey = StringToUint8Array("server:" + key); - bkey = bkey.slice(0, bkey.length - 1); - var server_name = "WebSocketRPCServer[" + key + "]"; - var RPC_MAGIC = 0xff271; - function checkEndian() { - var a = new ArrayBuffer(4); - var b = new Uint8Array(a); - var c = new Uint32Array(a); - b[0] = 0x11; - b[1] = 0x22; - b[2] = 0x33; - b[3] = 0x44; - CHECK(c[0] === 0x44332211, "Need little endian to work"); - } - checkEndian(); - // start rpc - function RPCServer(counter) { - var socket; - if (typeof module !== "undefined" && module.exports) { - // WebSocket for nodejs - const WebSocket = require("ws"); - socket = new WebSocket(url); - } else { - socket = new WebSocket(url); - } - var self = this; - socket.binaryType = "arraybuffer"; - this.init = true; - this.counter = counter; - - if (typeof systemFunc.fcreateServer === "undefined") { - systemFunc.fcreateServer = - getGlobalFunc("rpc._CreateEventDrivenServer"); - } - if (systemFunc.fcreateServer == null) { - throwError("RPCServer is not included in runtime"); - } - - var message_handler = systemFunc.fcreateServer( - function(cbytes) { - if (socket.readyState == 1) { - socket.send(cbytes); - return new TVMConstant(cbytes.length, "int32"); - } else { - return new TVMConstant(0, "int32"); - } - } , server_name, "%toinit"); - - function on_open(event) { - var intbuf = new Int32Array(1); - intbuf[0] = RPC_MAGIC; - socket.send(intbuf); - intbuf[0] = bkey.length; - socket.send(intbuf); - socket.send(bkey); - logging(server_name + " connected..."); - } - - function on_message(event) { - if (self.init) { - var msg = new Uint8Array(event.data); - CHECK(msg.length >= 4, "Need message header to be bigger than 4"); - var magic = new Int32Array(event.data)[0]; - - if (magic == RPC_MAGIC + 1) { - throwError("key: " + key + " has already been used in proxy"); - } else if (magic == RPC_MAGIC + 2) { - logging(server_name + ": RPCProxy do not have matching client key " + key); - } else { - CHECK(magic == RPC_MAGIC, url + "is not RPC Proxy"); - self.init = false; - } - logging(server_name + "init end..."); - if (msg.length > 4) { - if (message_handler( - new Uint8Array(event.data, 4, msg.length -4), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } else { - if (message_handler(new Uint8Array(event.data), - new TVMConstant(3, "int32")) == 0) { - socket.close(); - } - } - } - function on_close(event) { - message_handler.release(); - logging(server_name + ": closed finish..."); - if (!self.init && self.counter != 0) { - logging(server_name + ":reconnect to serve another request, session left=" + counter); - // start a new server. - new RPCServer(counter - 1); - } - } - socket.addEventListener("open", on_open); - socket.addEventListener("message", on_message); - socket.addEventListener("close", on_close); - } - return new RPCServer(counter); - }; - - /** - * Load a TVM module from a library file. - * The file must be present in the Emscripten virtual file system. - * For example, you can pass "--preload-file file" or "--preload-file dir/" - * to "emcc" when compiling the TVM library, in order to populate files into - * the file system. 
- * For more detail, see: - * https://kripken.github.io/emscripten-site/docs/porting/files/packaging_files - * @param {string} file_name Path of the file to be loaded. The path refers - * to the Emscripten virtual file system. - * @param {string} format The format of the file. - * @return {tvm.TVMModule} The loaded module. - */ - this.loadModuleFromFile = function (file_name, format) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModLoadFromFile(file_name, format, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle != 0) { - return new TVMModule(out_handle); - } else { - return null; - } - }; - var loadModuleFromFile = this.loadModuleFromFile; - - /** - * Wrapper runtime module. - * Wraps around set_input, load_params, run, and get_output. - * - * @class - * @memberof tvm - */ - function GraphModule(tvm_graph_module, ctx) { - CHECK(tvm_graph_module instanceof TVMModule, - "tvm_graph_module must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - this.tvm_graph_module = tvm_graph_module; - this.ctx = ctx; - this._set_input = tvm_graph_module.getFunction("set_input"); - this._load_params = tvm_graph_module.getFunction("load_params"); - this._run = tvm_graph_module.getFunction("run"); - this._get_output = tvm_graph_module.getFunction("get_output"); - }; - - GraphModule.prototype = { - /** - * Set input to graph module. - * - * @param {string} key The name of the input. - * @param {NDArray} value The input value. - */ - "set_input" : function(key, value) { - CHECK(typeof key == "string", "key must be string"); - CHECK(value instanceof NDArray, "value must be NDArray"); - this._set_input(key, value); - }, - - /** - * Load parameters from serialized byte array of parameter dict. - * - * @param {Uint8Array} params The serialized parameter dict. - */ - "load_params" : function(params) { - CHECK(params instanceof Uint8Array, "params must be Uint8Array"); - this._load_params(params); - }, - - /** - * Load parameters from serialized base64 string of parameter dict. - * - * @param {string} base64_params The serialized parameter dict. - */ - "load_base64_params" : function(base64_params) { - CHECK(typeof base64_params == "string", "base64_params must be string"); - var decoded_string = atob(base64_params); - var decoded_u8 = new Uint8Array(decoded_string.length); - for (var i = 0; i < decoded_string.length; i++) { - decoded_u8[i] = decoded_string[i].charCodeAt(0); - } - this.load_params(decoded_u8); - }, - - /** - * Run forward execution of the graph. - */ - "run" : function() { - this._run(); - }, - - /** - * Get index-th output to out. - * - * @param {NDArray} out The output array container. - * @return {NDArray} The output array container. - */ - "get_output" : function(index, out) { - CHECK(typeof index == "number", "index must be number"); - CHECK(out instanceof NDArray, "out must be NDArray"); - this._get_output(new TVMConstant(index, "int32"), out); - return out; - } - }; - - /** - * Create a runtime executor module given a graph and a module. - * @param {string} graph_json_str The Json string of the graph. - * @param {TVMModule} libmod The TVM module. - * @param {TVMContext} ctx The context to deploy the module. - * @return {GraphModule} Runtime graph module for executing the graph. 
- */ - this.createGraphRuntime = function(graph_json_str, libmod, ctx) { - CHECK(typeof graph_json_str == "string", "graph_json_str must be string"); - CHECK(libmod instanceof TVMModule, "libmod must be TVMModule"); - CHECK(ctx instanceof TVMContext, "ctx must be TVMContext"); - - var fcreate = getGlobalFunc("tvm.graph_runtime.create"); - CHECK(fcreate != null, "Cannot find tvm.graph_runtime.create"); - - var tvm_graph_module = fcreate(graph_json_str, libmod, - new TVMConstant(ctx.device_type, "int32"), - new TVMConstant(ctx.device_id, "int32")); - - return new GraphModule(tvm_graph_module, ctx); - }; - - //----------------------------------------- - // Class defintions - // ---------------------------------------- - // NDArray. - NDArray.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMArrayFree(this.handle)); - this.handle = 0; - } - }, - /** - * Copy data from another NDArray or javascript array. - * The number of elements must match. - * - * @param {Array} data The source data array. - */ - copyFrom : function(data) { - if (data instanceof NDArray) { - TVM_CALL(TVMArrayCopyFromTo(data.handle, this.handle)); - } else { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - if (data.length != size) { - throwError("data size and shape mismatch data.length" + data.length + " vs " + size); - } - if (this.dtype == "float32") { - data = Float32Array.from(data); - } else if (this.dtype == "float64") { - data = Float64Array.from(data); - } else if (this.dtype == "int32") { - data = Int32Array.from(data); - } else if (this.dtype == "int8") { - data = Int8Array.from(data); - } else if (this.dtype == "uint8") { - data = Uint8Array.from(data); - } else { - throwError("Unsupported data type " + this.dtype); - } - return this.copyFromRawBytes(new Uint8Array(data.buffer)); - } - }, - /** - * Copy data from raw bytes. - * @param {Uint8Array} data Uint8Array of bytes. - */ - copyFromRawBytes : function(data) { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var dtype = getTVMType(this.dtype); - var nbytes = this.BYTES_PER_ELEMENT * size; - CHECK(data instanceof Uint8Array); - CHECK(data.length == nbytes, - "Data length and bytes do not match " + data.length + - " vs " + nbytes); - var temp = Module._malloc(nbytes); - Module.HEAPU8.set(data, temp); - TVM_CALL(TVMArrayCopyFromBytes(this.handle, temp, nbytes)); - Module._free(temp); - return this; - }, - /** - * Return a copied Uint8Array of the raw bytes in the NDArray. - * @return {Uint8Array} The created array. - */ - asRawBytes : function() { - var size = this.shape.reduce(function(a, b) { return a * b; }, 1); - var nbytes = this.BYTES_PER_ELEMENT * size; - var temp = Module._malloc(nbytes); - TVM_CALL(TVMArrayCopyToBytes(this.handle, temp, nbytes)); - var ret = new Uint8Array(new ArrayBuffer(nbytes)); - ret.set(new Uint8Array(Module.HEAPU8.buffer, temp, nbytes)); - Module._free(temp); - return ret; - }, - /** - * Return Array data content as javascript typed array. - * @return {TypedArray} The created array. 
- */ - asArray : function() { - if (this.dtype == "float32") { - return new Float32Array(this.asRawBytes().buffer); - } else if (this.dtype == "float64") { - return new Float64Array(this.asRawBytes().buffer); - } else if (this.dtype == "int32") { - return new Int32Array(this.asRawBytes().buffer); - } else if (this.dtype == "int8") { - return new Int8Array(this.asRawBytes().buffer); - } else if (this.dtype == "uint8") { - return new Uint8Array(this.asRawBytes().buffer); - } else { - throwError("Unsupported data type " + this.dtype); - } - } - }; - - TVMModule.prototype = { - /** - * Finalizer: resources from the object. - */ - release : function() { - if (this.handle != 0) { - TVM_CALL(TVMModFree(this.handle)); - this.handle = 0; - } - }, - /** - * Get function from the module. - * @param {string} name The name of the function. - * @return {Function} The correspondin function. - */ - getFunction : function(name) { - // alloc - var out = new RefTVMValue(); - TVM_CALL(TVMModGetFunction(this.handle, name, 0, out.data)); - var out_handle = out.asHandle(); - // release - out.release(); - if (out_handle == 0) { - throwError("Module has no function " + name); - } - return makeTVMFunction(out_handle); - }, - /** - * Add module to the import list of current one. - * @param {tvm.TVMModule} mod The other module to be imported. - */ - import_module : function(mod) { - CHECK(mod instanceof TVMModule, "mod must be instance of TVMModule"); - TVM_CALL(TVMModImport(this.handle, mod.handle)); - } - }; - //----------------------------------------- - // Static variables. - // ---------------------------------------- - /** Float32 type */ - this.float32 = "float32"; - /** Int32 type */ - this.int32 = "int32"; - } - /** - * Create a TVM runtime given emscripten module. - * @property {string} create - * @memberof tvm_runtime - * @param Module The emscripten module. - * @return {tvm.TVMRuntime} The created TVM runtime. - */ - this.create = function(Module) { - var tvm = {}; - tvm.Module = Module; - if (typeof Module.addFunction !== "undefined") { - tvm.Runtime = Module; - } else { - tvm.Runtime = Module.Runtime; - } - TVMRuntime.apply(tvm); - return tvm; - }; -}).apply(tvm_runtime); - -// export things in node -if (typeof module !== "undefined" && module.exports) { - module.exports = tvm_runtime; -} diff --git a/web/typedoc.json b/web/typedoc.json new file mode 100644 index 000000000000..65631ea5efa8 --- /dev/null +++ b/web/typedoc.json @@ -0,0 +1,11 @@ +{ + "out": "dist/docs", + "readme": "none", + "mode": "file", + "excludeNotExported": true, + "excludePrivate": true, + "listInvalidSymbolLinks": true, + "module": "umd", + "includes": ["src"], + "exclude": ["node_modules"] +} diff --git a/web/web_runtime.cc b/web/web_runtime.cc deleted file mode 100644 index 701ded76288e..000000000000 --- a/web/web_runtime.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file web_runtime.cc - */ -#include -#include - -#include "../src/runtime/c_runtime_api.cc" -#include "../src/runtime/cpu_device_api.cc" -#include "../src/runtime/workspace_pool.cc" -#include "../src/runtime/library_module.cc" -#include "../src/runtime/system_library.cc" -#include "../src/runtime/module.cc" -#include "../src/runtime/ndarray.cc" -#include "../src/runtime/object.cc" -#include "../src/runtime/registry.cc" -#include "../src/runtime/file_util.cc" -#include "../src/runtime/dso_library.cc" -#include "../src/runtime/rpc/rpc_session.cc" -#include "../src/runtime/rpc/rpc_event_impl.cc" -#include "../src/runtime/rpc/rpc_server_env.cc" -#include "../src/runtime/graph/graph_runtime.cc" -#include "../src/runtime/opengl/opengl_device_api.cc" -#include "../src/runtime/opengl/opengl_module.cc" - -namespace tvm { -namespace contrib { - -struct RPCEnv { - public: - RPCEnv() { - base_ = "/rpc"; - mkdir(&base_[0], 0777); - } - // Get Path. - std::string GetPath(const std::string& file_name) { - return base_ + "/" + file_name; - } - - private: - std::string base_; -}; - -TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body_typed([](std::string path) { - static RPCEnv env; - return env.GetPath(path); - }); - -TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body_typed([](std::string path) { - std::string file_name = "/rpc/" + path; - LOG(INFO) << "Load module from " << file_name << " ..."; - return Module::LoadFromFile(file_name, ""); - }); -} // namespace contrib -} // namespace tvm - -// dummy parallel runtime -int TVMBackendParallelLaunch( - FTVMParallelLambda flambda, - void* cdata, - int num_task) { - TVMAPISetLastError("Parallel is not supported in Web runtime"); - return -1; -} - -int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { - return 0; -}